From 1459031802e3db662b2bd08037f3fb9f6fe4f4bb Mon Sep 17 00:00:00 2001 From: Gleb Levitski <36483986+glevv@users.noreply.github.com> Date: Wed, 9 Nov 2022 07:46:32 +0300 Subject: [PATCH 01/30] validate param for ledoit_wolf --- sklearn/covariance/_shrunk_covariance.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/sklearn/covariance/_shrunk_covariance.py b/sklearn/covariance/_shrunk_covariance.py index 72f5101f4d753..7361e75ba1ba8 100644 --- a/sklearn/covariance/_shrunk_covariance.py +++ b/sklearn/covariance/_shrunk_covariance.py @@ -20,7 +20,7 @@ from . import empirical_covariance, EmpiricalCovariance from .._config import config_context from ..utils import check_array -from ..utils._param_validation import Interval +from ..utils._param_validation import Interval, validate_params # ShrunkCovariance estimator @@ -288,6 +288,13 @@ def ledoit_wolf_shrinkage(X, assume_centered=False, block_size=1000): return shrinkage +@validate_params( + { + "X": ["array-like"], + "assume_centered": ["boolean"], + "block_size": [Interval(Integral, 1, None, closed="left")], + } +) def ledoit_wolf(X, *, assume_centered=False, block_size=1000): """Estimate the shrunk Ledoit-Wolf covariance matrix. From 5a1e1d01c6065efc424f08cb41af9d03742de510 Mon Sep 17 00:00:00 2001 From: Gleb Levitski <36483986+glevv@users.noreply.github.com> Date: Wed, 9 Nov 2022 07:49:03 +0300 Subject: [PATCH 02/30] added to the test --- sklearn/tests/test_public_functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index d4e645c052dab..61b59d5f88020 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -10,6 +10,7 @@ PARAM_VALIDATION_FUNCTION_LIST = [ "sklearn.cluster.kmeans_plusplus", + "sklearn.covariance.ledoit_wolf", ] From f22c55149f313d6ad532603e49401a09100fd4b6 Mon Sep 17 00:00:00 2001 From: Gleb Levitski <36483986+glevv@users.noreply.github.com> Date: Thu, 10 Nov 2022 16:23:51 +0000 Subject: [PATCH 03/30] revert commit and add note --- sklearn/covariance/_shrunk_covariance.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/sklearn/covariance/_shrunk_covariance.py b/sklearn/covariance/_shrunk_covariance.py index 7361e75ba1ba8..71bd6a3482f01 100644 --- a/sklearn/covariance/_shrunk_covariance.py +++ b/sklearn/covariance/_shrunk_covariance.py @@ -20,7 +20,7 @@ from . import empirical_covariance, EmpiricalCovariance from .._config import config_context from ..utils import check_array -from ..utils._param_validation import Interval, validate_params +from ..utils._param_validation import Interval # ShrunkCovariance estimator @@ -288,13 +288,6 @@ def ledoit_wolf_shrinkage(X, assume_centered=False, block_size=1000): return shrinkage -@validate_params( - { - "X": ["array-like"], - "assume_centered": ["boolean"], - "block_size": [Interval(Integral, 1, None, closed="left")], - } -) def ledoit_wolf(X, *, assume_centered=False, block_size=1000): """Estimate the shrunk Ledoit-Wolf covariance matrix. @@ -331,6 +324,10 @@ def ledoit_wolf(X, *, assume_centered=False, block_size=1000): (1 - shrinkage) * cov + shrinkage * mu * np.identity(n_features) where mu = trace(cov) / n_features + + Note that the input parameters will not be validated, to optimize execution time. + If you wish to validate the parameters, use the :class:`LedoitWolf` instead. + It will validate the input parameters when calling the method :term:`fit`. """ X = check_array(X) # for only one feature, the result is the same whatever the shrinkage From 451ff42b3aa653604b0969ab9df39865dbcbdf4b Mon Sep 17 00:00:00 2001 From: Gleb Levitski <36483986+glevv@users.noreply.github.com> Date: Thu, 10 Nov 2022 16:24:24 +0000 Subject: [PATCH 04/30] revert commit --- sklearn/tests/test_public_functions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 61b59d5f88020..d4e645c052dab 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -10,7 +10,6 @@ PARAM_VALIDATION_FUNCTION_LIST = [ "sklearn.cluster.kmeans_plusplus", - "sklearn.covariance.ledoit_wolf", ] From 4dec4e75c36cda42dc3b5a294f269e67bf421cbc Mon Sep 17 00:00:00 2001 From: Gleb Levitski <36483986+glevv@users.noreply.github.com> Date: Thu, 10 Nov 2022 18:23:17 +0000 Subject: [PATCH 05/30] refactor ledoit_wolf logic --- sklearn/covariance/_shrunk_covariance.py | 215 ++++++++++------------- 1 file changed, 91 insertions(+), 124 deletions(-) diff --git a/sklearn/covariance/_shrunk_covariance.py b/sklearn/covariance/_shrunk_covariance.py index 71bd6a3482f01..bbd8d4a101d11 100644 --- a/sklearn/covariance/_shrunk_covariance.py +++ b/sklearn/covariance/_shrunk_covariance.py @@ -19,10 +19,89 @@ from . import empirical_covariance, EmpiricalCovariance from .._config import config_context -from ..utils import check_array +from ..utils import check_array, as_float_array from ..utils._param_validation import Interval +def _ledoit_wolf_shrinkage(X, assume_centered, block_size): + """Estimate Ledoit-Wolf shrinkage.""" + # for only one feature, the result is the same whatever the shrinkage + if len(X.shape) == 2 and X.shape[1] == 1: + return 0.0 + n_samples, n_features = X.shape + + # optionally center data + if not assume_centered: + X = X - X.mean(0) + + # A non-blocked version of the computation is present in the tests + # in tests/test_covariance.py + + # number of blocks to split the covariance matrix into + n_splits = int(n_features / block_size) + X2 = X**2 + emp_cov_trace = np.sum(X2, axis=0) / n_samples + mu = np.sum(emp_cov_trace) / n_features + beta_ = 0.0 # sum of the coefficients of + delta_ = 0.0 # sum of the *squared* coefficients of + # starting block computation + for i in range(n_splits): + for j in range(n_splits): + rows = slice(block_size * i, block_size * (i + 1)) + cols = slice(block_size * j, block_size * (j + 1)) + beta_ += np.sum(np.dot(X2.T[rows], X2[:, cols])) + delta_ += np.sum(np.dot(X.T[rows], X[:, cols]) ** 2) + rows = slice(block_size * i, block_size * (i + 1)) + beta_ += np.sum(np.dot(X2.T[rows], X2[:, block_size * n_splits :])) + delta_ += np.sum(np.dot(X.T[rows], X[:, block_size * n_splits :]) ** 2) + for j in range(n_splits): + cols = slice(block_size * j, block_size * (j + 1)) + beta_ += np.sum(np.dot(X2.T[block_size * n_splits :], X2[:, cols])) + delta_ += np.sum(np.dot(X.T[block_size * n_splits :], X[:, cols]) ** 2) + delta_ += np.sum( + np.dot(X.T[block_size * n_splits :], X[:, block_size * n_splits :]) ** 2 + ) + delta_ /= n_samples**2 + beta_ += np.sum( + np.dot(X2.T[block_size * n_splits :], X2[:, block_size * n_splits :]) + ) + # use delta_ to compute beta + beta = 1.0 / (n_features * n_samples) * (beta_ / n_samples - delta_) + # delta is the sum of the squared coefficients of ( - mu*Id) / p + delta = delta_ - 2.0 * mu * emp_cov_trace.sum() + n_features * mu**2 + delta /= n_features + # get final beta as the min between beta and delta + # We do this to prevent shrinking more than "1", which would invert + # the value of covariances + beta = min(beta, delta) + # finally get shrinkage + shrinkage = 0 if beta == 0 else beta / delta + return shrinkage + + +def _ledoit_wolf(X, *, assume_centered, block_size): + """Estimate the shrunk Ledoit-Wolf covariance matrix.""" + # for only one feature, the result is the same whatever the shrinkage + if len(X.shape) == 2 and X.shape[1] == 1: + if not assume_centered: + X = X - X.mean() + return np.atleast_2d((X**2).mean()), 0.0 + n_features = X.shape[1] + + # get Ledoit-Wolf shrinkage + shrinkage = _ledoit_wolf_shrinkage( + X, assume_centered=assume_centered, block_size=block_size + ) + emp_cov = empirical_covariance(X, assume_centered=assume_centered) + mu = np.sum(np.trace(emp_cov)) / n_features + shrunk_cov = (1.0 - shrinkage) * emp_cov + shrunk_cov.flat[:: n_features + 1] += shrinkage * mu + + return shrunk_cov, shrinkage + + +############################################################################### +# Public API # ShrunkCovariance estimator @@ -193,101 +272,6 @@ def fit(self, X, y=None): # Ledoit-Wolf estimator -def ledoit_wolf_shrinkage(X, assume_centered=False, block_size=1000): - """Estimate the shrunk Ledoit-Wolf covariance matrix. - - Read more in the :ref:`User Guide `. - - Parameters - ---------- - X : array-like of shape (n_samples, n_features) - Data from which to compute the Ledoit-Wolf shrunk covariance shrinkage. - - assume_centered : bool, default=False - If True, data will not be centered before computation. - Useful to work with data whose mean is significantly equal to - zero but is not exactly zero. - If False, data will be centered before computation. - - block_size : int, default=1000 - Size of blocks into which the covariance matrix will be split. - - Returns - ------- - shrinkage : float - Coefficient in the convex combination used for the computation - of the shrunk estimate. - - Notes - ----- - The regularized (shrunk) covariance is: - - (1 - shrinkage) * cov + shrinkage * mu * np.identity(n_features) - - where mu = trace(cov) / n_features - """ - X = check_array(X) - # for only one feature, the result is the same whatever the shrinkage - if len(X.shape) == 2 and X.shape[1] == 1: - return 0.0 - if X.ndim == 1: - X = np.reshape(X, (1, -1)) - - if X.shape[0] == 1: - warnings.warn( - "Only one sample available. You may want to reshape your data array" - ) - n_samples, n_features = X.shape - - # optionally center data - if not assume_centered: - X = X - X.mean(0) - - # A non-blocked version of the computation is present in the tests - # in tests/test_covariance.py - - # number of blocks to split the covariance matrix into - n_splits = int(n_features / block_size) - X2 = X**2 - emp_cov_trace = np.sum(X2, axis=0) / n_samples - mu = np.sum(emp_cov_trace) / n_features - beta_ = 0.0 # sum of the coefficients of - delta_ = 0.0 # sum of the *squared* coefficients of - # starting block computation - for i in range(n_splits): - for j in range(n_splits): - rows = slice(block_size * i, block_size * (i + 1)) - cols = slice(block_size * j, block_size * (j + 1)) - beta_ += np.sum(np.dot(X2.T[rows], X2[:, cols])) - delta_ += np.sum(np.dot(X.T[rows], X[:, cols]) ** 2) - rows = slice(block_size * i, block_size * (i + 1)) - beta_ += np.sum(np.dot(X2.T[rows], X2[:, block_size * n_splits :])) - delta_ += np.sum(np.dot(X.T[rows], X[:, block_size * n_splits :]) ** 2) - for j in range(n_splits): - cols = slice(block_size * j, block_size * (j + 1)) - beta_ += np.sum(np.dot(X2.T[block_size * n_splits :], X2[:, cols])) - delta_ += np.sum(np.dot(X.T[block_size * n_splits :], X[:, cols]) ** 2) - delta_ += np.sum( - np.dot(X.T[block_size * n_splits :], X[:, block_size * n_splits :]) ** 2 - ) - delta_ /= n_samples**2 - beta_ += np.sum( - np.dot(X2.T[block_size * n_splits :], X2[:, block_size * n_splits :]) - ) - # use delta_ to compute beta - beta = 1.0 / (n_features * n_samples) * (beta_ / n_samples - delta_) - # delta is the sum of the squared coefficients of ( - mu*Id) / p - delta = delta_ - 2.0 * mu * emp_cov_trace.sum() + n_features * mu**2 - delta /= n_features - # get final beta as the min between beta and delta - # We do this to prevent shrinking more than "1", which would invert - # the value of covariances - beta = min(beta, delta) - # finally get shrinkage - shrinkage = 0 if beta == 0 else beta / delta - return shrinkage - - def ledoit_wolf(X, *, assume_centered=False, block_size=1000): """Estimate the shrunk Ledoit-Wolf covariance matrix. @@ -324,36 +308,15 @@ def ledoit_wolf(X, *, assume_centered=False, block_size=1000): (1 - shrinkage) * cov + shrinkage * mu * np.identity(n_features) where mu = trace(cov) / n_features - - Note that the input parameters will not be validated, to optimize execution time. - If you wish to validate the parameters, use the :class:`LedoitWolf` instead. - It will validate the input parameters when calling the method :term:`fit`. """ - X = check_array(X) - # for only one feature, the result is the same whatever the shrinkage - if len(X.shape) == 2 and X.shape[1] == 1: - if not assume_centered: - X = X - X.mean() - return np.atleast_2d((X**2).mean()), 0.0 - if X.ndim == 1: - X = np.reshape(X, (1, -1)) - warnings.warn( - "Only one sample available. You may want to reshape your data array" - ) - n_features = X.size - else: - _, n_features = X.shape + X = as_float_array(X) - # get Ledoit-Wolf shrinkage - shrinkage = ledoit_wolf_shrinkage( - X, assume_centered=assume_centered, block_size=block_size - ) - emp_cov = empirical_covariance(X, assume_centered=assume_centered) - mu = np.sum(np.trace(emp_cov)) / n_features - shrunk_cov = (1.0 - shrinkage) * emp_cov - shrunk_cov.flat[:: n_features + 1] += shrinkage * mu + estimator = LedoitWolf( + assume_centered=assume_centered, + block_size=block_size, + ).fit(X) - return shrunk_cov, shrinkage + return estimator.covariance_, estimator.shrinkage_ class LedoitWolf(EmpiricalCovariance): @@ -488,12 +451,16 @@ def fit(self, X, y=None): # Not calling the parent object to fit, to avoid computing the # covariance matrix (and potentially the precision) X = self._validate_data(X) + if X.shape[0] == 1: + warnings.warn( + "Only one sample available. You may want to reshape your data array" + ) if self.assume_centered: self.location_ = np.zeros(X.shape[1]) else: self.location_ = X.mean(0) with config_context(assume_finite=True): - covariance, shrinkage = ledoit_wolf( + covariance, shrinkage = _ledoit_wolf( X - self.location_, assume_centered=True, block_size=self.block_size ) self.shrinkage_ = shrinkage From fbf52772f30ba336bada7d31c3d13472a240ea53 Mon Sep 17 00:00:00 2001 From: Gleb Levitski <36483986+glevv@users.noreply.github.com> Date: Thu, 10 Nov 2022 19:58:18 +0000 Subject: [PATCH 06/30] Update _shrunk_covariance.py --- sklearn/covariance/_shrunk_covariance.py | 148 ++++++++++++++--------- 1 file changed, 90 insertions(+), 58 deletions(-) diff --git a/sklearn/covariance/_shrunk_covariance.py b/sklearn/covariance/_shrunk_covariance.py index bbd8d4a101d11..633bb9ce69ad5 100644 --- a/sklearn/covariance/_shrunk_covariance.py +++ b/sklearn/covariance/_shrunk_covariance.py @@ -23,62 +23,6 @@ from ..utils._param_validation import Interval -def _ledoit_wolf_shrinkage(X, assume_centered, block_size): - """Estimate Ledoit-Wolf shrinkage.""" - # for only one feature, the result is the same whatever the shrinkage - if len(X.shape) == 2 and X.shape[1] == 1: - return 0.0 - n_samples, n_features = X.shape - - # optionally center data - if not assume_centered: - X = X - X.mean(0) - - # A non-blocked version of the computation is present in the tests - # in tests/test_covariance.py - - # number of blocks to split the covariance matrix into - n_splits = int(n_features / block_size) - X2 = X**2 - emp_cov_trace = np.sum(X2, axis=0) / n_samples - mu = np.sum(emp_cov_trace) / n_features - beta_ = 0.0 # sum of the coefficients of - delta_ = 0.0 # sum of the *squared* coefficients of - # starting block computation - for i in range(n_splits): - for j in range(n_splits): - rows = slice(block_size * i, block_size * (i + 1)) - cols = slice(block_size * j, block_size * (j + 1)) - beta_ += np.sum(np.dot(X2.T[rows], X2[:, cols])) - delta_ += np.sum(np.dot(X.T[rows], X[:, cols]) ** 2) - rows = slice(block_size * i, block_size * (i + 1)) - beta_ += np.sum(np.dot(X2.T[rows], X2[:, block_size * n_splits :])) - delta_ += np.sum(np.dot(X.T[rows], X[:, block_size * n_splits :]) ** 2) - for j in range(n_splits): - cols = slice(block_size * j, block_size * (j + 1)) - beta_ += np.sum(np.dot(X2.T[block_size * n_splits :], X2[:, cols])) - delta_ += np.sum(np.dot(X.T[block_size * n_splits :], X[:, cols]) ** 2) - delta_ += np.sum( - np.dot(X.T[block_size * n_splits :], X[:, block_size * n_splits :]) ** 2 - ) - delta_ /= n_samples**2 - beta_ += np.sum( - np.dot(X2.T[block_size * n_splits :], X2[:, block_size * n_splits :]) - ) - # use delta_ to compute beta - beta = 1.0 / (n_features * n_samples) * (beta_ / n_samples - delta_) - # delta is the sum of the squared coefficients of ( - mu*Id) / p - delta = delta_ - 2.0 * mu * emp_cov_trace.sum() + n_features * mu**2 - delta /= n_features - # get final beta as the min between beta and delta - # We do this to prevent shrinking more than "1", which would invert - # the value of covariances - beta = min(beta, delta) - # finally get shrinkage - shrinkage = 0 if beta == 0 else beta / delta - return shrinkage - - def _ledoit_wolf(X, *, assume_centered, block_size): """Estimate the shrunk Ledoit-Wolf covariance matrix.""" # for only one feature, the result is the same whatever the shrinkage @@ -89,7 +33,7 @@ def _ledoit_wolf(X, *, assume_centered, block_size): n_features = X.shape[1] # get Ledoit-Wolf shrinkage - shrinkage = _ledoit_wolf_shrinkage( + shrinkage = ledoit_wolf_shrinkage( X, assume_centered=assume_centered, block_size=block_size ) emp_cov = empirical_covariance(X, assume_centered=assume_centered) @@ -272,6 +216,93 @@ def fit(self, X, y=None): # Ledoit-Wolf estimator +def ledoit_wolf_shrinkage(X, assume_centered=False, block_size=1000): + """Estimate the shrunk Ledoit-Wolf covariance matrix. + Read more in the :ref:`User Guide `. + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + Data from which to compute the Ledoit-Wolf shrunk covariance shrinkage. + assume_centered : bool, default=False + If True, data will not be centered before computation. + Useful to work with data whose mean is significantly equal to + zero but is not exactly zero. + If False, data will be centered before computation. + block_size : int, default=1000 + Size of blocks into which the covariance matrix will be split. + Returns + ------- + shrinkage : float + Coefficient in the convex combination used for the computation + of the shrunk estimate. + Notes + ----- + The regularized (shrunk) covariance is: + (1 - shrinkage) * cov + shrinkage * mu * np.identity(n_features) + where mu = trace(cov) / n_features + """ + X = check_array(X) + # for only one feature, the result is the same whatever the shrinkage + if len(X.shape) == 2 and X.shape[1] == 1: + return 0.0 + if X.ndim == 1: + X = np.reshape(X, (1, -1)) + + if X.shape[0] == 1: + warnings.warn( + "Only one sample available. You may want to reshape your data array" + ) + n_samples, n_features = X.shape + + # optionally center data + if not assume_centered: + X = X - X.mean(0) + + # A non-blocked version of the computation is present in the tests + # in tests/test_covariance.py + + # number of blocks to split the covariance matrix into + n_splits = int(n_features / block_size) + X2 = X**2 + emp_cov_trace = np.sum(X2, axis=0) / n_samples + mu = np.sum(emp_cov_trace) / n_features + beta_ = 0.0 # sum of the coefficients of + delta_ = 0.0 # sum of the *squared* coefficients of + # starting block computation + for i in range(n_splits): + for j in range(n_splits): + rows = slice(block_size * i, block_size * (i + 1)) + cols = slice(block_size * j, block_size * (j + 1)) + beta_ += np.sum(np.dot(X2.T[rows], X2[:, cols])) + delta_ += np.sum(np.dot(X.T[rows], X[:, cols]) ** 2) + rows = slice(block_size * i, block_size * (i + 1)) + beta_ += np.sum(np.dot(X2.T[rows], X2[:, block_size * n_splits :])) + delta_ += np.sum(np.dot(X.T[rows], X[:, block_size * n_splits :]) ** 2) + for j in range(n_splits): + cols = slice(block_size * j, block_size * (j + 1)) + beta_ += np.sum(np.dot(X2.T[block_size * n_splits :], X2[:, cols])) + delta_ += np.sum(np.dot(X.T[block_size * n_splits :], X[:, cols]) ** 2) + delta_ += np.sum( + np.dot(X.T[block_size * n_splits :], X[:, block_size * n_splits :]) ** 2 + ) + delta_ /= n_samples**2 + beta_ += np.sum( + np.dot(X2.T[block_size * n_splits :], X2[:, block_size * n_splits :]) + ) + # use delta_ to compute beta + beta = 1.0 / (n_features * n_samples) * (beta_ / n_samples - delta_) + # delta is the sum of the squared coefficients of ( - mu*Id) / p + delta = delta_ - 2.0 * mu * emp_cov_trace.sum() + n_features * mu**2 + delta /= n_features + # get final beta as the min between beta and delta + # We do this to prevent shrinking more than "1", which would invert + # the value of covariances + beta = min(beta, delta) + # finally get shrinkage + shrinkage = 0 if beta == 0 else beta / delta + return shrinkage + + def ledoit_wolf(X, *, assume_centered=False, block_size=1000): """Estimate the shrunk Ledoit-Wolf covariance matrix. @@ -450,8 +481,9 @@ def fit(self, X, y=None): self._validate_params() # Not calling the parent object to fit, to avoid computing the # covariance matrix (and potentially the precision) - X = self._validate_data(X) + X = self._validate_data(X, ensure_2d=False) if X.shape[0] == 1: + X = X.reshape(1, -1) warnings.warn( "Only one sample available. You may want to reshape your data array" ) From 082b8c7c5c6eef6a4a0ea25f4762332f48fd5e16 Mon Sep 17 00:00:00 2001 From: Gleb Levitski <36483986+glevv@users.noreply.github.com> Date: Fri, 11 Nov 2022 15:32:28 +0000 Subject: [PATCH 07/30] Update _shrunk_covariance.py --- sklearn/covariance/_shrunk_covariance.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sklearn/covariance/_shrunk_covariance.py b/sklearn/covariance/_shrunk_covariance.py index 633bb9ce69ad5..c5592e77ff435 100644 --- a/sklearn/covariance/_shrunk_covariance.py +++ b/sklearn/covariance/_shrunk_covariance.py @@ -340,11 +340,10 @@ def ledoit_wolf(X, *, assume_centered=False, block_size=1000): where mu = trace(cov) / n_features """ - X = as_float_array(X) - estimator = LedoitWolf( assume_centered=assume_centered, block_size=block_size, + store_precision=False, ).fit(X) return estimator.covariance_, estimator.shrinkage_ @@ -482,6 +481,11 @@ def fit(self, X, y=None): # Not calling the parent object to fit, to avoid computing the # covariance matrix (and potentially the precision) X = self._validate_data(X, ensure_2d=False) + if X.ndim == 1: + X = np.reshape(X, (1, -1)) + warnings.warn( + "Only one sample available. You may want to reshape your data array" + ) if X.shape[0] == 1: X = X.reshape(1, -1) warnings.warn( From 8a8a17b3d976de65cfc9503ec0e5d19994918fd5 Mon Sep 17 00:00:00 2001 From: Gleb Levitski <36483986+glevv@users.noreply.github.com> Date: Fri, 11 Nov 2022 15:36:24 +0000 Subject: [PATCH 08/30] lint --- sklearn/covariance/_shrunk_covariance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/covariance/_shrunk_covariance.py b/sklearn/covariance/_shrunk_covariance.py index c5592e77ff435..5d55df2d70fdb 100644 --- a/sklearn/covariance/_shrunk_covariance.py +++ b/sklearn/covariance/_shrunk_covariance.py @@ -19,7 +19,7 @@ from . import empirical_covariance, EmpiricalCovariance from .._config import config_context -from ..utils import check_array, as_float_array +from ..utils import check_array from ..utils._param_validation import Interval From 8503a881845143cf46def1d8cca29fa1903c02c2 Mon Sep 17 00:00:00 2001 From: Gleb Levitski <36483986+glevv@users.noreply.github.com> Date: Fri, 11 Nov 2022 16:20:37 +0000 Subject: [PATCH 09/30] Update _shrunk_covariance.py --- sklearn/covariance/_shrunk_covariance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/covariance/_shrunk_covariance.py b/sklearn/covariance/_shrunk_covariance.py index 5d55df2d70fdb..fef8de1c5f79c 100644 --- a/sklearn/covariance/_shrunk_covariance.py +++ b/sklearn/covariance/_shrunk_covariance.py @@ -480,7 +480,7 @@ def fit(self, X, y=None): self._validate_params() # Not calling the parent object to fit, to avoid computing the # covariance matrix (and potentially the precision) - X = self._validate_data(X, ensure_2d=False) + X = self._validate_data(X) if X.ndim == 1: X = np.reshape(X, (1, -1)) warnings.warn( From dbfdf44b5eea2c5336dd703fe95fbf8afebc1426 Mon Sep 17 00:00:00 2001 From: Gleb Levitski <36483986+glevv@users.noreply.github.com> Date: Fri, 11 Nov 2022 17:41:44 +0000 Subject: [PATCH 10/30] fix docs --- sklearn/covariance/_shrunk_covariance.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sklearn/covariance/_shrunk_covariance.py b/sklearn/covariance/_shrunk_covariance.py index fef8de1c5f79c..8e4435458020b 100644 --- a/sklearn/covariance/_shrunk_covariance.py +++ b/sklearn/covariance/_shrunk_covariance.py @@ -218,27 +218,35 @@ def fit(self, X, y=None): def ledoit_wolf_shrinkage(X, assume_centered=False, block_size=1000): """Estimate the shrunk Ledoit-Wolf covariance matrix. + Read more in the :ref:`User Guide `. + Parameters ---------- X : array-like of shape (n_samples, n_features) Data from which to compute the Ledoit-Wolf shrunk covariance shrinkage. + assume_centered : bool, default=False If True, data will not be centered before computation. Useful to work with data whose mean is significantly equal to zero but is not exactly zero. If False, data will be centered before computation. + block_size : int, default=1000 Size of blocks into which the covariance matrix will be split. + Returns ------- shrinkage : float Coefficient in the convex combination used for the computation of the shrunk estimate. + Notes ----- The regularized (shrunk) covariance is: + (1 - shrinkage) * cov + shrinkage * mu * np.identity(n_features) + where mu = trace(cov) / n_features """ X = check_array(X) From c89c581cb5ca14fb4dc53a6822793022d1760d19 Mon Sep 17 00:00:00 2001 From: Gleb Levitski <36483986+glevv@users.noreply.github.com> Date: Sun, 13 Nov 2022 19:35:54 +0000 Subject: [PATCH 11/30] added to public tests --- sklearn/tests/test_public_functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index d4e645c052dab..61b59d5f88020 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -10,6 +10,7 @@ PARAM_VALIDATION_FUNCTION_LIST = [ "sklearn.cluster.kmeans_plusplus", + "sklearn.covariance.ledoit_wolf", ] From 8d322e1ecf3703457b3852547cd6817f3d41447d Mon Sep 17 00:00:00 2001 From: Gleb Levitski <36483986+glevv@users.noreply.github.com> Date: Mon, 14 Nov 2022 16:00:00 +0300 Subject: [PATCH 12/30] revert --- sklearn/tests/test_public_functions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 61b59d5f88020..d4e645c052dab 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -10,7 +10,6 @@ PARAM_VALIDATION_FUNCTION_LIST = [ "sklearn.cluster.kmeans_plusplus", - "sklearn.covariance.ledoit_wolf", ] From 1a3409988353b9a705bd4069af894149ece72360 Mon Sep 17 00:00:00 2001 From: Gleb Levitski <36483986+glevv@users.noreply.github.com> Date: Wed, 16 Nov 2022 16:29:13 +0000 Subject: [PATCH 13/30] Update test_covariance.py --- sklearn/covariance/tests/test_covariance.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/sklearn/covariance/tests/test_covariance.py b/sklearn/covariance/tests/test_covariance.py index 6a9031d0fcb36..e93fdb8abbb49 100644 --- a/sklearn/covariance/tests/test_covariance.py +++ b/sklearn/covariance/tests/test_covariance.py @@ -19,6 +19,7 @@ shrunk_covariance, LedoitWolf, ledoit_wolf, + _ledoit_wolf, ledoit_wolf_shrinkage, OAS, oas, @@ -158,6 +159,7 @@ def test_ledoit_wolf(): assert_almost_equal(lw.shrinkage_, shrinkage_, 4) assert_almost_equal(lw.shrinkage_, ledoit_wolf_shrinkage(X)) assert_almost_equal(lw.shrinkage_, ledoit_wolf(X)[1]) + assert_almost_equal(lw.shrinkage_, _ledoit_wolf(X, False, 1000)[1]) assert_almost_equal(lw.score(X), score_, 4) # compare shrunk covariance obtained from data and from MLE estimate lw_cov_from_mle, lw_shrinkage_from_mle = ledoit_wolf(X) @@ -177,17 +179,6 @@ def test_ledoit_wolf(): assert_almost_equal(lw_shrinkage_from_mle, lw.shrinkage_) assert_array_almost_equal(empirical_covariance(X_1d), lw.covariance_, 4) - # test with one sample - # warning should be raised when using only 1 sample - X_1sample = np.arange(5).reshape(1, 5) - lw = LedoitWolf() - - warn_msg = "Only one sample available. You may want to reshape your data array" - with pytest.warns(UserWarning, match=warn_msg): - lw.fit(X_1sample) - - assert_array_almost_equal(lw.covariance_, np.zeros(shape=(5, 5), dtype=np.float64)) - # test shrinkage coeff on a simple data set (without saving precision) lw = LedoitWolf(store_precision=False) lw.fit(X) From 37c7acc3f0dabafa875538dcafde45593a6c9b1d Mon Sep 17 00:00:00 2001 From: Gleb Levitski <36483986+glevv@users.noreply.github.com> Date: Wed, 16 Nov 2022 16:29:15 +0000 Subject: [PATCH 14/30] Update _shrunk_covariance.py --- sklearn/covariance/_shrunk_covariance.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/sklearn/covariance/_shrunk_covariance.py b/sklearn/covariance/_shrunk_covariance.py index 8e4435458020b..8a26249dd3c76 100644 --- a/sklearn/covariance/_shrunk_covariance.py +++ b/sklearn/covariance/_shrunk_covariance.py @@ -489,16 +489,6 @@ def fit(self, X, y=None): # Not calling the parent object to fit, to avoid computing the # covariance matrix (and potentially the precision) X = self._validate_data(X) - if X.ndim == 1: - X = np.reshape(X, (1, -1)) - warnings.warn( - "Only one sample available. You may want to reshape your data array" - ) - if X.shape[0] == 1: - X = X.reshape(1, -1) - warnings.warn( - "Only one sample available. You may want to reshape your data array" - ) if self.assume_centered: self.location_ = np.zeros(X.shape[1]) else: From a144812ea52320e96237635913d608b590cbe50e Mon Sep 17 00:00:00 2001 From: Gleb Levitski <36483986+glevv@users.noreply.github.com> Date: Wed, 16 Nov 2022 17:02:07 +0000 Subject: [PATCH 15/30] Update test_covariance.py --- sklearn/covariance/tests/test_covariance.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sklearn/covariance/tests/test_covariance.py b/sklearn/covariance/tests/test_covariance.py index e93fdb8abbb49..cce889ea27c03 100644 --- a/sklearn/covariance/tests/test_covariance.py +++ b/sklearn/covariance/tests/test_covariance.py @@ -19,7 +19,6 @@ shrunk_covariance, LedoitWolf, ledoit_wolf, - _ledoit_wolf, ledoit_wolf_shrinkage, OAS, oas, @@ -159,7 +158,6 @@ def test_ledoit_wolf(): assert_almost_equal(lw.shrinkage_, shrinkage_, 4) assert_almost_equal(lw.shrinkage_, ledoit_wolf_shrinkage(X)) assert_almost_equal(lw.shrinkage_, ledoit_wolf(X)[1]) - assert_almost_equal(lw.shrinkage_, _ledoit_wolf(X, False, 1000)[1]) assert_almost_equal(lw.score(X), score_, 4) # compare shrunk covariance obtained from data and from MLE estimate lw_cov_from_mle, lw_shrinkage_from_mle = ledoit_wolf(X) From 020ba7653c1b8e0bd2b123b702b090fd1979f9a9 Mon Sep 17 00:00:00 2001 From: Gleb Levitski <36483986+glevv@users.noreply.github.com> Date: Thu, 17 Nov 2022 16:26:35 +0000 Subject: [PATCH 16/30] add validate params --- sklearn/covariance/_shrunk_covariance.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/sklearn/covariance/_shrunk_covariance.py b/sklearn/covariance/_shrunk_covariance.py index 8a26249dd3c76..79c714698761f 100644 --- a/sklearn/covariance/_shrunk_covariance.py +++ b/sklearn/covariance/_shrunk_covariance.py @@ -20,7 +20,7 @@ from . import empirical_covariance, EmpiricalCovariance from .._config import config_context from ..utils import check_array -from ..utils._param_validation import Interval +from ..utils._param_validation import Interval, validate_params def _ledoit_wolf(X, *, assume_centered, block_size): @@ -311,6 +311,13 @@ def ledoit_wolf_shrinkage(X, assume_centered=False, block_size=1000): return shrinkage +@validate_params( + { + "X": ["array-like"], + "assume_centered": ["boolean"], + "block_size": [Interval(Integral, 1, None, closed="left")], + } +) def ledoit_wolf(X, *, assume_centered=False, block_size=1000): """Estimate the shrunk Ledoit-Wolf covariance matrix. From 82e59b985d77eab6b61a2b5e3d817b597090fe00 Mon Sep 17 00:00:00 2001 From: Gleb Levitski <36483986+glevv@users.noreply.github.com> Date: Thu, 17 Nov 2022 16:26:59 +0000 Subject: [PATCH 17/30] added ledoit_wolf to public tests --- sklearn/tests/test_public_functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index d4e645c052dab..61b59d5f88020 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -10,6 +10,7 @@ PARAM_VALIDATION_FUNCTION_LIST = [ "sklearn.cluster.kmeans_plusplus", + "sklearn.covariance.ledoit_wolf", ] From 4a895f468800524eda992dcc7edc93b878eb9bf4 Mon Sep 17 00:00:00 2001 From: Gleb Levitski <36483986+glevv@users.noreply.github.com> Date: Thu, 17 Nov 2022 16:28:15 +0000 Subject: [PATCH 18/30] added tests --- sklearn/covariance/tests/test_covariance.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sklearn/covariance/tests/test_covariance.py b/sklearn/covariance/tests/test_covariance.py index cce889ea27c03..251f2abce53aa 100644 --- a/sklearn/covariance/tests/test_covariance.py +++ b/sklearn/covariance/tests/test_covariance.py @@ -23,6 +23,7 @@ OAS, oas, ) +from sklearn.covariance._shrunk_covariance import _ledoit_wolf X, _ = datasets.load_diabetes(return_X_y=True) X_1d = X[:, 0] @@ -158,6 +159,7 @@ def test_ledoit_wolf(): assert_almost_equal(lw.shrinkage_, shrinkage_, 4) assert_almost_equal(lw.shrinkage_, ledoit_wolf_shrinkage(X)) assert_almost_equal(lw.shrinkage_, ledoit_wolf(X)[1]) + assert_almost_equal(lw.shrinkage_, _ledoit_wolf(X)[1]) assert_almost_equal(lw.score(X), score_, 4) # compare shrunk covariance obtained from data and from MLE estimate lw_cov_from_mle, lw_shrinkage_from_mle = ledoit_wolf(X) From 8bd6de4aaf89c53285a16c6aaf00b73c6a2fad31 Mon Sep 17 00:00:00 2001 From: Gleb Levitski <36483986+glevv@users.noreply.github.com> Date: Thu, 17 Nov 2022 16:46:15 +0000 Subject: [PATCH 19/30] added default values --- sklearn/covariance/tests/test_covariance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/covariance/tests/test_covariance.py b/sklearn/covariance/tests/test_covariance.py index 251f2abce53aa..4a22fcba29d32 100644 --- a/sklearn/covariance/tests/test_covariance.py +++ b/sklearn/covariance/tests/test_covariance.py @@ -159,7 +159,7 @@ def test_ledoit_wolf(): assert_almost_equal(lw.shrinkage_, shrinkage_, 4) assert_almost_equal(lw.shrinkage_, ledoit_wolf_shrinkage(X)) assert_almost_equal(lw.shrinkage_, ledoit_wolf(X)[1]) - assert_almost_equal(lw.shrinkage_, _ledoit_wolf(X)[1]) + assert_almost_equal(lw.shrinkage_, _ledoit_wolf(X, False, 10000)[1]) assert_almost_equal(lw.score(X), score_, 4) # compare shrunk covariance obtained from data and from MLE estimate lw_cov_from_mle, lw_shrinkage_from_mle = ledoit_wolf(X) From 7f8bad6e2fcb32818d968829dc6208d6d9ab2dbb Mon Sep 17 00:00:00 2001 From: Gleb Levitski <36483986+glevv@users.noreply.github.com> Date: Thu, 17 Nov 2022 17:02:23 +0000 Subject: [PATCH 20/30] Update test_covariance.py --- sklearn/covariance/tests/test_covariance.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sklearn/covariance/tests/test_covariance.py b/sklearn/covariance/tests/test_covariance.py index 4a22fcba29d32..2b2be05aa2196 100644 --- a/sklearn/covariance/tests/test_covariance.py +++ b/sklearn/covariance/tests/test_covariance.py @@ -159,7 +159,11 @@ def test_ledoit_wolf(): assert_almost_equal(lw.shrinkage_, shrinkage_, 4) assert_almost_equal(lw.shrinkage_, ledoit_wolf_shrinkage(X)) assert_almost_equal(lw.shrinkage_, ledoit_wolf(X)[1]) - assert_almost_equal(lw.shrinkage_, _ledoit_wolf(X, False, 10000)[1]) + assert_almost_equal(lw.shrinkage_, _ledoit_wolf( + X=X, + assume_centered=False, + block_size=10000 + )[1]) assert_almost_equal(lw.score(X), score_, 4) # compare shrunk covariance obtained from data and from MLE estimate lw_cov_from_mle, lw_shrinkage_from_mle = ledoit_wolf(X) From 1a787f63e2b6768d951467d3c8d022777f3ff8ee Mon Sep 17 00:00:00 2001 From: Gleb Levitski <36483986+glevv@users.noreply.github.com> Date: Thu, 17 Nov 2022 17:38:40 +0000 Subject: [PATCH 21/30] linting --- sklearn/covariance/tests/test_covariance.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/sklearn/covariance/tests/test_covariance.py b/sklearn/covariance/tests/test_covariance.py index 2b2be05aa2196..a1e5063099f7f 100644 --- a/sklearn/covariance/tests/test_covariance.py +++ b/sklearn/covariance/tests/test_covariance.py @@ -159,11 +159,9 @@ def test_ledoit_wolf(): assert_almost_equal(lw.shrinkage_, shrinkage_, 4) assert_almost_equal(lw.shrinkage_, ledoit_wolf_shrinkage(X)) assert_almost_equal(lw.shrinkage_, ledoit_wolf(X)[1]) - assert_almost_equal(lw.shrinkage_, _ledoit_wolf( - X=X, - assume_centered=False, - block_size=10000 - )[1]) + assert_almost_equal( + lw.shrinkage_, _ledoit_wolf(X=X, assume_centered=False, block_size=10000)[1] + ) assert_almost_equal(lw.score(X), score_, 4) # compare shrunk covariance obtained from data and from MLE estimate lw_cov_from_mle, lw_shrinkage_from_mle = ledoit_wolf(X) From 2a63033b225703bf71e8ca1ab2eed3fe337f0cac Mon Sep 17 00:00:00 2001 From: Gleb Levitski <36483986+glevv@users.noreply.github.com> Date: Fri, 18 Nov 2022 07:34:28 +0300 Subject: [PATCH 22/30] removed context manager --- sklearn/covariance/_shrunk_covariance.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sklearn/covariance/_shrunk_covariance.py b/sklearn/covariance/_shrunk_covariance.py index 79c714698761f..2874e8716a0a0 100644 --- a/sklearn/covariance/_shrunk_covariance.py +++ b/sklearn/covariance/_shrunk_covariance.py @@ -500,10 +500,9 @@ def fit(self, X, y=None): self.location_ = np.zeros(X.shape[1]) else: self.location_ = X.mean(0) - with config_context(assume_finite=True): - covariance, shrinkage = _ledoit_wolf( - X - self.location_, assume_centered=True, block_size=self.block_size - ) + covariance, shrinkage = _ledoit_wolf( + X - self.location_, assume_centered=True, block_size=self.block_size + ) self.shrinkage_ = shrinkage self._set_covariance(covariance) From b01c33fde8ba2440ff0ba250e1efcc99419f95d9 Mon Sep 17 00:00:00 2001 From: Gleb Levitski <36483986+glevv@users.noreply.github.com> Date: Fri, 18 Nov 2022 08:07:47 +0300 Subject: [PATCH 23/30] Update _shrunk_covariance.py --- sklearn/covariance/_shrunk_covariance.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/covariance/_shrunk_covariance.py b/sklearn/covariance/_shrunk_covariance.py index 2874e8716a0a0..de85ef6b85a99 100644 --- a/sklearn/covariance/_shrunk_covariance.py +++ b/sklearn/covariance/_shrunk_covariance.py @@ -18,7 +18,6 @@ import numpy as np from . import empirical_covariance, EmpiricalCovariance -from .._config import config_context from ..utils import check_array from ..utils._param_validation import Interval, validate_params From 63191a6dd9b8e6bcd3786dbc6640eafa8a0fe6e3 Mon Sep 17 00:00:00 2001 From: Gleb Levitski <36483986+glevv@users.noreply.github.com> Date: Fri, 18 Nov 2022 16:17:27 +0000 Subject: [PATCH 24/30] added test --- sklearn/covariance/tests/test_covariance.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sklearn/covariance/tests/test_covariance.py b/sklearn/covariance/tests/test_covariance.py index a1e5063099f7f..4fce1f3f8e201 100644 --- a/sklearn/covariance/tests/test_covariance.py +++ b/sklearn/covariance/tests/test_covariance.py @@ -144,7 +144,7 @@ def test_ledoit_wolf(): lw_cov_from_mle, lw_shrinkage_from_mle = ledoit_wolf(X_1d, assume_centered=True) assert_array_almost_equal(lw_cov_from_mle, lw.covariance_, 4) assert_almost_equal(lw_shrinkage_from_mle, lw.shrinkage_) - assert_array_almost_equal((X_1d**2).sum() / n_samples, lw.covariance_, 4) + assert_array_almost_equal((X_1d**2).mean(), lw.covariance_, 4) # test shrinkage coeff on a simple data set (without saving precision) lw = LedoitWolf(store_precision=False, assume_centered=True) @@ -176,6 +176,10 @@ def test_ledoit_wolf(): X_1d = X[:, 0].reshape((-1, 1)) lw = LedoitWolf() lw.fit(X_1d) + assert_almost_equal( + X_1d.var(ddof=0), + _ledoit_wolf(X=X_1d, assume_centered=False, block_size=10000)[1] + ) lw_cov_from_mle, lw_shrinkage_from_mle = ledoit_wolf(X_1d) assert_array_almost_equal(lw_cov_from_mle, lw.covariance_, 4) assert_almost_equal(lw_shrinkage_from_mle, lw.shrinkage_) From 76f71fbeffecb6ee374b571cf3855f640492ca50 Mon Sep 17 00:00:00 2001 From: Gleb Levitski <36483986+glevv@users.noreply.github.com> Date: Fri, 18 Nov 2022 16:59:55 +0000 Subject: [PATCH 25/30] lint --- sklearn/covariance/tests/test_covariance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/covariance/tests/test_covariance.py b/sklearn/covariance/tests/test_covariance.py index 4fce1f3f8e201..d809c365bb7c3 100644 --- a/sklearn/covariance/tests/test_covariance.py +++ b/sklearn/covariance/tests/test_covariance.py @@ -178,7 +178,7 @@ def test_ledoit_wolf(): lw.fit(X_1d) assert_almost_equal( X_1d.var(ddof=0), - _ledoit_wolf(X=X_1d, assume_centered=False, block_size=10000)[1] + _ledoit_wolf(X=X_1d, assume_centered=False, block_size=10000)[1], ) lw_cov_from_mle, lw_shrinkage_from_mle = ledoit_wolf(X_1d) assert_array_almost_equal(lw_cov_from_mle, lw.covariance_, 4) From 8baa450658bebe7598134c86e44d280b859e1d75 Mon Sep 17 00:00:00 2001 From: Gleb Levitski <36483986+glevv@users.noreply.github.com> Date: Fri, 18 Nov 2022 18:00:56 +0000 Subject: [PATCH 26/30] Update test_covariance.py --- sklearn/covariance/tests/test_covariance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/covariance/tests/test_covariance.py b/sklearn/covariance/tests/test_covariance.py index d809c365bb7c3..8ddaadc9eae8c 100644 --- a/sklearn/covariance/tests/test_covariance.py +++ b/sklearn/covariance/tests/test_covariance.py @@ -178,7 +178,7 @@ def test_ledoit_wolf(): lw.fit(X_1d) assert_almost_equal( X_1d.var(ddof=0), - _ledoit_wolf(X=X_1d, assume_centered=False, block_size=10000)[1], + _ledoit_wolf(X=X_1d, assume_centered=False, block_size=10000)[0], ) lw_cov_from_mle, lw_shrinkage_from_mle = ledoit_wolf(X_1d) assert_array_almost_equal(lw_cov_from_mle, lw.covariance_, 4) From 9d5dcb0b687f7641f43ac529520424efbf868bd6 Mon Sep 17 00:00:00 2001 From: Gleb Levitski <36483986+glevv@users.noreply.github.com> Date: Sun, 20 Nov 2022 13:30:11 +0000 Subject: [PATCH 27/30] Update sklearn/covariance/tests/test_covariance.py Co-authored-by: Guillaume Lemaitre --- sklearn/covariance/tests/test_covariance.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/covariance/tests/test_covariance.py b/sklearn/covariance/tests/test_covariance.py index 8ddaadc9eae8c..88ef76a9bc0a8 100644 --- a/sklearn/covariance/tests/test_covariance.py +++ b/sklearn/covariance/tests/test_covariance.py @@ -159,8 +159,8 @@ def test_ledoit_wolf(): assert_almost_equal(lw.shrinkage_, shrinkage_, 4) assert_almost_equal(lw.shrinkage_, ledoit_wolf_shrinkage(X)) assert_almost_equal(lw.shrinkage_, ledoit_wolf(X)[1]) - assert_almost_equal( - lw.shrinkage_, _ledoit_wolf(X=X, assume_centered=False, block_size=10000)[1] + assert lw.shrinkage_ == pytest.approx( + _ledoit_wolf(X=X, assume_centered=False, block_size=10000)[1] ) assert_almost_equal(lw.score(X), score_, 4) # compare shrunk covariance obtained from data and from MLE estimate From da2476a59855726989c5312b3a2bb58b14a9eec3 Mon Sep 17 00:00:00 2001 From: Gleb Levitski <36483986+glevv@users.noreply.github.com> Date: Sun, 20 Nov 2022 13:30:44 +0000 Subject: [PATCH 28/30] Update sklearn/covariance/tests/test_covariance.py Co-authored-by: Guillaume Lemaitre --- sklearn/covariance/tests/test_covariance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/covariance/tests/test_covariance.py b/sklearn/covariance/tests/test_covariance.py index 88ef76a9bc0a8..d121d5c43ce8b 100644 --- a/sklearn/covariance/tests/test_covariance.py +++ b/sklearn/covariance/tests/test_covariance.py @@ -176,7 +176,7 @@ def test_ledoit_wolf(): X_1d = X[:, 0].reshape((-1, 1)) lw = LedoitWolf() lw.fit(X_1d) - assert_almost_equal( + assert_allclose( X_1d.var(ddof=0), _ledoit_wolf(X=X_1d, assume_centered=False, block_size=10000)[0], ) From 27039d30de40b83e56d14adbfa1b9d1052b4dc80 Mon Sep 17 00:00:00 2001 From: Gleb Levitski <36483986+glevv@users.noreply.github.com> Date: Sun, 20 Nov 2022 13:38:10 +0000 Subject: [PATCH 29/30] Update test_covariance.py --- sklearn/covariance/tests/test_covariance.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/covariance/tests/test_covariance.py b/sklearn/covariance/tests/test_covariance.py index d121d5c43ce8b..1541a18cf9bb8 100644 --- a/sklearn/covariance/tests/test_covariance.py +++ b/sklearn/covariance/tests/test_covariance.py @@ -7,6 +7,7 @@ import numpy as np import pytest +from numpy.testing import assert_allclose from sklearn.utils._testing import assert_almost_equal from sklearn.utils._testing import assert_array_almost_equal from sklearn.utils._testing import assert_array_equal From 769e28a3b406a6d4c9373b1aebcd9278d2ef25d1 Mon Sep 17 00:00:00 2001 From: jeremiedbb Date: Tue, 27 Dec 2022 18:52:35 +0100 Subject: [PATCH 30/30] partial validation for class wrappers --- sklearn/covariance/_shrunk_covariance.py | 8 +------- sklearn/covariance/tests/test_covariance.py | 19 +++++++++++++++---- sklearn/tests/test_public_functions.py | 2 +- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/sklearn/covariance/_shrunk_covariance.py b/sklearn/covariance/_shrunk_covariance.py index de85ef6b85a99..bba1a6216bb56 100644 --- a/sklearn/covariance/_shrunk_covariance.py +++ b/sklearn/covariance/_shrunk_covariance.py @@ -310,13 +310,7 @@ def ledoit_wolf_shrinkage(X, assume_centered=False, block_size=1000): return shrinkage -@validate_params( - { - "X": ["array-like"], - "assume_centered": ["boolean"], - "block_size": [Interval(Integral, 1, None, closed="left")], - } -) +@validate_params({"X": ["array-like"]}) def ledoit_wolf(X, *, assume_centered=False, block_size=1000): """Estimate the shrunk Ledoit-Wolf covariance matrix. diff --git a/sklearn/covariance/tests/test_covariance.py b/sklearn/covariance/tests/test_covariance.py index 1541a18cf9bb8..fb4eeb26138df 100644 --- a/sklearn/covariance/tests/test_covariance.py +++ b/sklearn/covariance/tests/test_covariance.py @@ -7,7 +7,7 @@ import numpy as np import pytest -from numpy.testing import assert_allclose +from sklearn.utils._testing import assert_allclose from sklearn.utils._testing import assert_almost_equal from sklearn.utils._testing import assert_array_almost_equal from sklearn.utils._testing import assert_array_equal @@ -145,7 +145,7 @@ def test_ledoit_wolf(): lw_cov_from_mle, lw_shrinkage_from_mle = ledoit_wolf(X_1d, assume_centered=True) assert_array_almost_equal(lw_cov_from_mle, lw.covariance_, 4) assert_almost_equal(lw_shrinkage_from_mle, lw.shrinkage_) - assert_array_almost_equal((X_1d**2).mean(), lw.covariance_, 4) + assert_array_almost_equal((X_1d**2).sum() / n_samples, lw.covariance_, 4) # test shrinkage coeff on a simple data set (without saving precision) lw = LedoitWolf(store_precision=False, assume_centered=True) @@ -160,8 +160,8 @@ def test_ledoit_wolf(): assert_almost_equal(lw.shrinkage_, shrinkage_, 4) assert_almost_equal(lw.shrinkage_, ledoit_wolf_shrinkage(X)) assert_almost_equal(lw.shrinkage_, ledoit_wolf(X)[1]) - assert lw.shrinkage_ == pytest.approx( - _ledoit_wolf(X=X, assume_centered=False, block_size=10000)[1] + assert_almost_equal( + lw.shrinkage_, _ledoit_wolf(X=X, assume_centered=False, block_size=10000)[1] ) assert_almost_equal(lw.score(X), score_, 4) # compare shrunk covariance obtained from data and from MLE estimate @@ -186,6 +186,17 @@ def test_ledoit_wolf(): assert_almost_equal(lw_shrinkage_from_mle, lw.shrinkage_) assert_array_almost_equal(empirical_covariance(X_1d), lw.covariance_, 4) + # test with one sample + # warning should be raised when using only 1 sample + X_1sample = np.arange(5).reshape(1, 5) + lw = LedoitWolf() + + warn_msg = "Only one sample available. You may want to reshape your data array" + with pytest.warns(UserWarning, match=warn_msg): + lw.fit(X_1sample) + + assert_array_almost_equal(lw.covariance_, np.zeros(shape=(5, 5), dtype=np.float64)) + # test shrinkage coeff on a simple data set (without saving precision) lw = LedoitWolf(store_precision=False) lw.fit(X) diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 91adf35df8819..de1bded95c17b 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -107,7 +107,6 @@ def _check_function_param_validation( "sklearn.metrics.zero_one_loss", "sklearn.model_selection.train_test_split", "sklearn.svm.l1_min_c", - "sklearn.covariance.ledoit_wolf", ] @@ -127,6 +126,7 @@ def test_function_param_validation(func_module): PARAM_VALIDATION_CLASS_WRAPPER_LIST = [ ("sklearn.decomposition.non_negative_factorization", "sklearn.decomposition.NMF"), + ("sklearn.covariance.ledoit_wolf", "sklearn.covariance.LedoitWolf"), ]