
ENH Improve regularization messages for QuadraticDiscriminantAnalysis #19731


Merged
merged 24 commits into from
Jun 6, 2024

Conversation

Contributor

@azihna azihna commented Mar 19, 2021

Reference Issues/PRs

References #14997

What does this implement/fix? Explain your changes.

  • Removed the "variables are collinear" warning.
  • Checked the covariance matrix after regularization and raised a LinAlgError prompting the user to increase regularization.
  • Left the tol argument as a way to control the error; the user can turn it off if need be.
  • Updated the documentation for the tol argument.
  • Changed the tests to check for the correct errors.

Any other comments?

The last test checks cases where n_samples_class < n_features. I tried many different configurations but was unable to produce a set of variables whose output didn't throw an error, so the test only contains negative examples. I am open to suggestions for positive examples in this case.

Member

@thomasjpfan thomasjpfan left a comment

Thank you for the PR @azihna !

S2 = (S ** 2) / (len(Xg) - 1)
S2 = ((1 - self.reg_param) * S2) + self.reg_param
cov_reg = np.dot(S2 * Vt.T, Vt)
det = linalg.det(cov_reg)
if det < self.tol:
Member

Traditionally, tol is compared to the singular values when determining the rank. Can we use np.linalg.matrix_rank here and pass in tol? (Internally, matrix_rank computes the SVD.)
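For illustration, a minimal sketch of that rank check on made-up collinear data (the tol value and the setup below are assumptions for the example, not code from this PR):

```python
import numpy as np

rng = np.random.default_rng(0)
Xg = rng.normal(size=(10, 4))
Xg[:, 3] = Xg[:, 0]  # duplicate a column -> rank-deficient covariance

cov = np.cov(Xg, rowvar=False)
tol = 1e-8

# matrix_rank computes the SVD internally and counts the
# singular values that exceed `tol`
rank = np.linalg.matrix_rank(cov, tol=tol)
is_deficient = rank < cov.shape[0]
```

With an exactly duplicated column, one singular value of the covariance is numerically zero, so the rank comes back as 3 out of 4.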

Member

Computing the matrix rank is going to be expensive as well. I think we should just use the results of the SVD, as explained in the comment above.

Member

In your comment here: #14997 (comment)

I think that when the covariance matrix is not full rank we should probably raise a LinAlgError at fit time explicitly recommending the user to increase regularization.

From looking at the code, S does not depend on reg_param, which means increasing the regularization would still result in the error message.

Contributor Author

How about using S2 in the same way as S?
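To illustrate the point, a rough sketch reusing the S/S2 names from the quoted fit code (the data, the tol value, and the helper function are made up for the example): checking the regularized spectrum S2 means increasing reg_param can actually make the check pass, while the raw S never changes with reg_param.

```python
import numpy as np

rng = np.random.default_rng(0)
Xg = rng.normal(size=(10, 4))
Xg[:, 3] = Xg[:, 0]              # collinear column
Xgc = Xg - Xg.mean(axis=0)

_, S, _ = np.linalg.svd(Xgc, full_matrices=False)
S2 = (S ** 2) / (len(Xg) - 1)    # eigenvalues of the class covariance

tol = 1e-4  # assumed threshold

def effective_rank(reg_param):
    # reg_param puts a floor of reg_param under every eigenvalue,
    # so a large enough value always restores full rank
    return int(np.sum(((1 - reg_param) * S2) + reg_param > tol))

rank_unregularized = effective_rank(0.0)  # the zero eigenvalue stays
rank_regularized = effective_rank(0.1)    # floor of 0.1 clears tol
```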

Contributor Author

I experimented with the proposed change. When using S2 instead of S, the solution works as intended, but there are two tests that use random data which the changes couldn't pass:

  • test_common.py/test_estimators[Quadratic_DiscriminantAnalysis-check_estimators_dtype]
  • test_common.py/test_estimators[Quadratic_DiscriminantAnalysis-check_dtype_object]

Both of them fit the model with random data and then check for the existence of methods or attributes; this random data triggers the errors. As quick fixes, I could:

  • Change the LinAlgError to a LinAlgWarning, which would still warn the user of the problem
  • Mark the tests as expected failures

Do you have any further recommendations?

Member

As a workaround, we can update:

def _set_checking_parameters(estimator):

To set reg_param.
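A hedged sketch of that workaround (`_set_checking_parameters` is the scikit-learn helper named above; the reg_param value 0.1 is an assumption, and the stub class exists only so the snippet runs standalone — it is not the real estimator):

```python
# Stand-in for sklearn.discriminant_analysis.QuadraticDiscriminantAnalysis,
# reduced to the one method this sketch needs.
class QuadraticDiscriminantAnalysis:
    def __init__(self):
        self.reg_param = 0.0

    def set_params(self, **params):
        for key, value in params.items():
            setattr(self, key, value)
        return self


def _set_checking_parameters(estimator):
    # set parameters to speed up some estimators and
    # avoid deprecated behaviour
    if estimator.__class__.__name__ == "QuadraticDiscriminantAnalysis":
        # mild regularization keeps the per-class covariance
        # full rank on the random data used by the common tests
        estimator.set_params(reg_param=0.1)


qda = QuadraticDiscriminantAnalysis()
_set_checking_parameters(qda)
```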

@azihna azihna requested review from thomasjpfan and ogrisel March 24, 2021 18:37
Comment on lines 937 to 939
X, y = make_classification(random_state=seed, n_samples=40,
                           n_informative=8, n_features=10,
                           n_classes=4)
Member

We try not to special case estimators here. Is there no reg_param we can use to make this work?

Contributor Author

@azihna azihna Apr 7, 2021

Thanks a lot for the review @thomasjpfan. Sadly, no: with random data the check always fails no matter how large reg_param is, and I tried quite a few values. That's why I mentioned in an earlier comment that changing the error to a warning might be a better solution. I think users would be annoyed if the model threw an error when fitting on toy datasets like that.

Contributor Author

I changed the error to a warning in the latest commit and reverted the estimator checks to their original state.

Contributor Author

It seems LinAlgWarning was introduced in SciPy 1.1.0, while the CI pipelines for Python 3.6 use SciPy 0.19. SciPy actually supported Python 3.6 until version 1.6.0. Is it possible to upgrade the CI requirements?

@azihna
Contributor Author

azihna commented Apr 23, 2021

@ogrisel
Using LinAlgError with the current implementation seems to fail in some use cases: for some datasets (such as the randomized datasets used in estimator_checks) the error doesn't go away no matter how high the regularization is, although it is fine with normal datasets. This caused problems in test_common.py because the tests would fail with a LinAlgError.
I have now switched to LinAlgWarning and made the message a bit more informative; however, LinAlgWarning was introduced in scipy 1.1.0 and the CI pipeline fails in the py36 checks.
I'd like to change the warning to a RuntimeWarning with a note to upgrade once the minimum scipy dependency is raised. Do you think this changes the point of the PR too much? Would you recommend another approach?
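The fallback being proposed could look roughly like this (a sketch; the helper name and warning message are made up for illustration — only the scipy 1.1.0 fact comes from the discussion):

```python
import warnings

try:
    from scipy.linalg import LinAlgWarning  # added in scipy 1.1.0
except ImportError:
    LinAlgWarning = RuntimeWarning  # fallback for older scipy


def _warn_collinear():
    # illustrative message, not the merged wording
    warnings.warn(
        "The covariance matrix of a class is not full rank; "
        "increasing reg_param might help.",
        LinAlgWarning,
    )


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _warn_collinear()
```

Because RuntimeWarning is used as the alias on old scipy, callers filtering on the alias work the same in both environments.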

@azihna azihna requested a review from thomasjpfan May 5, 2021 19:35
Member

@thomasjpfan thomasjpfan left a comment

If there is an issue with test_common.py, you can add some regularization by updating _set_checking_parameters:

def _set_checking_parameters(estimator):
# set parameters to speed up some estimators and
# avoid deprecated behaviour

Comment on lines 518 to 521
- |Enhancement| :class:`discriminant_analysis.QuadraticDiscriminantAnalysis`
will now cause 'RuntimeWarning' in case of collinear variables. These errors
can be silenced by the 'reg_param' attribute.:pr:`19731` by
:user:`Alihan Zihna <azihna>`
Member

May we move the discriminant_analysis section higher in this list? (The modules are in alphabetical order)

Suggested change
- |Enhancement| :class:`discriminant_analysis.QuadraticDiscriminantAnalysis`
will now cause 'RuntimeWarning' in case of collinear variables. These errors
can be silenced by the 'reg_param' attribute.:pr:`19731` by
:user:`Alihan Zihna <azihna>`
- |Enhancement| :class:`discriminant_analysis.QuadraticDiscriminantAnalysis`
will now cause 'LinAlgWarning' in case of collinear variables. These errors
can be silenced by the 'reg_param' attribute. :pr:`19731` by
:user:`Alihan Zihna <azihna>`

Comment on lines 554 to 619
with pytest.warns(RuntimeWarning, match="divide by zero"):
y_pred = clf.predict(X2)
assert np.any(y_pred != y6)
Member

With the LinAlgWarning, is this RuntimeWarning still being raised?

Member

So that means we either need to record the warnings and check that both are raised, or should we instead catch the division-by-zero warning, which is not informative?

Member

I've put back the test, catching all the raised warnings.
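One way to catch every raised warning at once can be sketched like this (the stand-in helper below just re-emits the two warning messages from the test; in the real test, `pytest.warns` or `warnings.catch_warnings(record=True)` wraps `clf.predict(X2)`):

```python
import warnings


def predict_with_warnings():
    # stand-in that re-emits the two warnings discussed above;
    # in the real test these come from clf.predict(X2)
    warnings.warn("variables are collinear", RuntimeWarning)
    warnings.warn("divide by zero encountered", RuntimeWarning)
    return [0, 1]


# record=True collects every warning rather than asserting on one,
# so both messages can be checked afterwards
with warnings.catch_warnings(record=True) as record:
    warnings.simplefilter("always")
    y_pred = predict_with_warnings()

messages = [str(w.message) for w in record]
```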

@glemaitre glemaitre self-assigned this Jul 26, 2021

@glemaitre glemaitre removed their assignment Jul 27, 2021
adrinjalali and others added 5 commits April 16, 2024 12:49
Co-authored-by: Guillaume Lemaitre <guillaume@probabl.ai>

github-actions bot commented Apr 16, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 5d8f3ed.

@adrinjalali
Member

@thomasjpfan @glemaitre @ogrisel I gave this a refresh and applied suggestions, could you please have another go?

@adrinjalali
Member

So somehow the warnings are NOT raised on some of our CI runs, so I removed them in 8854f58 (they were raised locally, though).

@@ -217,6 +217,14 @@ Changelog
have the `n_features_in_` and `feature_names_in_` attributes after `fit`.
:pr:`27937` by :user:`Marco vd Boom <tvdboom>`.

:mod:`sklearn.discriminant_analysis`
Member

We will need to move in 1.6 :)

Member

@glemaitre glemaitre left a comment

LGTM

@adrinjalali adrinjalali enabled auto-merge (squash) June 6, 2024 13:18
@adrinjalali adrinjalali merged commit 99916c4 into scikit-learn:main Jun 6, 2024
30 checks passed
@jeremiedbb jeremiedbb mentioned this pull request Jul 2, 2024
11 tasks