sklearn.mixture.GMM: 'random_state' makes number of initialisations ('n_init') meaningless #7377

Closed
@lijing-lin

I would expect that, even if random_state is given (for reproducibility reasons), each of the n_init initialisations would give a different result, so that the run with the largest log-likelihood can be chosen.
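
For reference, the expected behaviour (one fit per initialisation, keep the one with the largest log-likelihood) can be sketched in plain Python. This is not scikit-learn's code: `run_em` is a hypothetical stand-in for one initialisation plus EM that returns a toy score, and `best_of_n_init` only illustrates the selection step.

```python
import random


def run_em(seed):
    """Hypothetical stand-in for one initialisation followed by EM.

    Returns (log_likelihood, params). The "fit" is a toy: params are random
    draws, so different seeds give different scores.
    """
    rng = random.Random(seed)
    params = [rng.gauss(0.0, 1.0) for _ in range(3)]
    log_likelihood = -sum(p * p for p in params)  # toy score, higher is better
    return log_likelihood, params


def best_of_n_init(seeds):
    """Run one fit per seed and keep the one with the largest log-likelihood."""
    best_ll, best_params = float("-inf"), None
    for seed in seeds:
        ll, params = run_em(seed)
        if ll > best_ll:
            best_ll, best_params = ll, params
    return best_ll, best_params


# Distinct seeds give distinct candidates; the best one is kept.
best_ll, best_params = best_of_n_init(range(10))
```

Note that `best_of_n_init([1] * 10)` returns exactly what a single `run_em(1)` does: repeating one seed adds no new candidates.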

However, this is not the case. No matter what value of n_init is specified, if random_state is not None, GMM always gives the same cluster labels, as the following code shows:

import numpy as np
import sklearn.mixture  # `import sklearn` alone does not load the mixture submodule

# Toy data: 100 samples, 5 features, standardised
X = np.random.randn(100, 5)
X = X - X.mean(0)
X = X / np.std(X, 0)

gmm = sklearn.mixture.GMM(n_components=3,
                          covariance_type='diag',
                          random_state=1,  # this actually makes n_init meaningless
                          n_init=1000)

gmm.fit(X)
label = gmm.predict(X)

The reason is that, by default, GMM uses KMeans for initialisation. When random_state is specified, the same value is passed to cluster.KMeans in every round of initialisation:

for init in range(self.n_init):
    if self.verbose > 0:
        print('Initialization ' + str(init + 1))
        start_init_time = time()
    if 'm' in self.init_params or not hasattr(self, 'means_'):
        self.means_ = cluster.KMeans(
            n_clusters=self.n_components,
            # the same self.random_state is reused on every pass of the loop
            random_state=self.random_state).fit(X).cluster_centers_

So every round starts from the same initialisation, and the n_init runs are all identical.
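
A toy model of the loop above shows why. If every round builds its generator from the same integer seed, every round produces the same starting centres. The sketch uses Python's `random` module; `init_centres` is a hypothetical stand-in for the KMeans initialisation, not scikit-learn code, and it assumes (as scikit-learn does for an int `random_state`) that a fresh generator is created from the seed on each call.

```python
import random


def init_centres(X, k, random_state):
    """Hypothetical stand-in for a KMeans-style initialisation: pick k rows of X.

    An int random_state builds a fresh, identically seeded generator on
    every call, so the same seed always yields the same centres.
    """
    rng = random.Random(random_state)
    return tuple(rng.sample(X, k))


X = [(i, i * i) for i in range(100)]  # toy data set
n_init = 1000

# The same int is passed on every round, mirroring the GMM loop.
inits = {init_centres(X, 3, random_state=1) for _ in range(n_init)}
print(len(inits))  # 1: all 1000 "initialisations" are identical
```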

One way around this is to seed NumPy's global random state before running GMM and leave random_state=None:

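import numpy as np
import sklearn.mixture

random_state = 1
np.random.seed(random_state)  # seed the global RNG once; KMeans draws from it when random_state=None
gmm = sklearn.mixture.GMM(n_components=3,
                          covariance_type='diag',
                          random_state=None,
                          n_init=1000)

gmm.fit(X)
label = gmm.predict(X)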
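
The workaround works because, with random_state=None, each KMeans call draws from NumPy's global generator, which keeps advancing between rounds after being seeded once. The same idea can be sketched without global state: build one master generator from the seed and let every round draw from it. This is a plain-Python sketch under that assumption; `run_inits` and `init_centres` are hypothetical names, not scikit-learn API.

```python
import random


def init_centres(X, k, rng):
    """Hypothetical stand-in for a KMeans-style initialisation: pick k rows of X."""
    return tuple(rng.sample(X, k))


def run_inits(X, k, n_init, seed):
    """Hypothetical helper: one master generator shared across all rounds.

    Each round consumes fresh random numbers, so the starts differ, but
    re-running with the same seed reproduces the whole sequence.
    """
    master = random.Random(seed)
    return [init_centres(X, k, master) for _ in range(n_init)]


X = [(i, i * i) for i in range(100)]
first = run_inits(X, 3, n_init=20, seed=1)
second = run_inits(X, 3, n_init=20, seed=1)
assert first == second       # reproducible across runs
assert len(set(first)) > 1   # but the 20 starts are not all the same
```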

Versions

Darwin-15.3.0-x86_64-i386-64bit
('Python', '2.7.11 |Anaconda 2.5.0 (x86_64)| (default, Dec  6 2015, 18:57:58) \n[GCC 4.2.1 (Apple Inc. build 5577)]')
('NumPy', '1.11.1')
('SciPy', '0.17.1')
('Scikit-Learn', '0.17')
