Skip to content

WIP : Self-Organizing Map #39

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions examples/cluster/plot_som_colormap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""
===========================================================
A demo of SelfOrganisingMap with colored neurons
===========================================================

Example of SOM clustering using 3-dimensional vectors (RGB)
with 8 colors (black, white, red, green, blue, yellow, cyan, magenta)

"""
print __doc__
import pylab as pl
from matplotlib.colors import ListedColormap, NoNorm, rgb2hex
import numpy as np
from scikits.learn.cluster import SelfOrganizingMap


def plot(neurons):
    """Draw a grid of RGB neurons as a coloured image on the current axes.

    Parameters
    ----------
    neurons : ndarray of shape (height, width, 3)
        RGB value of every neuron, with components in the 0-255 range.

    Raises
    ------
    ValueError
        If ``neurons`` is not a 3-D array whose last axis has length 3.
    """
    if neurons.ndim != 3 or neurons.shape[-1] != 3:
        raise ValueError("neurons must have shape (height, width, 3), "
                         "got %r" % (neurons.shape,))
    h, w = neurons.shape[:2]
    # rgb2hex expects components in [0, 1]. Dividing by 255.0 (not 256) lets
    # a full-intensity channel reach 1.0, and the float divisor prevents
    # Python 2 floor division from zeroing an integer-valued grid.
    colors = neurons.reshape(-1, 3) / 255.0
    hexmap = [rgb2hex(color) for color in colors]
    # Each cell holds its own position in the colour list; NoNorm makes
    # pcolor use that integer directly as an index into the colormap.
    index = np.arange(h * w).reshape(h, w)
    pl.pcolor(index, cmap=ListedColormap(hexmap), norm=NoNorm())

# Training set: the eight corner colours of the RGB cube, one row per colour,
# with components given on a 0-255 scale.
train = np.array([[0, 0, 0], # black
[255, 255, 255], # white
[255, 0, 0], # red
[0, 255, 0], # green
[0, 0, 255], # blue
[255, 255, 0], # yellow
[0, 255, 255], # cyan
[255, 0, 255]]) # magenta

# Initial map: a 16x16 grid of random RGB neurons with values in [0, 255).
init = np.random.rand(16, 16, 3) * 255

# Left panel: the map before training.
pl.subplot(1, 2, 1, aspect='equal')
plot(init)
pl.title('Initial map')

# `init` is passed positionally as the `size` argument; together with
# init='matrix' the estimator uses it as the starting neuron grid.
som = SelfOrganizingMap(init, n_iterations=1024,
init='matrix', learning_rate=1)
som.fit(train)

# Right panel: the map after training on the eight colours.
pl.subplot(1, 2, 2, aspect='equal')
plot(som.neurons_)
pl.title('Organized Map')
# Enlarge the current figure before displaying it.
F = pl.gcf()
F.set_size_inches((40, 20))
pl.show()
66 changes: 66 additions & 0 deletions examples/cluster/som_digits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""
=======================================================================
A demo of Self-Organising Map and KMeans on the handwritten digits data
=======================================================================

Comparing various SOM and Kmeans clustering on the handwritten digits data
with the pseudo_F index

"""
from __future__ import division
print __doc__

from time import time
import numpy as np

from scikits.learn.cluster import KMeans
from scikits.learn.cluster import SelfOrganizingMap
from scikits.learn.cluster import pseudo_F
from scikits.learn.datasets import load_digits
from scikits.learn.preprocessing import scale
from scikits.learn.metrics import confusion_matrix

# Fix the global NumPy seed so the random draws below are repeatable.
np.random.seed(42)

################################################################################
# Load dataset

digits = load_digits()
# Cluster the output of scale() applied to the raw digit features.
data = scale(digits.data)
n_samples, n_features = data.shape
# Number of distinct target values; only reported below, not used for fitting.
n_digits = len(np.unique(digits.target))

print "Digits dataset"
print "n_digits : %d" % n_digits
print "n_features : %d" % n_features
print "n_samples : %d" % n_samples
print

################################################################################
# Digits dataset clustering using Self-Organizing Map

print "Self-Organizing Map "
t0 = time()
# A grid_width x grid_width map, trained for five times as many iterations
# as there are samples.
grid_width = 4
som = SelfOrganizingMap(size=grid_width, n_iterations=n_samples*5,
learning_rate=1)
som.fit(data)
print "done in %0.3fs" % (time() - t0)
print

# Report the pseudo F statistic, plus F / (1 + F) expressed as a percentage
# (a value in [0, 100) whenever F is non-negative).
F = pseudo_F(data, som.labels_, som.neurons_)
print 'pseudo_F %0.2f | %0.2f%%' % (F, 100 * (F / (1 + F)))
print

