Logistieke regressie in Java

1. Inleiding

Logistische regressie is een belangrijk instrument in de toolbox van machine learning (ML).

In deze tutorial we zullen het belangrijkste idee achter logistieke regressie onderzoeken.

Laten we eerst beginnen met een kort overzicht van ML-paradigma's en algoritmen.

2. Overzicht

ML stelt ons in staat problemen op te lossen die we mensvriendelijk kunnen formuleren. Dit feit kan echter een uitdaging vormen voor ons, softwareontwikkelaars. We zijn eraan gewend om de problemen aan te pakken die we computervriendelijk kunnen formuleren. Als mens kunnen we bijvoorbeeld gemakkelijk de objecten op een foto detecteren of de sfeer van een zin bepalen. Hoe zouden we zo'n probleem voor een computer kunnen formuleren?

Om tot een oplossing te komen, in ML is er een speciale fase genaamd opleiding. Tijdens deze fase voeren we de invoergegevens naar ons algoritme zodat het probeert een optimale set parameters te bedenken (de zogenaamde gewichten). Hoe meer invoergegevens we naar het algoritme kunnen sturen, hoe nauwkeuriger we er voorspellingen van mogen verwachten.

Training maakt deel uit van een iteratieve ML-workflow:

We beginnen met het verzamelen van gegevens. Vaak zijn de gegevens afkomstig uit verschillende bronnen. Daarom moeten we ervoor zorgen dat het hetzelfde formaat heeft. We moeten ook controleren of de dataset het domein van de studie eerlijk vertegenwoordigt. Als het model nog nooit op rode appels is getraind, kan het het nauwelijks voorspellen.

Vervolgens moeten we een model bouwen dat de gegevens verbruikt en voorspellingen kan doen. In ML zijn er geen voorgedefinieerde modellen die in alle situaties goed werken.

Bij het zoeken naar het juiste model kan het gemakkelijk gebeuren dat we een model bouwen, trainen, de voorspellingen bekijken en het model weggooien omdat we niet tevreden zijn met de voorspellingen die het doet. In dit geval moeten we een stap terug doen en een ander model bouwen en het proces opnieuw herhalen.

3. ML-paradigma's

In ML kunnen we, op basis van het soort invoergegevens dat we tot onze beschikking hebben, drie hoofdparadigma's onderscheiden:

  • begeleid leren (beeldclassificatie, objectherkenning, sentimentanalyse)
  • onbewaakt leren (detectie van afwijkingen)
  • versterking van leren (spelstrategieën)

Het geval dat we gaan beschrijven behoort in deze tutorial tot begeleid leren.

4. ML Toolbox

In ML is er een set tools die we kunnen toepassen bij het bouwen van een model. Laten we er een paar noemen:

  • Lineaire regressie
  • Logistieke regressie
  • Neurale netwerken
  • Ondersteuning van Vector Machine
  • k-Naaste buren

We kunnen verschillende tools combineren bij het bouwen van een model met een hoge voorspelbaarheid. In feite zal ons model voor deze tutorial logistische regressie en neurale netwerken gebruiken.

5. ML-bibliotheken

Ook al is Java niet de meest populaire taal voor het maken van prototypes van ML-modellen,het heeft de reputatie een betrouwbare tool te zijn voor het maken van robuuste software op veel gebieden, waaronder ML. Daarom kunnen we ML-bibliotheken vinden die in Java zijn geschreven.

In deze context kunnen we de de facto standaardbibliotheek Tensorflow noemen die ook een Java-versie heeft. Een ander vermeldenswaardig is een deep learning-bibliotheek genaamd Deeplearning4j. Dit is een zeer krachtig hulpmiddel en we gaan het ook in deze tutorial gebruiken.

6. Logistieke regressie op cijferherkenning

Het belangrijkste idee van logistieke regressie is om een ​​model te bouwen dat de labels van de invoergegevens zo nauwkeurig mogelijk voorspelt.

We trainen het model totdat de zogenaamde verliesfunctie of objectieve functie een minimale waarde bereikt. De verliesfunctie is afhankelijk van de feitelijke modelvoorspellingen en verwachte (de labels van de invoergegevens). Ons doel is om de divergentie van de werkelijke modelvoorspellingen en de verwachte modelvoorspellingen te minimaliseren.

Als we niet tevreden zijn met die minimumwaarde, moeten we een ander model bouwen en de training opnieuw uitvoeren.

Om logistieke regressie in actie te zien, illustreren we het aan de hand van de herkenning van handgeschreven cijfers. Dit probleem is al een klassiek probleem geworden. De Deeplearning4j-bibliotheek heeft een reeks realistische voorbeelden die laten zien hoe de API moet worden gebruikt. Het code-gerelateerde deel van deze tutorial is sterk gebaseerd op MNIST-classificatie.

6.1. Invoergegevens

Als invoergegevens gebruiken we de bekende MNIST-database met handgeschreven cijfers. Als invoergegevens hebben we grijsschaalafbeeldingen van 28 × 28 pixels. Elke afbeelding heeft een natuurlijk label, het cijfer dat de afbeelding vertegenwoordigt:

Om de efficiëntie van het model dat we gaan bouwen in te schatten, splitsen we de invoergegevens op in trainings- en testsets:

DataSetIterator train = nieuwe RecordReaderDataSetIterator (...); DataSetIterator test = nieuwe RecordReaderDataSetIterator (...);

