Skip to content

[WIP] Added PredictionTransformer and ThresholdClassifier #6663

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions sklearn/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# Virgile Fritsch
# Alexandre Gramfort
# Lars Buitinck
# Tim Head
# Licence: BSD

from collections import defaultdict
Expand All @@ -15,14 +16,16 @@
import numpy as np
from scipy import sparse

from .base import BaseEstimator, TransformerMixin
from .base import BaseEstimator, ClassifierMixin
from .base import MetaEstimatorMixin, TransformerMixin
from .externals.joblib import Parallel, delayed
from .externals import six
from .utils import tosequence
from .utils.metaestimators import if_delegate_has_method
from .externals.six import iteritems

__all__ = ['Pipeline', 'FeatureUnion']
__all__ = ['Pipeline', 'FeatureUnion', 'PredictionTransformer',
'ThresholdClassifier']


class Pipeline(BaseEstimator):
Expand Down Expand Up @@ -576,3 +579,32 @@ def make_union(*transformers):
f : FeatureUnion
"""
return FeatureUnion(_name_estimators(transformers))


class PredictionTransformer(BaseEstimator, TransformerMixin, MetaEstimatorMixin):
def __init__(self, clf):
"""Replaces all features with `clf.predict_proba(X)`"""
self.clf = clf

def fit(self, X, y):
self.clf.fit(X, y)
return self

def transform(self, X):
return self.clf.predict_proba(X)


class ThresholdClassifier(BaseEstimator, ClassifierMixin):
def __init__(self, threshold=0.5):
"""Classify samples based on whether they are above of below `threshold`"""
self.threshold = threshold

def fit(self, X, y):
self.classes_ = np.unique(y)
return self

def predict(self, X):
# the implementation used here breaks ties differently
# from the one used in RFs:
#return self.classes_.take(np.argmax(X, axis=1), axis=0)
return np.where(X[:, 0]>self.threshold, *self.classes_)
17 changes: 16 additions & 1 deletion sklearn/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@

from sklearn.base import clone
from sklearn.pipeline import Pipeline, FeatureUnion, make_pipeline, make_union
from sklearn.pipeline import PredictionTransformer, ThresholdClassifier
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.linear_model import LinearRegression
from sklearn.cluster import KMeans
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris, make_classification
from sklearn.preprocessing import StandardScaler
from sklearn.feature_extraction.text import CountVectorizer

Expand Down Expand Up @@ -472,3 +474,16 @@ def test_X1d_inverse_transform():
X = np.ones(10)
msg = "1d X will not be reshaped in pipeline.inverse_transform"
assert_warns_message(FutureWarning, msg, pipeline.inverse_transform, X)


def test_prediction_transformer_pipeline():
X, y = make_classification()

pipe = make_pipeline(PredictionTransformer(RandomForestClassifier()),
ThresholdClassifier())
pipe.fit(X, y)

clf = RandomForestClassifier()
clf.fit(X, y)

assert_array_equal(clf.predict(X), pipe.predict(X))