Het K-Means Clustering-algoritme in Java

1. Overzicht

Clustering is een overkoepelende term voor een klasse van niet-gecontroleerde algoritmen ontdek groepen dingen, mensen of ideeën die nauw met elkaar verband houden.

In deze ogenschijnlijk eenvoudige one-liner-definitie zagen we een paar modewoorden. Wat is clustering precies? Wat is een algoritme zonder toezicht?

In deze tutorial gaan we eerst wat licht werpen op deze concepten. Daarna zullen we zien hoe ze zich in Java kunnen manifesteren.

2. Ongecontroleerde algoritmen

Voordat we de meeste leeralgoritmen gebruiken, moeten we er op de een of andere manier wat voorbeeldgegevens aan toevoegen en het algoritme van die gegevens laten leren. In Machine Learning-terminologie, we noemen dat voorbeeldgegevensset trainingsgegevens. Ook, het hele proces staat bekend als het trainingsproces.

In ieder geval, we kunnen leeralgoritmen classificeren op basis van de hoeveelheid supervisie die ze nodig hebben tijdens het trainingsproces. De twee belangrijkste soorten leeralgoritmen in deze categorie zijn:

  • Leren onder toezicht: In algoritmen onder supervisie moeten de trainingsgegevens de daadwerkelijke oplossing voor elk punt bevatten. Als we bijvoorbeeld op het punt staan ​​ons spamfilteralgoritme te trainen, sturen we zowel de voorbeeld-e-mails als hun label, d.w.z. spam of niet-spam, naar het algoritme. Wiskundig gesproken gaan we de f (x) uit een trainingsset inclusief beide xs en ys.
  • Ongecontroleerd leren: Als de trainingsgegevens geen labels bevatten, is het algoritme niet onder toezicht. We hebben bijvoorbeeld veel gegevens over muzikanten en we gaan groepen van vergelijkbare muzikanten ontdekken in de gegevens.

3. Clustering

Clustering is een algoritme zonder toezicht om groepen vergelijkbare dingen, ideeën of mensen te ontdekken. In tegenstelling tot algoritmen onder supervisie, trainen we clusteringalgoritmen niet met voorbeelden van bekende labels. In plaats daarvan probeert clustering structuren te vinden binnen een trainingsset waar geen enkel punt van de gegevens het label is.

3.1. K-Means Clustering

K-Means is een clusteralgoritme met één fundamentele eigenschap: het aantal clusters wordt vooraf bepaald. Naast K-Means zijn er andere soorten clusteralgoritmen, zoals hiërarchische clustering, affiniteitsverspreiding of spectrale clustering.

3.2. Hoe K-Means werkt

Stel dat het ons doel is om een ​​paar vergelijkbare groepen in een dataset te vinden, zoals:

K-Means begint met k willekeurig geplaatste centroïden. Centroids, zoals hun naam suggereert, zijn de middelpunten van de clusters. Hier voegen we bijvoorbeeld vier willekeurige centroïden toe:

Vervolgens wijzen we elk bestaand gegevenspunt toe aan het dichtstbijzijnde zwaartepunt:

Na de toewijzing verplaatsen we de centroïden naar de gemiddelde locatie van de punten die eraan zijn toegewezen. Onthoud dat centroïden de middelpunten van clusters zouden moeten zijn:

De huidige iteratie eindigt elke keer dat we klaar zijn met het verplaatsen van de centroïden. We herhalen deze iteraties totdat de toewijzing tussen meerdere opeenvolgende iteraties niet meer verandert:

Wanneer het algoritme eindigt, worden die vier clusters gevonden zoals verwacht. Nu we weten hoe K-Means werkt, gaan we het in Java implementeren.

3.3. Feature Vertegenwoordiging

Bij het modelleren van verschillende trainingsdatasets hebben we een datastructuur nodig om modelattributen en hun corresponderende waarden weer te geven. Een muzikant kan bijvoorbeeld een genreattribuut hebben met een waarde als Rock. We gebruiken de term feature meestal om te verwijzen naar de combinatie van een attribuut en zijn waarde.

