
New API to be more compatible with scikit-learn #91


Closed
6 of 7 tasks
wdevazelhes opened this issue May 17, 2018 · 2 comments


wdevazelhes commented May 17, 2018

Hi, this is the new proposal for metric-learn's API.

The previously proposed API (#85) used a custom data object (ConstrainedDataset), but such custom objects should be avoided. This new proposal instead finds a way to use only the usual data formats (arrays etc.) as inputs to the algorithms.

The main goal of this new API, like the previous attempt, is to change the input data of metric learners that learn on tuples of points (pairs plus labels of pairs, triplets, quadruplets), so that they are compatible with scikit-learn utilities like cross-validation.

UML Diagram

Here is the current UML diagram of classes:

[UML class diagram image]

Quick overview

Let's consider the case of a metric learning algorithm that fits on tuples of points. The data points X (or a lazy equivalent, like the path to a folder on disk) would first be given as an argument to a Transformer object, which would memorize it.

Then, when creating a TuplesLearner (let's use this name for a Metric Learner that learns and predicts on tuples of points), we would pass the previous Transformer as an argument.

The TuplesLearner would then be able to take as input an array of indices (of shape (n_constraints, 2) for a PairsLearner, for instance), and under the hood form tuples from these input indices together with the Transformer.

Therefore the input to TuplesLearner would be a single object (the array of tuples) that can be split (slicing the dataset of tuples along the first dimension yields two smaller datasets of tuples), and could therefore stand for the usual X in scikit-learn, enabling cross-validation etc.

A short snippet to sum up this main use case:

class ArrayIndexer(TransformerMixin):
  def __init__(self, X):
    self.X = X
  def transform(self, list_of_indices):
    return self.X[list_of_indices]
# ArrayIndexer would be an object provided by metric-learn

pairs_extractor = ArrayIndexer(X)  # X: data; ArrayIndexer: the aforementioned Transformer
mmc = MMC(preprocessor=pairs_extractor)
mmc.fit(pairs, y)  # pairs: integer indices of pairs (2D array of shape (n_constraints, 2));
                   # y: labels of constraints (1 if similar, 0 if not), array of shape (n_constraints,)
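Since the pairs given to fit are then just a plain 2D integer array, standard scikit-learn splitters apply to them directly. A minimal sketch of this benefit (the data and pair indices below are made up for illustration; no metric-learn code is needed):

```python
import numpy as np
from sklearn.model_selection import KFold

# Toy data: 10 points with 3 features, and 6 labeled pairs of indices.
X = np.random.RandomState(0).randn(10, 3)
pairs = np.array([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [0, 2]])
y = np.array([1, 0, 1, 0, 1, 0])  # 1 = similar, 0 = dissimilar

# Because `pairs` plays the role of X, KFold works out of the box:
for train_idx, test_idx in KFold(n_splits=3).split(pairs):
    pairs_train, y_train = pairs[train_idx], y[train_idx]
    pairs_test, y_test = pairs[test_idx], y[test_idx]
    # each split is itself a valid dataset of pairs
    assert pairs_train.shape[1] == 2 and pairs_test.shape[1] == 2
```

This is exactly what the ConstrainedDataset object tried to achieve, but without any custom container.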

More details:

Here are more details on the API, including more cases that we would want to deal with:

There will mainly be two uses of metric learning algorithms: the first is to classify pairs/triplets/quadruplets (fitting, predicting, etc. on labeled tuples of points), and the second is to transform/classify usual points X.

The following classes will be implemented:

A BaseMetricLearner (invisible to the user):

It will only provide a score_pairs function, which is the basis of all metric learning algorithms: we should be able to measure a similarity score between two points.

Also, every metric learner can have as an attribute a Transformer that is used as a preprocessor. This allows, for instance when training on minibatches, transforming one minibatch at a time.
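As an illustration only (not metric-learn's actual code), here is a minimal sketch of such a base class, with a plain Euclidean distance standing in for a learned metric:

```python
import numpy as np

class BaseMetricLearner:
    """Hypothetical sketch: the one required capability is score_pairs."""
    def __init__(self, preprocessor=None):
        self.preprocessor = preprocessor  # optional Transformer applied to raw inputs

    def score_pairs(self, pairs):
        """pairs: array of shape (n_pairs, 2, n_features) -> (n_pairs,) scores.

        Here the score is the Euclidean distance between the two points of
        each pair; a real learner would use its learned metric instead.
        """
        pairs = np.asarray(pairs)
        diff = pairs[:, 0, :] - pairs[:, 1, :]
        return np.sqrt((diff ** 2).sum(axis=1))

pairs = np.array([[[0., 0.], [3., 4.]],
                  [[1., 1.], [1., 1.]]])
print(BaseMetricLearner().score_pairs(pairs))  # -> [5. 0.]
```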

A BaseExplicitMetricLearner(BaseMetricLearner) (invisible to the user):

This provides the function embed to all metric learners that can be expressed as an embedding into another space. It is somewhat like transform, except that it does not need to respect the scikit-learn API of transform: for instance, a "kind" argument can be specified for multimodal metric learning (to specify whether we want to embed some text or some image, for instance; see the second part of the Classifier of tuples section). Also, it is not really a transform: for pairs classifiers, for instance, it does not transform pairs, which are the basic sample structure for this classifier; embed always applies to points. By contrast, a real "transform" function coherent with the API would need to take the same input as fit, predict, etc., so we cannot define one for pairs classifiers; we would use embed instead.
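A hypothetical sketch of embed, where the learned metric is represented by a fixed linear map L and the optional kind argument is accepted but unused:

```python
import numpy as np

class ExplicitMetricLearner:
    """Illustrative only: a learner whose metric is an explicit embedding."""
    def __init__(self, L):
        self.L = L  # stands in for a learned linear transformation

    def embed(self, points, kind=None):
        # `kind` could select a modality-specific embedding (text vs image);
        # it is ignored in this single-modality sketch.
        return np.asarray(points) @ self.L.T

learner = ExplicitMetricLearner(L=np.array([[2., 0.], [0., 1.]]))
print(learner.embed([[1., 3.]]))  # -> [[2. 3.]]
```

Note how embed acts on points, not pairs, which is why it can coexist with a pairs-based fit/predict without violating the scikit-learn transform contract.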

Core algorithms

Core algorithms can be either just functions called by the classes below, or private classes if that helps in understanding the method. But these classes should not be used directly by the user (except for algorithms like NCA or LMNN (taking points and labels), which are the core algorithms themselves).

MetricLearningClassifier/Transformer (visible to the user):

Algorithm that trains on labeled points X and y. The classifier version would fit the metric learner, then at predict time transform the points and use a KNN or another specified algorithm, but it may not implement the transform function; the Transformer version may implement transform but may not be able to directly predict a class. Both can also use a preprocessor.
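A rough sketch of the classifier version under these assumptions, with the learned transformation faked by a fixed matrix L so the example stays self-contained:

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

class MetricLearningClassifier:
    """Hypothetical sketch: transform into the learned space, then KNN."""
    def __init__(self, L, n_neighbors=1):
        self.L = L  # stands in for a transformation learned from (X, y)
        self.knn = KNeighborsClassifier(n_neighbors=n_neighbors)

    def fit(self, X, y):
        # A real implementation would learn self.L here before fitting KNN.
        self.knn.fit(X @ self.L.T, y)
        return self

    def predict(self, X):
        return self.knn.predict(X @ self.L.T)

X = np.array([[0., 0.], [0., 1.], [5., 5.], [5., 6.]])
y = np.array([0, 0, 1, 1])
clf = MetricLearningClassifier(L=np.eye(2)).fit(X, y)
print(clf.predict([[0., 0.5]]))  # -> [0]
```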

Classifier of tuples (visible to the user):

Algorithm that implements scikit-learn-like functions (fit, predict, score...) but on tuples of points. This is the main change from the current API in master.

This class would do the same as a scikit-learn Classifier, but the input X of fit, instead of being a 2D array of shape (n_samples, n_features), would be either a 2D or 3D array representing tuples of points (see below). A preprocessor would be available too. (Note that here the use of the preprocessor is a bit trickier, in order to support multimodal inputs, e.g. pairs of text and images, or triplets ("text", "image1", "image2") for tasks like "which image is closer to a given word?".)

There would be two possible input forms:

  • Either the input is a 3D (or more) array. In this case, we consider that the user gives already-formed tuples. For pairs, the shape of the array should therefore be (n_constraints, 2, n_features) (or, for images for instance, (n_constraints, 2, n_channels, n_pixels_x, n_pixels_y): the dimensions after the second one are those intrinsic to each data point). In this case, if a preprocessor is specified as an argument when instantiating the metric learner, it should work on points, not pairs, and a tuple of preprocessors could be specified if the points are of different modalities (e.g. image and one-hot vector). Here the preprocessor works more like a usual Transformer: transforming data to a better form.
  • Or the input is a 2D array (matrix), for instance of the form [["cat", "img1.jpg"], ... ["dog", "img3.jpg"]] or [["img3.jpg", "img1.jpg"], ... ["img4", "img3.jpg"]]. There the preprocessor would act more like a data fetcher (and a transformer).
    - We could also support the case where the input is more complex (where we have "columns" of arbitrary objects like ndarrays and/or objects, strings, etc.): a dictionary might be a way to represent it, like {'images': array_of_shape(100, 8, 8), 'words': ["cat", ..., "dog"]}. Note that after discussion with @GaelVaroquaux and @ogrisel, there is an increasing will to support this (dict inputs) in scikit-learn itself. The preprocessor could then be specified as a dictionary like {'images': Transformer_img(), 'words': Transformer_words()}.
    Note also that in this case, if we have quadruplets of the form (imgs1, txts1, imgs2, txts2), then we should fit preprocessor_imgs on set(union(imgs1, imgs2)). We could detect this case as follows: a metric learner would receive the argument preprocessor={'imgs1': transfo_imgs, 'txts1': transfo_txt, 'imgs2': transfo_imgs, 'txts2': transfo_txt} (where transfo_txt and transfo_imgs are defined in advance). Then if preprocessor[somekey] is preprocessor[someotherkey], both keys should be used as column keys to fit that transformer on.
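The identity-based detection just described could be sketched like this (all names here are hypothetical, and plain object() instances stand in for real transformers):

```python
def group_columns_by_preprocessor(preprocessor):
    """Group column keys that share the *same* transformer object.

    Columns mapped to one transformer instance are pooled so that the
    transformer is fitted once on their union.
    """
    groups = {}  # id(transformer) -> (transformer, [column keys])
    for key, transfo in preprocessor.items():
        groups.setdefault(id(transfo), (transfo, []))[1].append(key)
    return list(groups.values())

transfo_imgs, transfo_txt = object(), object()
preprocessor = {'imgs1': transfo_imgs, 'txts1': transfo_txt,
                'imgs2': transfo_imgs, 'txts2': transfo_txt}
for transfo, keys in group_columns_by_preprocessor(preprocessor):
    print(keys)  # -> ['imgs1', 'imgs2'] then ['txts1', 'txts2']
```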

Note: We could also allow passing to a PairsClassifier directly preprocessor=X, where X is the data, instead of a transformer object. In this case it would be implicit that the pairs given to the algorithm are integer indices of pairs from the data X. Ex:

mmc = MMC(preprocessor=X)
mmc.fit(pairs, y_pairs)

Note on predict: In the case of pairs classification, a threshold on the similarity score must be defined to predict a class for a pair. This threshold could be set by some simple rule, or found in a more sophisticated way by specifying a desired precision and searching on a validation set for the appropriate threshold. There is also a will to add to scikit-learn some meta-estimators that do exactly this (finding the threshold of an underlying estimator that matches a given precision), so we could benefit from that API if it exists.
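A naive sketch of the validation-set approach (scores and labels below are made up; this is one simple strategy, not the method the meta-estimators would necessarily use):

```python
import numpy as np

def threshold_for_precision(scores, y_true, target_precision):
    """Return the lowest threshold whose precision on (scores, y_true)
    still reaches target_precision, or None if no threshold does."""
    best = None
    for t in np.sort(np.unique(scores))[::-1]:  # candidates, highest first
        pred = scores >= t
        if pred.sum() == 0:
            continue
        precision = (pred & (y_true == 1)).sum() / pred.sum()
        if precision >= target_precision:
            best = t  # keep lowering the threshold while precision holds
    return best

# Hypothetical validation-set similarity scores and true pair labels:
scores = np.array([0.9, 0.8, 0.6, 0.4, 0.2])
y_true = np.array([1, 1, 0, 1, 0])
print(threshold_for_precision(scores, y_true, 0.9))  # -> 0.8
```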


This issue is also the place for general discussions for this API.

The branch where the development takes place is https://github.com/metric-learn/metric-learn/tree/new_api_design.

TODO list:

EDITED: The following developments are for the more distant future and do not need to be implemented in this PR:

  • Allow the preprocessor to work with multimodal samples (e.g. mixed tuples of images and text) represented by a single type in the array (e.g. an integer or a string acting as an identifier)
  • Allow the preprocessor to work with multiple types (e.g. with dictionaries; cf. the detailed section above)

Enhancement PRs

Along the new API's development, some small enhancements could be done in separate PRs; here is the list (often updated):

ping @perimosocordiae @bellet @nvauquie @GaelVaroquaux @ogrisel @agramfort Feel free to edit this thread and add any suggestion.


bellet commented May 24, 2018

@perimosocordiae feel free to share any feedback or thoughts about this new API idea


bellet commented Jan 2, 2019

Addressed by #139

bellet closed this as completed Jan 2, 2019