
[MRG] Refactor MiniBatchDictionaryLearning and add stopping criterion #18975


Merged
ogrisel merged 106 commits into scikit-learn:main on Mar 30, 2022

Conversation

@jeremiedbb (Member) commented Dec 7, 2020

Currently MiniBatchDictionaryLearning calls dict_learning_online. This PR inverts that relationship: dict_learning_online now calls the class instead.

There are two main reasons for this:

There were two options to do this: the first one (proposed here) is to make the function call the class; the other one is to make both the function and the class call a new common private function. I chose the former because it's what has already been done in a few other places, e.g. #14985 and #14994.
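For illustration, a minimal sketch of the delegation pattern chosen here (names and parameters simplified; this is not the actual scikit-learn implementation):

```python
from sklearn.decomposition import MiniBatchDictionaryLearning


def dict_learning_online(X, n_components=2, alpha=1, **params):
    # Sketch: the public function becomes a thin wrapper that
    # instantiates and fits the estimator, then returns plain arrays.
    est = MiniBatchDictionaryLearning(
        n_components=n_components, alpha=alpha, **params
    ).fit(X)
    # code (sparse representation of X) and the learned dictionary
    return est.transform(X), est.components_
```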

@ogrisel (Member) commented Dec 9, 2020

I am OK with the general idea of this refactoring, which is likely to improve the maintainability of this estimator. The fact that dict_learning_online has state-related arguments is a clue that it should fundamentally be a method on a class.
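To make that concrete, a sketch of the state threading the old function API required (parameter names taken from the deprecated signature discussed below; the exact call pattern is illustrative only):

```python
# Illustrative only: with the old API, the caller had to carry the
# optimizer state (dictionary, inner sufficient statistics, iteration
# offset) across successive calls by hand.
dictionary, inner_stats = dict_learning_online(
    X_batch_1, return_inner_stats=True
)
dictionary, inner_stats = dict_learning_online(
    X_batch_2,
    dict_init=dictionary,
    inner_stats=inner_stats,
    iter_offset=100,
    return_inner_stats=True,
)
# On a class, all of this is simply instance state.
```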

@jeremiedbb (Member, Author) commented:

For dict_learning_online I decided for now to deprecate:

  • iter_offset, inner_stats and return_inner_stats. These are only useful for partial_fit and serve a purely private purpose.
  • return_n_iter. I'm pretty sure it was introduced to avoid breaking backward compat, but there's no reason not to return n_iter.

For MiniBatchDictionaryLearning I decided to deprecate:

  • the iter_offset_, inner_stats_ and random_state_ attributes, making them private.
  • iter_offset in partial_fit. I don't see the point of this parameter; we don't have it for other online estimators.

Are you OK with that?
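With these deprecations, streaming use reduces to plain partial_fit calls, e.g. (a self-contained sketch with made-up data):

```python
import numpy as np

from sklearn.decomposition import MiniBatchDictionaryLearning

rng = np.random.RandomState(0)
est = MiniBatchDictionaryLearning(n_components=10, random_state=0)

# No iter_offset to maintain by hand: the estimator tracks its own
# iteration count and inner statistics internally.
for _ in range(5):
    X_batch = rng.randn(20, 8)
    est.partial_fit(X_batch)

print(est.components_.shape)  # (10, 8)
```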

@@ -1509,7 +1587,7 @@ class MiniBatchDictionaryLearning(_BaseSparseCoding, BaseEstimator):
We can check the level of sparsity of `X_transformed`:

>>> np.mean(X_transformed == 0)
- 0.87...
+ 0.85...
@jeremiedbb (Member, Author) commented on the diff:

There's a small difference, due to what I think should be considered a bug in the previous behavior: the update of the inner stats depends on the batch size, see https://github.com/scikit-learn/scikit-learn/pull/18975/files#diff-20a73e7d385ab5d19a05026b635c8b256ff568e1a0e9e2fed606fec82d3b956fR1732

The thing is that, due to how the batches are generated, they may not all have the same size:

# In fit: generate the mini-batches, then cycle over them indefinitely.
batches = gen_batches(n_samples, self.batch_size)
batches = itertools.cycle(batches)

If n_samples is not a multiple of batch_size, the last batch will be smaller than batch_size. In this PR I currently just use the actual size of each batch for the update of the stats, hence the small difference.

Actually, I wonder if we should do something about the generation of the batches to make them all the same size. WDYT?
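To make the size issue concrete, a small standalone sketch of the current batch generation (outside the estimator, so self.batch_size is replaced by a local variable):

```python
import itertools

from sklearn.utils import gen_batches

n_samples, batch_size = 10, 4
batches = itertools.cycle(gen_batches(n_samples, batch_size))

# The last slice of each pass only covers n_samples % batch_size rows.
for batch in itertools.islice(batches, 6):
    print(batch, "-> size", batch.stop - batch.start)
# slice(0, 4, None) -> size 4
# slice(4, 8, None) -> size 4
# slice(8, 10, None) -> size 2  <- smaller batch, hence a different
#                                  weight in the inner-stats update
# ...and the pattern repeats
```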

@jeremiedbb jeremiedbb changed the title [WIP] Make dict_learning_online call MiniBatchDictionaryLearning instead of the opposite [MRG] Make dict_learning_online call MiniBatchDictionaryLearning instead of the opposite Dec 11, 2020
@jeremiedbb jeremiedbb changed the title [MRG] Make dict_learning_online call MiniBatchDictionaryLearning instead of the opposite [WIP] Make dict_learning_online call MiniBatchDictionaryLearning instead of the opposite Dec 16, 2020
@jeremiedbb jeremiedbb added this to the 1.1 milestone Feb 10, 2022
@ogrisel (Member) commented Feb 15, 2022

I pushed a commit to tweak the parameters of the denoising example to get a better dictionary, which leads to cleaner denoising results while still being fast enough.

@ogrisel (Member) commented Feb 15, 2022

I am not sure why black has started to complain. According to our docs we should be using the pinned black==21.6b0 version, but that is not the case on this build.

@jeremiedbb (Member, Author) commented Feb 15, 2022

We upgraded to a stable version in #22474.

@ogrisel (Member) commented Feb 15, 2022

Just saw that :)

@glemaitre (Member) left a comment:

Posting these comments, but this is only a partial review. I will finish it now.

@glemaitre (Member) left a comment:

LGTM

@jjerphan jjerphan self-requested a review March 11, 2022 10:53
@ogrisel (Member) left a comment:

LGTM! Thank you very much for the clean-up.

@ogrisel ogrisel merged commit a23c2ed into scikit-learn:main Mar 30, 2022
glemaitre pushed a commit to glemaitre/scikit-learn that referenced this pull request Apr 6, 2022
…scikit-learn#18975)


Co-authored-by: Olivier Grisel <olivier.grisel@gmail.com>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>