K-Means++

From Wikibooks, open books for an open world
Jump to: navigation, search

Contents

C#

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Drawing;
using System.Security.Cryptography;

/// <summary>
/// k-means clustering of 2-D integer points with k-means++ seeding.
/// Entry point: <see cref="GetKMeansPP"/>.
/// </summary>
class KMeansPP
{
    // Upper bound on assignment/update rounds (re-seeds included). Centroids are truncated to
    // integer coordinates, so the iteration is not guaranteed to settle; the cap prevents an
    // endless loop in that case.
    private const int MaxIterations = 1000;

    // One generator per instance. Creating a fresh Random for every call can hand back the
    // same value on rapid successive calls on runtimes that seed it from the clock.
    private readonly Random _random = new Random();

    /// <summary>
    /// Output object: maps each cluster centre to the points assigned to it.
    /// </summary>
    public class PointClusters
    {
        private Dictionary<Point, List<Point>> _pc = new Dictionary<Point, List<Point>>();

        /// <summary>Cluster centre -> member points.</summary>
        public Dictionary<Point, List<Point>> PC
        {
            get { return _pc; } set { _pc = value; }
        }
    }

    /// <summary>
    /// Intermediate calculation object: squared distances from one seed point to every
    /// point of the ensemble.
    /// </summary>
    public struct PointDetails
    {
        private Point _seedpoint;
        private double[] _Weights;
        private double _Sum;
        private double _minD;

        /// <summary>The seed the distances were measured from.</summary>
        public Point SeedPoint
        {
            get { return _seedpoint; } set { _seedpoint = value; }
        }

        /// <summary>Squared distance from the seed to each ensemble point (same indexing as the ensemble).</summary>
        public double[] Weights
        {
            get { return _Weights; } set { _Weights = value; }
        }

        /// <summary>Sum of all entries of <see cref="Weights"/>.</summary>
        public double Sum
        {
            get { return _Sum; } set { _Sum = value; }
        }

        /// <summary>Smallest squared distance to a point that differs from the seed.</summary>
        public double MinD
        {
            get { return _minD; } set { _minD = value; }
        }
    }


    /// <summary>
    /// Basic (non kd-tree) implementation of kmeans++ algorithm.
    /// cf. http://en.wikipedia.org/wiki/K-means%2B%2B
    /// Excellent for financial diversification cf.
    /// Clustering Techniques for Financial Diversification, March 2009
    /// cf http://www.cse.ohio-state.edu/~johansek/clustering.pdf
    /// Zach Howard and Keith Johansen
    /// Note1: If unsure what value of k to use, try: k ~ (n/2)^0.5
    /// cf. http://en.wikipedia.org/wiki/Determining_the_number_of_clusters_in_a_data_set
    /// </summary>
    /// <param name="allPoints">All points in ensemble</param>
    /// <param name="k">Number of clusters. Reduced to the number of distinct points when larger.</param>
    /// <returns>
    /// The clusters keyed by their centre. Empty when <paramref name="allPoints"/> is empty
    /// or <paramref name="k"/> is 0.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="allPoints"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="k"/> is negative.</exception>
    public PointClusters GetKMeansPP(List<Point> allPoints, int k)
    {
        if (allPoints == null)
        {
            throw new ArgumentNullException("allPoints");
        }
        if (k < 0)
        {
            throw new ArgumentOutOfRangeException("k", "Number of clusters cannot be negative.");
        }
        if (k == 0 || allPoints.Count == 0)
        {
            return new PointClusters();
        }

        // More clusters than distinct locations can never all be non-empty.
        int distinctCount = new HashSet<Point>(allPoints).Count;
        if (k > distinctCount)
        {
            k = distinctCount;
        }

        //1. Preprocess KMeans (obtain optimized seed points)
        List<Point> seedPoints = GetSeedPoints(allPoints, k);

        //2. Regular KMeans algorithm
        return GetKMeans(allPoints, seedPoints, k);
    }

    /// <summary>
    /// Bog standard k-means: assign every point to its nearest centre, move each centre to
    /// the centroid of its cluster, repeat until the centres stop changing.
    /// </summary>
    /// <param name="allPoints">All points in ensemble</param>
    /// <param name="seedPoints">Initial cluster centres</param>
    /// <param name="k">Number of clusters expected</param>
    /// <returns>The converged clustering, or the most recent clustering with k clusters if the iteration cap is hit.</returns>
    private PointClusters GetKMeans(List<Point> allPoints, List<Point> seedPoints, int k)
    {
        PointClusters lastValid = null;

        for (int iteration = 0; iteration < MaxIterations; iteration++)
        {
            PointClusters cluster = AssignToNearestSeed(allPoints, seedPoints);

            //Bulletproof check - if it comes out of the wash incorrect then re-seed and start over.
            if (cluster.PC.Count != k)
            {
                seedPoints = GetSeedPoints(allPoints, k);
                continue;
            }

            lastValid = cluster;
            List<Point> newSeeds = GetCentroid(cluster);

            //Determine exit: converged when every new centre is already a current centre.
            bool exit = true;
            foreach (Point newSeed in newSeeds)
            {
                if (!cluster.PC.ContainsKey(newSeed))
                {
                    exit = false;
                    break;
                }
            }

            if (exit)
            {
                return cluster;
            }

            seedPoints = newSeeds;
        }

        if (lastValid != null)
        {
            return lastValid;
        }
        return AssignToNearestSeed(allPoints, GetSeedPoints(allPoints, k));
    }