################################################################################
# Digits dataset clustering using Kmeans

print "KMeans "
t0 = time()
# k matches the number of neurons in the SOM grid above (grid_width ** 2).
km = KMeans(init='k-means++', k=grid_width**2, n_init=10)
km.fit(data)
print "done in %0.3fs" % (time() - t0)
print

# Same report as for the SOM, using the KMeans labels and cluster centers.
F = pseudo_F(data, km.labels_, km.cluster_centers_)
print 'pseudo_F %0.2f | %0.2f%%' % (F, 100 * (F / (1 + F)))
23 changes: 23 additions & 0 deletions scikits/learn/cluster/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,27 @@
from .mean_shift_ import mean_shift, MeanShift, estimate_bandwidth
from .affinity_propagation_ import affinity_propagation, AffinityPropagation
from .k_means_ import k_means, KMeans
from .som_ import SelfOrganizingMap

import numpy as np

def pseudo_F(X, labels, centroids):
    """Pseudo F statistic (Calinski and Harabasz, 1974) of a clustering.

    pseudo F = ((T - PG) / (G - 1)) / (PG / (n - G))

    where T is the total sum of squares of ``X`` around its overall mean,
    PG is the pooled within-group sum of squares (squared distance of every
    sample to its assigned centroid), G is the number of centroids and n is
    the number of samples.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features) or (n_samples,)
        The clustered observations.

    labels : array-like
        For every sample, the index of its centroid in ``centroids``: a
        single integer when ``centroids`` is 2-D, or a tuple of integers
        (e.g. ``(row, col)``) when the centroids are laid out on a grid.

    centroids : array-like of shape (..., n_features)
        The cluster centres. Every leading axis is an index axis, so a
        (height, width, n_features) grid counts as height * width centroids.

    Returns
    -------
    float
        The pseudo F statistic. A within-group sum of squares of zero makes
        the ratio infinite (NumPy division semantics).

    Raises
    ------
    ValueError
        If there are fewer than two centroids or not more samples than
        centroids, in which case the statistic is undefined.

    References
    ----------
    Calinski, T. and J. Harabasz. 1974.
    A dendrite method for cluster analysis. Commun. Stat. 3: 1-27.
    http://dx.doi.org/10.1080/03610927408827101
    """
    # Work in floating point so integer inputs cannot trigger floor division
    # under Python 2.
    X = np.asarray(X, dtype=float)
    centroids = np.asarray(centroids, dtype=float)
    labels = np.asarray(labels, dtype=np.intp)
    if X.ndim == 1:
        # A flat array is treated as n one-dimensional observations.
        X = X[:, np.newaxis]
        if centroids.ndim == 1:
            centroids = centroids[:, np.newaxis]

    n = X.shape[0]
    # Count centroids over all index axes; len(centroids) would only see the
    # first axis of a grid-shaped centroid array.
    G = int(np.prod(centroids.shape[:-1]))
    if G < 2:
        raise ValueError("pseudo_F needs at least 2 centroids, got %d" % G)
    if n <= G:
        raise ValueError("pseudo_F needs more samples (%d) than centroids "
                         "(%d)" % (n, G))

    # One column of indices per centroid index axis, then fancy-index to get
    # the centroid assigned to every sample.
    index_columns = tuple(labels.reshape(n, -1).T)
    assigned = centroids[index_columns]

    total = np.sum((X - X.mean(axis=0)) ** 2)      # T
    within = np.sum((X - assigned) ** 2)           # PG
    between = total - within                       # T - PG
    return (between / (G - 1)) / (within / (n - G))

138 changes: 138 additions & 0 deletions scikits/learn/cluster/som_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""
Self-organizing map

