Skip to content

[MRG] New API should allow prediction functions and scoring #95

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

wdevazelhes
Copy link
Member

@wdevazelhes wdevazelhes commented May 22, 2018

The new API should allow metric learning algorithms that fit on tuples of points to also predict, score, etc on tuples, Therefore being usable in scikit-learn's cross-validation routines. This is part of the PRs need for issue #91.

  • Take previous tests from [WIP] New API proposal #85, and refactor them using pytest, and to allow already formed tuples (3D arrays) instead of ConstrainedDatasets.
  • Take previous code from [WIP] New API proposal #85 and adapt it to formed tuples
  • Make tests work
  • Make some modifications if needed
    • Move docstrings from _fit to fit
    • Remove unused imports

Basically these are the tests from PR scikit-learn-contrib#85, but reformatted to use pytest, and formed tuples instead of ConstrainedDatasets.
William de Vazelhes added 4 commits May 24, 2018 11:50
- Make PairsClassifierMixin and QuadrupletsClassifierMixin classes, to implement scoring functions
- Implement a new API for supervised wrappers of weakly supervised learning estimators (through the use of base classes, (ex: BaseMMC), from which inherit child classes (ex: MMC and MMC_Supervised) (which is the same idea as in PR scikit-learn-contrib#85
- Delete tests that use tuples learners as transformers (as we do not want to support this behaviour anymore: it is too complicated to allow such different input types (tuples or points) for the same estimator
# Conflicts:
#	metric_learn/sdml.py
@wdevazelhes
Copy link
Member Author

wdevazelhes commented May 25, 2018

I just merged with the recently merged PR #92, so changes introduced by this PR are now clearer.

@wdevazelhes wdevazelhes changed the title [WIP] New API should allow prediction functions and scoring [MRG] New API should allow prediction functions and scoring May 25, 2018
Copy link
Member

@bellet bellet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • maybe good to test warning/errors in the check functions? such as wrong labels (not -1/1), etc
  • predict function for pairs: later we should think about how to implement a threshold-based predict, without fixing the threshold in advance but tuning it automatically on the train set to achieve desired precision
  • It looks like we are loosing the ability of weakly supervised algorithm to be used to transform the data, but I guess this will be fixed in the next PR introducing a Mahalanobis Mixin with an embed method

class _PairsClassifierMixin:

def predict(self, pairs):
"""Predicts the learned similarity between input pairs.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be metric instead of similarity here

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes indeed, thanks

class _QuadrupletsClassifierMixin:

def predict(self, quadruplets):
"""Predicts differences between sample similarities in input quadruplets.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

distances?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, thanks



def build_pairs():
# test that you can do cross validation on a ConstrainedDataset with
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no ConstrainedDataset anymore. also X_constrained should be renamed (this is a set of pairs)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, thanks

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as some tests are parameterized to work for pairs and quadruplets, I will rename them tuples in the tests, but pairs and quadruplets in build_pairs and build_quadruplets functions that initialize data



def build_quadruplets():
# test that you can do cross validation on a ConstrainedDataset with
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

return (np.sqrt(np.sum(similar_diffs.dot(self.metric()) *
similar_diffs, axis=1)) -
np.sqrt(np.sum(dissimilar_diffs.dot(self.metric()) *
dissimilar_diffs, axis=1)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This pattern, distance under some metric, seems like it should be factored out.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes indeed, the function will call function score_pairs (that returns the new metric between points) that will be inherited from the BaseMetricLearner, and implemented through ExplicitMixin (a Mixin for all learners that can embed data) (so score_pairs will be implemented as the euclidean distance between embeddings)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(this should ultimately be in the Mahalanobis Mixin)

The quadruplets score.
"""
predicted_sign = self.decision_function(quadruplets) < 0
return np.sum(predicted_sign) / predicted_sign.shape[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not np.mean(np.sign(...)) here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much cleaner indeed, thanks !

@wdevazelhes
Copy link
Member Author

maybe good to test warning/errors in the check functions? such as wrong labels (not -1/1), etc

Yes indeed, I will add it to the TODO in the issue #91

predict function for pairs: later we should think about how to implement a threshold-based predict, without fixing the threshold in advance but tuning it automatically on the train set to achieve desired precision

Yes, it is in the TODO

It looks like we are loosing the ability of weakly supervised algorithm to be used to transform the data, but I guess this will be fixed in the next PR introducing a Mahalanobis Mixin with an embed method

Yes, indeed, the abstract method will be created in ExplicitMixin, and then implemented in MahalanobisMixin. I wonder however if we could not postpone ExplicitMixin to when there are metric learners which are not Explicit, and for now implement embed and score_pairs directly in MahalanobisMixin.

…and scikit-learn-contrib#95 (review)

- replace similarity by metric
- replace constrained dataset by pairs/quadruplets
- simplify score on quadruplets expression
- replace ``X_constrained`` in tests by pairs/quadruplets/tuples
@bellet
Copy link
Member

bellet commented Jun 5, 2018

Yes, one possibility is to implement only a Mahalanobis Mixin for now (since all current algorithms fall in this category)

@wdevazelhes wdevazelhes mentioned this pull request Jun 6, 2018
7 tasks
@wdevazelhes wdevazelhes merged commit 24b0def into scikit-learn-contrib:new_api_design Jun 8, 2018
@wdevazelhes wdevazelhes deleted the feat/api_prediction branch August 22, 2018 06:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants