From 6b789d39b1f9859a7efac9bf7e6b539bb584acbd Mon Sep 17 00:00:00 2001 From: RobinVogel Date: Thu, 14 Nov 2019 17:58:23 +0100 Subject: [PATCH 1/8] maj --- test_components.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 test_components.py diff --git a/test_components.py b/test_components.py new file mode 100644 index 00000000..6edf73db --- /dev/null +++ b/test_components.py @@ -0,0 +1,21 @@ +import numpy as np +import pytest +from numpy.linalg import LinAlgError +from scipy.stats import ortho_group + +rng = np.random.RandomState(42) + +# an orthonormal matrix useful for creating matrices with given +# eigenvalues: +P = ortho_group.rvs(7, random_state=rng) + +# matrix with a determinant still high but which should be considered as a +# non-definite matrix (to check we don't test the definiteness with the +# determinant which is a bad strategy) +M = np.diag([1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e-20]) +M = P.dot(M).dot(P.T) +assert np.abs(np.linalg.det(M)) > 10 +assert np.linalg.slogdet(M)[1] > 1 # (just to show that the computed +# determinant is far from null) +with pytest.raises(LinAlgError) as err_msg: + np.linalg.cholesky(M) From 275c69a8493dfba872aaf2db6dbaad1acbd7c4e0 Mon Sep 17 00:00:00 2001 From: RobinVogel Date: Thu, 28 Nov 2019 17:01:07 +0100 Subject: [PATCH 2/8] added fit checks --- metric_learn/base_metric.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py index 6feccc72..f238ccd2 100644 --- a/metric_learn/base_metric.py +++ b/metric_learn/base_metric.py @@ -215,6 +215,7 @@ def score_pairs(self, pairs): :ref:`mahalanobis_distances` : The section of the project documentation that describes Mahalanobis Distances. """ + check_is_fitted(self, 'components_') pairs = check_input(pairs, type_of_inputs='tuples', preprocessor=self.preprocessor_, estimator=self, tuple_size=2) @@ -240,12 +241,15 @@ def transform(self, X): X_embedded : `numpy.ndarray`, shape=(n_samples, n_components) The embedded data points. """ + check_is_fitted(self, 'components_') + X_checked = check_input(X, type_of_inputs='classic', estimator=self, preprocessor=self.preprocessor_, accept_sparse=True) return X_checked.dot(self.components_.T) def get_metric(self): + check_is_fitted(self, 'components_') components_T = self.components_.T.copy() def metric_fun(u, v, squared=False): @@ -285,6 +289,7 @@ def metric(self): """Deprecated. Will be removed in v0.6.0. Use `get_mahalanobis_matrix` instead""" # TODO: remove this method in version 0.6.0 + check_is_fitted(self, 'components_') warnings.warn(("`metric` is deprecated since version 0.5.0 and will be " "removed in 0.6.0. Use `get_mahalanobis_matrix` instead."), DeprecationWarning) @@ -298,6 +303,7 @@ def get_mahalanobis_matrix(self): M : `numpy.ndarray`, shape=(n_features, n_features) The copy of the learned Mahalanobis matrix. """ + check_is_fitted(self, 'components_') return self.components_.T.dot(self.components_) @@ -357,6 +363,7 @@ def decision_function(self, pairs): y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,) The predicted decision function value for each pair. """ + check_is_fitted(self, 'components_') pairs = check_input(pairs, type_of_inputs='tuples', preprocessor=self.preprocessor_, estimator=self, tuple_size=self._tuple_size) @@ -628,6 +635,7 @@ def decision_function(self, quadruplets): decision_function : `numpy.ndarray` of floats, shape=(n_constraints,) Metric differences. """ + check_is_fitted(self, 'components_') quadruplets = check_input(quadruplets, type_of_inputs='tuples', preprocessor=self.preprocessor_, estimator=self, tuple_size=self._tuple_size) From 1c28b5663dc5f6e9d2f002bda93c4dc9072471cd Mon Sep 17 00:00:00 2001 From: RobinVogel Date: Thu, 28 Nov 2019 17:02:31 +0100 Subject: [PATCH 3/8] maj --- test_components.py | 21 --------------------- 1 file changed, 21 deletions(-) delete mode 100644 test_components.py diff --git a/test_components.py b/test_components.py deleted file mode 100644 index 6edf73db..00000000 --- a/test_components.py +++ /dev/null @@ -1,21 +0,0 @@ -import numpy as np -import pytest -from numpy.linalg import LinAlgError -from scipy.stats import ortho_group - -rng = np.random.RandomState(42) - -# an orthonormal matrix useful for creating matrices with given -# eigenvalues: -P = ortho_group.rvs(7, random_state=rng) - -# matrix with a determinant still high but which should be considered as a -# non-definite matrix (to check we don't test the definiteness with the -# determinant which is a bad strategy) -M = np.diag([1e5, 1e5, 1e5, 1e5, 1e5, 1e5, 1e-20]) -M = P.dot(M).dot(P.T) -assert np.abs(np.linalg.det(M)) > 10 -assert np.linalg.slogdet(M)[1] > 1 # (just to show that the computed -# determinant is far from null) -with pytest.raises(LinAlgError) as err_msg: - np.linalg.cholesky(M) From 76ffccb9e578cf8220177de801a24451810a1b8d Mon Sep 17 00:00:00 2001 From: RobinVogel Date: Thu, 28 Nov 2019 17:09:15 +0100 Subject: [PATCH 4/8] Added checks that the function was fitted. --- metric_learn/base_metric.py | 1 - 1 file changed, 1 deletion(-) diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py index f238ccd2..707b9d8b 100644 --- a/metric_learn/base_metric.py +++ b/metric_learn/base_metric.py @@ -242,7 +242,6 @@ def transform(self, X): The embedded data points. """ check_is_fitted(self, 'components_') - X_checked = check_input(X, type_of_inputs='classic', estimator=self, preprocessor=self.preprocessor_, accept_sparse=True) From c6d695f6837d58238a4444dd1ff6c3d08b8e5e24 Mon Sep 17 00:00:00 2001 From: RobinVogel Date: Fri, 29 Nov 2019 10:32:54 +0100 Subject: [PATCH 5/8] check the input before if model is fitted --- metric_learn/base_metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py index 707b9d8b..426a4c55 100644 --- a/metric_learn/base_metric.py +++ b/metric_learn/base_metric.py @@ -215,10 +215,10 @@ def score_pairs(self, pairs): :ref:`mahalanobis_distances` : The section of the project documentation that describes Mahalanobis Distances. """ - check_is_fitted(self, 'components_') pairs = check_input(pairs, type_of_inputs='tuples', preprocessor=self.preprocessor_, estimator=self, tuple_size=2) + check_is_fitted(self, 'components_') pairwise_diffs = self.transform(pairs[:, 1, :] - pairs[:, 0, :]) # (for MahalanobisMixin, the embedding is linear so we can just embed the # difference) From 0a0a22775778b35009ebf2936115b8307c53d1ad Mon Sep 17 00:00:00 2001 From: RobinVogel Date: Fri, 29 Nov 2019 10:46:43 +0100 Subject: [PATCH 6/8] made more sensible checks. --- metric_learn/base_metric.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py index 426a4c55..0508be69 100644 --- a/metric_learn/base_metric.py +++ b/metric_learn/base_metric.py @@ -218,7 +218,6 @@ def score_pairs(self, pairs): pairs = check_input(pairs, type_of_inputs='tuples', preprocessor=self.preprocessor_, estimator=self, tuple_size=2) - check_is_fitted(self, 'components_') pairwise_diffs = self.transform(pairs[:, 1, :] - pairs[:, 0, :]) # (for MahalanobisMixin, the embedding is linear so we can just embed the # difference) @@ -241,7 +240,7 @@ def transform(self, X): X_embedded : `numpy.ndarray`, shape=(n_samples, n_components) The embedded data points. """ - check_is_fitted(self, 'components_') + check_is_fitted(self, ['preprocessor_', 'components_']) X_checked = check_input(X, type_of_inputs='classic', estimator=self, preprocessor=self.preprocessor_, accept_sparse=True) @@ -288,7 +287,6 @@ def metric(self): """Deprecated. Will be removed in v0.6.0. Use `get_mahalanobis_matrix` instead""" # TODO: remove this method in version 0.6.0 - check_is_fitted(self, 'components_') warnings.warn(("`metric` is deprecated since version 0.5.0 and will be " "removed in 0.6.0. Use `get_mahalanobis_matrix` instead."), DeprecationWarning) @@ -338,7 +336,6 @@ def predict(self, pairs): y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,) The predicted learned metric value between samples in every pair. """ - check_is_fitted(self, ['threshold_', 'components_']) return 2 * (- self.decision_function(pairs) <= self.threshold_) - 1 def decision_function(self, pairs): @@ -362,7 +359,7 @@ def decision_function(self, pairs): y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,) The predicted decision function value for each pair. """ - check_is_fitted(self, 'components_') + check_is_fitted(self, 'preprocessor_') pairs = check_input(pairs, type_of_inputs='tuples', preprocessor=self.preprocessor_, estimator=self, tuple_size=self._tuple_size) @@ -605,7 +602,7 @@ def predict(self, quadruplets): prediction : `numpy.ndarray` of floats, shape=(n_constraints,) Predictions of the ordering of pairs, for each quadruplet. """ - check_is_fitted(self, 'components_') + check_is_fitted(self, 'preprocessor_') quadruplets = check_input(quadruplets, type_of_inputs='tuples', preprocessor=self.preprocessor_, estimator=self, tuple_size=self._tuple_size) @@ -634,7 +631,7 @@ def decision_function(self, quadruplets): decision_function : `numpy.ndarray` of floats, shape=(n_constraints,) Metric differences. """ - check_is_fitted(self, 'components_') + check_is_fitted(self, 'preprocessor_') quadruplets = check_input(quadruplets, type_of_inputs='tuples', preprocessor=self.preprocessor_, estimator=self, tuple_size=self._tuple_size) From 658ef4a8c34063a7cb6302335945837cc4e85109 Mon Sep 17 00:00:00 2001 From: RobinVogel Date: Fri, 29 Nov 2019 11:12:16 +0100 Subject: [PATCH 7/8] added a test for a threshold --- metric_learn/base_metric.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py index 0508be69..b8be13b5 100644 --- a/metric_learn/base_metric.py +++ b/metric_learn/base_metric.py @@ -336,6 +336,10 @@ def predict(self, pairs): y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,) The predicted learned metric value between samples in every pair. """ + if "threshold_" not in vars(self): + msg = ("A threshold for this estimator has not been set," + "call the set_threshold or calibrate_threshold method.") + raise AttributeError(msg) return 2 * (- self.decision_function(pairs) <= self.threshold_) - 1 def decision_function(self, pairs): From f6a1c96494f88c7eab7b69692534e1365629cc4a Mon Sep 17 00:00:00 2001 From: RobinVogel Date: Fri, 29 Nov 2019 11:32:24 +0100 Subject: [PATCH 8/8] added a test for the unset threshold --- metric_learn/base_metric.py | 2 +- test/test_pairs_classifiers.py | 19 +++++++++++++++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py index b8be13b5..427fcf86 100644 --- a/metric_learn/base_metric.py +++ b/metric_learn/base_metric.py @@ -338,7 +338,7 @@ def predict(self, pairs): """ if "threshold_" not in vars(self): msg = ("A threshold for this estimator has not been set," - "call the set_threshold or calibrate_threshold method.") + "call its set_threshold or calibrate_threshold method.") raise AttributeError(msg) return 2 * (- self.decision_function(pairs) <= self.threshold_) - 1 diff --git a/test/test_pairs_classifiers.py b/test/test_pairs_classifiers.py index affc70f6..840cd151 100644 --- a/test/test_pairs_classifiers.py +++ b/test/test_pairs_classifiers.py @@ -73,7 +73,7 @@ def test_raise_not_fitted_error_if_not_fitted(estimator, build_dataset, estimator.set_params(preprocessor=preprocessor) set_random_state(estimator) with pytest.raises(NotFittedError): - estimator.predict(input_data) + estimator.decision_function(input_data) @pytest.mark.parametrize('calibration_params', @@ -133,10 +133,25 @@ def fit(self, pairs, y): pairs, y = self._prepare_inputs(pairs, y, type_of_inputs='tuples') self.components_ = np.atleast_2d(np.identity(pairs.shape[2])) - self.threshold_ = 'I am not set.' + # self.threshold_ is not set. return self +def test_unset_threshold(): + # test that set_threshold indeed sets the threshold + identity_pairs_classifier = IdentityPairsClassifier() + pairs = np.array([[[0.], [1.]], [[1.], [3.]], [[2.], [5.]], [[3.], [7.]]]) + y = np.array([1, 1, -1, -1]) + identity_pairs_classifier.fit(pairs, y) + with pytest.raises(AttributeError) as e: + identity_pairs_classifier.predict(pairs) + + expected_msg = ("A threshold for this estimator has not been set," + "call its set_threshold or calibrate_threshold method.") + + assert str(e.value) == expected_msg + + def test_set_threshold(): # test that set_threshold indeed sets the threshold identity_pairs_classifier = IdentityPairsClassifier()