Skip to content

MAINT Parameters validation for covariance.ledoit_wolf #24870

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 41 commits into from
Dec 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
1459031
validate param for ledoit_wolf
glevv Nov 9, 2022
5a1e1d0
added to the test
glevv Nov 9, 2022
f22c551
revert commit and add note
glevv Nov 10, 2022
451ff42
revert commit
glevv Nov 10, 2022
4dec4e7
refactor ledoit_wolf logic
glevv Nov 10, 2022
fbf5277
Update _shrunk_covariance.py
glevv Nov 10, 2022
082b8c7
Update _shrunk_covariance.py
glevv Nov 11, 2022
8a8a17b
lint
glevv Nov 11, 2022
8503a88
Update _shrunk_covariance.py
glevv Nov 11, 2022
dbfdf44
fix docs
glevv Nov 11, 2022
ff60a98
Merge branch 'scikit-learn:main' into ledoit-val
glevv Nov 13, 2022
c89c581
added to public tests
glevv Nov 13, 2022
8d322e1
revert
glevv Nov 14, 2022
2c6fa59
Merge branch 'main' into ledoit-val
glevv Nov 15, 2022
1a34099
Update test_covariance.py
glevv Nov 16, 2022
37c7acc
Update _shrunk_covariance.py
glevv Nov 16, 2022
a144812
Update test_covariance.py
glevv Nov 16, 2022
020ba76
add validate params
glevv Nov 17, 2022
82e59b9
added ledoit_wolf to public tests
glevv Nov 17, 2022
4a895f4
added tests
glevv Nov 17, 2022
8bd6de4
added default values
glevv Nov 17, 2022
7f8bad6
Update test_covariance.py
glevv Nov 17, 2022
1a787f6
linting
glevv Nov 17, 2022
2a63033
removed context manager
glevv Nov 18, 2022
d06bdc5
Merge branch 'scikit-learn:main' into ledoit-val
glevv Nov 18, 2022
b01c33f
Update _shrunk_covariance.py
glevv Nov 18, 2022
63191a6
added test
glevv Nov 18, 2022
88e3e6b
Merge branch 'main' into ledoit-val
glevv Nov 18, 2022
76f71fb
lint
glevv Nov 18, 2022
8baa450
Update test_covariance.py
glevv Nov 18, 2022
9d5dcb0
Update sklearn/covariance/tests/test_covariance.py
glevv Nov 20, 2022
da2476a
Update sklearn/covariance/tests/test_covariance.py
glevv Nov 20, 2022
27039d3
Update test_covariance.py
glevv Nov 20, 2022
16cf982
Merge branch 'main' into ledoit-val
glevv Nov 22, 2022
4e061b4
Merge branch 'main' into ledoit-val
glevv Nov 30, 2022
4a755ae
Merge branch 'main' into ledoit-val
glevv Nov 30, 2022
801877e
Merge branch 'main' into ledoit-val
glevv Dec 15, 2022
2664bb4
Merge branch 'main' into ledoit-val
glevv Dec 16, 2022
ba7841b
Merge branch 'main' into ledoit-val
glevv Dec 22, 2022
769e28a
partial validation for class wrappers
jeremiedbb Dec 27, 2022
35c8f1c
Merge remote-tracking branch 'upstream/main' into pr/glevv/24870
jeremiedbb Dec 27, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 34 additions & 30 deletions sklearn/covariance/_shrunk_covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,33 @@
import numpy as np

from . import empirical_covariance, EmpiricalCovariance
from .._config import config_context
from ..utils import check_array
from ..utils._param_validation import Interval
from ..utils._param_validation import Interval, validate_params


def _ledoit_wolf(X, *, assume_centered, block_size):
"""Estimate the shrunk Ledoit-Wolf covariance matrix."""
# for only one feature, the result is the same whatever the shrinkage
if len(X.shape) == 2 and X.shape[1] == 1:
if not assume_centered:
X = X - X.mean()
return np.atleast_2d((X**2).mean()), 0.0
n_features = X.shape[1]

# get Ledoit-Wolf shrinkage
shrinkage = ledoit_wolf_shrinkage(
X, assume_centered=assume_centered, block_size=block_size
)
emp_cov = empirical_covariance(X, assume_centered=assume_centered)
mu = np.sum(np.trace(emp_cov)) / n_features
shrunk_cov = (1.0 - shrinkage) * emp_cov
shrunk_cov.flat[:: n_features + 1] += shrinkage * mu

return shrunk_cov, shrinkage


###############################################################################
# Public API
# ShrunkCovariance estimator


Expand Down Expand Up @@ -288,6 +310,7 @@ def ledoit_wolf_shrinkage(X, assume_centered=False, block_size=1000):
return shrinkage


