diff --git a/metric_learn/rca.py b/metric_learn/rca.py index c9fedd59..cbb90430 100644 --- a/metric_learn/rca.py +++ b/metric_learn/rca.py @@ -26,7 +26,9 @@ def _chunk_mean_centering(data, chunks): num_chunks = chunks.max() + 1 chunk_mask = chunks != -1 - chunk_data = data[chunk_mask] + # We need to ensure the data is float so that we can substract the + # mean on it + chunk_data = data[chunk_mask].astype(float, copy=False) chunk_labels = chunks[chunk_mask] for c in xrange(num_chunks): mask = chunk_labels == c @@ -98,7 +100,7 @@ def fit(self, X, chunks): When ``chunks[i] == -1``, point i doesn't belong to any chunklet. When ``chunks[i] == j``, point i belongs to chunklet j. """ - X = self._prepare_inputs(X, ensure_min_samples=2) + X, chunks = self._prepare_inputs(X, chunks, ensure_min_samples=2) # PCA projection to remove noise and redundant information. if self.pca_comps is not None: @@ -109,7 +111,6 @@ def fit(self, X, chunks): X_t = X - X.mean(axis=0) M_pca = None - chunks = np.asanyarray(chunks, dtype=int) chunk_mask, chunked_data = _chunk_mean_centering(X_t, chunks) inner_cov = np.atleast_2d(np.cov(chunked_data, rowvar=0, bias=1)) diff --git a/test/test_sklearn_compat.py b/test/test_sklearn_compat.py index 5d6c5d77..091c56e2 100644 --- a/test/test_sklearn_compat.py +++ b/test/test_sklearn_compat.py @@ -89,9 +89,15 @@ def stable_init(self, sparsity_param=0.01, num_labeled='deprecated', dSDML.__init__ = stable_init check_estimator(dSDML) - # This fails because the default num_chunks isn't data-dependent. - # def test_rca(self): - # check_estimator(RCA_Supervised) + def test_rca(self): + def stable_init(self, num_dims=None, pca_comps=None, + chunk_size=2, preprocessor=None): + # this init makes RCA stable for scikit-learn examples. + RCA_Supervised.__init__(self, num_chunks=2, num_dims=num_dims, + pca_comps=pca_comps, chunk_size=chunk_size, + preprocessor=preprocessor) + dRCA.__init__ = stable_init + check_estimator(dRCA) RNG = check_random_state(0)