diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py
index 5fe2ca14..0978b17b 100644
--- a/metric_learn/base_metric.py
+++ b/metric_learn/base_metric.py
@@ -45,6 +45,28 @@ def transform(self, X=None):
       X = self.X
     L = self.transformer()
     return X.dot(L.T)
+
+  def fit_transform(self, *args, **kwargs):
+    """Fit the metric learner, then return the transformed input data.
+
+    Calls ``fit(*args, **kwargs)`` and then ``transform()`` with no
+    argument, so the result is the metric-transformed version of the
+    data that ``transform`` falls back to (``self.X``).
+
+    Parameters
+    ----------
+    *args, **kwargs
+        Forwarded unchanged to ``fit()``; see the ``fit`` method of the
+        concrete metric learning algorithm for the accepted arguments.
+
+    Returns
+    -------
+    transformed : (n x d) matrix
+        Input data transformed to the metric space by :math:`XL^{\\top}`
+
+    """
+    self.fit(*args, **kwargs)
+    return self.transform()
 
   def get_params(self, deep=False):
     """Get parameters for this metric learner.
diff --git a/metric_learn/itml.py b/metric_learn/itml.py
index 19e5bb71..6a6fcf04 100644
--- a/metric_learn/itml.py
+++ b/metric_learn/itml.py
@@ -183,5 +183,5 @@ def fit(self, X, labels, random_state=np.random):
       num_constraints = 20*(len(num_classes))**2
     c = Constraints.random_subset(labels, self.params['num_labeled'],
                                   random_state=random_state)
-    return ITML.fit(self, X, c.positive_negative_pairs(num_constraints),
+    return ITML.fit(self, X, c.positive_negative_pairs(num_constraints, random_state=random_state),
                     bounds=self.params['bounds'], A0=self.params['A0'])
diff --git a/metric_learn/lsml.py b/metric_learn/lsml.py
index 077cdd5d..343c0b7f 100644
--- a/metric_learn/lsml.py
+++ b/metric_learn/lsml.py
@@ -172,6 +172,6 @@ def fit(self, X, labels, random_state=np.random):
       num_constraints = 20*(len(num_classes))**2
     c = Constraints.random_subset(labels, self.params['num_labeled'],
                                   random_state=random_state)
-    pairs = c.positive_negative_pairs(num_constraints, same_length=True)
+    pairs = c.positive_negative_pairs(num_constraints, same_length=True, random_state=random_state)
     return LSML.fit(self, X, pairs, weights=self.params['weights'],
                     prior=self.params['prior'])
diff --git a/metric_learn/sdml.py b/metric_learn/sdml.py
index 474f2502..852b00f3 100644
--- a/metric_learn/sdml.py
+++ b/metric_learn/sdml.py
@@ -106,4 +106,4 @@ def fit(self, X, labels, random_state=np.random):
       num_constraints = 20*(len(num_classes))**2
     c = Constraints.random_subset(labels, self.params['num_labeled'],
                                   random_state=random_state)
-    return SDML.fit(self, X, c.adjacency_matrix(num_constraints))
+    return SDML.fit(self, X, c.adjacency_matrix(num_constraints, random_state=random_state))
diff --git a/test/test_fit_transform.py b/test/test_fit_transform.py
new file mode 100644
index 00000000..a25511ce
--- /dev/null
+++ b/test/test_fit_transform.py
@@ -0,0 +1,132 @@
+import unittest
+import numpy as np
+from sklearn.datasets import load_iris
+from numpy.testing import assert_array_almost_equal
+
+from metric_learn import (
+    LMNN, NCA, LFDA, Covariance,
+    LSML_Supervised, ITML_Supervised, SDML_Supervised, RCA_Supervised)
+
+
+
+class MetricTestCase(unittest.TestCase):
+  @classmethod
+  def setUpClass(cls):
+    # runs once per test class
+    iris_data = load_iris()
+    cls.iris_points = iris_data['data']
+    cls.iris_labels = iris_data['target']
+
+
+class TestCovariance(MetricTestCase):
+  def test_cov(self):
+    cov = Covariance()
+    cov.fit(self.iris_points)
+    res_1 = cov.transform()
+
+    cov = Covariance()
+    res_2 = cov.fit_transform(self.iris_points)
+    # deterministic result
+    assert_array_almost_equal(res_1, res_2)
+
+
+class TestLSML(MetricTestCase):
+  def test_lsml(self):
+
+    seed = np.random.RandomState(1234)
+    lsml = LSML_Supervised(num_constraints=200)
+    lsml.fit(self.iris_points, self.iris_labels, random_state=seed)
+    res_1 = lsml.transform()
+
+    seed = np.random.RandomState(1234)
+    lsml = LSML_Supervised(num_constraints=200)
+    res_2 = lsml.fit_transform(self.iris_points, self.iris_labels, random_state=seed)
+
+    assert_array_almost_equal(res_1, res_2)
+
+class TestITML(MetricTestCase):
+  def test_itml(self):
+
+    seed = np.random.RandomState(1234)
+    itml = ITML_Supervised(num_constraints=200)
+    itml.fit(self.iris_points, self.iris_labels, random_state=seed)
+    res_1 = itml.transform()
+
+    seed = np.random.RandomState(1234)
+    itml = ITML_Supervised(num_constraints=200)
+    res_2 = itml.fit_transform(self.iris_points, self.iris_labels, random_state=seed)
+
+    assert_array_almost_equal(res_1, res_2)
+
+class TestLMNN(MetricTestCase):
+  def test_lmnn(self):
+
+    lmnn = LMNN(k=5, learn_rate=1e-6, verbose=False)
+    lmnn.fit(self.iris_points, self.iris_labels)
+    res_1 = lmnn.transform()
+
+    lmnn = LMNN(k=5, learn_rate=1e-6, verbose=False)
+    res_2 = lmnn.fit_transform(self.iris_points, self.iris_labels)
+
+    assert_array_almost_equal(res_1, res_2)
+
+class TestSDML(MetricTestCase):
+  def test_sdml(self):
+
+    seed = np.random.RandomState(1234)
+    sdml = SDML_Supervised(num_constraints=1500)
+    sdml.fit(self.iris_points, self.iris_labels, random_state=seed)
+    res_1 = sdml.transform()
+
+    seed = np.random.RandomState(1234)
+    sdml = SDML_Supervised(num_constraints=1500)
+    res_2 = sdml.fit_transform(self.iris_points, self.iris_labels, random_state=seed)
+
+    assert_array_almost_equal(res_1, res_2)
+
+class TestNCA(MetricTestCase):
+  def test_nca(self):
+
+    n = self.iris_points.shape[0]
+    nca = NCA(max_iter=(100000//n), learning_rate=0.01)
+    nca.fit(self.iris_points, self.iris_labels)
+    res_1 = nca.transform()
+
+    nca = NCA(max_iter=(100000//n), learning_rate=0.01)
+    res_2 = nca.fit_transform(self.iris_points, self.iris_labels)
+
+    assert_array_almost_equal(res_1, res_2)
+
+class TestLFDA(MetricTestCase):
+  def test_lfda(self):
+
+    lfda = LFDA(k=2, dim=2)
+    lfda.fit(self.iris_points, self.iris_labels)
+    res_1 = lfda.transform()
+
+    lfda = LFDA(k=2, dim=2)
+    res_2 = lfda.fit_transform(self.iris_points, self.iris_labels)
+
+    res_1 = round(res_1[0][0], 3)
+    res_2 = round(res_2[0][0], 3)
+    res = (res_1 == res_2 or res_1 == -res_2)
+
+    self.assertTrue(res)
+
+class TestRCA(MetricTestCase):
+  def test_rca(self):
+
+    seed = np.random.RandomState(1234)
+    rca = RCA_Supervised(dim=2, num_chunks=30, chunk_size=2)
+    rca.fit(self.iris_points, self.iris_labels, random_state=seed)
+    res_1 = rca.transform()
+
+    seed = np.random.RandomState(1234)
+    rca = RCA_Supervised(dim=2, num_chunks=30, chunk_size=2)
+    res_2 = rca.fit_transform(self.iris_points, self.iris_labels, random_state=seed)
+
+    assert_array_almost_equal(res_1, res_2)
+
+
+if __name__ == '__main__':
+  unittest.main()