Reference : (to check)
Kohonen, T.; , "The self-organizing map,"
Proceedings of the IEEE , vol.78, no.9, pp.1464-1480, Sep 1990
"""
# Authors: Sebastien Campion <sebastien.campion@inria.fr>
# License: BSD
from __future__ import division
import math
import numpy as np
from ..base import BaseEstimator


class SelfOrganizingMap(BaseEstimator):
    """Self-Organizing Map on a two-dimensional grid of neurons.

    Parameters
    ----------
    size : int or ndarray of shape (height, width, n_features)
        With ``init='random'``, the side length of the square map
        (``size * size`` neurons are created). With ``init='matrix'``,
        the initial neuron grid itself.

    init : {'random', 'matrix'}
        Method for initialization, defaults to 'random':

        'random': neurons are drawn with ``np.random.rand``.

        'matrix': interpret the ``size`` parameter as the array of
        initial neurons.

    n_iterations : int
        Number of training steps. Each step presents one sample drawn at
        random (with replacement) from the training data.

    learning_rate : float
        Initial learning rate; it decays exponentially with the iteration
        number.

    callback : callable or None
        If not None, called as ``callback(som, iteration)`` after every
        training step, with ``iteration`` counting from 1.

    Methods
    -------
    fit(X):
        Compute SOM

    Attributes
    ----------
    neurons_ : ndarray of shape (height, width, n_features)
        Weight vector of every neuron of the map.

    labels_ : list of (row, col) tuples
        Grid coordinates of the best matching neuron of each training
        sample.
    """

    def __init__(self, size=16, init='random', n_iterations=64,
                 learning_rate=1, callback=None):
        self.size = size
        self.init = init
        self.n_iterations = n_iterations
        self.learning_rate = learning_rate
        self.callback = callback

    def fit(self, X, **params):
        """Train the map on X and label every sample.

        One sample is picked at random for each of the ``n_iterations``
        training steps.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data.

        **params
            Forwarded to ``_set_params`` before training.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``init`` is unknown, or if ``init='matrix'`` and ``size`` is
            not a (height, width, n_features) array matching X.
        """
        X = np.asanyarray(X)
        self._set_params(**params)
        self.neurons_ = None
        n_features = X.shape[-1]

        if self.init == 'random':
            neurons = np.random.rand(self.size, self.size, n_features)
        elif self.init == 'matrix':
            # Train on a float copy: the caller's array stays untouched and
            # the ``size`` parameter is left as given, so fit can be called
            # again on the same estimator.
            neurons = np.array(self.size, dtype=float)
            if neurons.ndim != 3:
                raise ValueError("with init='matrix', size must be a "
                                 "(height, width, n_features) array, got "
                                 "shape %r" % (neurons.shape,))
            if neurons.shape[-1] != n_features:
                raise ValueError("initial neurons have %d features but X "
                                 "has %d" % (neurons.shape[-1], n_features))
        else:
            raise ValueError("init must be 'random' or 'matrix', got %r"
                             % (self.init,))
        self.neurons_ = neurons

        # randint's upper bound is exclusive, so this samples row indices
        # 0 .. len(X) - 1, like the deprecated random_integers(0, len(X) - 1).
        indices = np.random.randint(0, len(X), self.n_iterations)
        for iteration, i in enumerate(indices):
            lr = self.learning_rate * self._decay(iteration)
            self._learn_vector(X[i], lr, iteration)
            if self.callback is not None:
                # The callback sees the number of completed steps.
                self.callback(self, iteration + 1)

        # assign labels
        self.labels_ = [self.bmu(x) for x in X]
        return self

    def _grid_shape(self):
        """Return the (height, width) of the map."""
        neurons = getattr(self, 'neurons_', None)
        if neurons is not None:
            return tuple(neurons.shape[:2])
        # Not fitted yet: fall back on the ``size`` parameter.
        if np.ndim(self.size) == 0:
            return (self.size, self.size)
        return tuple(np.shape(self.size)[:2])

    def _extent(self):
        """Return the longest side of the map (the side of a square map)."""
        return max(self._grid_shape())

    def _decay(self, iteration):
        """Exponential decay factor shared by learning rate and radius."""
        time_constant = float(self.n_iterations) / self._extent()
        return math.exp(-iteration / time_constant)

    def _learn_vector(self, vector, lr, iteration):
        """Pull the winner's neighbourhood towards ``vector``."""
        winner = self.bmu(vector)
        radius = self.radius_of_the_neighbordhood(iteration)
        for nx, ny in self.neurons_in_radius(winner, radius):
            wt = self.neurons_[nx, ny]
            dr = self.dist(winner, (nx, ny), radius)
            self.neurons_[nx, ny] = wt + dr * lr * (vector - wt)

    def bmu(self, vector):
        """Best matching unit.

        Returns the (row, col) grid coordinates of the neuron with the
        smallest squared Euclidean distance to ``vector``.

        Raises
        ------
        ValueError
            If ``vector`` does not have one entry per neuron feature.
        """
        vector = np.asarray(vector)
        if vector.shape[0] != self.neurons_.shape[-1]:
            raise ValueError("vector has %d features but the neurons have %d"
                             % (vector.shape[0], self.neurons_.shape[-1]))
        # Broadcasting subtracts the vector from every neuron at once.
        dists = np.sum((self.neurons_ - vector) ** 2, axis=-1)
        # unravel_index uses the real grid shape, so non-square maps are
        # indexed correctly.
        return np.unravel_index(dists.argmin(), dists.shape)

    def dist(self, w, n, radius):
        """Neighbourhood weight of neuron ``n`` for winner ``w``.

        Returns ``exp(-d / radius)`` where ``d`` is the squared Euclidean
        distance between the two grid positions, so the winner itself gets
        weight 1.
        """
        wx, wy = w
        nx, ny = n
        d = (wx - nx) ** 2 + (wy - ny) ** 2
        # NOTE(review): the original author noted that the reference paper
        # uses a Gaussian kernel (presumably exp(-d / (2 * radius ** 2)));
        # confirm against the paper before changing the kernel.
        return math.exp(-d / radius)

    def neurons_in_radius(self, winner, radius):
        """Return the grid coordinates of the neurons around ``winner``.

        The result is an (n, 2) integer array of (row, col) pairs whose
        Euclidean grid distance to ``winner`` is strictly below ``radius``.
        """
        wi, wj = winner
        height, width = self._grid_shape()
        # ogrid yields a row axis and a column axis in (row, col) order, so
        # the mask is centred on (wi, wj). meshgrid's default 'xy' indexing
        # centred it on the transposed position (wj, wi).
        rows, cols = np.ogrid[:height, :width]
        mask = np.sqrt((rows - wi) ** 2 + (cols - wj) ** 2) < radius
        return np.c_[np.nonzero(mask)]

    def radius_of_the_neighbordhood(self, iteration):
        """Neighbourhood radius at ``iteration``.

        Starts at the map side length and decays exponentially with the
        iteration number.
        """
        return self._extent() * self._decay(iteration)