Zodra we de invoerafbeeldingen hebben gelabeld en opgesplitst in twee sets, is de fase van "gegevensuitwerking" voorbij en kunnen we naar het "modelbouw" gaan.

6.2. Model gebouw

Zoals we al zeiden, zijn er geen modellen die in elke situatie goed werken. Desalniettemin hebben wetenschappers na vele jaren van onderzoek naar ML modellen gevonden die zeer goed presteren bij het herkennen van handgeschreven cijfers. Hier gebruiken we het zogenaamde LeNet-5-model.

LeNet-5 is een neuraal netwerk dat bestaat uit een reeks lagen die het 28 × 28 pixelbeeld transformeren in een tiendimensionale vector:

De tiendimensionale uitvoervector bevat waarschijnlijkheden dat het label van het invoerbeeld 0, of 1, of 2 is, enzovoort.

Als de uitvoervector bijvoorbeeld de volgende vorm heeft:

{0.1, 0.0, 0.3, 0.2, 0.1, 0.1, 0.0, 0.1, 0.1, 0.0}

het betekent dat de kans dat het invoerbeeld nul is 0,1 is, één is 0, dat twee is 0,3, enz. We zien dat de maximale waarschijnlijkheid (0,3) overeenkomt met label 3.

Laten we dieper ingaan op de details van modelbouw. We laten Java-specifieke details weg en concentreren ons op ML-concepten.

We hebben het model opgezet door een MultiLayerNetwork voorwerp:

MultiLayerNetwork-model = nieuw MultiLayerNetwork (config);

In de constructor moeten we een MultiLayerConfiguration voorwerp. Dit is precies het object dat de geometrie van het neurale netwerk beschrijft. Om de netwerkgeometrie te definiëren, moeten we elke laag definiëren.

Laten we laten zien hoe we dit doen met de eerste en de tweede:

ConvolutionLayer layer1 = nieuwe ConvolutionLayer .Builder (5, 5) .nIn (kanalen) .stride (1, 1) .nOut (20) .activation (Activation.IDENTITY) .build (); SubsamplingLayer layer2 = nieuwe SubsamplingLayer .Builder (SubsamplingLayer.PoolingType.MAX) .kernelSize (2, 2) .stride (2, 2) .build ();

We zien dat de definities van lagen een aanzienlijk aantal ad-hocparameters bevatten die een aanzienlijke invloed hebben op de hele netwerkprestaties. Dit is precies waar ons vermogen om een ​​goed model te vinden in het landschap van iedereen cruciaal wordt.

Nu zijn we klaar om het MultiLayerConfiguration voorwerp:

MultiLayerConfiguration config = nieuwe NeuralNetConfiguration.Builder () // voorbereidingsstappen .list () .layer (layer1) .layer (layer2) // andere lagen en laatste stappen .build ();

dat we doorgeven aan de MultiLayerNetwork constructeur.

6.3. Opleiding

Het model dat we hebben geconstrueerd, bevat 431080 parameters of gewichten. We gaan hier niet de exacte berekening van dit aantal geven, maar we moeten ons ervan bewust zijn dat alleen tDe eerste laag heeft meer dan 24x24x20 = 11520 gewichten.

De trainingsfase is zo simpel als:

model.fit (trein); 

Aanvankelijk hebben de 431080-parameters enkele willekeurige waarden, maar na de training krijgen ze enkele waarden die de modelprestaties bepalen. We kunnen de voorspellende waarde van het model evalueren:

Evaluatie eval = model.evaluate (test); logger.info (eval.stats ());

Het LeNet-5-model bereikt een vrij hoge nauwkeurigheid van bijna 99%, zelfs in slechts een enkele trainingsherhaling (epoch). Als we een hogere nauwkeurigheid willen bereiken, moeten we meer iteraties maken met een vlakte for loop:

for (int i = 0; i <epochs; i ++) {model.fit (trein); train.reset (); test.reset (); } 

6.4. Voorspelling

Nu we het model hebben getraind en we zijn blij met de voorspellingen van de testgegevens, kunnen we het model uitproberen op een geheel nieuwe input. Laten we hiervoor een nieuwe klas maken MnistPrediction waarin we een afbeelding laden uit een bestand dat we selecteren uit het bestandssysteem:

INDArray afbeelding = nieuwe NativeImageLoader (hoogte, breedte, kanalen) .asMatrix (bestand); nieuwe ImagePreProcessingScaler (0, 1) .transform (afbeelding);

De variabele beeld bevat onze afbeelding die wordt verkleind tot 28 × 28 grijstinten. We kunnen het naar ons model sturen:

INDArray output = model.output (afbeelding);

De variabele output bevat de waarschijnlijkheid dat de afbeelding nul, een, twee, etc. is.

Laten we nu een beetje spelen en een cijfer 2 schrijven, deze afbeelding digitaliseren en het model voeden. We kunnen zoiets als dit krijgen:

Zoals we zien, heeft de component met maximale waarde 0,99 index twee. Het betekent dat het model ons handgeschreven cijfer correct heeft herkend.

7. Conclusie

In deze tutorial hebben we de algemene concepten van machine learning beschreven. We hebben deze concepten geïllustreerd aan de hand van een voorbeeld van een logistische regressie die we hebben toegepast op een handgeschreven cijferherkenning.

Zoals altijd kunnen we de bijbehorende codefragmenten vinden in onze GitHub-repository.