[MRG] Fix Random initialisation of GMM should consider data magnitude #10850 #11101

Closed

Conversation

@g-walsh (Contributor) commented May 16, 2018

Reference Issues/PRs

Fixes #10850
See also #10741

What does this implement/fix? Explain your changes.

Added a third initialisation option, 'rand_data', which samples points from the data set. It assigns zero to all responsibilities except for the sampled points, each of which is given a responsibility of 1 for a single component.

When the initial parameters are computed on calling the GMM, this resp produces initial means located at the sampled points.
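
To make this concrete, here is a minimal sketch of the rand_data responsibility construction (illustrative code, not the PR's actual implementation):

    import numpy as np

    rng = np.random.RandomState(0)
    X = rng.randn(500, 2)                       # toy data
    n_samples, n_components = X.shape[0], 4

    # Sample n_components distinct data points; each gets responsibility 1
    # for exactly one component, and every other responsibility is 0.
    idx = rng.choice(n_samples, size=n_components, replace=False)
    resp = np.zeros((n_samples, n_components))
    resp[idx, np.arange(n_components)] = 1

    # The usual M-step mean update then recovers the sampled points.
    means = resp.T @ X / resp.sum(axis=0)[:, np.newaxis]
    assert np.allclose(means, X[idx])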

Any other comments?

I think this adds the desired functionality from the original PR #10741, but I'm not sure it is a useful addition, so I did a bit of extra investigating. I've added gmm_test2.py (which I would not include in an eventual merge) to show how I produced the following.

Here (seed 1234) the sampling works fine for all three methods. The original data is 4 sets of Gaussian data. Orange crosses are the initial_mean values and the colouring shows the labelling of the data by the GMM.

[figure_1234: results with seed 1234 for all three methods]

But here (seed 10) it does not:

[figure_10: the same experiment with seed 10, where rand_data gives a poor fit]

I think the fact that two of the sampled points are close together gives rand_data a poor fit. I've labelled this as WIP to see whether you think this is worth investigating further as a feature. I know that documentation (and proper testing) would also need to be added.

N.B. It is also worth pointing out that the current 'random' option does take data magnitude into account in some respects, as it produces initial centres very close to the mean of the data set (along all dimensions), as seen here.
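
A quick sketch of why that happens, assuming (as I understand the current code) that 'random' draws uniform responsibilities and normalises them per sample:

    import numpy as np

    rng = np.random.RandomState(0)
    # Toy data with a deliberate offset and anisotropic scale
    X = rng.randn(500, 2) * [10.0, 0.1] + [100.0, -5.0]

    resp = rng.rand(X.shape[0], 4)
    resp /= resp.sum(axis=1)[:, np.newaxis]

    # Each column of resp is a near-uniform weighting over many samples,
    # so every initial mean lands close to the overall data mean.
    means = resp.T @ X / resp.sum(axis=0)[:, np.newaxis]
    print(means)            # all four rows are close to...
    print(X.mean(axis=0))   # ...the overall data mean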

@jnothman (Member) commented May 16, 2018 via email

@g-walsh (Contributor, Author) commented May 17, 2018

@jnothman Sure, will take a look at that next.

@g-walsh (Contributor, Author) commented May 18, 2018

@jnothman I've taken a look at kmeans++ and it seems to be just a better version of random data selection. It randomly selects an initial centre and then selects subsequent ones with a bias against being close to the existing points. It seems to solve the problems I had with random before. See below:

[figure_1: k-means++ initialisation result]

I've added a 'k-means++' option to the PR in base.py. It feels a bit inelegant, though, as the existing k-means++ initialisation function _k_init (in the kmeans module) is an internal function and outputs the points themselves rather than the responsibilities. This means I have to look up the centres in the original data and create a resp to recreate the means. Let me know what you think.
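
For reference, a toy sketch of the k-means++ seeding idea (plain D^2 sampling; _k_init itself is more sophisticated, e.g. it runs several local trials per step):

    import numpy as np

    def kmeanspp_indices(X, n_clusters, rng):
        """Draw each new centre with probability proportional to its
        squared distance to the nearest centre chosen so far."""
        indices = [rng.randint(X.shape[0])]   # first centre uniformly at random
        for _ in range(n_clusters - 1):
            d2 = ((X[:, None, :] - X[indices]) ** 2).sum(-1).min(axis=1)
            indices.append(rng.choice(X.shape[0], p=d2 / d2.sum()))
        return np.asarray(indices)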

@jnothman (Member) commented May 21, 2018 via email

@g-walsh (Contributor, Author) commented May 22, 2018

@jnothman Thanks, I've now added tracking of indices to the km++ implementation, and this is now used in the initialisation of the GMM.

@jnothman (Member) left a comment

The implementation looks good! Do you think rand_data is still helpful?

You need to update:

  • The tests: add a test of _k_init to ensure the two outputs are consistent, or change the implementation to output only indices and test that (a rough sketch follows this list).
  • The tests: otherwise, at least check that each approach gives different results in the GMM, and ideally test some properties thereof.
  • The parameter docstring in the GMM classes.
  • Any examples that would be a better illustration if k-means++ initialisation were used.
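
A rough sketch of the first suggested test, assuming _k_init is changed to also return the indices of the chosen samples as discussed above (the two-output signature is this PR's proposal, not released API):

    import numpy as np
    from sklearn.utils.extmath import row_norms
    from sklearn.cluster.k_means_ import _k_init   # private helper

    def test_k_init_centers_match_indices():
        rng = np.random.RandomState(0)
        X = rng.randn(100, 3)
        # Assumed two-output form: (centres, indices of the chosen samples)
        centers, indices = _k_init(X, 4, row_norms(X, squared=True), rng)
        # Consistency check: each centre is the data point at its index.
        assert np.allclose(centers, X[indices])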

@g-walsh (Contributor, Author) commented May 23, 2018

Thanks! No, I don't think rand_data is useful at all. I only left it in as it was the suggestion of the original Issue and PR. I'll remove it.

I'm not too experienced with writing tests, but I'll take a look at how you implement (and document) them and then have a go.

@amueller (Member) commented

I think rand_data is useful; the complexity of kmeans++ is terrible for large n_clusters! It often takes longer than running k-means itself.

@g-walsh (Contributor, Author) commented May 30, 2018

@jnothman I've just taken a look and I have a couple of questions.

  • I think that with this patch only two functions use _k_init. The first is k_means, which already has tests for kmeans++ initialisation (as well as random initialisation), so would I need to add any extra tests for that? If the existing tests pass, can I assume that my addition of indices hasn't affected the existing implementation and interaction of kmeans/kmeans++?
  • The new addition of kmeans++ to the mixture models obviously needs a test. I currently can't find any consistency tests for the existing initialisation methods (random and kmeans); do they exist? I can try to write tests for all four potential initialisations (two existing: random and kmeans; two new: kmeans++ and rand_data) if that works.

I'll take a look at the documentation and examples after the testing looks good.

@amueller I'm happy to leave rand_data in, but as I showed in my example above it can lead to poor fitting. Should we add a warning for it, perhaps with an example similar to the one above?

@jnothman (Member) commented Jun 5, 2018

Thanks for just giving it a go. You have some PEP 8 issues (at a glance, I saw you needed spaces after commas).

@amueller knows what he's talking about when it comes to KMeans initialisation ;) Let's leave rand_data in as an option and just document their caveats.

@jnothman (Member) left a comment

It might be good to illustrate the quality and variation due to initialisation in an example, but perhaps it's not essential. The user guide (i.e. mixture.rst) might be a good place to note the pros and cons of each method.

@jnothman (Member) left a comment

The main thing this is missing is documentation or an example of when to use which setting.

@g-walsh (Contributor, Author) commented Jun 8, 2018

I've had a first pass at adding to the documentation and have added an example that I feel shows the different initialisations. This is my first attempt at this kind of documentation, so feedback would be very helpful for me!

@jnothman (Member) left a comment

Nice work! I think it'd be helpful to indicate on the example the number of iterations to convergence. It would also be interesting to record the amount of time spent in initialisation, but I assume it's too tiny to be meaningful in this example.

@g-walsh (Contributor, Author) commented Jun 9, 2018

Thanks, I've dealt with the camel case and I've added a bit of evaluation of the methods.

It's a bit simple, but I've just used the built-in timing functions to time how long the initialisation takes, then added it to the plot as a relative timing: the time taken for each initialisation is given as a multiple of the time taken for the random initialisation.
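
Roughly, a sketch of the measurement (the init option names 'k-means++' and 'rand_data' are this PR's proposal; using a tiny max_iter so the fit time is dominated by initialisation is only a crude proxy):

    from timeit import default_timer as timer
    import numpy as np
    from sklearn.mixture import GaussianMixture

    rng = np.random.RandomState(0)
    X = rng.randn(1000, 2)

    times = {}
    for method in ['random', 'kmeans', 'k-means++', 'rand_data']:
        start = timer()
        GaussianMixture(n_components=4, init_params=method,
                        max_iter=1, random_state=0).fit(X)
        times[method] = timer() - start

    # Report each timing as a multiple of the 'random' timing
    relative = {m: t / times['random'] for m, t in times.items()}
    print(relative)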

I also used the verbose output of the GMM to get an idea of how many iterations the Gaussian mixture takes on my desktop machine. It might be better to read the number of iterations into a variable and display it on the plot.

@jnothman (Member) commented Jun 9, 2018 via email

@jnothman (Member) commented Jun 9, 2018

You can change this from WIP to MRG when ready.

@g-walsh (Contributor, Author) commented Jun 10, 2018

Thanks! I hadn't seen the n_iter_ attribute before. I've added that to the plot now.
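
For reference, a minimal sketch of reading the iteration count (init_params='k-means++' is this PR's proposed option):

    import numpy as np
    from sklearn.mixture import GaussianMixture

    rng = np.random.RandomState(0)
    X = rng.randn(500, 2)

    gmm = GaussianMixture(n_components=4, init_params='k-means++',
                          random_state=0).fit(X)
    # n_iter_ is the number of EM iterations run until convergence,
    # handy for annotating each panel of the example plot.
    print(gmm.n_iter_)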

g-walsh changed the title from "[WIP] Fix Random initialisation of GMM should consider data magnitude #10850" to "[MRG] Fix Random initialisation of GMM should consider data magnitude #10850" on Jun 10, 2018
@jnothman (Member) left a comment

Otherwise, this looks great!

    from sklearn.cluster import KMeans
    from sklearn.mixture import GaussianMixture
    from sklearn.utils.extmath import row_norms
    from sklearn.cluster.k_means_ import _k_init  # private helper; see review comment below
Member:

We really shouldn't be using private API here...

It's a pity that this is the only substantial problem I have with the example. @amueller, what do you think of making kmeans++ public?

Alternatively, we could consider:

  • storing the initial means on the GMM estimator
  • allowing GMM to work with max_iter=0

Member:

This is a blocker, unfortunately.

Member:

Do you intend to try fixing this?

Contributor Author:

Yes, I was waiting to see whether you thought making kmeans++ public was feasible or not. If not, then I can certainly look into an alternative method.

Member:

@amueller, what do you think of making kmeans++ public? But I don't mind allowing max_iter=0 either, as a diagnostic tool.

Contributor Author:

I've added support for max_iter=0 and changed the plot so it no longer uses the private kmeans function. Do you think I should add a warning about running with max_iter=0, or is the convergence warning enough?
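
A sketch of the diagnostic use, under this PR's proposed max_iter=0 behaviour (fit stops immediately after initialisation and emits a ConvergenceWarning):

    import numpy as np
    from sklearn.mixture import GaussianMixture

    rng = np.random.RandomState(0)
    X = rng.randn(300, 2)

    # With max_iter=0, fit() returns straight after initialisation, so
    # means_ exposes the initial centres without touching private API.
    gmm = GaussianMixture(n_components=4, init_params='k-means++',
                          max_iter=0, random_state=0).fit(X)
    initial_means = gmm.means_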

@g-walsh (Contributor, Author) commented Jun 11, 2018

I've fixed all of these except for the private _k_init, which is awaiting a decision.

@jnothman (Member) left a comment

Very nice!

Please add an entry to the change log at doc/whats_new/v0.21.rst. Like the other entries there, please reference this pull request with :issue: and credit yourself (and any other contributors, if applicable) with :user:.

@cmarmo (Contributor) commented Jul 9, 2020

@g-walsh thanks for your patience! Do you think you will find some time to fix conflicts with upstream? Then maybe @jeremiedbb could check? Thanks!

@g-walsh (Contributor, Author) commented Jul 9, 2020

> @g-walsh thanks for your patience! Do you think you will find some time to fix conflicts with upstream? Then maybe @jeremiedbb could check? Thanks!

@cmarmo yep, I should have some time soon to go over and update this. Thanks for following it up!

@g-walsh (Contributor, Author) commented Jul 10, 2020

@cmarmo Any chance you could help with the errors in these failing checks? They seem to occur during install/build and I'm not sure where to start debugging. My local builds work just fine.

The install error that seems most common across the builds is:

    distutils.errors.DistutilsModuleError: invalid command 'develop'

Seen this before?

@jeremiedbb (Member) commented

Merging master will fix the CI.

@jeremiedbb (Member) commented

Continued and finished in #20408. Thanks @g-walsh!

jeremiedbb closed this on Apr 6, 2022