Inleiding tot Tensorflow voor Java

1. Overzicht

TensorFlow is een open source-bibliotheek voor dataflow-programmering. Dit is oorspronkelijk ontwikkeld door Google en is beschikbaar voor een breed scala aan platforms. Hoewel TensorFlow op een enkele kern kan werken, kan het zo profiteer eenvoudig van meerdere beschikbare CPU, GPU of TPU.

In deze zelfstudie bespreken we de basisprincipes van TensorFlow en hoe u deze in Java kunt gebruiken. Houd er rekening mee dat de TensorFlow Java API een experimentele API is en daarom niet wordt gedekt door enige stabiliteitsgarantie. We zullen later in de tutorial mogelijke use-cases behandelen voor het gebruik van de TensorFlow Java API.

2. Basisprincipes

TensorFlow-berekening draait in feite om twee fundamentele concepten: grafiek en sessie. Laten we ze snel doornemen om de achtergrond te krijgen die nodig is om de rest van de tutorial te doorlopen.

2.1. TensorFlow-grafiek

Laten we om te beginnen de fundamentele bouwstenen van TensorFlow-programma's begrijpen. In TensorFlow worden berekeningen weergegeven als grafieken. Een grafiek is typisch een gerichte acyclische grafiek van bewerkingen en gegevens, bijvoorbeeld:

De bovenstaande afbeelding vertegenwoordigt de computergrafiek voor de volgende vergelijking:

f (X, Y) = z = een * X + b * Y

Een computationele grafiek van TensorFlow bestaat uit twee elementen:

  1. Tensor: dit zijn de belangrijkste gegevenseenheid in TensorFlow. Ze worden weergegeven als de randen in een computergrafiek en geven de gegevensstroom door de grafiek weer. Een tensor kan een vorm hebben met een willekeurig aantal afmetingen. Het aantal dimensies in een tensor wordt meestal de rangorde genoemd. Dus een scalair is een tensor van rang 0, een vector is een tensor van rang 1, een matrix is ​​een tensor van rang 2, enzovoort, enzovoort.
  2. Werking: dit zijn de knooppunten in een computergrafiek. Ze verwijzen naar een breed scala aan berekeningen die kunnen plaatsvinden op de tensoren die in de operatie worden ingevoerd. Ze resulteren vaak ook in tensoren die voortkomen uit de bewerking in een computergrafiek.

2.2. TensorFlow-sessie

Nu is een TensorFlow-grafiek slechts een schema van de berekening die eigenlijk geen waarden bevat. Zo'n een grafiek moet worden uitgevoerd binnen een zogenaamde TensorFlow-sessie om de tensoren in de grafiek te evalueren. De sessie kan een aantal tensoren kosten om vanuit een grafiek als invoerparameters te evalueren. Vervolgens loopt het achteruit in de grafiek en voert alle knooppunten uit die nodig zijn om die tensoren te evalueren.

Met deze kennis zijn we nu klaar om dit op te pakken en toe te passen op de Java API!

3. Maven-instellingen

We zullen een snel Maven-project opzetten om een ​​TensorFlow-grafiek in Java te maken en uit te voeren. We hebben alleen de tensorflow afhankelijkheid:

 org.tensorflow tensorflow 1.12.0 

4. De grafiek maken

Laten we nu proberen de grafiek te bouwen die we in de vorige sectie hebben besproken met behulp van de TensorFlow Java API. Meer precies, voor deze tutorial gebruiken we TensorFlow Java API om de functie op te lossen die wordt weergegeven door de volgende vergelijking:

z = 3 * x + 2 * y

De eerste stap is het declareren en initialiseren van een grafiek:

Graph graph = nieuwe Graph ()

Nu moeten we alle vereiste bewerkingen definiëren. Onthoud dat bewerkingen in TensorFlow verbruiken en produceren nul of meer tensoren. Bovendien is elk knooppunt in de grafiek een bewerking inclusief constanten en tijdelijke aanduidingen. Dit lijkt misschien contra-intuïtief, maar wacht even!

De klas Grafiek heeft een generieke functie genaamd opBuilder () om elke bewerking op TensorFlow te bouwen.

4.1. Constanten definiëren

Laten we om te beginnen constante bewerkingen definiëren in onze bovenstaande grafiek. Merk op dat a constante werking heeft een tensor nodig voor zijn waarde:

Bewerking a = graph.opBuilder ("Const", "a") .setAttr ("dtype", DataType.fromClass (Double.class)) .setAttr ("value", Tensor.create (3.0, Double.class)). bouwen(); Bewerking b = graph.opBuilder ("Const", "b") .setAttr ("dtype", DataType.fromClass (Double.class)) .setAttr ("value", Tensor.create (2.0, Double.class)). bouwen();

Hier hebben we een Operatie van constant type, voedend in de Tensor met Dubbele waarden 2.0 en 3.0. Het lijkt misschien wat overweldigend om mee te beginnen, maar zo zit het voorlopig in de Java API. Deze constructies zijn veel beknopter in talen zoals Python.

