[WIP] Resamplers #13269
Changes to `sklearn/base.py`:

```diff
@@ -13,7 +13,8 @@
 import numpy as np
 
 from . import __version__
-from .utils import _IS_32BIT
+from sklearn.utils import _IS_32BIT
+from sklearn.utils import safe_indexing, check_X_y_kwargs
 
 _DEFAULT_TAGS = {
     'non_deterministic': False,
```
```diff
@@ -603,6 +604,45 @@ def fit_predict(self, X, y=None):
         return self.fit(X).predict(X)
 
 
+class OutlierRejectionMixin:
+    """Mixin class for all outlier detection resamplers in scikit-learn.
+
+    Child classes remove outliers from the dataset.
+    """
+    _estimator_type = "outlier_rejector"
```
> **Review comment:** it's problematic to change the type of existing outlier detectors.
```diff
+
+    def fit_resample(self, X, y, **kws):
```

> **Review comment:** I think we should just be adding this to `OutlierMixin` rather than adding a new mixin and type.
"""Performs fit on X and returns a new X and y consisting of only the | ||
inliers. | ||
Parameters | ||
---------- | ||
X : ndarray, shape (n_samples, n_features) | ||
Input data X. | ||
y : ndarray, shape (n_samples,) | ||
Input data y. | ||
Returns | ||
------- | ||
X : ndarray, shape (n_samples, n_features) | ||
The original X with outlier samples removed. | ||
y : ndarray, shape (n_samples,) | ||
The original y with outlier samples removed. | ||
kws : dict of ndarray | ||
dict of keyword arguments, with all outlier samples removed. | ||
""" | ||
|
||
check_X_y_kwargs(X, y, kws) | ||
inliers = self.fit_predict(X) == 1 | ||
kwsr = { | ||
kw: safe_indexing(kws[kw], inliers) | ||
for kw in kws | ||
} | ||
return safe_indexing(X, inliers), safe_indexing(y, inliers), kwsr | ||
|
||
|
||
class MetaEstimatorMixin: | ||
_required_parameters = ["estimator"] | ||
"""Mixin class for all meta estimators in scikit-learn.""" | ||
New file in `sklearn/compose/` (112 added lines) introducing `ResampledTrainer`:
```python
# Author: Joel Nothman

from ..base import BaseEstimator, MetaEstimatorMixin, clone
from ..utils.metaestimators import if_delegate_has_method
from ..utils.validation import check_is_fitted, check_X_y_kwargs


class ResampledTrainer(MetaEstimatorMixin, BaseEstimator):
```

> **Review comment:** We should get consensus on this name. What's a better one?
"""Composition of a resampler and a estimator | ||
|
||
Read more in the :ref:`User Guide <pipeline_resamplers>`. | ||
|
||
Parameters | ||
---------- | ||
resampler : Estimator supporting fit_resample | ||
|
||
estimator : Estimator | ||
|
||
Attributes | ||
---------- | ||
resampler_ : Estimator | ||
Fitted clone of `resampler`. | ||
|
||
estimator_ : Estimator | ||
Fitted clone of `estimator`. | ||
|
||
Examples | ||
-------- | ||
>>> from sklearn.base import BaseEstimator | ||
>>> from sklearn.compose import ResampledTrainer | ||
>>> from sklearn.datasets import load_iris | ||
>>> from sklearn.linear_model import LogisticRegression | ||
>>> | ||
>>> class HalfSampler(BaseEstimator): | ||
... "Train with every second sample" | ||
... def fit_resample(self, X, y, **kw): | ||
... return X[::2], y[::2] | ||
>>> | ||
>>> est = ResampledTrainer(HalfSampler(), LogisticRegression()) | ||
>>> X, y = load_iris(return_X_y=True) | ||
>>> est.fit(X, y) | ||
... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS | ||
ResampledTrainer(...) | ||
>>> est.predict(X[:2]) | ||
array([0, 0]) | ||
""" | ||
|
||
def __init__(self, resampler, estimator): | ||
self.resampler = resampler | ||
self.estimator = estimator | ||
|
||
_required_parameters = ["resampler", "estimator"] | ||
|
||
# TODO: tags? | ||
|
||
def fit(self, X, y=None, **kw): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. TODO: docstrings |
||
```python
        X, y, kw = check_X_y_kwargs(X, y, kw)
        self.resampler_ = clone(self.resampler)
        X, y, kw = self.resampler_.fit_resample(X, y, **kw)

        self.estimator_ = clone(self.estimator).fit(X, y, **kw)
        return self

    @if_delegate_has_method(delegate="estimator")
    def predict(self, X, **predict_params):
        check_is_fitted(self, "estimator_")
        return self.estimator_.predict(X, **predict_params)

    @if_delegate_has_method(delegate="estimator")
    def transform(self, X):
        check_is_fitted(self, "estimator_")
        return self.estimator_.transform(X)

    @if_delegate_has_method(delegate="estimator")
    def predict_proba(self, X):
        check_is_fitted(self, "estimator_")
        return self.estimator_.predict_proba(X)

    @if_delegate_has_method(delegate="estimator")
    def predict_log_proba(self, X):
        check_is_fitted(self, "estimator_")
        return self.estimator_.predict_log_proba(X)

    @if_delegate_has_method(delegate="estimator")
    def decision_function(self, X):
        check_is_fitted(self, "estimator_")
        return self.estimator_.decision_function(X)

    @if_delegate_has_method(delegate="estimator")
    def score(self, X, y, **kw):
        check_is_fitted(self, "estimator_")
        return self.estimator_.score(X, y, **kw)

    @property
    def fit_transform(self):
        # check that the estimator has a transform method; raises
        # AttributeError here rather than after fitting
        self.estimator.transform

        def fit_transform(X, y, **kwargs):
            self.fit(X, y, **kwargs)
            # since estimator_ exists now, we can return transform
            return self.estimator_.transform(X)

        return fit_transform

    @property
    def _estimator_type(self):
        return self.estimator._estimator_type

    @property
    def classes_(self):
        return self.estimator_.classes_
```
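To make the intended composition concrete, here is a hypothetical end-to-end sketch. `ResampledTrainer` and the three-tuple `fit_resample` contract come from this PR, so this runs only with the PR branch installed; `IsolationForestRejector` and the toy data are my own illustration:

```python
# Hypothetical usage sketch; ResampledTrainer is proposed in this PR and
# is not part of released scikit-learn.
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.compose import ResampledTrainer  # proposed in this PR
from sklearn.ensemble import IsolationForest
from sklearn.linear_model import LogisticRegression


class IsolationForestRejector(BaseEstimator):
    """Illustrative resampler: drop samples IsolationForest flags as outliers."""

    def fit_resample(self, X, y, **kws):
        inliers = IsolationForest(random_state=0).fit_predict(X) == 1
        kws = {kw: arg[inliers] for kw, arg in kws.items()}
        return X[inliers], y[inliers], kws


rng = np.random.RandomState(0)
X = np.r_[rng.randn(50, 2), rng.randn(50, 2) + [4, 4], [[25.0, -25.0]]]
y = np.r_[np.zeros(50), np.ones(50), [1.0]]

est = ResampledTrainer(IsolationForestRejector(), LogisticRegression())
est.fit(X, y)       # the point at (25, -25) never reaches LogisticRegression
est.predict(X[:2])  # prediction still covers every incoming sample
```

The design point visible in the code above: resampling applies only during `fit`, while `predict`, `score`, etc. delegate directly and see the incoming data unchanged, which is why `fit_resample` is a separate verb from `transform`.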
> **Review comment:** I think we should remove this and just add a note under outlier detector that, if `fit_resample` is provided, it should act as an outlier rejector, returning the training data with outliers removed.
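A sketch of what that suggestion might look like; this is my reading of the comment, not code from the diff (`safe_indexing` was a public scikit-learn helper at the time):

```python
# Hypothetical sketch of the reviewer's suggestion: add fit_resample to the
# existing OutlierMixin rather than introducing a new mixin and type.
from sklearn.utils import safe_indexing  # public helper in this era


class OutlierMixin:
    _estimator_type = "outlier_detector"  # existing type, unchanged

    def fit_predict(self, X, y=None):
        return self.fit(X).predict(X)

    def fit_resample(self, X, y, **kws):
        # Acts as an outlier rejector: return the training data with
        # outliers removed, per the reviewer's note.
        inliers = self.fit_predict(X) == 1
        kws = {kw: safe_indexing(arg, inliers) for kw, arg in kws.items()}
        return safe_indexing(X, inliers), safe_indexing(y, inliers), kws
```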