TST activate common tests for TSNE #25374

Merged · 7 commits · Mar 14, 2023

9 changes: 7 additions & 2 deletions sklearn/manifold/_t_sne.py
@@ -17,7 +17,7 @@
 from scipy.sparse import csr_matrix, issparse
 from numbers import Integral, Real
 from ..neighbors import NearestNeighbors
-from ..base import BaseEstimator
+from ..base import BaseEstimator, ClassNamePrefixFeaturesOutMixin, TransformerMixin
 from ..utils import check_random_state
 from ..utils._openmp_helpers import _openmp_effective_n_threads
 from ..utils.validation import check_non_negative
@@ -537,7 +537,7 @@ def trustworthiness(X, X_embedded, *, n_neighbors=5, metric="euclidean"):
     return t


-class TSNE(BaseEstimator):
+class TSNE(ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator):
Member:

I find it weird to have a sub-class of TransformerMixin without a transform method. But maybe it's better than what we do currently...

Member Author (@glemaitre, Jan 13, 2023):

Since the Mixin only defines fit_transform and we override it, I find it fine to inherit.
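
For context, TransformerMixin defines only the fit_transform convenience; a simplified sketch of what the mixin provides, leaving out the set_output wrapping details:

```python
class TransformerMixin:
    """Simplified sketch: the mixin supplies fit_transform only,
    not transform itself."""

    def fit_transform(self, X, y=None, **fit_params):
        # Fit to data, then transform it; subclasses such as TSNE
        # override this with a more direct implementation.
        if y is None:
            return self.fit(X, **fit_params).transform(X)
        return self.fit(X, y, **fit_params).transform(X)
```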

Member:

The problem we are trying to solve by inheriting is having TSNE selected as a transformer by all_estimators. Is there another way of doing that, other than inheriting from TransformerMixin?

It feels weird to inherit, which to me says "I am a transformer with all the things a transformer can do!", and then not have a transform method. But the mixin doesn't define it either. So is having a transform method not part of being "a real transformer"?

On a pragmatic level, using the mixin is a nice way to get into the list of all_estimators, and maybe mixins aren't like real "is a"-style inheritance?
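
For reference, the selection in question is all_estimators' type_filter, which keys off the mixins; a quick check that TransformerMixin inheritance is what makes TSNE show up:

```python
from sklearn.utils import all_estimators

# type_filter="transformer" keeps only classes that inherit from
# TransformerMixin, which is why TSNE was excluded before this PR.
transformers = all_estimators(type_filter="transformer")
print(any(name == "TSNE" for name, Est in transformers))
```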

Member Author (@glemaitre):

> "I am a transformer with all the things a transformer can do!"

Looking at the documentation, it seems that the original spirit is indeed to have at least fit and transform, and that fit_transform is just a convenience.

> On a pragmatic level, using the mixin is a nice way to get into the list of all_estimators, and maybe mixins aren't like real "is a"-style inheritance?

We could always replace the mixin check with duck typing: if an estimator implements fit+transform and/or fit_transform, then it should pass the tests of a transformer. Indeed, some of our checks already use duck typing, while the helper that lists the estimators checks for the mixins.
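
A minimal sketch of that duck-typing rule (looks_like_transformer is a hypothetical helper, not existing scikit-learn API):

```python
def looks_like_transformer(Estimator):
    # Duck typing per the rule above: fit+transform and/or
    # fit_transform makes an estimator a transformer. Note that
    # hasattr also works on the class itself, without an instance.
    has_fit = hasattr(Estimator, "fit")
    has_transform = hasattr(Estimator, "transform")
    has_fit_transform = hasattr(Estimator, "fit_transform")
    return (has_fit and has_transform) or has_fit_transform
```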

Member Author (@glemaitre):

Actually, the annoying part with duck typing is that we would already need an instance, while at this stage we are still working with classes.

Member:

I don't quite understand what you mean. To find out whether a method is implemented, you could do something like "fit" in dir(Estimator), no?

Member Author (@glemaitre):

Indeed we could. I just had in mind HasMethods, a piece of the parameter-validation framework that already does this job for instances.
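
A sketch of that helper (HasMethods lives in the private sklearn.utils._param_validation module, so this is internal API that may change):

```python
from sklearn.manifold import TSNE
from sklearn.utils._param_validation import HasMethods

# HasMethods checks an *instance*, hence the annoyance mentioned
# above when only the class is available at collection time.
constraint = HasMethods(["fit", "fit_transform"])
print(constraint.is_satisfied_by(TSNE()))  # expected: True
```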

Member (@jeremiedbb, Mar 2, 2023):

According to the glossary, a transformer is an estimator that implements transform and/or fit_transform, so TSNE complies with that. And since TransformerMixin is supposed to be the "mixin class for all transformers in scikit-learn", I find it entirely appropriate that TSNE inherits from this mixin.

Member (@ogrisel, Mar 10, 2023):

Good point.

Yet, I am still not convinced we should rely on isinstance(obj, TransformerMixin) to discover all transformers in scikit-learn. We should rather use an estimator tag or duck-typing.

But +1 for moving forward with this TSNE-specific PR, which is already a net improvement in itself, and for deferring the discussion of how to properly discover transformers in the common tests to another issue/PR.
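
A hypothetical sketch of the tag-based alternative; the "transformer" tag name below is illustrative and not an existing scikit-learn estimator tag:

```python
from sklearn.base import BaseEstimator


class TSNE(BaseEstimator):
    def _more_tags(self):
        # Hypothetical: advertise transformer-ness through an
        # estimator tag instead of inheritance, so that discovery
        # code can read the tag rather than rely on
        # isinstance(obj, TransformerMixin).
        return {"transformer": True}
```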

Member:

> We should rather use an estimator tag or duck-typing.

This is related to https://github.com/scikit-learn/scikit-learn/pull/17806/files, which is about all the other estimator types.

     """T-distributed Stochastic Neighbor Embedding.

     t-SNE [1] is a tool to visualize high-dimensional data. It converts
@@ -1145,5 +1145,10 @@ def fit(self, X, y=None):
         self.fit_transform(X)
         return self

+    @property
+    def _n_features_out(self):
+        """Number of transformed output features."""
+        return self.embedding_.shape[1]
+
     def _more_tags(self):
         return {"pairwise": self.metric == "precomputed"}
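
With the new mixin and the _n_features_out property, TSNE now exposes get_feature_names_out; the expected names below assume ClassNamePrefixFeaturesOutMixin's lowercased-class-name naming scheme:

```python
import numpy as np
from sklearn.manifold import TSNE

X = np.random.RandomState(0).rand(50, 5)
tsne = TSNE(n_components=2, perplexity=5).fit(X)

# _n_features_out reads embedding_.shape[1], so this should print
# ['tsne0' 'tsne1'] under the class-name-prefix naming scheme.
print(tsne.get_feature_names_out())
```
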
45 changes: 26 additions & 19 deletions sklearn/utils/estimator_checks.py
@@ -4234,9 +4234,14 @@ def fit_then_transform(est):
     def fit_transform(est):
         return est.fit_transform(X, y)

-    transform_methods = [fit_then_transform, fit_transform]
-    for transform_method in transform_methods:
+    transform_methods = {
+        "transform": fit_then_transform,
+        "fit_transform": fit_transform,
+    }
+    for name, transform_method in transform_methods.items():
         transformer = clone(transformer)
+        if not hasattr(transformer, name):
+            continue
         X_trans_no_setting = transform_method(transformer)

         # Auto wrapping only wraps the first array
@@ -4269,29 +4274,31 @@ def _output_from_fit_transform(transformer, name, X, df, y):
         ("fit.transform/array/df", X, df),
         ("fit.transform/array/array", X, X),
     ]
-    for (
-        case,
-        data_fit,
-        data_transform,
-    ) in cases:
-        transformer.fit(data_fit, y)
-        if name in CROSS_DECOMPOSITION:
-            X_trans, _ = transformer.transform(data_transform, y)
-        else:
-            X_trans = transformer.transform(data_transform)
-        outputs[case] = (X_trans, transformer.get_feature_names_out())
+    if all(hasattr(transformer, meth) for meth in ["fit", "transform"]):
+        for (
+            case,
+            data_fit,
+            data_transform,
+        ) in cases:
+            transformer.fit(data_fit, y)
+            if name in CROSS_DECOMPOSITION:
+                X_trans, _ = transformer.transform(data_transform, y)
+            else:
+                X_trans = transformer.transform(data_transform)
+            outputs[case] = (X_trans, transformer.get_feature_names_out())

     # fit_transform case:
     cases = [
         ("fit_transform/df", df),
         ("fit_transform/array", X),
     ]
-    for case, data in cases:
-        if name in CROSS_DECOMPOSITION:
-            X_trans, _ = transformer.fit_transform(data, y)
-        else:
-            X_trans = transformer.fit_transform(data, y)
-        outputs[case] = (X_trans, transformer.get_feature_names_out())
+    if hasattr(transformer, "fit_transform"):
+        for case, data in cases:
+            if name in CROSS_DECOMPOSITION:
+                X_trans, _ = transformer.fit_transform(data, y)
+            else:
+                X_trans = transformer.fit_transform(data, y)
+            outputs[case] = (X_trans, transformer.get_feature_names_out())

     return outputs