4.2. Tijdelijke aanduidingen definiëren

Hoewel we waarden moeten geven aan onze constanten, tijdelijke aanduidingen hebben geen waarde nodig tijdens de definitie. De waarden voor tijdelijke aanduidingen moeten worden opgegeven wanneer de grafiek binnen een sessie wordt uitgevoerd. We zullen dat gedeelte later in de tutorial bespreken.

Laten we voorlopig eens kijken hoe we onze tijdelijke aanduidingen kunnen definiëren:

Bewerking x = graph.opBuilder ("Placeholder", "x") .setAttr ("dtype", DataType.fromClass (Double.class)) .build (); Bewerking y = graph.opBuilder ("Placeholder", "y") .setAttr ("dtype", DataType.fromClass (Double.class)) .build ();

Merk op dat we geen waarde hoefden op te geven voor onze tijdelijke aanduidingen. Deze waarden worden ingevoerd als Tensoren wanneer rennen.

4.3. Functies definiëren

Ten slotte moeten we de wiskundige bewerkingen van onze vergelijking definiëren, namelijk vermenigvuldigen en optellen om het resultaat te krijgen.

Dit zijn weer niets anders dan Operaties in TensorFlow en Graph.opBuilder () is weer handig:

Bewerking ax = graph.opBuilder ("Mul", "ax") .addInput (a.output (0)) .addInput (x.output (0)) .build (); Bewerking door = graph.opBuilder ("Mul", "door") .addInput (b.output (0)) .addInput (y.output (0)) .build (); Bewerking z = graph.opBuilder ("Add", "z") .addInput (ax.output (0)) .addInput (by.output (0)) .build ();

Hier hebben we daar gedefinieerd Operatie, twee voor het vermenigvuldigen van onze input en de laatste voor het optellen van de tussenresultaten. Merk op dat operaties hier tensoren ontvangen die niets anders zijn dan de output van onze eerdere operaties.

Houd er rekening mee dat we de uitvoer krijgen Tensor van de Operatie met index ‘0 '. Zoals we eerder hebben besproken, een Operatie kan resulteren in een of meer Tensor en daarom moeten we bij het ophalen van een handvat de index vermelden. Omdat we weten dat onze operaties er maar één teruggeven Tensor, ‘0 'werkt prima!

5. Visualiseren van de grafiek

Het is moeilijk om de grafiek in de gaten te houden naarmate deze groter wordt. Dit maakt het belangrijk om het op de een of andere manier te visualiseren. We kunnen altijd een handtekening maken zoals de kleine grafiek die we eerder hebben gemaakt, maar het is niet praktisch voor grotere grafieken. TensorFlow biedt een hulpprogramma genaamd TensorBoard om dit te vergemakkelijken.

Helaas heeft de Java API niet de mogelijkheid om een ​​gebeurtenisbestand te genereren dat wordt gebruikt door TensorBoard. Maar met behulp van API's in Python kunnen we een gebeurtenisbestand genereren zoals:

writer = tf.summary.FileWriter ('.') ...... writer.add_graph (tf.get_default_graph ()) writer.flush ()

Maakt u zich alstublieft geen zorgen als dit in de context van Java geen zin heeft, dit is hier alleen voor de volledigheid toegevoegd en niet nodig om de rest van de tutorial voort te zetten.

We kunnen nu het gebeurtenisbestand in TensorBoard laden en visualiseren, zoals:

tensorboard --logdir.

TensorBoard wordt geleverd als onderdeel van de TensorFlow-installatie.

Let op de gelijkenis tussen deze en de eerder handmatig getekende grafiek!

6. Werken met sessie

We hebben nu een computergrafiek gemaakt voor onze eenvoudige vergelijking in TensorFlow Java API. Maar hoe voeren we het uit? Laten we, voordat we daarop ingaan, eens kijken wat de toestand is Grafiek we hebben zojuist op dit punt gemaakt. Als we proberen de uitvoer van onze finale af te drukken Operatie "Z":

System.out.println (z.output (0));

Dit resulteert in iets als:

Dit hadden we niet verwacht! Maar als we ons herinneren wat we eerder hebben besproken, is dit eigenlijk logisch. De Grafiek die we zojuist hebben gedefinieerd, is nog niet uitgevoerd, dus de tensoren daarin hebben eigenlijk geen werkelijke waarde. De bovenstaande uitvoer zegt alleen dat dit een Tensor van het type Dubbele.

Laten we nu een Sessie om onze Grafiek:

Session sess = nieuwe sessie (grafiek)

Eindelijk zijn we nu klaar om onze grafiek uit te voeren en de output te krijgen die we verwachtten:

Tensor tensor = sess.runner (). Fetch ("z") .feed ("x", Tensor.create (3.0, Double.class)) .feed ("y", Tensor.create (6.0, Double.class) ) .run (). get (0) .expect (Double.class); System.out.println (tensor.doubleValue ());

Dus wat doen we hier? Het moet redelijk intuïtief zijn:

  • Krijg een Loper van de Sessie
  • Definieer het Operatie ophalen met de naam 'z'
  • Voer tensoren in voor onze tijdelijke aanduidingen 'x' en 'y'
  • Voer de ... uit Grafiek in de Sessie

