MAINT Parameters validation for covariance.ledoit_wolf #24870

Conversation
Therefore, I think we should avoid any validation here. However, we should inform our users about it. We can add a note in the "Notes" section:

```
Notes
-----
Note that the input parameters will not be validated, to optimize execution time.
If you wish to validate the parameters, use the :class:`LedoitWolf` estimator instead.
It will validate the input parameters when calling the method :term:`fit`.
```
@glevv We iterated and had a couple of discussions on another PR (#24868 (comment)). Concretely, we would like, if possible, to make the function call the class. In this case, we delegate the parameter validation from the function to the class. The pattern would be similar to: https://github.com/scikit-learn/scikit-learn/pull/24884/files I am almost sure that this could be done here as well.
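A minimal sketch of the delegation pattern being discussed, assuming the signature of `covariance.ledoit_wolf`; the `covariance_` and `shrinkage_` attributes come from the fitted `LedoitWolf` estimator:

```python
# Sketch only: the public function builds the estimator and lets its `fit`
# do the parameter validation, then returns the fitted attributes.
def ledoit_wolf(X, *, assume_centered=False, block_size=1000):
    estimator = LedoitWolf(
        assume_centered=assume_centered,
        block_size=block_size,
    ).fit(X)
    return estimator.covariance_, estimator.shrinkage_
```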
@glemaitre Ok, I will look into it

There is no function
I assume that a follow-up PR could be the deprecation of `ledoit_wolf_shrinkage` to make it purely private.
```python
from ..utils._param_validation import Interval


def _ledoit_wolf_shrinkage(X, assume_centered, block_size):
```
I would leave this function as is. I assume that we wanted it private. However, we would need a deprecation if we make this change.
But if it stays as is, then we will be checking

scikit-learn/sklearn/covariance/_shrunk_covariance.py, lines 236 to 239 in 64432e1:

```python
if X.shape[0] == 1:
    warnings.warn(
        "Only one sample available. You may want to reshape your data array"
    )
```

twice: once in the class and once in the shrinkage estimation.
No, you just need to remove it from the shrinkage estimation because we already did it in the class.
```python
X = self._validate_data(X)
if X.shape[0] == 1:
    warnings.warn(
        "Only one sample available. You may want to reshape your data array"
    )
```
It seems that we were supporting 1D arrays in the function. Therefore, we can let the 1D vector through, reshape it, and raise the warning afterwards.
Suggested change:

```diff
-X = self._validate_data(X)
-if X.shape[0] == 1:
-    warnings.warn(
-        "Only one sample available. You may want to reshape your data array"
-    )
+X = self._validate_data(X, ensure_2d=False)
+if X.ndim == 1:
+    X = X.reshape(1, -1)
+    warnings.warn(
+        "Only one sample available. You may want to reshape your data array"
+    )
```
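For illustration, a hedged example of the behaviour this preserves, assuming the 1D path shown above is kept (the warning text is from the snippet):

```python
import numpy as np
from sklearn.covariance import ledoit_wolf

# A 1D input is treated as a single sample: it is reshaped to
# (1, n_features) and a UserWarning is emitted instead of a
# validation error.
shrunk_cov, shrinkage = ledoit_wolf(np.array([1.0, 2.0, 3.0, 4.0]))
# UserWarning: Only one sample available. You may want to reshape your data array
```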
```python
    if X.ndim == 1:
        X = np.reshape(X, (1, -1))
        warnings.warn(
            "Only one sample available. You may want to reshape your data array"
        )
        n_features = X.size
    else:
        _, n_features = X.shape
    X = as_float_array(X)

    # get Ledoit-Wolf shrinkage
    shrinkage = ledoit_wolf_shrinkage(
        X, assume_centered=assume_centered, block_size=block_size
    )
    emp_cov = empirical_covariance(X, assume_centered=assume_centered)
    mu = np.sum(np.trace(emp_cov)) / n_features
    shrunk_cov = (1.0 - shrinkage) * emp_cov
    shrunk_cov.flat[:: n_features + 1] += shrinkage * mu
    estimator = LedoitWolf(
        assume_centered=assume_centered,
        block_size=block_size,
    ).fit(X)

    return shrunk_cov, shrinkage
```
I would isolate this part of the code in a `_ledoit_wolf` private function for the moment. The only thing that we can now assume is that X has been validated and is 2D, so we can safely use `n_features = X.shape[1]`.
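A hedged sketch of that isolation, reusing the computation from the hunk above; the helper name `_ledoit_wolf` follows the suggestion, and X is assumed already validated and 2D:

```python
import numpy as np


def _ledoit_wolf(X, *, assume_centered, block_size):
    """Estimate the shrunk Ledoit-Wolf covariance matrix.

    X is assumed to be already validated and 2D.
    """
    # Safe because X is guaranteed to be 2D at this point.
    n_features = X.shape[1]

    # Get the Ledoit-Wolf shrinkage intensity.
    shrinkage = ledoit_wolf_shrinkage(
        X, assume_centered=assume_centered, block_size=block_size
    )
    emp_cov = empirical_covariance(X, assume_centered=assume_centered)
    mu = np.trace(emp_cov) / n_features
    shrunk_cov = (1.0 - shrinkage) * emp_cov
    # Add the shrinkage target to the diagonal in place.
    shrunk_cov.flat[:: n_features + 1] += shrinkage * mu
    return shrunk_cov, shrinkage
```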
```python
    (1 - shrinkage) * cov + shrinkage * mu * np.identity(n_features)

    where mu = trace(cov) / n_features
    """
    X = check_array(X)
```
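A tiny numeric illustration of the convex combination in the docstring above (the values are made up):

```python
import numpy as np

cov = np.array([[2.0, 0.8],
                [0.8, 1.0]])
shrinkage = 0.25
n_features = cov.shape[0]
mu = np.trace(cov) / n_features  # (2.0 + 1.0) / 2 = 1.5
shrunk = (1 - shrinkage) * cov + shrinkage * mu * np.identity(n_features)
# array([[1.875, 0.6  ],
#        [0.6  , 1.125]])
```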
I see what you mean. We check again and raise a new warning. If we deprecate and make this function private, then we can avoid these checks because we don't have any source code that can reach these warnings.
So what is the deprecation policy here? This function is technically public, but it is not mentioned in the docs.
We can do the deprecation in another PR. The idea will be:

- create a private function with the same signature
- make the public function call the private one
- if people call the public one, they should get a warning that it is deprecated in the current version (e.g. 1.2) and will be removed in 2 minor versions (1.4)

You can find more details regarding the deprecation handling in scikit-learn here: https://scikit-learn.org/dev/developers/contributing.html#deprecation

It is true that we never documented it, but we tend to be extra careful about removing something that was half public.
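A hedged sketch of that plan, using scikit-learn's `deprecated` decorator; the function names and versions follow the comment above, and this is not the actual follow-up PR:

```python
from sklearn.utils import deprecated


def _ledoit_wolf_shrinkage(X, assume_centered=False, block_size=1000):
    # The real implementation would move here unchanged.
    ...


@deprecated(
    "ledoit_wolf_shrinkage is deprecated in 1.2 and will be removed in 1.4. "
    "Use the LedoitWolf estimator instead."
)
def ledoit_wolf_shrinkage(X, assume_centered=False, block_size=1000):
    # The public function simply forwards to the private one.
    return _ledoit_wolf_shrinkage(
        X, assume_centered=assume_centered, block_size=block_size
    )
```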
```python
        n_features = X.size
    else:
        _, n_features = X.shape
    X = as_float_array(X)
```
We don't need this check. We delegate it to the class (done by the call to `_validate_data`).
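For context, a small hedged illustration of why the explicit `as_float_array` conversion becomes redundant once validation happens in the class, assuming the estimator validates with a float dtype (as `check_array` allows):

```python
import numpy as np
from sklearn.utils import check_array

X_int = np.array([[1, 2], [3, 4]])
# Requesting a float dtype during validation subsumes as_float_array.
X = check_array(X_int, dtype=np.float64)
print(X.dtype)  # float64
```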
```diff
@@ -483,13 +481,18 @@ def fit(self, X, y=None):
         self._validate_params()
         # Not calling the parent object to fit, to avoid computing the
         # covariance matrix (and potentially the precision)
-        X = self._validate_data(X)
+        X = self._validate_data(X, ensure_2d=False)
         if X.shape[0] == 1:
```
Suggested change:

```diff
-        if X.shape[0] == 1:
+        if X.ndim == 1:
```
This should solve the error.
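To see why (an illustrative snippet, not from the diff): with `ensure_2d=False`, a 1D input keeps shape `(n_features,)`, so `X.shape[0]` counts entries rather than samples:

```python
import numpy as np

X = np.array([1.0, 2.0, 3.0, 4.0])
print(X.shape[0])  # 4 -- would wrongly skip the single-sample warning
print(X.ndim)      # 1 -- the reliable signal that we got a 1D input
```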
We still need to add the function to the list to run the common tests, as stated in the original issue.
@glemaitre did you mean

Sorry, my bad. We don't need to add it since it will be validated through the estimator.

@glemaitre nah, it's my fault, should've asked before making a commit
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
Maybe function
…24870) Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com> Co-authored-by: jeremiedbb <jeremiedbb@yahoo.fr>
Linked to #24862 and #24868