@validate_params({"X": ["array-like"]})
def ledoit_wolf(X, *, assume_centered=False, block_size=1000):
"""Estimate the shrunk Ledoit-Wolf covariance matrix.

Expand Down Expand Up @@ -325,31 +348,13 @@ def ledoit_wolf(X, *, assume_centered=False, block_size=1000):

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We discussed this with @jeremiedbb in another PR. He suggested actually doing the double validation and adding the function to the list to run the common test.

Sorry for the back and forth. We did not anticipate all the issues. We start to settle on the way we want the validation to be done :)

Copy link
Contributor Author

@glevv glevv Nov 17, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I added the function to the public tests, it raised this error:

parameter_constraints = getattr(func, "_skl_parameter_constraints")
E       AttributeError: 'function' object has no attribute '_skl_parameter_constraints'

As I understand it, the validate_params decorator adds an attribute to the function; but in this case we do not apply the decorator, so the attribute can't be accessed.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean we need to have:

@validate_params({
    "X": ["array-like"],
    "assume_centered": ["boolean"],
    "block_size": [Integral],
})
def ledoit_wolf(X, *, assume_centered=False, block_size=1000):

and

PARAM_VALIDATION_FUNCTION_LIST = [
    "sklearn.cluster.kmeans_plusplus",
    "sklearn.covariance.ledoit_wolf",
]

Doing this on your PR, works for me.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then why not implement the first proposition: add param validation to the functions without rewriting them to use class inside?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are trying to apply the "function delegates to the class" pattern consistently across the library.

The other pattern is to make the class and function calling a private function. However, it makes the code difficult to read after trying a couple of examples.

Double validation becomes an issue when this is called within loops.

where mu = trace(cov) / n_features
"""
X = check_array(X)
# for only one feature, the result is the same whatever the shrinkage
if len(X.shape) == 2 and X.shape[1] == 1:
if not assume_centered:
X = X - X.mean()
return np.atleast_2d((X**2).mean()), 0.0
if X.ndim == 1:
X = np.reshape(X, (1, -1))
warnings.warn(
"Only one sample available. You may want to reshape your data array"
)
n_features = X.size
else:
_, n_features = X.shape
estimator = LedoitWolf(
assume_centered=assume_centered,
block_size=block_size,
store_precision=False,
).fit(X)

# get Ledoit-Wolf shrinkage
shrinkage = ledoit_wolf_shrinkage(
X, assume_centered=assume_centered, block_size=block_size
)
emp_cov = empirical_covariance(X, assume_centered=assume_centered)
mu = np.sum(np.trace(emp_cov)) / n_features
shrunk_cov = (1.0 - shrinkage) * emp_cov
shrunk_cov.flat[:: n_features + 1] += shrinkage * mu

return shrunk_cov, shrinkage
return estimator.covariance_, estimator.shrinkage_


class LedoitWolf(EmpiricalCovariance):
Expand Down Expand Up @@ -488,10 +493,9 @@ def fit(self, X, y=None):
self.location_ = np.zeros(X.shape[1])
else:
self.location_ = X.mean(0)
with config_context(assume_finite=True):
covariance, shrinkage = ledoit_wolf(
X - self.location_, assume_centered=True, block_size=self.block_size
)
covariance, shrinkage = _ledoit_wolf(
X - self.location_, assume_centered=True, block_size=self.block_size
)
self.shrinkage_ = shrinkage
self._set_covariance(covariance)

Expand Down
9 changes: 9 additions & 0 deletions sklearn/covariance/tests/test_covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import pytest

from sklearn.utils._testing import assert_allclose
from sklearn.utils._testing import assert_almost_equal
from sklearn.utils._testing import assert_array_almost_equal
from sklearn.utils._testing import assert_array_equal
Expand All @@ -23,6 +24,7 @@
OAS,
oas,
)
from sklearn.covariance._shrunk_covariance import _ledoit_wolf

X, _ = datasets.load_diabetes(return_X_y=True)
X_1d = X[:, 0]
Expand Down Expand Up @@ -158,6 +160,9 @@ def test_ledoit_wolf():
assert_almost_equal(lw.shrinkage_, shrinkage_, 4)
assert_almost_equal(lw.shrinkage_, ledoit_wolf_shrinkage(X))
assert_almost_equal(lw.shrinkage_, ledoit_wolf(X)[1])
assert_almost_equal(
lw.shrinkage_, _ledoit_wolf(X=X, assume_centered=False, block_size=10000)[1]
)
assert_almost_equal(lw.score(X), score_, 4)
# compare shrunk covariance obtained from data and from MLE estimate
lw_cov_from_mle, lw_shrinkage_from_mle = ledoit_wolf(X)
Expand All @@ -172,6 +177,10 @@ def test_ledoit_wolf():
X_1d = X[:, 0].reshape((-1, 1))
lw = LedoitWolf()
lw.fit(X_1d)
assert_allclose(
X_1d.var(ddof=0),
_ledoit_wolf(X=X_1d, assume_centered=False, block_size=10000)[0],
)
lw_cov_from_mle, lw_shrinkage_from_mle = ledoit_wolf(X_1d)
assert_array_almost_equal(lw_cov_from_mle, lw.covariance_, 4)
assert_almost_equal(lw_shrinkage_from_mle, lw.shrinkage_)
Expand Down
1 change: 1 addition & 0 deletions sklearn/tests/test_public_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def test_function_param_validation(func_module):

# Pairs of (public function, estimator class) where the function delegates
# (part of) its parameter validation to the estimator class it wraps.
# NOTE(review): presumably consumed by a common test that checks the
# function/class validation is consistent — confirm against the test using it.
PARAM_VALIDATION_CLASS_WRAPPER_LIST = [
    ("sklearn.decomposition.non_negative_factorization", "sklearn.decomposition.NMF"),
    ("sklearn.covariance.ledoit_wolf", "sklearn.covariance.LedoitWolf"),
]


Expand Down