    /// <summary>
    /// Builds one cluster per seed that is the nearest seed of at least one point.
    /// Ties between equally near seeds are broken at random.
    /// </summary>
    private PointClusters AssignToNearestSeed(List<Point> allPoints, List<Point> seedPoints)
    {
        PointClusters cluster = new PointClusters();
        List<Point> nearest = new List<Point>();

        foreach (Point p in allPoints)
        {
            double minD = double.MaxValue;
            nearest.Clear();

            foreach (Point sPoint in seedPoints)
            {
                double dist = GetEuclideanD(p, sPoint);
                if (dist < minD)
                {
                    nearest.Clear();
                    nearest.Add(sPoint);
                    minD = dist;
                }
                else if (dist == minD && !nearest.Contains(sPoint))
                {
                    // Exact comparison is intended: this only collects seeds whose
                    // computed distance is identical to the current minimum.
                    nearest.Add(sPoint);
                }
            }

            //Extract nearest central point.
            Point keyPoint = nearest.Count > 1
                ? nearest[GetRandNumCrypto(0, nearest.Count)]
                : nearest[0];

            //Assign ensemble point to correct central point cluster
            List<Point> members;
            if (!cluster.PC.TryGetValue(keyPoint, out members))
            {
                members = new List<Point>();
                cluster.PC.Add(keyPoint, members);
            }
            members.Add(p);
        }

        return cluster;
    }

    /// <summary>
    /// Get the centroid of each cluster (coordinates truncated to integers).
    /// cf. http://en.wikipedia.org/wiki/Centroid
    /// Consider also: Medoid cf. http://en.wikipedia.org/wiki/Medoids
    /// </summary>
    /// <param name="pcs">Clusters to average; every cluster must hold at least one point.</param>
    /// <returns>One centroid per cluster, in the enumeration order of the clusters.</returns>
    private List<Point> GetCentroid(PointClusters pcs)
    {
        List<Point> newSeeds = new List<Point>(pcs.PC.Count);

        foreach (List<Point> cluster in pcs.PC.Values)
        {
            // 64-bit accumulators: summing many int coordinates can exceed the int range.
            long sumX = 0;
            long sumY = 0;

            foreach (Point p in cluster)
            {
                sumX += p.X;
                sumY += p.Y;
            }

            newSeeds.Add(new Point((int)(sumX / cluster.Count), (int)(sumY / cluster.Count)));
        }

        return newSeeds;
    }


    /// <summary>
    /// k-means++ seeding: the first seed is picked uniformly at random, every further seed
    /// with probability proportional to D(x)^2, the squared distance from x to the nearest
    /// seed chosen so far.
    /// </summary>
    /// <param name="allPoints">All points in ensemble</param>
    /// <param name="k">Number of seeds wanted</param>
    /// <returns>Up to k seeds; fewer only when there are fewer than k distinct points.</returns>
    private List<Point> GetSeedPoints(List<Point> allPoints, int k)
    {
        List<Point> seedPoints = new List<Point>(Math.Max(k, 0));
        if (k <= 0 || allPoints.Count == 0)
        {
            return seedPoints;
        }

        //1. Choose 1 random point as first seed
        Point firstPoint = allPoints[GetRandNorm(0, allPoints.Count)];
        seedPoints.Add(firstPoint);

        // nearestSq[i] = squared distance from allPoints[i] to its nearest seed so far.
        double[] nearestSq = GetAllDetails(allPoints, firstPoint, new PointDetails()).Weights;

        //2. Choose the remaining seeds with probability proportional to D(x)^2
        while (seedPoints.Count < k)
        {
            double total = 0d;
            foreach (double d in nearestSq)
            {
                total += d;
            }
            if (total <= 0d)
            {
                break; // every point coincides with a seed already chosen
            }

            Point nextSeed = allPoints[GetWeightedProbDist(nearestSq, total)];
            seedPoints.Add(nextSeed);

            double[] toNext = GetAllDetails(allPoints, nextSeed, new PointDetails()).Weights;
            for (int i = 0; i < nearestSq.Length; i++)
            {
                if (toNext[i] < nearestSq[i])
                {
                    nearestSq[i] = toNext[i];
                }
            }
        }

        return seedPoints;
    }

