[MRG] Refactor MiniBatchDictionaryLearning and add stopping criterion #18975
Conversation
I am ok with the general idea of this refactoring, which is likely to improve the maintainability of this estimator.
Are you ok with that?
@@ -1509,7 +1587,7 @@ class MiniBatchDictionaryLearning(_BaseSparseCoding, BaseEstimator):
     We can check the level of sparsity of `X_transformed`:

     >>> np.mean(X_transformed == 0)
-    0.87...
+    0.85...
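For context, here is a minimal runnable sketch of the kind of check this docstring line performs. The data and parameters below are illustrative (the actual docstring example uses `make_sparse_coded_signal`), so the measured sparsity will not exactly match the 0.85.../0.87... values above.

```python
import numpy as np
from sklearn.decomposition import MiniBatchDictionaryLearning

# Illustrative random data, not the real docstring fixture.
rng = np.random.RandomState(42)
X = rng.randn(100, 20)

dict_learner = MiniBatchDictionaryLearning(
    n_components=15,
    batch_size=3,
    transform_algorithm="lasso_lars",
    transform_alpha=0.1,
    random_state=42,
)
X_transformed = dict_learner.fit_transform(X)

# Fraction of exactly-zero coefficients, i.e. the sparsity of the codes.
print(np.mean(X_transformed == 0))
```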
There's a small difference due to what I think should be considered a bug in the previous behavior. The update of the inner stats depends on the batch size, see https://github.com/scikit-learn/scikit-learn/pull/18975/files#diff-20a73e7d385ab5d19a05026b635c8b256ff568e1a0e9e2fed606fec82d3b956fR1732
The thing is that due to how batches are generated, the batches may not all have the same size:
batches = gen_batches(n_samples, self.batch_size)
batches = itertools.cycle(batches)
If n_samples is not a multiple of batch_size, the last batch will be smaller than batch_size. In this PR I currently just use the actual size of each batch for the update of the stats, hence the small difference.
Actually I wonder if we should do something about the generation of the batches, to make them all have the same size. wdyt?
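To illustrate the uneven batch sizes, here is a small standalone sketch using `sklearn.utils.gen_batches` directly; the numbers are just an example.

```python
import itertools
from sklearn.utils import gen_batches

# With n_samples not a multiple of batch_size, the last slice is smaller.
n_samples, batch_size = 10, 3
batches = list(gen_batches(n_samples, batch_size))
print(batches)
# [slice(0, 3, None), slice(3, 6, None), slice(6, 9, None), slice(9, 10, None)]

# itertools.cycle replays this uneven sequence forever, so every fourth
# mini-batch update would be computed from a batch of size 1 instead of 3.
sizes = [b.stop - b.start for b, _ in zip(itertools.cycle(batches), range(8))]
print(sizes)  # [3, 3, 3, 1, 3, 3, 3, 1]
```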
I pushed a commit to tweak the parameters of the denoising example to get a better dictionary that leads to cleaner denoising results while still being fast enough.
I am not sure why black has started to complain. According to our doc we should use the pinned version.
We upgraded to a stable version (#22474).
Just saw that :)
Posting these comments now, but this is only a partial review. I will finish it shortly.
LGTM
LGTM! Thank you very much for the clean-up.
…scikit-learn#18975)
Co-authored-by: Olivier Grisel <olivier.grisel@gmail.com>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Currently `MiniBatchDictionaryLearning` calls `dict_learning_online`. This PR switches to make `dict_learning_online` call the class instead.

2 main reasons for this:

- Currently `dict_learning_online` serves too many purposes and exposes private stuff useful for partial_fit, like `iter_offset` and `inner_stats`. It makes the code very hard to follow. For instance the function has 6 possible return statements, see `scikit-learn/sklearn/decomposition/_dict_learning.py` (line 866 in 4773f3e).
- It will greatly ease the implementation of [WIP] online matrix factorization with missing values #18492.

There are 2 options to do this: the first one (proposed here) is to make the function call the class; the other one is to make both the function and the class call a common new private function. I chose the former because it's what has already been done in a few other places like #14985 and #14994.
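For illustration only, here is a rough sketch of the proposed direction; this is not the actual scikit-learn code, and the function name and parameter list below are made up for the example. The idea is that the functional API becomes a thin wrapper that instantiates and fits the estimator, instead of the estimator delegating to the function.

```python
import numpy as np
from sklearn.decomposition import MiniBatchDictionaryLearning


def dict_learning_online_sketch(X, n_components=2, batch_size=3,
                                return_code=True, random_state=None):
    """Hypothetical wrapper: the function delegates to the class."""
    est = MiniBatchDictionaryLearning(
        n_components=n_components,
        batch_size=batch_size,
        random_state=random_state,
    ).fit(X)
    dictionary = est.components_
    if return_code:
        code = est.transform(X)
        return code, dictionary
    return dictionary


# Usage: same style of call as a functional dictionary-learning API.
X = np.random.RandomState(0).randn(50, 8)
code, dictionary = dict_learning_online_sketch(X, n_components=5, random_state=0)
print(code.shape, dictionary.shape)  # (50, 5) (5, 8)
```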