Skip to content

[MRG] change bounds parameter of ITML_Supervised from init to fit #163

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 24 additions & 12 deletions metric_learn/itml.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ class ITML_Supervised(_BaseITML, TransformerMixin):
"""

def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
num_labeled='deprecated', num_constraints=None, bounds=None,
A0=None, verbose=False, preprocessor=None):
num_labeled='deprecated', num_constraints=None,
bounds='deprecated', A0=None, verbose=False, preprocessor=None):
"""Initialize the supervised version of `ITML`.

`ITML_Supervised` creates pairs of similar samples by taking same class
Expand All @@ -222,14 +222,11 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
be removed in 0.6.0.
num_constraints: int, optional
number of constraints to generate
bounds : `list` of two numbers
Bounds on similarity, aside slack variables, s.t.
``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a``
and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of
dissimilar points ``c`` and ``d``, with ``d`` the learned distance.
If not provided at initialization, bounds_[0] and bounds_[1] will be
set to the 5th and 95th percentile of the pairwise distances among all
points in the training data `X`.
bounds : Not used
.. deprecated:: 0.5.0
`bounds` was deprecated in version 0.5.0 and will
be removed in 0.6.0. Set `bounds` at fit time instead:
`itml_supervised.fit(X, y, bounds=...)`
A0 : (d x d) matrix, optional
initial regularization matrix, defaults to identity
verbose : bool, optional
Expand All @@ -245,7 +242,7 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
self.num_constraints = num_constraints
self.bounds = bounds

def fit(self, X, y, random_state=np.random):
def fit(self, X, y, random_state=np.random, bounds=None):
"""Create constraints from labels and learn the ITML model.


Expand All @@ -259,11 +256,26 @@ def fit(self, X, y, random_state=np.random):

random_state : numpy.random.RandomState, optional
If provided, controls random number generation.

bounds : `list` of two numbers
Bounds on similarity, aside slack variables, s.t.
``d(a, b) < bounds_[0]`` for all given pairs of similar points ``a``
and ``b``, and ``d(c, d) > bounds_[1]`` for all given pairs of
dissimilar points ``c`` and ``d``, with ``d`` the learned distance.
If not provided, bounds_[0] and bounds_[1] will be
set to the 5th and 95th percentile of the pairwise distances among all
points in the training data `X`.
"""
# TODO: remove these in v0.6.0
if self.num_labeled != 'deprecated':
warnings.warn('"num_labeled" parameter is not used.'
' It has been deprecated in version 0.5.0 and will be'
'removed in 0.6.0', DeprecationWarning)
if self.bounds != 'deprecated':
warnings.warn('"bounds" parameter from initialization is not used.'
' It has been deprecated in version 0.5.0 and will be'
'removed in 0.6.0. Use the "bounds" parameter of this '
'fit method instead.', DeprecationWarning)
X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
num_constraints = self.num_constraints
if num_constraints is None:
Expand All @@ -274,4 +286,4 @@ def fit(self, X, y, random_state=np.random):
pos_neg = c.positive_negative_pairs(num_constraints,
random_state=random_state)
pairs, y = wrap_pairs(X, pos_neg)
return _BaseITML._fit(self, pairs, y, bounds=self.bounds)
return _BaseITML._fit(self, pairs, y, bounds=bounds)
41 changes: 29 additions & 12 deletions test/metric_learn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,10 @@ def test_iris(self):
csep = class_separation(lsml.transform(self.iris_points), self.iris_labels)
self.assertLess(csep, 0.8) # it's pretty terrible

def test_deprecation(self):
# test that the right deprecation message is thrown.
# TODO: remove in v.0.5
def test_deprecation_num_labeled(self):
# test that a deprecation message is thrown if num_labeled is set at
# initialization
# TODO: remove in v.0.6
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
y = np.array([1, 0, 1, 0])
lsml_supervised = LSML_Supervised(num_labeled=np.inf)
Expand All @@ -77,9 +78,10 @@ def test_iris(self):
csep = class_separation(itml.transform(self.iris_points), self.iris_labels)
self.assertLess(csep, 0.2)

def test_deprecation(self):
# test that the right deprecation message is thrown.
# TODO: remove in v.0.5
def test_deprecation_num_labeled(self):
# test that a deprecation message is thrown if num_labeled is set at
# initialization
# TODO: remove in v.0.6
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
y = np.array([1, 0, 1, 0])
itml_supervised = ITML_Supervised(num_labeled=np.inf)
Expand All @@ -88,6 +90,19 @@ def test_deprecation(self):
'removed in 0.6.0')
assert_warns_message(DeprecationWarning, msg, itml_supervised.fit, X, y)

def test_deprecation_bounds(self):
# test that a deprecation message is thrown if bounds is set at
# initialization
# TODO: remove in v.0.6
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
y = np.array([1, 0, 1, 0])
itml_supervised = ITML_Supervised(bounds=None)
msg = ('"bounds" parameter from initialization is not used.'
' It has been deprecated in version 0.5.0 and will be'
'removed in 0.6.0. Use the "bounds" parameter of this '
'fit method instead.')
assert_warns_message(DeprecationWarning, msg, itml_supervised.fit, X, y)


class TestLMNN(MetricTestCase):
def test_iris(self):
Expand Down Expand Up @@ -143,9 +158,10 @@ def test_iris(self):
csep = class_separation(sdml.transform(self.iris_points), self.iris_labels)
self.assertLess(csep, 0.25)

def test_deprecation(self):
# test that the right deprecation message is thrown.
# TODO: remove in v.0.5
def test_deprecation_num_labeled(self):
# test that a deprecation message is thrown if num_labeled is set at
# initialization
# TODO: remove in v.0.6
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
y = np.array([1, 0, 1, 0])
sdml_supervised = SDML_Supervised(num_labeled=np.inf)
Expand Down Expand Up @@ -370,9 +386,10 @@ def test_iris(self):
csep = class_separation(mmc.transform(self.iris_points), self.iris_labels)
self.assertLess(csep, 0.2)

def test_deprecation(self):
# test that the right deprecation message is thrown.
# TODO: remove in v.0.5
def test_deprecation_num_labeled(self):
# test that a deprecation message is thrown if num_labeled is set at
# initialization
# TODO: remove in v.0.6
X = np.array([[0, 0], [0, 1], [2, 0], [2, 1]])
y = np.array([1, 0, 1, 0])
mmc_supervised = MMC_Supervised(num_labeled=np.inf)
Expand Down
6 changes: 3 additions & 3 deletions test/test_base_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ def test_itml(self):
preprocessor=None, verbose=False)
""".strip('\n'))
self.assertEqual(str(metric_learn.ITML_Supervised()), """
ITML_Supervised(A0=None, bounds=None, convergence_threshold=0.001, gamma=1.0,
max_iter=1000, num_constraints=None, num_labeled='deprecated',
preprocessor=None, verbose=False)
ITML_Supervised(A0=None, bounds='deprecated', convergence_threshold=0.001,
gamma=1.0, max_iter=1000, num_constraints=None,
num_labeled='deprecated', preprocessor=None, verbose=False)
""".strip('\n'))

def test_lsml(self):
Expand Down