    /// <summary>
    /// Very simple weighted probability distribution. NB: No ranking involved.
    /// Returns a random index with probability proportional to its weight.
    /// Entries with zero weight are never returned unless all weights are zero.
    /// </summary>
    /// <param name="w">Weights (non-negative)</param>
    /// <param name="s">Sum total of weights</param>
    /// <returns>Index into <paramref name="w"/></returns>
    /// <exception cref="ArgumentException"><paramref name="w"/> is null or empty.</exception>
    private int GetWeightedProbDist(double[] w, double s)
    {
        if (w == null || w.Length == 0)
        {
            throw new ArgumentException("At least one weight is required.", "w");
        }
        if (!(s > 0d))
        {
            // Zero, negative or NaN total: no usable distribution, fall back to uniform.
            return GetRandNumCrypto(0, w.Length);
        }

        double target = GetRandNumCrypto() * s;
        double running = 0d;
        int lastPositive = -1;

        for (int i = 0; i < w.Length; i++)
        {
            if (w[i] <= 0d)
            {
                continue;
            }

            running += w[i];
            lastPositive = i;
            if (running > target)
            {
                return i;
            }
        }

        // Rounding can leave the running total marginally below the target.
        return lastPositive >= 0 ? lastPositive : GetRandNumCrypto(0, w.Length);
    }

    //Gets a pseudo random number (of normal quality) in range: [0, 1)
    private double GetRandNorm()
    {
        return _random.NextDouble();
    }

    //Gets a pseudo random number (of normal quality) in range: [min, max)
    private int GetRandNorm(int min, int max)
    {
        return _random.Next(min, max);
    }

    //Pseudorandom number (of crypto strength) in range: [min,max). Returns min when the range is empty.
    private int GetRandNumCrypto(int min, int max)
    {
        if (max <= min)
        {
            return min;
        }

        long range = (long)max - min;
        long offset = (long)(GetRandNumCrypto() * range);
        if (offset >= range)
        {
            offset = range - 1; // guard against rounding up to the excluded upper bound
        }
        return (int)(min + offset);
    }

    //Pseudorandom number (of crypto strength) in range: [0.0,1.0)
    private double GetRandNumCrypto()
    {
        byte[] buffer = new byte[8];
        using (RandomNumberGenerator rng = RandomNumberGenerator.Create())
        {
            rng.GetBytes(buffer);
        }

        // Keep 53 random bits (the precision of a double) and scale by 2^-53, so the
        // result can never reach 1.0.
        ulong bits = BitConverter.ToUInt64(buffer, 0) >> 11;
        return bits * (1.0 / (1UL << 53));
    }


    /// <summary>
    /// Gets the weights, their sum and the min distance for one seed. Loop consolidation essentially.
    /// Weights[i] always belongs to allPoints[i]; points equal to the seed keep weight 0.
    /// </summary>
    /// <param name="allPoints">All points in ensemble</param>
    /// <param name="seedPoint">Seed the distances are measured from</param>
    /// <param name="pd">Details object to fill in</param>
    /// <returns>The filled-in details</returns>
    private PointDetails GetAllDetails(List<Point> allPoints, Point seedPoint, PointDetails pd)
    {
        double[] weights = new double[allPoints.Count];
        double minD = double.MaxValue;
        double sum = 0d;

        for (int i = 0; i < allPoints.Count; i++)
        {
            Point p = allPoints[i];
            if (p == seedPoint) //Delta is 0; the slot is kept so indices stay aligned with allPoints
            {
                continue;
            }

            weights[i] = GetEuclideanD(p, seedPoint);
            sum += weights[i];
            if (weights[i] < minD)
            {
                minD = weights[i];
            }
        }

        pd.SeedPoint = seedPoint;
        pd.Weights = weights;
        pd.Sum = sum;
        pd.MinD = minD;

        return pd;
    }

    /// <summary>
    /// Squared Euclidean distance (no square root is taken).
    /// cf. http://en.wikipedia.org/wiki/Euclidean_distance
    /// Consider also: Manhattan, Chebyshev and Minkowski distances
    /// </summary>
    /// <param name="P1">First point</param>
    /// <param name="P2">Second point</param>
    /// <returns>dx^2 + dy^2</returns>
    private double GetEuclideanD(Point P1, Point P2)
    {
        // Widen before subtracting so extreme coordinates cannot overflow int.
        double dx = (double)P1.X - P2.X;
        double dy = (double)P1.Y - P2.Y;
        return (dx * dx) + (dy * dy);
    }
}