ENH Improve regularization messages for QuadraticDiscriminantAnalysis #19731
Conversation
Thank you for the PR @azihna !
sklearn/discriminant_analysis.py
Outdated
S2 = (S ** 2) / (len(Xg) - 1)
S2 = ((1 - self.reg_param) * S2) + self.reg_param
cov_reg = np.dot(S2 * Vt.T, Vt)
det = linalg.det(cov_reg)
if det < self.tol:
Traditionally, `tol` is compared to the singular values for determining the rank. Can we use `np.linalg.matrix_rank` here and pass in `tol`? (Internally, `matrix_rank` computes the SVD.)
`matrix_rank` is going to be expensive as well. I think we should just use the results of the SVD, as explained in the comment above.
In your comment here (#14997 (comment)) you said: "I think that when the covariance matrix is not full rank we should probably raise a LinAlgError at fit time explicitly recommending the user to increase regularization."
From looking at the code, `S` does not depend on `reg_param`, which means increasing the regularization would still result in the same error message.
How about using `S2` in the same way as `S`?
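A hedged sketch of that idea, with names taken from the snippet above and an illustrative message:

```python
import numpy as np
from scipy import linalg

# Hypothetical sketch: check the regularized singular values S2 rather than S,
# so that increasing reg_param can actually make the check pass.
rank = np.sum(S2 > self.tol)
if rank < Xg.shape[1]:
    raise linalg.LinAlgError(
        "The covariance matrix of one of the classes is not full rank. "
        "Increasing the value of reg_param might help."
    )
```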
I experimented with the proposed change. Using `S2` instead of `S`, the solution works as intended, but there are two tests that use random data which the changes couldn't pass:
- test_common.py/test_estimators[QuadraticDiscriminantAnalysis-check_estimators_dtype]
- test_common.py/test_estimators[QuadraticDiscriminantAnalysis-check_dtype_object]

Both of them fit the model with random data and then check for the existence of methods or attributes, and this random data triggers the error. As quick fixes, I could:
- change the LinAlgError to a LinAlgWarning, which would still warn the user of the problem, or
- mark the tests as expected to fail.

Do you have any further recommendations?
As a workaround, we can update

def _set_checking_parameters(estimator):

to set `reg_param`.
sklearn/utils/estimator_checks.py
Outdated
X, y = make_classification(random_state=seed, n_samples=40,
                           n_informative=8, n_features=10,
                           n_classes=4)
We try not to special-case estimators here. Is there no `reg_param` we can use to make this work?
Thanks a lot for the review @thomasjpfan. Sadly, no. With random data the check always fails no matter how large `reg_param` is; I tried quite a few different values. That's why I also mentioned in the earlier comment that changing the error to a warning might be a better solution. I think users might be annoyed if the model throws an error when they try to fit on toy datasets like that.
I changed the error to a warning in the latest commit and reverted the estimator checks back to their original state.
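For reference, a sketch of what the warning variant might look like at the point where the error used to be raised (illustrative wording only, not the exact merged code):

```python
import warnings

from scipy import linalg

# Hypothetical sketch: emit a LinAlgWarning instead of raising, so fit still
# completes while the user is told how to regularize the model.
warnings.warn(
    "The covariance matrix of one of the classes is not full rank. "
    "Increasing the value of reg_param might help.",
    linalg.LinAlgWarning,
)
```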
It seems like `LinAlgWarning` was introduced in SciPy 1.1.0, and the CI pipelines for Python 3.6 are using SciPy 0.19. SciPy actually supported Python 3.6 until version 1.6.0. Is it possible to upgrade the CI requirements?
@ogrisel
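If upgrading the CI floor were not an option, a minimal compatibility shim would be conceivable (purely a sketch, not what the PR ended up doing):

```python
# Hypothetical fallback for SciPy < 1.1.0, where scipy.linalg.LinAlgWarning
# does not exist yet.
try:
    from scipy.linalg import LinAlgWarning
except ImportError:
    class LinAlgWarning(UserWarning):
        """Stand-in warning category for older SciPy versions."""
```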
If there is an issue with `test_common.py`, you can add some regularization by updating `_set_checking_parameters`:
scikit-learn/sklearn/utils/estimator_checks.py
Lines 566 to 568 in 36915ae
def _set_checking_parameters(estimator):
    # set parameters to speed up some estimators and
    # avoid deprecated behaviour
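A sketch of what such an addition to `_set_checking_parameters` could look like (the `reg_param` value is a hypothetical placeholder that would need tuning):

```python
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis

def _set_checking_parameters(estimator):
    # set parameters to speed up some estimators and
    # avoid deprecated behaviour
    ...
    # Hypothetical addition: regularize QDA so that the random data used by
    # the common checks does not trigger the rank-deficiency warning.
    if isinstance(estimator, QuadraticDiscriminantAnalysis):
        estimator.set_params(reg_param=0.1)
```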
doc/whats_new/v1.0.rst
Outdated
- |Enhancement| :class:`discriminant_analysis.QuadraticDiscriminantAnalysis`
  will now cause 'RuntimeWarning' in case of collinear variables. These errors
  can be silenced by the 'reg_param' attribute.:pr:`19731` by
  :user:`Alihan Zihna <azihna>`
May we move the `discriminant_analysis` section higher in this list? (The modules are in alphabetical order.)
Suggested change:

- |Enhancement| :class:`discriminant_analysis.QuadraticDiscriminantAnalysis`
  will now cause 'LinAlgWarning' in case of collinear variables. These errors
  can be silenced by the 'reg_param' attribute. :pr:`19731` by
  :user:`Alihan Zihna <azihna>`
with pytest.warns(RuntimeWarning, match="divide by zero"):
    y_pred = clf.predict(X2)
assert np.any(y_pred != y6)
With the `LinAlgWarning`, is this `RuntimeWarning` still being raised?
So it means that we need to record the warnings and check that both are raised, or shall we instead catch the division-by-zero warning, which is not informative?
I've put back the test, catching all the raised warnings.
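One way to record everything and assert that both warnings were emitted could look roughly like this (hypothetical; `clf`, `X2`, and `y6` are the fixtures from the surrounding test):

```python
import warnings

from scipy import linalg

# Hypothetical sketch: capture every warning raised during fit and predict,
# then check that both the rank-deficiency warning and the divide-by-zero
# RuntimeWarning are among them.
with warnings.catch_warnings(record=True) as records:
    warnings.simplefilter("always")
    y_pred = clf.fit(X2, y6).predict(X2)

categories = [type(w.message) for w in records]
assert linalg.LinAlgWarning in categories
assert RuntimeWarning in categories
```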
Co-authored-by: Guillaume Lemaitre <guillaume@probabl.ai>
@thomasjpfan @glemaitre @ogrisel I gave this a refresh and applied the suggestions; could you please have another go?
Somehow the warnings are NOT raised in some of our CI, so I removed them in 8854f58 (they were raised locally, though).
doc/whats_new/v1.5.rst
Outdated
@@ -217,6 +217,14 @@ Changelog
have the `n_features_in_` and `feature_names_in_` attributes after `fit`.
:pr:`27937` by :user:`Marco vd Boom <tvdboom>`.

:mod:`sklearn.discriminant_analysis`
We will need to move this to 1.6 :)
LGTM
Reference Issues/PRs
References #14997
What does this implement/fix? Explain your changes.
- Remove the "variables are collinear" warning.
- Check the covariance matrix after regularization and raise a LinAlgError prompting the user to increase regularization.
- I left the `tol` argument as a way to control the error; the user can turn it off if need be.
- Updated the documentation for the `tol` argument.
- Changed the tests to check for the correct errors.
Any other comments?
The last test checks the case where n_samples_class < n_features. I tried many different configurations but was unable to produce a set of variables that didn't throw an error, so the test only contains negative examples. I am open to suggestions for positive examples in this case.
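For illustration, a hedged sketch of the negative case described above (hypothetical data and names; the exact warning class and message depend on the final implementation):

```python
import numpy as np
import pytest
from scipy import linalg

from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis


def test_qda_warns_when_classes_are_rank_deficient():
    # Five features but only three samples per class, so each per-class
    # covariance matrix cannot be full rank.
    rng = np.random.RandomState(0)
    X = rng.randn(6, 5)
    y = np.array([0, 0, 0, 1, 1, 1])

    with pytest.warns(linalg.LinAlgWarning):
        QuadraticDiscriminantAnalysis(reg_param=0.0).fit(X, y)
```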