Skip to content

More systematic checks that an estimator was fit before using its parameters #267

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Dec 4, 2019
12 changes: 10 additions & 2 deletions metric_learn/base_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,12 +240,14 @@ def transform(self, X):
X_embedded : `numpy.ndarray`, shape=(n_samples, n_components)
The embedded data points.
"""
check_is_fitted(self, ['preprocessor_', 'components_'])
X_checked = check_input(X, type_of_inputs='classic', estimator=self,
preprocessor=self.preprocessor_,
accept_sparse=True)
return X_checked.dot(self.components_.T)

def get_metric(self):
check_is_fitted(self, 'components_')
components_T = self.components_.T.copy()

def metric_fun(u, v, squared=False):
Expand Down Expand Up @@ -298,6 +300,7 @@ def get_mahalanobis_matrix(self):
M : `numpy.ndarray`, shape=(n_features, n_features)
The copy of the learned Mahalanobis matrix.
"""
check_is_fitted(self, 'components_')
return self.components_.T.dot(self.components_)


Expand Down Expand Up @@ -333,7 +336,10 @@ def predict(self, pairs):
y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,)
The predicted learned metric value between samples in every pair.
"""
check_is_fitted(self, ['threshold_', 'components_'])
if "threshold_" not in vars(self):
msg = ("A threshold for this estimator has not been set,"
"call its set_threshold or calibrate_threshold method.")
raise AttributeError(msg)
return 2 * (- self.decision_function(pairs) <= self.threshold_) - 1

def decision_function(self, pairs):
Expand All @@ -357,6 +363,7 @@ def decision_function(self, pairs):
y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,)
The predicted decision function value for each pair.
"""
check_is_fitted(self, 'preprocessor_')
pairs = check_input(pairs, type_of_inputs='tuples',
preprocessor=self.preprocessor_,
estimator=self, tuple_size=self._tuple_size)
Expand Down Expand Up @@ -599,7 +606,7 @@ def predict(self, quadruplets):
prediction : `numpy.ndarray` of floats, shape=(n_constraints,)
Predictions of the ordering of pairs, for each quadruplet.
"""
check_is_fitted(self, 'components_')
check_is_fitted(self, 'preprocessor_')
quadruplets = check_input(quadruplets, type_of_inputs='tuples',
preprocessor=self.preprocessor_,
estimator=self, tuple_size=self._tuple_size)
Expand Down Expand Up @@ -628,6 +635,7 @@ def decision_function(self, quadruplets):
decision_function : `numpy.ndarray` of floats, shape=(n_constraints,)
Metric differences.
"""
check_is_fitted(self, 'preprocessor_')
quadruplets = check_input(quadruplets, type_of_inputs='tuples',
preprocessor=self.preprocessor_,
estimator=self, tuple_size=self._tuple_size)
Expand Down
19 changes: 17 additions & 2 deletions test/test_pairs_classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_raise_not_fitted_error_if_not_fitted(estimator, build_dataset,
estimator.set_params(preprocessor=preprocessor)
set_random_state(estimator)
with pytest.raises(NotFittedError):
estimator.predict(input_data)
estimator.decision_function(input_data)


@pytest.mark.parametrize('calibration_params',
Expand Down Expand Up @@ -133,10 +133,25 @@ def fit(self, pairs, y):
pairs, y = self._prepare_inputs(pairs, y,
type_of_inputs='tuples')
self.components_ = np.atleast_2d(np.identity(pairs.shape[2]))
self.threshold_ = 'I am not set.'
# self.threshold_ is not set.
return self


def test_unset_threshold():
# test that predict raises an AttributeError when no threshold has been set
identity_pairs_classifier = IdentityPairsClassifier()
pairs = np.array([[[0.], [1.]], [[1.], [3.]], [[2.], [5.]], [[3.], [7.]]])
y = np.array([1, 1, -1, -1])
identity_pairs_classifier.fit(pairs, y)
with pytest.raises(AttributeError) as e:
identity_pairs_classifier.predict(pairs)

expected_msg = ("A threshold for this estimator has not been set,"
"call its set_threshold or calibrate_threshold method.")

assert str(e.value) == expected_msg


def test_set_threshold():
# test that set_threshold indeed sets the threshold
identity_pairs_classifier = IdentityPairsClassifier()
Expand Down