42 changes: 42 additions & 0 deletions scikits/learn/cluster/tests/test_som.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""
Testing for SOM

"""
import numpy as np
from numpy.testing import assert_equal

from ..som_ import SelfOrganizingMap
from .common import generate_clustered_data

# Shared fixture for the tests below, built by generate_clustered_data.
n_clusters = 4
n_features = 2
# NOTE(review): n_features is passed as a literal 2 here rather than through
# the n_features variable defined above; the two must be kept in sync.
X = generate_clustered_data(n_clusters=n_clusters, n_features=2, std=.1)


def test_som():
    """A randomly initialised 2x2 map should give each cluster one neuron."""
    np.random.seed(1)
    som = SelfOrganizingMap(size=2, n_iterations=10, learning_rate=1)
    som.fit(X)
    # labels_ holds one (row, col) pair per sample. np.unique would flatten
    # those pairs into bare coordinate values (at most 2 distinct values on
    # a 2x2 map), so collapse each pair into a single neuron index first.
    coords = np.asarray(som.labels_)
    labels = coords[:, 0] * som.neurons_.shape[1] + coords[:, 1]

    assert_equal(np.unique(labels).shape[0], 4)
    # NOTE(review): the slices assume X is ordered as consecutive groups of
    # 20 samples per cluster; confirm against generate_clustered_data.
    assert_equal(np.unique(labels[:20]).shape[0], 1)
    assert_equal(np.unique(labels[20:40]).shape[0], 1)
    assert_equal(np.unique(labels[40:60]).shape[0], 1)
    assert_equal(np.unique(labels[60:]).shape[0], 1)

def test_som_init_matrix():
    """A 2x2 map seeded from samples of X should separate the clusters."""
    np.random.seed(1)
    random_ind = np.random.randint(0, X.shape[0], size=n_clusters)
    init_map = X[random_ind].reshape(2, 2, n_features)

    som = SelfOrganizingMap(size=init_map, init='matrix',
                            n_iterations=2000, learning_rate=0.1)

    som.fit(X)
    # labels_ holds one (row, col) pair per sample. np.unique would flatten
    # those pairs into bare coordinate values (at most 2 distinct values on
    # a 2x2 map), so collapse each pair into a single neuron index first.
    coords = np.asarray(som.labels_)
    labels = coords[:, 0] * som.neurons_.shape[1] + coords[:, 1]

    assert_equal(np.unique(labels).shape[0], 4)
    # NOTE(review): the slices assume X is ordered as consecutive groups of
    # 20 samples per cluster; confirm against generate_clustered_data.
    assert_equal(np.unique(labels[:20]).shape[0], 1)
    assert_equal(np.unique(labels[20:40]).shape[0], 1)
    assert_equal(np.unique(labels[40:60]).shape[0], 1)
    assert_equal(np.unique(labels[60:]).shape[0], 1)