Gradient Descent in Java

1. Inleiding

In deze tutorial leren we over het Gradient Descent-algoritme. We implementeren het algoritme in Java en illustreren het stap voor stap.

2. Wat is Gradient Descent?

Gradient Descent is een optimalisatie-algoritme dat wordt gebruikt om een ​​lokaal minimum van een bepaalde functie te vinden. Het wordt veel gebruikt in algoritmen voor machine learning op hoog niveau om verliesfuncties te minimaliseren.

Verloop is een ander woord voor helling, en afdalen betekent dalen. Zoals de naam al doet vermoeden, gaat Gradient Descent de helling van een functie af totdat het einde is bereikt.

3. Eigenschappen van Gradient Descent

Gradient Descent vindt een lokaal minimum, dat kan verschillen van het globale minimum. Het lokale startpunt wordt als parameter aan het algoritme gegeven.

Het is een iteratief algoritme, en bij elke stap probeert het de helling af te gaan en dichter bij het lokale minimum te komen.

In de praktijk keert het algoritme terug. In deze tutorial zullen we backtracking Gradient Descent illustreren en implementeren.

4. Stapsgewijze illustratie

Gradient Descent heeft een functie en een startpunt nodig als invoer. Laten we een functie definiëren en plotten:

We kunnen op elk gewenst punt starten. Laten we beginnen bij X=1:

In de eerste stap gaat Gradient Descent de helling af met een vooraf gedefinieerde stapgrootte:

Vervolgens gaat het verder met dezelfde stapgrootte. Deze keer komt het echter uit op een groter y dan de laatste stap:

Dit geeft aan dat het algoritme het lokale minimum heeft overschreden, dus het gaat achteruit met een verlaagde stapgrootte:

Vervolgens, wanneer de huidige y is groter dan de vorige y, wordt de stapgrootte verlaagd en teniet gedaan. De iteratie gaat door totdat de gewenste precisie is bereikt.

Zoals we kunnen zien, vond Gradient Descent hier een lokaal minimum, maar dit is niet het globale minimum. Als we beginnen bij X= -1 in plaats van X= 1, wordt het globale minimum gevonden.

5. Implementatie in Java

Er zijn verschillende manieren om Gradient Descent te implementeren. Hier berekenen we niet de afgeleide van de functie om de richting van de helling te vinden, dus onze implementatie werkt ook voor niet-differentieerbare functies.

Laten we het definiëren precisie en stepCoefficient en geef ze de beginwaarden:

dubbele precisie = 0,000001; dubbele stap Coëfficiënt = 0,1;

In de eerste stap hebben we geen vorige y ter vergelijking. We kunnen de waarde van verhogen of verlagen X om te kijken of y verlaagt of verhoogt. Een positief stepCoefficient betekent dat we de waarde van verhogen X.

Laten we nu de eerste stap uitvoeren:

double previousX = initialX; double previousY = f.apply (previousX); currentX + = stepCoefficient * previousY;

In de bovenstaande code, f is een Functie, en initialX is een dubbele, beide worden geleverd als input.

Een ander belangrijk punt om te overwegen is dat Gradient Descent niet gegarandeerd convergeert. Laten we een limiet stellen aan het aantal iteraties om te voorkomen dat we vast komen te zitten in de lus:

int iter = 100;

Later zullen we verlagen iter met één bij elke iteratie. Daarom komen we uit de lus met maximaal 100 iteraties.

Nu we een vorigeX, kunnen we onze lus opzetten:

while (previousStep> precisie && iter> 0) {iter--; dubbele currentY = f.apply (currentX); if (currentY> previousY) {stepCoefficient = -stepCoefficient / 2; } previousX = currentX; currentX + = stepCoefficient * previousY; previousY = currentY; previousStep = StrictMath.abs (currentX - previousX); }

In elke iteratie berekenen we de nieuwe y en vergelijk het met de vorige y. Als huidig is groter dan vorige, veranderen we onze richting en verkleinen we de stapgrootte.

De lus gaat door totdat onze stapgrootte kleiner is dan gewenst precisie. Eindelijk kunnen we terugkeren currentX als het lokale minimum:

retourneer currentX;

6. Conclusie

In dit artikel hebben we het Gradient Descent-algoritme doorlopen met een stapsgewijze illustratie.

We hebben ook Gradient Descent geïmplementeerd in Java. De code is beschikbaar op GitHub.