En nu zien we de scalaire output:

21.0

Dit is wat we hadden verwacht, is het niet!

7. De use case voor Java API

Op dit punt klinkt TensorFlow misschien als overdreven voor het uitvoeren van basisbewerkingen. Maar natuurlijk, TensorFlow is bedoeld om grafieken veel groter uit te voeren dan dit.

Bovendien, de tensoren waarmee het in real-world modellen te maken heeft, zijn veel groter in omvang en rangorde. Dit zijn de daadwerkelijke machine learning-modellen waar TensorFlow zijn echte gebruik vindt.

Het is niet moeilijk om te zien dat het werken met de kern-API in TensorFlow erg omslachtig kan worden naarmate de grafiek groter wordt. Daartoe TensorFlow biedt API's van hoog niveau, zoals Keras, om met complexe modellen te werken. Helaas is er nog weinig tot geen officiële ondersteuning voor Keras op Java.

Maar we kunnen gebruik Python om complexe modellen te definiëren en te trainen ofwel rechtstreeks in TensorFlow of met behulp van hoogwaardige API's zoals Keras. Vervolgens kunnen we exporteer een getraind model en gebruik dat in Java met behulp van de TensorFlow Java API.

Nu, waarom zouden we zoiets willen doen? Dit is met name handig voor situaties waarin we machine learning-functies willen gebruiken in bestaande clients die op Java draaien. Bijvoorbeeld bijschrift aanbevelen voor gebruikersafbeeldingen op een Android-apparaat. Desalniettemin zijn er verschillende gevallen waarin we geïnteresseerd zijn in de output van een machine learning-model, maar niet per se dat model in Java willen maken en trainen.

Dit is waar TensorFlow Java API het grootste deel van het gebruik ervan vindt. In de volgende sectie zullen we bespreken hoe dit kan worden bereikt.

8. Opgeslagen modellen gebruiken

We zullen nu begrijpen hoe we een model in TensorFlow kunnen opslaan in het bestandssysteem en dat mogelijk terug kunnen laden in een compleet andere taal en platform. TensorFlow biedt API's om modelbestanden te genereren in een taal- en platformneutrale structuur die protocolbuffer wordt genoemd.

8.1. Modellen opslaan in het bestandssysteem

We beginnen met het definiëren van dezelfde grafiek die we eerder in Python hebben gemaakt en die opslaan in het bestandssysteem.

Laten we eens kijken dat we dit in Python kunnen doen:

importeer tensorflow als tf graph = tf.Graph () builder = tf.saved_model.builder.SavedModelBuilder ('./ model') met graph.as_default (): a = tf.constant (2, name = "a") b = tf.constant (3, name = "b") x = tf.placeholder (tf.int32, name = "x") y = tf.placeholder (tf.int32, name = "y") z = tf.math. add (a * x, b * y, name = "z") sess = tf.Session () sess.run (z, feed_dict = {x: 2, y: 3}) builder.add_meta_graph_and_variables (sess, [tf. opgeslagen_model.tag_constants.SERVING]) builder.save ()

Als focus van deze tutorial in Java, laten we niet veel aandacht besteden aan de details van deze code in Python, behalve het feit dat het een bestand genereert met de naam "saved_model.pb". Let op bij het doorgeven van de beknoptheid bij het definiëren van een vergelijkbare grafiek in vergelijking met Java!

8.2. Modellen laden vanuit het bestandssysteem

We zullen nu "saved_model.pb" in Java laden. Java TensorFlow API heeft OpgeslagenModelBundle om met opgeslagen modellen te werken:

SavedModelBundle model = SavedModelBundle.load ("./ model", "serve"); Tensor tensor = model.session (). Runner (). Fetch ("z") .feed ("x", Tensor.create (3, Integer.class)) .feed ("y", Tensor.create (3, Integer.class)) .run (). Get (0) .expect (Integer.class); System.out.println (tensor.intValue ());

Het zou nu redelijk intuïtief moeten zijn om te begrijpen wat de bovenstaande code doet. Het laadt eenvoudig de modelgrafiek uit de protocolbuffer en stelt de sessie daarin beschikbaar. Vanaf dat moment kunnen we vrijwel alles met deze grafiek doen, zoals we zouden hebben gedaan met een lokaal gedefinieerde grafiek.

9. Conclusie

Samenvattend hebben we in deze tutorial de basisconcepten met betrekking tot de TensorFlow-computergrafiek doorgenomen. We hebben gezien hoe we de TensorFlow Java API kunnen gebruiken om zo'n grafiek te maken en uit te voeren. Vervolgens hebben we het gehad over de use-cases voor de Java API met betrekking tot TensorFlow.

Tijdens het proces hebben we ook begrepen hoe we de grafiek konden visualiseren met TensorBoard en een model konden opslaan en opnieuw laden met behulp van Protocol Buffer.

Zoals altijd is de code voor de voorbeelden beschikbaar op GitHub.