diff --git a/sklearn/covariance/_shrunk_covariance.py b/sklearn/covariance/_shrunk_covariance.py index bba1a6216bb56..c195076fb8be7 100644 --- a/sklearn/covariance/_shrunk_covariance.py +++ b/sklearn/covariance/_shrunk_covariance.py @@ -43,6 +43,31 @@ def _ledoit_wolf(X, *, assume_centered, block_size): return shrunk_cov, shrinkage +def _oas(X, *, assume_centered=False): + """Estimate covariance with the Oracle Approximating Shrinkage algorithm.""" + # for only one feature, the result is the same whatever the shrinkage + if len(X.shape) == 2 and X.shape[1] == 1: + if not assume_centered: + X = X - X.mean() + return np.atleast_2d((X**2).mean()), 0.0 + + n_samples, n_features = X.shape + + emp_cov = empirical_covariance(X, assume_centered=assume_centered) + mu = np.trace(emp_cov) / n_features + + # formula from Chen et al.'s **implementation** + alpha = np.mean(emp_cov**2) + num = alpha + mu**2 + den = (n_samples + 1.0) * (alpha - (mu**2) / n_features) + + shrinkage = 1.0 if den == 0 else min(num / den, 1.0) + shrunk_cov = (1.0 - shrinkage) * emp_cov + shrunk_cov.flat[:: n_features + 1] += shrinkage * mu + + return shrunk_cov, shrinkage + + ############################################################################### # Public API # ShrunkCovariance estimator @@ -503,6 +528,7 @@ def fit(self, X, y=None): # OAS estimator +@validate_params({"X": ["array-like"]}) def oas(X, *, assume_centered=False): """Estimate covariance with the Oracle Approximating Shrinkage algorithm. @@ -537,35 +563,10 @@ def oas(X, *, assume_centered=False): The formula we used to implement the OAS is slightly modified compared to the one given in the article. See :class:`OAS` for more details. """ - X = np.asarray(X) - # for only one feature, the result is the same whatever the shrinkage - if len(X.shape) == 2 and X.shape[1] == 1: - if not assume_centered: - X = X - X.mean() - return np.atleast_2d((X**2).mean()), 0.0 - if X.ndim == 1: - X = np.reshape(X, (1, -1)) - warnings.warn( - "Only one sample available. You may want to reshape your data array" - ) - n_samples = 1 - n_features = X.size - else: - n_samples, n_features = X.shape - - emp_cov = empirical_covariance(X, assume_centered=assume_centered) - mu = np.trace(emp_cov) / n_features - - # formula from Chen et al.'s **implementation** - alpha = np.mean(emp_cov**2) - num = alpha + mu**2 - den = (n_samples + 1.0) * (alpha - (mu**2) / n_features) - - shrinkage = 1.0 if den == 0 else min(num / den, 1.0) - shrunk_cov = (1.0 - shrinkage) * emp_cov - shrunk_cov.flat[:: n_features + 1] += shrinkage * mu - - return shrunk_cov, shrinkage + estimator = OAS( + assume_centered=assume_centered, + ).fit(X) + return estimator.covariance_, estimator.shrinkage_ class OAS(EmpiricalCovariance): @@ -697,7 +698,7 @@ def fit(self, X, y=None): else: self.location_ = X.mean(0) - covariance, shrinkage = oas(X - self.location_, assume_centered=True) + covariance, shrinkage = _oas(X - self.location_, assume_centered=True) self.shrinkage_ = shrinkage self._set_covariance(covariance) diff --git a/sklearn/covariance/tests/test_covariance.py b/sklearn/covariance/tests/test_covariance.py index fb4eeb26138df..bbd3a4757a835 100644 --- a/sklearn/covariance/tests/test_covariance.py +++ b/sklearn/covariance/tests/test_covariance.py @@ -26,6 +26,8 @@ ) from sklearn.covariance._shrunk_covariance import _ledoit_wolf +from .._shrunk_covariance import _oas + X, _ = datasets.load_diabetes(return_X_y=True) X_1d = X[:, 0] n_samples, n_features = X.shape @@ -336,6 +338,16 @@ def test_oas(): assert_almost_equal(oa.score(X), score_, 4) assert oa.precision_ is None + # test function _oas without assuming centered data + X_1f = X[:, 0:1] + oa = OAS() + oa.fit(X_1f) + # compare shrunk covariance obtained from data and from MLE estimate + _oa_cov_from_mle, _oa_shrinkage_from_mle = _oas(X_1f) + assert_array_almost_equal(_oa_cov_from_mle, oa.covariance_, 4) + assert_almost_equal(_oa_shrinkage_from_mle, oa.shrinkage_) + assert_array_almost_equal((X_1f**2).sum() / n_samples, oa.covariance_, 4) + def test_EmpiricalCovariance_validates_mahalanobis(): """Checks that EmpiricalCovariance validates data with mahalanobis.""" diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 0d01428b8db4b..cef75b9be9d4b 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -132,6 +132,7 @@ def test_function_param_validation(func_module): PARAM_VALIDATION_CLASS_WRAPPER_LIST = [ ("sklearn.decomposition.non_negative_factorization", "sklearn.decomposition.NMF"), ("sklearn.covariance.ledoit_wolf", "sklearn.covariance.LedoitWolf"), + ("sklearn.covariance.oas", "sklearn.covariance.OAS"), ]