-
Notifications
You must be signed in to change notification settings - Fork 228
[MRG] New API should allow prediction functions and scoring #95
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] New API should allow prediction functions and scoring #95
Conversation
Basically these are the tests from PR scikit-learn-contrib#85, but reformatted to use pytest, and formed tuples instead of ConstrainedDatasets.
- Make PairsClassifierMixin and QuadrupletsClassifierMixin classes, to implement scoring functions - Implement a new API for supervised wrappers of weakly supervised learning estimators (through the use of base classes, (ex: BaseMMC), from which inherit child classes (ex: MMC and MMC_Supervised) (which is the same idea as in PR scikit-learn-contrib#85 - Delete tests that use tuples learners as transformers (as we do not want to support this behaviour anymore: it is too complicated to allow such different input types (tuples or points) for the same estimator
# Conflicts: # metric_learn/sdml.py
I just merged with the recently merged PR #92, so changes introduced by this PR are now clearer. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- maybe good to test warning/errors in the check functions? such as wrong labels (not -1/1), etc
- predict function for pairs: later we should think about how to implement a threshold-based predict, without fixing the threshold in advance but tuning it automatically on the train set to achieve desired precision
- It looks like we are loosing the ability of weakly supervised algorithm to be used to transform the data, but I guess this will be fixed in the next PR introducing a Mahalanobis Mixin with an
embed
method
metric_learn/base_metric.py
Outdated
class _PairsClassifierMixin: | ||
|
||
def predict(self, pairs): | ||
"""Predicts the learned similarity between input pairs. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should be metric instead of similarity here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes indeed, thanks
metric_learn/base_metric.py
Outdated
class _QuadrupletsClassifierMixin: | ||
|
||
def predict(self, quadruplets): | ||
"""Predicts differences between sample similarities in input quadruplets. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
distances?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, thanks
test/test_weakly_supervised.py
Outdated
|
||
|
||
def build_pairs(): | ||
# test that you can do cross validation on a ConstrainedDataset with |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no ConstrainedDataset anymore. also X_constrained
should be renamed (this is a set of pairs)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, thanks
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
as some tests are parameterized to work for pairs and quadruplets, I will rename them tuples in the tests, but pairs and quadruplets in build_pairs
and build_quadruplets
functions that initialize data
test/test_weakly_supervised.py
Outdated
|
||
|
||
def build_quadruplets(): | ||
# test that you can do cross validation on a ConstrainedDataset with |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here
return (np.sqrt(np.sum(similar_diffs.dot(self.metric()) * | ||
similar_diffs, axis=1)) - | ||
np.sqrt(np.sum(dissimilar_diffs.dot(self.metric()) * | ||
dissimilar_diffs, axis=1))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This pattern, distance under some metric, seems like it should be factored out.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes indeed, the function will call function score_pairs
(that returns the new metric between points) that will be inherited from the BaseMetricLearner, and implemented through ExplicitMixin
(a Mixin for all learners that can embed data) (so score_pairs
will be implemented as the euclidean distance between embeddings)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(this should ultimately be in the Mahalanobis Mixin)
metric_learn/base_metric.py
Outdated
The quadruplets score. | ||
""" | ||
predicted_sign = self.decision_function(quadruplets) < 0 | ||
return np.sum(predicted_sign) / predicted_sign.shape[0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not np.mean(np.sign(...))
here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Much cleaner indeed, thanks !
Yes indeed, I will add it to the TODO in the issue #91
Yes, it is in the TODO
Yes, indeed, the abstract method will be created in |
…and scikit-learn-contrib#95 (review) - replace similarity by metric - replace constrained dataset by pairs/quadruplets - simplify score on quadruplets expression - replace ``X_constrained`` in tests by pairs/quadruplets/tuples
Yes, one possibility is to implement only a Mahalanobis Mixin for now (since all current algorithms fall in this category) |
The new API should allow metric learning algorithms that fit on tuples of points to also predict, score, etc on tuples, Therefore being usable in scikit-learn's cross-validation routines. This is part of the PRs need for issue #91.
ConstrainedDatasets
.Move docstrings from_fit
tofit
Remove unused imports