diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py index 856591cb..14acbf7c 100644 --- a/metric_learn/base_metric.py +++ b/metric_learn/base_metric.py @@ -618,6 +618,9 @@ def decision_function(self, quadruplets): decision_function : `numpy.ndarray` of floats, shape=(n_constraints,) Metric differences. """ + quadruplets = check_input(quadruplets, type_of_inputs='tuples', + preprocessor=self.preprocessor_, + estimator=self, tuple_size=self._tuple_size) return (self.score_pairs(quadruplets[:, 2:]) - self.score_pairs(quadruplets[:, :2])) diff --git a/test/test_sklearn_compat.py b/test/test_sklearn_compat.py index 0c0f098d..4c511263 100644 --- a/test/test_sklearn_compat.py +++ b/test/test_sklearn_compat.py @@ -105,6 +105,70 @@ def stable_init(self, n_components=None, pca_comps=None, # ---------------------- Test scikit-learn compatibility ---------------------- +def generate_array_like(input_data, labels=None): + """Helper function to generate array-like variants of numpy datasets, + for testing purposes.""" + list_data = input_data.tolist() + input_data_changed = [input_data, list_data, tuple(list_data)] + if input_data.ndim >= 2: + input_data_changed.append(tuple(tuple(x) for x in list_data)) + if input_data.ndim >= 3: + input_data_changed.append(tuple(tuple(tuple(x) for x in y) for y in + list_data)) + if input_data.ndim == 2: + pd = pytest.importorskip('pandas') + input_data_changed.append(pd.DataFrame(input_data)) + if labels is not None: + labels_changed = [labels, list(labels), tuple(labels)] + else: + labels_changed = [labels] + return input_data_changed, labels_changed + + +@pytest.mark.integration +@pytest.mark.parametrize('with_preprocessor', [True, False]) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners, + ids=ids_metric_learners) +def test_array_like_inputs(estimator, build_dataset, with_preprocessor): + """Test that metric-learners can have as input (of all functions that are + applied on data) any array-like object.""" + input_data, labels, preprocessor, X = build_dataset(with_preprocessor) + + # we subsample the data for the test to be more efficient + input_data, _, labels, _ = train_test_split(input_data, labels, + train_size=20) + X = X[:10] + + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor) + set_random_state(estimator) + input_variants, label_variants = generate_array_like(input_data, labels) + for input_variant in input_variants: + for label_variant in label_variants: + estimator.fit(*remove_y_quadruplets(estimator, input_variant, + label_variant)) + if hasattr(estimator, "predict"): + estimator.predict(input_variant) + if hasattr(estimator, "predict_proba"): + estimator.predict_proba(input_variant) # anticipation in case some + # time we have that, or if ppl want to contribute with new algorithms + # it will be checked automatically + if hasattr(estimator, "decision_function"): + estimator.decision_function(input_variant) + if hasattr(estimator, "score"): + for label_variant in label_variants: + estimator.score(*remove_y_quadruplets(estimator, input_variant, + label_variant)) + + X_variants, _ = generate_array_like(X) + for X_variant in X_variants: + estimator.transform(X_variant) + + pairs = np.array([[X[0], X[1]], [X[0], X[2]]]) + pairs_variants, _ = generate_array_like(pairs) + for pairs_variant in pairs_variants: + estimator.score_pairs(pairs_variant) + @pytest.mark.parametrize('with_preprocessor', [True, False]) @pytest.mark.parametrize('estimator, build_dataset', pairs_learners, diff --git a/test/test_utils.py b/test/test_utils.py index 2e57f489..970b40a1 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -118,7 +118,7 @@ def build_quadruplets(with_preprocessor=False): (ITML_Supervised(max_iter=5), build_classification), (LSML_Supervised(), build_classification), (MMC_Supervised(max_iter=5), build_classification), - (RCA_Supervised(num_chunks=10), build_classification), + (RCA_Supervised(num_chunks=5), build_classification), (SDML_Supervised(prior='identity', balance_param=1e-5), build_classification)] ids_classifiers = list(map(lambda x: x.__class__.__name__,