Skip to content

[MRG] API: remove num_labeled parameter #119

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 1 addition & 12 deletions metric_learn/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,17 +89,6 @@ def chunks(self, num_chunks=100, chunk_size=2, random_state=np.random):
(num_chunks, chunk_size))
return chunks

@staticmethod
def random_subset(all_labels, num_preserved=np.inf, random_state=np.random):
    """Build a `Constraints` object from a random subset of the labels.

    At most `num_preserved` labels are kept; every other entry is
    overwritten with -1 in a copy of `all_labels` before the copy is handed
    to `Constraints` (presumably -1 is treated as "unlabeled" there --
    confirm against the `Constraints` constructor).

    Parameters
    ----------
    all_labels : array-like
        Label of every point. The input is copied, never modified.
    num_preserved : int or np.inf, optional (default=np.inf)
        Number of labels to keep. With the default, no label is masked.
    random_state : numpy.random.RandomState or the numpy.random module
        Source of randomness; it must expose a numpy-style `choice` method.

    Returns
    -------
    Constraints
        Constraints object built from the partially masked labels.
    """
    n = len(all_labels)
    # n - np.inf is -inf, so the default clamps to 0; int() also accepts
    # whole-valued floats such as 3.0.
    num_ignored = int(max(0, n - num_preserved))
    partial_labels = np.array(all_labels, copy=True)
    if num_ignored > 0:
        # Sample WITHOUT replacement: drawing indices with replacement could
        # repeat an index and therefore mask fewer than `num_ignored` labels.
        idx = random_state.choice(n, size=num_ignored, replace=False)
        partial_labels[idx] = -1
    return Constraints(partial_labels)

def wrap_pairs(X, constraints):
a = np.array(constraints[0])
Expand All @@ -109,4 +98,4 @@ def wrap_pairs(X, constraints):
constraints = np.vstack((np.column_stack((a, b)), np.column_stack((c, d))))
y = np.vstack([np.ones((len(a), 1)), - np.ones((len(c), 1))])
pairs = X[constraints]
return pairs, y
return pairs, y
20 changes: 12 additions & 8 deletions metric_learn/itml.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""

from __future__ import print_function, absolute_import
import warnings
import numpy as np
from six.moves import xrange
from sklearn.metrics import pairwise_distances
Expand Down Expand Up @@ -172,8 +173,8 @@ class ITML_Supervised(_BaseITML, TransformerMixin):
"""

def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
num_labeled=np.inf, num_constraints=None, bounds=None, A0=None,
verbose=False, preprocessor=None):
num_labeled='deprecated', num_constraints=None, bounds=None,
A0=None, verbose=False, preprocessor=None):
"""Initialize the supervised version of `ITML`.

`ITML_Supervised` creates pairs of similar sample by taking same class
Expand All @@ -186,10 +187,10 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
value for slack variables
max_iter : int, optional
convergence_threshold : float, optional
num_labeled : int, optional (default=np.inf)
number of labeled points to keep for building pairs. Extra
labeled points will be considered unlabeled, and ignored as such.
Use np.inf (default) to use all labeled points.
num_labeled : Not used
.. deprecated:: 0.5.0
`num_labeled` was deprecated in version 0.5.0 and will
be removed in 0.6.0.
num_constraints: int, optional
number of constraints to generate
bounds : list (pos,neg) pairs, optional
Expand Down Expand Up @@ -224,14 +225,17 @@ def fit(self, X, y, random_state=np.random):
random_state : numpy.random.RandomState, optional
If provided, controls random number generation.
"""
if self.num_labeled != 'deprecated':
warnings.warn('"num_labeled" parameter is not used.'
' It has been deprecated in version 0.5.0 and will be'
'removed in 0.6.0', DeprecationWarning)
X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
num_constraints = self.num_constraints
if num_constraints is None:
num_classes = len(np.unique(y))
num_constraints = 20 * num_classes**2

c = Constraints.random_subset(y, self.num_labeled,
random_state=random_state)
c = Constraints(y)
pos_neg = c.positive_negative_pairs(num_constraints,
random_state=random_state)
pairs, y = wrap_pairs(X, pos_neg)
Expand Down
21 changes: 13 additions & 8 deletions metric_learn/lsml.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

from __future__ import print_function, absolute_import, division
import warnings
import numpy as np
import scipy.linalg
from six.moves import xrange
Expand Down Expand Up @@ -172,8 +173,9 @@ class LSML_Supervised(_BaseLSML, TransformerMixin):
metric (See :meth:`transformer_from_metric`.)
"""

