From c99fef2b42294f817c7bde63f3f7849ceb7b8d46 Mon Sep 17 00:00:00 2001
From: William de Vazelhes
Date: Mon, 29 Apr 2019 15:28:11 +0200
Subject: [PATCH] Add checks for labels when having pairs

---
 metric_learn/_util.py | 12 +++++++
 test/test_utils.py    | 73 ++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 84 insertions(+), 1 deletion(-)

diff --git a/metric_learn/_util.py b/metric_learn/_util.py
index 33311620..ff9c021c 100644
--- a/metric_learn/_util.py
+++ b/metric_learn/_util.py
@@ -137,6 +137,11 @@ def check_input(input_data, y=None, preprocessor=None,
     input_data = check_input_tuples(input_data, context, preprocessor,
                                     args_for_sk_checks, tuple_size)
 
+    # When y is given and the tuples are pairs (size-2 tuples), every label
+    # must be exactly -1 or +1, so validate y here:
+    if y is not None and input_data.shape[1] == 2:
+      check_y_valid_values_for_pairs(y)
+
   else:
     raise ValueError("Unknown value {} for type_of_inputs. Valid values are "
                      "'classic' or 'tuples'.".format(type_of_inputs))
@@ -297,6 +302,13 @@ def check_tuple_size(tuples, tuple_size, context):
     raise ValueError(msg_t)
 
 
+def check_y_valid_values_for_pairs(y):
+  """Raise ValueError unless every value in y is exactly -1 or +1."""
+  if not np.array_equal(np.abs(y), np.ones_like(y)):
+    raise ValueError("When training on pairs, the labels (y) should contain "
+                     "only values in [-1, 1]. Found an incorrect value.")
+
+
 class ArrayIndexer:
 
   def __init__(self, X):
diff --git a/test/test_utils.py b/test/test_utils.py
index 52ebc7a6..4cec7444 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -10,7 +10,8 @@
 from metric_learn._util import (check_input, make_context, preprocess_tuples,
                                 make_name, preprocess_points,
                                 check_collapsed_pairs, validate_vector,
-                                _check_sdp_from_eigen)
+                                _check_sdp_from_eigen,
+                                check_y_valid_values_for_pairs)
 from metric_learn import (ITML, LSML, MMC, RCA, SDML, Covariance, LFDA,
                           LMNN, MLKR, NCA, ITML_Supervised, LSML_Supervised,
                           MMC_Supervised, RCA_Supervised, SDML_Supervised,
@@ -1067,3 +1068,73 @@ def test_check_sdp_from_eigen_positive_err_messages():
   _check_sdp_from_eigen(w, 1.)
   _check_sdp_from_eigen(w, 0.)
   _check_sdp_from_eigen(w, None)
+
+
+@pytest.mark.unit
+@pytest.mark.parametrize('wrong_labels',
+                         [[0.5, 0.6, 0.7, 0.8, 0.9],
+                          np.random.RandomState(42).randn(5),
+                          np.random.RandomState(42).choice([0, 1], size=5)])
+def test_check_y_valid_values_for_pairs(wrong_labels):
+  expected_msg = ("When training on pairs, the labels (y) should contain "
+                  "only values in [-1, 1]. Found an incorrect value.")
+  with pytest.raises(ValueError) as raised_error:
+    check_y_valid_values_for_pairs(wrong_labels)
+  assert str(raised_error.value) == expected_msg
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize('wrong_labels',
+                         [[0.5, 0.6, 0.7, 0.8, 0.9],
+                          np.random.RandomState(42).randn(5),
+                          np.random.RandomState(42).choice([0, 1], size=5)])
+def test_check_input_invalid_tuples_without_preprocessor(wrong_labels):
+  pairs = np.random.RandomState(42).randn(5, 2, 3)
+  expected_msg = ("When training on pairs, the labels (y) should contain "
+                  "only values in [-1, 1]. Found an incorrect value.")
+  with pytest.raises(ValueError) as raised_error:
+    check_input(pairs, wrong_labels, preprocessor=None,
+                type_of_inputs='tuples')
+  assert str(raised_error.value) == expected_msg
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize('wrong_labels',
+                         [[0.5, 0.6, 0.7, 0.8, 0.9],
+                          np.random.RandomState(42).randn(5),
+                          np.random.RandomState(42).choice([0, 1], size=5)])
+def test_check_input_invalid_tuples_with_preprocessor(wrong_labels):
+  n_samples, n_features, n_pairs = 10, 4, 5
+  rng = np.random.RandomState(42)
+  pairs = rng.randint(10, size=(n_pairs, 2))
+  preprocessor = rng.randn(n_samples, n_features)
+  expected_msg = ("When training on pairs, the labels (y) should contain "
+                  "only values in [-1, 1]. Found an incorrect value.")
+  with pytest.raises(ValueError) as raised_error:
+    check_input(pairs, wrong_labels, preprocessor=ArrayIndexer(preprocessor),
+                type_of_inputs='tuples')
+  assert str(raised_error.value) == expected_msg
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize('with_preprocessor', [True, False])
+@pytest.mark.parametrize('estimator, build_dataset', pairs_learners,
+                         ids=ids_pairs_learners)
+def test_check_input_pairs_learners_invalid_y(estimator, build_dataset,
+                                              with_preprocessor):
+  """Check that pairs learners raise ValueError for labels other than +1/-1."""
+  input_data, labels, _, X = build_dataset()
+  wrong_labels_list = [labels + 0.5,
+                       np.random.RandomState(42).randn(len(labels)),
+                       np.random.RandomState(42).choice([0, 1],
+                                                        size=len(labels))]
+  model = clone(estimator)
+  set_random_state(model)
+
+  expected_msg = ("When training on pairs, the labels (y) should contain "
+                  "only values in [-1, 1]. Found an incorrect value.")
+
+  for wrong_labels in wrong_labels_list:
+    with pytest.raises(ValueError) as raised_error:
+      model.fit(input_data, wrong_labels)
+    assert str(raised_error.value) == expected_msg