Om een ​​dataset voor een bepaald leeralgoritme voor te bereiden, gebruiken we meestal een gemeenschappelijke set numerieke attributen die kunnen worden gebruikt om verschillende items te vergelijken. Als we onze gebruikers bijvoorbeeld elke artiest met een genre laten taggen, kunnen we aan het eind van de dag tellen hoe vaak elke artiest is getagd met een specifiek genre:

De feature-vector voor een artiest als Linkin Park is [rock -> 7890, nu-metal -> 700, alternatief -> 520, pop -> 3]. Dus als we een manier zouden kunnen vinden om attributen als numerieke waarden weer te geven, dan kunnen we eenvoudig twee verschillende items vergelijken, bijv. kunstenaars, door hun overeenkomstige vectoringangen te vergelijken.

Omdat numerieke vectoren zulke veelzijdige datastructuren zijn, gaan we functies weergeven die ze gebruiken. Hier is hoe we feature-vectoren in Java implementeren:

public class Record {private final String description; privé definitieve kaartfuncties; // constructor, getter, toString, is gelijk aan en hashcode}

3.4. Vergelijkbare items zoeken

In elke iteratie van K-Means hebben we een manier nodig om het dichtstbijzijnde zwaartepunt voor elk item in de dataset te vinden. Een van de eenvoudigste manieren om de afstand tussen twee kenmerkvectoren te berekenen, is door Euclidische afstand te gebruiken. De Euclidische afstand tussen twee vectoren zoals [p1, q1] en [p2, q2] is gelijk aan:

Laten we deze functie in Java implementeren. Ten eerste de abstractie:

openbare interface Afstand {dubbel berekenen (kaart f1, kaart f2); }

Naast Euclidische afstand, er zijn andere benaderingen om de afstand of gelijkenis tussen verschillende items te berekenen, zoals de Pearson-correlatiecoëfficiënt. Deze abstractie maakt het gemakkelijk om te schakelen tussen verschillende afstandsmetrieken.

Laten we eens kijken naar de implementatie voor Euclidische afstand:

openbare klasse EuclideanDistance implementeert Afstand {@Override openbare dubbele berekening (kaart f1, kaart f2) {dubbele som = 0; for (String key: f1.keySet ()) {Double v1 = f1.get (key); Dubbel v2 = f2.get (sleutel); if (v1! = null && v2! = null) {som + = Math.pow (v1 - v2, 2); }} retourneer Math.sqrt (som); }}

Eerst berekenen we de som van de gekwadrateerde verschillen tussen de corresponderende items. Vervolgens door de sqrt functie, berekenen we de werkelijke Euclidische afstand.

3.5. Centroid Vertegenwoordiging

Centroids staan ​​in dezelfde ruimte als normale objecten, dus we kunnen ze op dezelfde manier weergeven als objecten:

openbare klasse Centroid {privé definitieve kaartcoördinaten; // constructors, getter, toString, equals en hashcode}

Nu we een paar noodzakelijke abstracties hebben, is het tijd om onze K-Means-implementatie te schrijven. Hier is een korte blik op onze methodehandtekening:

openbare klasse KMeans {privé statisch definitief Willekeurig willekeurig = nieuw Willekeurig (); openbare statische kaart fit (List records, int k, Distance distance, int maxIterations) {// weggelaten}}

Laten we de handtekening van deze methode opsplitsen:

  • De dataset is een set kenmerkvectoren. Omdat elke kenmerkvector een Vermelding, dan is het dataset-type Lijst
  • De k parameter bepaalt het aantal clusters, dat we van tevoren moeten opgeven
  • afstand vat de manier samen waarop we het verschil tussen twee kenmerken gaan berekenen
  • K-Means wordt beëindigd wanneer de toewijzing gedurende enkele opeenvolgende iteraties niet meer verandert. Naast deze beëindigingsvoorwaarde kunnen we ook een bovengrens stellen voor het aantal iteraties. De maxIteraties argument bepaalt die bovengrens
  • Wanneer K-Means wordt beëindigd, zou elk zwaartepunt een paar toegewezen functies moeten hebben, daarom gebruiken we een Kaartals het retourtype. In principe komt elk kaartitem overeen met een cluster

3.6. Centroid generatie

De eerste stap is het genereren van k willekeurig geplaatste centroïden.

Hoewel elk zwaartepunt volledig willekeurige coördinaten kan bevatten, is het een goede gewoonte om dat te doen genereer willekeurige coördinaten tussen de minimum en maximum mogelijke waarden voor elk attribuut. Door willekeurige centroïden te genereren zonder rekening te houden met het bereik van mogelijke waarden, zou het algoritme langzamer convergeren.

Eerst moeten we de minimum- en maximumwaarde voor elk attribuut berekenen en vervolgens de willekeurige waarden tussen elk paar ervan genereren:

private static List randomCentroids (List records, int k) {List centroids = new ArrayList (); Map maxs = nieuwe HashMap (); Kaartminuten = nieuwe HashMap (); voor (Record record: records) {record.getFeatures (). forEach ((sleutel, waarde) ->); } Stel attributen in = records.stream () .flatMap (e -> e.getFeatures (). KeySet (). Stream ()) .collect (toSet ()); for (int i = 0; i <k; i ++) {Map coordinates = new HashMap (); for (String attribuut: attributen) {double max = maxs.get (attribuut); double min = mins.get (attribuut); coordinates.put (attribuut, random.nextDouble () * (max - min) + min); } centroids.add (nieuwe Centroid (coördinaten)); } terugkeer centroïden; }

Nu kunnen we elk record aan een van deze willekeurige centroïden toewijzen.

3.7. Toewijzing

Ten eerste, gegeven een Vermelding, moeten we het zwaartepunt vinden dat er het dichtst bij is:

privé statisch zwaartepunt dichtstbijzijndeCentroid (record opnemen, zwaartepunten weergeven, afstandsafstand) {double minimumDistance = Double.MAX_VALUE; Centroid dichtstbijzijnde = null; voor (Centroid centroid: centroids) {double currentDistance = distance.calculate (record.getFeatures (), centroid.getCoordinates ()); if (currentDistance <minimumDistance) {minimumDistance = currentDistance; dichtstbijzijnde = zwaartepunt; }} terugkeer dichtstbijzijnde; }

Elk record behoort tot het dichtstbijzijnde centroïde cluster:

private static void assignToCluster (Map clusters, Recordrecord, Centroid centroid) {clusters.compute (centroid, (key, list) -> {if (list == null) {list = new ArrayList ();} list.add (record); return list;} ); }

3.8. Zwaartepunt verplaatsing

Als een zwaartepunt na één iteratie geen opdrachten meer bevat, zullen we het niet verplaatsen. Anders moeten we de centroïde coördinaat voor elk attribuut verplaatsen naar de gemiddelde locatie van alle toegewezen records:

privé statisch zwaartepunt gemiddelde (zwaartepunt zwaartepunt, lijst records) {if (records == null || records.isEmpty ()) {terugkeer zwaartepunt; } Kaart gemiddelde = centroid.getCoordinates (); records.stream (). flatMap (e -> e.getFeatures (). keySet (). stream ()) .forEach (k -> gemiddelde.put (k, 0.0)); voor (Recordrecord: records) {record.getFeatures (). forEach ((k, v) -> gemiddelde.pute (k, (k1, currentValue) -> v + currentValue)); } average.forEach ((k, v) -> gemiddelde.put (k, v / records.size ())); retourneer nieuwe Centroid (gemiddeld); }

Omdat we een enkel zwaartepunt kunnen verplaatsen, is het nu mogelijk om het relocateCentroids methode:

privé statische lijst relocateCentroids (Map clusters) {retourneer clusters.entrySet (). stream (). map (e -> gemiddelde (e.getKey (), e.getValue ())). collect (toList ()); }

Deze eenvoudige oneliner itereert door alle centroïden, verplaatst ze en retourneert de nieuwe centroïden.

3.9. Alles samenvoegen

In elke iteratie, nadat alle records aan hun dichtstbijzijnde zwaartepunt zijn toegewezen, moeten we eerst de huidige toewijzingen vergelijken met de laatste iteratie.

Als de toewijzingen identiek waren, wordt het algoritme beëindigd. Anders moeten we, voordat we naar de volgende iteratie springen, de centroïden verplaatsen:

openbare statische kaart fit (List records, int k, Distance distance, int maxIterations) {List centroids = randomCentroids (records, k); Kaart clusters = nieuwe HashMap (); Kaart lastState = nieuwe HashMap (); // itereer voor een vooraf gedefinieerd aantal keren voor (int i = 0; i <maxIterations; i ++) {boolean isLastIteration = i == maxIterations - 1; // in elke iteratie moeten we voor elk record het dichtstbijzijnde zwaartepunt vinden voor (Record record: records) {Centroid centroid = dichtstbijzijndeCentroid (record, centroids, distance); assignToCluster (clusters, record, centroid); } // als de toewijzingen niet veranderen, beëindigt het algoritme boolean shouldTerminate = isLastIteration || clusters.equals (lastState); lastState = clusters; if (shouldTerminate) {break; } // aan het einde van elke iteratie moeten we de centroids centroids = relocateCentroids (clusters) verplaatsen; clusters = nieuwe HashMap (); } return lastState; }

4. Voorbeeld: soortgelijke artiesten ontdekken op Last.fm

Last.fm bouwt een gedetailleerd profiel op van de muzikale smaak van elke gebruiker door details op te nemen van waar de gebruiker naar luistert. In deze sectie gaan we clusters van vergelijkbare artiesten zoeken. Om een ​​dataset te bouwen die geschikt is voor deze taak, gebruiken we drie API's van Last.fm:

  1. API om een ​​verzameling topartiesten op Last.fm te krijgen.
  2. Nog een API om populaire tags te vinden. Elke gebruiker kan een artiest ergens mee taggen, bijv. rots. Last.fm houdt dus een database bij van die tags en hun frequenties.
  3. En een API om de toptags voor een artiest te krijgen, gerangschikt op populariteit. Aangezien er veel van dergelijke tags zijn, behouden we alleen de tags die tot de top algemene tags behoren.

4.1. Last.fm's API

Om deze API's te gebruiken, moeten we een API-sleutel krijgen van Last.fm en deze in elk HTTP-verzoek verzenden. We gaan de volgende Retrofit-service gebruiken om die API's aan te roepen:

openbare interface LastFmService {@GET ("/ 2.0 /? method = chart.gettopartists & format = json & limit = 50") Bel topArtists (@Query ("page") int-pagina); @GET ("/ 2.0 /? Method = artist.gettoptags & format = json & limit = 20 & autocorrect = 1") Roep topTagsFor (@Query ("artist") String artist); @GET ("/ 2.0 /? Method = chart.gettoptags & format = json & limit = 100") Roep topTags (); // Een paar DTO's en één interceptor}

Laten we dus de meest populaire artiesten op Last.fm zoeken:

// het opzetten van de Retrofit-service privé statische lijst getTop100Artists () gooit IOException {List artists = new ArrayList (); // Ophalen van de eerste twee pagina's, elk met 50 records. voor (int i = 1; i <= 2; i ++) {artiesten.addAll (lastFm.topArtists (i) .execute (). body (). all ()); } terugkeer artiesten; }

Op dezelfde manier kunnen we de toptags ophalen:

private static Set getTop100Tags () gooit IOException {return lastFm.topTags (). execute (). body (). all (); }

Ten slotte kunnen we een dataset van artiesten samen met hun tagfrequenties bouwen:

privé statische lijst datasetWithTaggedArtists (lijst artiesten, set topTags) gooit IOException {List records = new ArrayList (); voor (String artiest: artiesten) {Map tags = lastFm.topTagsFor (artist) .execute (). body (). all (); // Bewaar alleen populaire tags. tags.entrySet (). removeIf (e ->! topTags.contains (e.getKey ())); records.add (nieuw record (artiest, tags)); } records retourneren; }

4.2. Vormen van artiestenclusters

Nu kunnen we de voorbereide dataset naar onze K-Means-implementatie sturen:

Lijst met artiesten = getTop100Artists (); Set topTags = getTop100Tags (); Lijstrecords = datasetWithTaggedArtists (artiesten, topTags); Kaart clusters = KMeans.fit (records, 7, nieuwe EuclideanDistance (), 1000); // Afdrukken van de clusterconfiguratie clusters.forEach ((sleutel, waarde) -> {System.out.println ("------------------------- - CLUSTER ---------------------------- "); // De coördinaten sorteren om eerst de belangrijkste tags te zien. System.out. println (sortCentroid (key)); String-leden = String.join (",", value.stream (). map (Record :: getDescription) .collect (toSet ())); System.out.print (leden); System.out.println (); System.out.println ();});

Als we deze code uitvoeren, zouden de clusters worden gevisualiseerd als tekstuitvoer:

------------------------------ TROS ------------------- ---------------- Centroid {classic rock = 65.58333333333333, rock = 64.41666666666667, british = 20.33333333333333332, ...} David Bowie, Led Zeppelin, Pink Floyd, System of a Down, Queen , blink-182, The Rolling Stones, Metallica, Fleetwood Mac, The Beatles, Elton John, The Clash ---------------------------- - CLUSTER ----------------------------------- Centroid {Hip-Hop = 97.21428571428571, rap = 64.85714285714286, hiphop = 29.285714285714285, ...} Kanye West, Post Malone, Childish Gambino, Lil Nas X, A $ AP Rocky, Lizzo, xxxtentacion, Travi $ Scott, Tyler, the Creator, Eminem, Frank Ocean, Kendrick Lamar, Nicki Minaj , Drake ------------------------------ CLUSTER ----------------- ------------------ Centroid {indierock = 54.0, rock = 52.0, Psychedelic Rock = 51.0, psychedelic = 47.0, ...} Tame Impala, The Black Keys - ---------------------------- TROS --------------------- -------------- Centroid {pop = 81.96428571428571, zangeressen = 41.2857142 85714285, indie = 22.785714285714285, ...} Ed Sheeran, Taylor Swift, Rihanna, Miley Cyrus, Billie Eilish, Lorde, Ellie Goulding, Bruno Mars, Katy Perry, Khalid, Ariana Grande, Bon Iver, Dua Lipa, Beyoncé, Sia, P! Nk, Sam Smith, Shawn Mendes, Mark Ronson, Michael Jackson, Halsey, Lana Del Rey, Carly Rae Jepsen, Britney Spears, Madonna, Adele, Lady Gaga, Jonas Brothers ------------ ------------------ TROS ------------------------------- ---- Centroid {indie = 95.23076923076923, alternative = 70.61538461538461, indierock = 64.46153846153847, ...} Twenty One Pilots, The Smiths, Florence + the Machine, Two Door Cinema Club, The 1975, Imagine Dragons, The Killers, Vampire Weekend, Foster the People, The Strokes, Cage the Elephant, Arcade Fire, Arctic Monkeys ------------------------------ CLUSTER - ---------------------------------- Centroid {elektronisch = 91.6923076923077, House = 39.46153846153846, dans = 38.0, .. .} Charli XCX, The Weeknd, Daft Punk, Calvin Harris, MGMT, Martin Garrix, Depeche Mode, The Chainsmokers, Avicii, Kygo, Marshmello, David Guetta, Major Lazer ------------------------------ CLUSTER ----- ------------------------------ Centroid {rock = 87.38888888888889, alternative = 72.11111111111111, alternatieve rock = 49.16666666, ...} Weezer , The White Stripes, Nirvana, Foo Fighters, Maroon 5, Oasis, Panic! at the Disco, Gorillaz, Green Day, The Cure, Fall Out Boy, OneRepublic, Paramore, Coldplay, Radiohead, Linkin Park, Red Hot Chili Peppers, Muse

Omdat centroïde coördinaties worden gesorteerd op de gemiddelde tagfrequentie, kunnen we gemakkelijk het dominante genre in elk cluster herkennen. De laatste cluster is bijvoorbeeld een cluster van goede oude rockbands, of de tweede is gevuld met rapsterren.

Hoewel deze clustering logisch is, is het voor het grootste deel niet perfect, omdat de gegevens alleen worden verzameld op basis van gebruikersgedrag.

5. Visualisatie

Een paar ogenblikken geleden visualiseerde ons algoritme het cluster van artiesten op een terminalvriendelijke manier. Als we onze clusterconfiguratie naar JSON converteren en deze naar D3.js sturen, hebben we met een paar regels JavaScript een mooie mensvriendelijke Radial Tidy-Tree:

We moeten onze Kaart naar een JSON met een vergelijkbaar schema zoals dit d3.js-voorbeeld.

6. Aantal clusters

Een van de fundamentele eigenschappen van K-Means is dat we van tevoren het aantal clusters moeten definiëren. Tot nu toe hebben we een statische waarde gebruikt voor k, maar het bepalen van deze waarde kan een uitdagend probleem zijn. Er zijn twee veelgebruikte manieren om het aantal clusters te berekenen:

  1. Domein kennis
  2. Wiskundige heuristieken

Als we het geluk hebben dat we zoveel weten over het domein, kunnen we misschien gewoon het juiste nummer raden.Anders kunnen we een paar heuristieken toepassen, zoals Elbow Method of Silhouette Method, om een ​​idee te krijgen van het aantal clusters.

Voordat we verder gaan, moeten we weten dat deze heuristieken, hoewel nuttig, slechts heuristieken zijn en mogelijk geen eenduidige antwoorden bieden.

6.1. Elleboog-methode

Om de elleboogmethode te gebruiken, moeten we eerst het verschil berekenen tussen elk clusterzwaartepunt en al zijn leden. Naarmate we meer niet-verwante leden in een cluster groeperen, neemt de afstand tussen het zwaartepunt en zijn leden toe, waardoor de kwaliteit van het cluster afneemt.

Een manier om deze afstandsberekening uit te voeren, is door de som van gekwadrateerde fouten te gebruiken. Som van gekwadrateerde fouten of SSE is gelijk aan de som van gekwadrateerde verschillen tussen een centroïde en al zijn leden:

openbare statische dubbele sse (Map geclusterd, Afstandsafstand) {dubbele som = 0; voor (Map.Entry entry: clustered.entrySet ()) {Centroid centroid = entry.getKey (); voor (Record record: entry.getValue ()) {dubbele d = afstand.calculate (centroid.getCoordinates (), record.getFeatures ()); som + = Math.pow (d, 2); }} retourbedrag; }

Dan, we kunnen het K-Means-algoritme uitvoeren voor verschillende waarden van ken bereken de SSE voor elk van hen:

Lijst met records = // de dataset; Afstand afstand = nieuwe EuclideanDistance (); Lijst sumOfSquaredErrors = nieuwe ArrayList (); voor (int k = 2; k <= 16; k ++) {Map clusters = KMeans.fit (records, k, afstand, 1000); double sse = Fouten.sse (clusters, afstand); sumOfSquaredErrors.add (sse); }

Aan het eind van de dag is het mogelijk om een ​​geschikte te vinden k door het aantal clusters uit te zetten tegen de SSE:

Gewoonlijk neemt de afstand tussen clusterleden af ​​naarmate het aantal clusters toeneemt. We kunnen echter geen willekeurige grote waarden kiezen voor k, omdat het hebben van meerdere clusters met slechts één lid het hele doel van clustering verslaat.

Het idee achter de elleboogmethode is om een ​​geschikte waarde voor te vinden k op een manier dat de SSE rond die waarde dramatisch afneemt. Bijvoorbeeld, k = 9 kan hier een goede kandidaat zijn.

7. Conclusie

In deze zelfstudie hebben we eerst enkele belangrijke concepten in Machine Learning behandeld. Toen kregen we kennis met de mechanica van het K-Means-clusteralgoritme. Ten slotte hebben we een eenvoudige implementatie voor K-Means geschreven, ons algoritme getest met een real-world dataset van Last.fm en het clusteringresultaat op een mooie grafische manier gevisualiseerd.

Zoals gewoonlijk is de voorbeeldcode beschikbaar op ons GitHub-project, dus zorg ervoor dat je het bekijkt!