def __init__(self, tol=1e-3, max_iter=1000, prior=None, num_labeled=np.inf,
num_constraints=None, weights=None, verbose=False,
def __init__(self, tol=1e-3, max_iter=1000, prior=None,
num_labeled='deprecated', num_constraints=None, weights=None,
verbose=False,
preprocessor=None):
"""Initialize the supervised version of `LSML`.

Expand All @@ -188,10 +190,10 @@ def __init__(self, tol=1e-3, max_iter=1000, prior=None, num_labeled=np.inf,
max_iter : int, optional
prior : (d x d) matrix, optional
guess at a metric [default: covariance(X)]
num_labeled : int, optional (default=np.inf)
number of labeled points to keep for building quadruplets. Extra
labeled points will be considered unlabeled, and ignored as such.
Use np.inf (default) to use all labeled points.
num_labeled : Not used
.. deprecated:: 0.5.0
`num_labeled` was deprecated in version 0.5.0 and will
be removed in 0.6.0.
num_constraints: int, optional
number of constraints to generate
weights : (m,) array of floats, optional
Expand Down Expand Up @@ -222,14 +224,17 @@ def fit(self, X, y, random_state=np.random):
random_state : numpy.random.RandomState, optional
If provided, controls random number generation.
"""
if self.num_labeled != 'deprecated':
warnings.warn('"num_labeled" parameter is not used.'
' It has been deprecated in version 0.5.0 and will be'
'removed in 0.6.0', DeprecationWarning)
X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
num_constraints = self.num_constraints
if num_constraints is None:
num_classes = len(np.unique(y))
num_constraints = 20 * num_classes**2

c = Constraints.random_subset(y, self.num_labeled,
random_state=random_state)
c = Constraints(y)
pos_neg = c.positive_negative_pairs(num_constraints, same_length=True,
random_state=random_state)
return _BaseLSML._fit(self, X[np.column_stack(pos_neg)],
Expand Down
20 changes: 12 additions & 8 deletions metric_learn/mmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""

from __future__ import print_function, absolute_import, division
import warnings
import numpy as np
from six.moves import xrange
from sklearn.base import TransformerMixin
Expand Down Expand Up @@ -389,8 +390,8 @@ class MMC_Supervised(_BaseMMC, TransformerMixin):
"""

def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6,
num_labeled=np.inf, num_constraints=None,
A0=None, diagonal=False, diagonal_c=1.0, verbose=False,
num_labeled='deprecated', num_constraints=None, A0=None,
diagonal=False, diagonal_c=1.0, verbose=False,
preprocessor=None):
"""Initialize the supervised version of `MMC`.

Expand All @@ -403,10 +404,10 @@ def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6,
max_iter : int, optional
max_proj : int, optional
convergence_threshold : float, optional
num_labeled : int, optional (default=np.inf)
number of labeled points to keep for building pairs. Extra
labeled points will be considered unlabeled, and ignored as such.
Use np.inf (default) to use all labeled points.
num_labeled : Not used
.. deprecated:: 0.5.0
`num_labeled` was deprecated in version 0.5.0 and will
be removed in 0.6.0.
num_constraints: int, optional
number of constraints to generate
A0 : (d x d) matrix, optional
Expand Down Expand Up @@ -443,14 +444,17 @@ def fit(self, X, y, random_state=np.random):
random_state : numpy.random.RandomState, optional
If provided, controls random number generation.
"""
if self.num_labeled != 'deprecated':
warnings.warn('"num_labeled" parameter is not used.'
' It has been deprecated in version 0.5.0 and will be'
'removed in 0.6.0', DeprecationWarning)
X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
num_constraints = self.num_constraints
if num_constraints is None:
num_classes = len(np.unique(y))
num_constraints = 20 * num_classes**2

c = Constraints.random_subset(y, self.num_labeled,
random_state=random_state)
c = Constraints(y)
pos_neg = c.positive_negative_pairs(num_constraints,
random_state=random_state)
pairs, y = wrap_pairs(X, pos_neg)
Expand Down
18 changes: 11 additions & 7 deletions metric_learn/sdml.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"""

from __future__ import absolute_import
import warnings
import numpy as np
from sklearn.base import TransformerMixin
from sklearn.covariance import graph_lasso
Expand Down Expand Up @@ -113,7 +114,7 @@ class SDML_Supervised(_BaseSDML, TransformerMixin):
"""

def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True,
num_labeled=np.inf, num_constraints=None, verbose=False,
num_labeled='deprecated', num_constraints=None, verbose=False,
preprocessor=None):
"""Initialize the supervised version of `SDML`.

Expand All @@ -128,10 +129,10 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, use_cov=True,
trade off between optimizer and sparseness (see graph_lasso)
use_cov : bool, optional
controls prior matrix, will use the identity if use_cov=False
num_labeled : int, optional (default=np.inf)
number of labeled points to keep for building pairs. Extra
labeled points will be considered unlabeled, and ignored as such.
Use np.inf (default) to use all labeled points.
num_labeled : Not used
.. deprecated:: 0.5.0
`num_labeled` was deprecated in version 0.5.0 and will
be removed in 0.6.0.
num_constraints : int, optional
number of constraints to generate
verbose : bool, optional
Expand Down Expand Up @@ -164,14 +165,17 @@ def fit(self, X, y, random_state=np.random):
self : object
Returns the instance.
"""
if self.num_labeled != 'deprecated':
warnings.warn('"num_labeled" parameter is not used.'
' It has been deprecated in version 0.5.0 and will be'
'removed in 0.6.0', DeprecationWarning)
X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
num_constraints = self.num_constraints
if num_constraints is None:
num_classes = len(np.unique(y))
num_constraints = 20 * num_classes**2

c = Constraints.random_subset(y, self.num_labeled,
random_state=random_state)
c = Constraints(y)
pos_neg = c.positive_negative_pairs(num_constraints,
random_state=random_state)
pairs, y = wrap_pairs(X, pos_neg)
Expand Down
44 changes: 44 additions & 0 deletions test/metric_learn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,17 @@ def test_iris(self):
csep = class_separation(lsml.transform(self.iris_points), self.iris_labels)
self.assertLess(csep, 0.8) # it's pretty terrible

def test_deprecation(self):
# Check that fitting LSML_Supervised with an explicit `num_labeled` emits a
# DeprecationWarning carrying the expected message.
# TODO: remove in v.0.6 -- the message below announces removal in 0.6.0
# (this note previously said v.0.5).
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
y = np.array([1, 0, 1, 0])
lsml_supervised = LSML_Supervised(num_labeled=np.inf)
# NOTE(review): the adjacent literals concatenate to "will beremoved" (no
# space between the last two pieces). Left as-is because it is the text
# handed to assert_warns_message; fix it together with the warning raised
# by the estimator's fit, not here alone.
msg = ('"num_labeled" parameter is not used.'
' It has been deprecated in version 0.5.0 and will be'
'removed in 0.6.0')
assert_warns_message(DeprecationWarning, msg, lsml_supervised.fit, X, y)


class TestITML(MetricTestCase):
def test_iris(self):
Expand All @@ -66,6 +77,17 @@ def test_iris(self):
csep = class_separation(itml.transform(self.iris_points), self.iris_labels)
self.assertLess(csep, 0.2)

def test_deprecation(self):
# Check that fitting ITML_Supervised with an explicit `num_labeled` emits a
# DeprecationWarning carrying the expected message.
# TODO: remove in v.0.6 -- the message below announces removal in 0.6.0
# (this note previously said v.0.5).
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
y = np.array([1, 0, 1, 0])
itml_supervised = ITML_Supervised(num_labeled=np.inf)
# NOTE(review): the adjacent literals concatenate to "will beremoved" (no
# space between the last two pieces). Left as-is because it is the text
# handed to assert_warns_message; fix it together with the warning raised
# by the estimator's fit, not here alone.
msg = ('"num_labeled" parameter is not used.'
' It has been deprecated in version 0.5.0 and will be'
'removed in 0.6.0')
assert_warns_message(DeprecationWarning, msg, itml_supervised.fit, X, y)


class TestLMNN(MetricTestCase):
def test_iris(self):
Expand Down Expand Up @@ -121,6 +143,17 @@ def test_iris(self):
csep = class_separation(sdml.transform(self.iris_points), self.iris_labels)
self.assertLess(csep, 0.25)

def test_deprecation(self):
# Check that fitting SDML_Supervised with an explicit `num_labeled` emits a
# DeprecationWarning carrying the expected message.
# TODO: remove in v.0.6 -- the message below announces removal in 0.6.0
# (this note previously said v.0.5).
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
y = np.array([1, 0, 1, 0])
sdml_supervised = SDML_Supervised(num_labeled=np.inf)
# NOTE(review): the adjacent literals concatenate to "will beremoved" (no
# space between the last two pieces). Left as-is because it is the text
# handed to assert_warns_message; fix it together with the warning raised
# by the estimator's fit, not here alone.
msg = ('"num_labeled" parameter is not used.'
' It has been deprecated in version 0.5.0 and will be'
'removed in 0.6.0')
assert_warns_message(DeprecationWarning, msg, sdml_supervised.fit, X, y)


class TestNCA(MetricTestCase):
def test_iris(self):
Expand Down Expand Up @@ -242,6 +275,17 @@ def test_iris(self):
csep = class_separation(mmc.transform(self.iris_points), self.iris_labels)
self.assertLess(csep, 0.2)

def test_deprecation(self):
# Check that fitting MMC_Supervised with an explicit `num_labeled` emits a
# DeprecationWarning carrying the expected message.
# TODO: remove in v.0.6 -- the message below announces removal in 0.6.0
# (this note previously said v.0.5).
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
y = np.array([1, 0, 1, 0])
mmc_supervised = MMC_Supervised(num_labeled=np.inf)
# NOTE(review): the adjacent literals concatenate to "will beremoved" (no
# space between the last two pieces). Left as-is because it is the text
# handed to assert_warns_message; fix it together with the warning raised
# by the estimator's fit, not here alone.
msg = ('"num_labeled" parameter is not used.'
' It has been deprecated in version 0.5.0 and will be'
'removed in 0.6.0')
assert_warns_message(DeprecationWarning, msg, mmc_supervised.fit, X, y)


@pytest.mark.parametrize(('algo_class', 'dataset'),
[(NCA, make_classification()),
Expand Down
12 changes: 6 additions & 6 deletions test/test_base_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_itml(self):
""".strip('\n'))
self.assertEqual(str(metric_learn.ITML_Supervised()), """
ITML_Supervised(A0=None, bounds=None, convergence_threshold=0.001, gamma=1.0,
max_iter=1000, num_constraints=None, num_labeled=inf,
max_iter=1000, num_constraints=None, num_labeled='deprecated',
preprocessor=None, verbose=False)
""".strip('\n'))

Expand All @@ -42,7 +42,7 @@ def test_lsml(self):
"LSML(max_iter=1000, preprocessor=None, prior=None, tol=0.001, "
"verbose=False)")
self.assertEqual(str(metric_learn.LSML_Supervised()), """
LSML_Supervised(max_iter=1000, num_constraints=None, num_labeled=inf,
LSML_Supervised(max_iter=1000, num_constraints=None, num_labeled='deprecated',
preprocessor=None, prior=None, tol=0.001, verbose=False,
weights=None)
""".strip('\n'))
Expand All @@ -52,9 +52,9 @@ def test_sdml(self):
"SDML(balance_param=0.5, preprocessor=None, "
"sparsity_param=0.01, use_cov=True,\n verbose=False)")
self.assertEqual(str(metric_learn.SDML_Supervised()), """
SDML_Supervised(balance_param=0.5, num_constraints=None, num_labeled=inf,
preprocessor=None, sparsity_param=0.01, use_cov=True,
verbose=False)
SDML_Supervised(balance_param=0.5, num_constraints=None,
num_labeled='deprecated', preprocessor=None, sparsity_param=0.01,
use_cov=True, verbose=False)
""".strip('\n'))

def test_rca(self):
Expand All @@ -78,7 +78,7 @@ def test_mmc(self):
self.assertEqual(str(metric_learn.MMC_Supervised()), """
MMC_Supervised(A0=None, convergence_threshold=1e-06, diagonal=False,
diagonal_c=1.0, max_iter=100, max_proj=10000, num_constraints=None,
num_labeled=inf, preprocessor=None, verbose=False)
num_labeled='deprecated', preprocessor=None, verbose=False)
""".strip('\n'))

if __name__ == '__main__':
Expand Down
3 changes: 1 addition & 2 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ def build_data():
input_data, labels = load_iris(return_X_y=True)
X, y = shuffle(input_data, labels, random_state=SEED)
num_constraints = 50
constraints = (
Constraints.random_subset(y, random_state=check_random_state(SEED)))
constraints = Constraints(y)
pairs = (
constraints
.positive_negative_pairs(num_constraints, same_length=True,
Expand Down