[MRG] CountFeaturizer for Categorical Data #9614
Changes from all commits
5f2ebc3
f1c8bc0
d8578f7
0d333a9
50c7136
4417c0e
@@ -0,0 +1,138 @@
"""
=========================================================
Using CountFeaturizer to featurize frequencies
=========================================================

Shows how to use CountFeaturizer to transform some categorical variables
into a frequency feature. CountFeaturizer can often be used to reduce
training time, classification time, and classification error.
"""
from __future__ import print_function
import matplotlib.pyplot as plt
import numpy as np
import time
from sklearn.preprocessing import FunctionTransformer
from sklearn.preprocessing.data import CountFeaturizer
from sklearn.preprocessing.data import OneHotEncoder
from collections import OrderedDict
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import make_pipeline
from sklearn.pipeline import FeatureUnion

RANDOM_STATE = 30

n_datapoints = 1000
n_informative = 30
n_features = 30
n_redundant = 0

# Generate a binary classification dataset.
X, y = make_classification(n_samples=n_datapoints, n_features=n_features,
                           n_clusters_per_class=1, n_informative=n_informative,
                           n_redundant=n_redundant, random_state=RANDOM_STATE)

# only make these selected features "categorical"
discretized_features = [0, 1]
non_discretized_features = \
    list(set(range(n_features)) - set(discretized_features))
non_discretized_features_count = \
    list(set(range(n_features + 1)) - set(discretized_features))


def select_non_discrete(X, count=False):
    """Selects the non-discrete features."""
    if count:
        return X[:, non_discretized_features_count]
    return X[:, non_discretized_features]


def select_discrete(X):
    """Selects the discrete features."""
    return X[:, discretized_features]


def process_discrete(X):
    """Processes discrete features to make them categorical."""
    for feature in discretized_features:
        X_transform_col = (X[:, feature]).astype(int)
        col_min = min(np.amin(X_transform_col), 0)
        X[:, feature] = X_transform_col - col_min
    return X


time_start = time.time()
pipeline_cf = make_pipeline(
Review comment: One of the premises CountFeaturizer was built on (based on this Microsoft ML article on count featurization: https://docs.microsoft.com/en-us/azure/machine-learning/studio-module-reference/data-transformation-learning-with-counts) is that it may reduce training and classification time compared to something like one-hot encoding, because fewer features are generated. The pipeline is split here so that the time spent preprocessing the data can be separated from the time spent training and using the classifier. (A small illustration of the width difference follows the two preprocessing pipelines below.)
    FunctionTransformer(func=process_discrete),
    CountFeaturizer(inclusion=discretized_features),
    FunctionTransformer(func=lambda X: select_non_discrete(X, count=True)))
X_count = pipeline_cf.fit_transform(X, y=y)
cf_time_preprocessing = time.time() - time_start

time_start = time.time()
pipeline_ohe_nd = make_pipeline(FunctionTransformer(func=select_non_discrete))
pipeline_ohe_d = make_pipeline(
    FunctionTransformer(func=select_discrete),
    FunctionTransformer(func=process_discrete),
    OneHotEncoder())
pipeline_ohe = FeatureUnion(
    [("discrete", pipeline_ohe_d), ("nondiscrete", pipeline_ohe_nd)])
X_one_hot = pipeline_ohe.fit_transform(X, y=y).todense()
ohe_time_preprocessing = time.time() - time_start
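To illustrate the reviewer's point above about output width, here is a small, hypothetical check (not part of the example script; the cardinality of 1000 and the variable names are made up for illustration):

```python
# Hypothetical illustration: for a single categorical column with many distinct
# values, one-hot encoding adds roughly one column per distinct value, while
# count featurization adds only one count column per class of y.
import numpy as np
from sklearn.preprocessing.data import OneHotEncoder, CountFeaturizer

rng = np.random.RandomState(0)
X_cat = rng.randint(0, 1000, size=(5000, 1))  # ~1000 distinct categories
y_cat = rng.randint(0, 2, size=5000)          # binary target

print(OneHotEncoder().fit_transform(X_cat).shape)                  # (5000, ~1000)
print(CountFeaturizer().fit(X_cat, y_cat).transform(X_cat).shape)  # (5000, 3)
```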


def get_classifier():
    return RandomForestClassifier(warm_start=True, max_features="log2",
                                  oob_score=True, random_state=RANDOM_STATE)


clf = get_classifier()
labels = ["CountFeaturizer + RandomForestClassifier",
Review comment: Can this experiment (running the three pipelines on an increasing number of trees) be put into a for loop? That removes duplicate code. (A possible refactor is sketched after the three loops below.)
"OneHotEncoder + RandomForestClassifier", | ||
"Only RandomForestClassifier"] | ||
error_rate = OrderedDict((label, []) for label in labels) | ||
|
||
min_estimators = (15 * n_datapoints // 500) | ||
max_estimators = (175 * n_datapoints // 500) | ||
time_start = time.time() | ||
|
||
for i in range(min_estimators, max_estimators + 1): | ||
Review comment: Does going in steps of 10 make this faster? (The refactor sketch below uses a configurable step.)
    clf.set_params(n_estimators=i)
    clf.fit(X_count, y)
    oob_error = 1 - clf.oob_score_
    error_rate[labels[0]].append((i, oob_error))

print("Time taken on CountFeaturizer: ",
      (time.time() - time_start + cf_time_preprocessing))
clf = get_classifier()
time_start = time.time()

for i in range(min_estimators, max_estimators + 1):
    clf.set_params(n_estimators=i)
    clf.fit(X_one_hot, y)
    oob_error = 1 - clf.oob_score_
    error_rate[labels[1]].append((i, oob_error))

print("Time taken on OneHotEncoder: ",
      (time.time() - time_start + ohe_time_preprocessing))
clf = get_classifier()
time_start = time.time()

for i in range(min_estimators, max_estimators + 1):
    clf.set_params(n_estimators=i)
    clf.fit(X, y)
    oob_error = 1 - clf.oob_score_
    error_rate[labels[2]].append((i, oob_error))

print("Time taken on No Encoding: ", (time.time() - time_start))

# Generate the "OOB error rate" vs. "n_estimators" plot.
for label, clf_err in error_rate.items():
    xs, ys = zip(*clf_err)
    plt.plot(xs, ys, label=label)

plt.xlim(min_estimators, max_estimators)
plt.xlabel("n_estimators")
Review comment: Personally I'm a big fan of declaring the plotting labels and other stuff (l135, 136, 137) high up in the code, so …
plt.ylabel("OOB error rate")
plt.legend(loc="upper right")
plt.show()
@@ -9,7 +9,9 @@

from __future__ import division

from collections import defaultdict
from itertools import chain, combinations
import functools
import numbers
import warnings
from itertools import combinations_with_replacement as combinations_w_r

@@ -22,6 +24,7 @@
from ..externals import six
from ..externals.six import string_types
from ..utils import check_array
from ..utils import check_X_y
from ..utils.extmath import row_norms
from ..utils.extmath import _incremental_mean_and_var
from ..utils.fixes import _argmax

@@ -2866,6 +2869,229 @@ def power_transform(X, method='box-cox', standardize=True, copy=True):
    return pt.fit_transform(X)


def _get_nested_counter(remaining, y_dim, inclusion_size):
Review comment: There is something in the CI tests (something to do with the pickle module) that does not allow building the nested dicts dynamically with something like a lambda default factory, so this is a workaround. Should I add "This is a workaround due to some pickle issues in CI testing regarding creating nested dicts dynamically" as a comment here? (A sketch of the picklability difference follows the function below.)
"A nested dictionary with 'remaining' layers and a 2D array at the end" | ||
if remaining == 1: | ||
return np.zeros((y_dim, inclusion_size)) | ||
return defaultdict( | ||
functools.partial( | ||
_get_nested_counter, remaining - 1, y_dim, inclusion_size)) | ||
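For context on the review thread above, a minimal standalone sketch (an assumed reconstruction of the pickling issue, not code from this PR) of why a `functools.partial` of a module-level function is used instead of a lambda default factory:

```python
# Hypothetical demonstration: a defaultdict whose default_factory is a lambda
# cannot be pickled, while a functools.partial of a module-level function can.
import functools
import pickle
from collections import defaultdict


def make_row():  # module-level, hence picklable
    return [0, 0]


good = defaultdict(functools.partial(make_row))
pickle.loads(pickle.dumps(good))  # round-trips fine

bad = defaultdict(lambda: [0, 0])
try:
    pickle.dumps(bad)
except Exception as exc:  # pickling the lambda factory raises
    print(type(exc).__name__)
```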


class CountFeaturizer(BaseEstimator, TransformerMixin):
"""Adds a feature representing each feature value's count in training | ||
|
||
Specifically, for each data point 'X_i' in the dataset 'X', it will add in | ||
a new set of columns 'count_X_i' to the end of 'X_i' where 'count_X_i' is the | ||
number of occurences of 'X_i' in the dataset 'X' given the equality indicator | ||
'inclusion'. | ||
|
||
If a 'y' argument is given during the fit step, then the | ||
count in the transform step will be conditional on the 'y'. | ||
The number of columns added will be the number of different values 'y' can | ||
take on. | ||
|
||
This preprocessing step is useful when the number of occurences | ||
of a particular piece of data is helpful in computing the prediction. | ||
|
||
Parameters | ||
---------- | ||
inclusion : 'all', 'each', list, or numpy.ndarray | ||
The inclusion criteria for counting | ||
|
||
- 'all' (default) : Every feature is concatenated and counted | ||
- 'each' : Each feature will have its own set of counts | ||
- list of indices : Only the given list of features is | ||
concatenated and counted | ||
- list of lists of indices : The given list of lists of features is | ||
concatenated and counted, but each list in the | ||
list of lists has its own set of counts | ||
|
||
Attributes | ||
---------- | ||
count_cache_ : defaultdict(int) | ||
The counts of each example learned during 'fit' | ||
|
||
classes_ : list of (index, y) tuples | ||
An enumerated set of all unique values 'y' can have | ||
|
||
n_input_features_ : int | ||
The number of columns of 'X' learned during 'fit' | ||
We use this to compare to the number of columns of 'X' | ||
during transform to make sure that the transformation is compatible | ||
|
||
n_output_features_ : int | ||
The number of columns of 'y' learned during 'fit' | ||
If 0, then the fit is not conditional on the y given | ||
|
||
Examples | ||
-------- | ||
Given a dataset with two features and four samples, we let the transformer | ||
find the number of occurences of each data point in the dataset | ||
Note how the first column duplicates the input data, the second column | ||
corresponds to the count of ``y=0`` and the third column corresponds to | ||
the count of ``y=1``. | ||
|
||
>>> from sklearn.preprocessing.data import CountFeaturizer | ||
>>> X = [[0], [0], [0], [0], [1], [1], [1], [1]] | ||
>>> y = [0, 1, 1, 1, 0, 0, 0, 0] | ||
>>> cf = CountFeaturizer().fit(X, y) | ||
>>> cf.transform(X) # doctest: +NORMALIZE_WHITESPACE | ||
array([[ 0., 1., 3.], | ||
[ 0., 1., 3.], | ||
[ 0., 1., 3.], | ||
[ 0., 1., 3.], | ||
[ 1., 4., 0.], | ||
[ 1., 4., 0.], | ||
[ 1., 4., 0.], | ||
[ 1., 4., 0.]]) | ||
|
||
See also | ||
-------- | ||
https://blogs.technet.microsoft.com/machinelearning/2015/02/17/big-learning-made-easy-with-counts/ # noqa | ||
https://msdn.microsoft.com/en-us/library/azure/dn913056.aspx | ||
""" | ||
    def __init__(self, inclusion='all'):
        self.inclusion = inclusion

    @staticmethod
    def _valid_data_type(type_check):
        return isinstance(type_check, (np.ndarray, list))

    @staticmethod
    def _check_inclusion(inclusion, n_input_features=1):
        if inclusion is None:
            raise ValueError("Inclusion cannot be none")
Review comment: not tested. (A possible test for these error branches is sketched after this method.)
        if isinstance(inclusion, str) and inclusion == "all":
            return np.array([range(n_input_features)])
        elif isinstance(inclusion, str) and inclusion == "each":
            return np.array([[i] for i in range(n_input_features)])
        elif CountFeaturizer._valid_data_type(inclusion):
            if len(inclusion) == 0:
                raise ValueError("Inclusion size must not be 0")
Review comment: not tested.
            if CountFeaturizer._valid_data_type(inclusion[0]):
                return inclusion
            else:
                return [inclusion]
        else:
            raise ValueError("Illegal value for inclusion parameter")

    def fit(self, X, y=None):
        """Fits the CountFeaturizer to X, y

        Stores the counts for each example X, conditional on y
        Both X and y must be appropriately reshaped to a 2D list

        Parameters
        ----------
        X : array
            The data set to learn the counts from, conditional to 'y'
            X must not be 1 dimensional

        y : array-like, optional
            If provided, a separate column is output for each value of 'y',
            counting the occurrences of 'X' conditioned on that 'y' value

        Returns
        -------
        self
        """

        if y is not None:
            X, y = check_X_y(X, y, multi_output=True)
            if len(y.shape) == 1:
                self.n_output_features_ = 1
                y = np.reshape(y, (-1, 1))
            else:
                self.n_output_features_ = len(y[0])
        else:
            X = check_array(X)
            y = np.zeros((len(X), 1))
            self.n_output_features_ = 1

        self.n_input_features_ = len(X[0])
        inclusion_used = \
            CountFeaturizer._check_inclusion(
                self.inclusion, n_input_features=self.n_input_features_)
        len_data = len(X)
        len_inclusion = len(inclusion_used)
        self.count_cache_ = \
            _get_nested_counter(3, self.n_output_features_, len_inclusion)
        classes_unsorted = [set() for i in range(self.n_output_features_)]

        for inclusion_i in range(len_inclusion):
            for i in range(len_data):
                X_key = tuple(X[i].take(inclusion_used[inclusion_i]))
                for j in range(self.n_output_features_):
                    y_key = y[i, j]
                    self.count_cache_[X_key][y_key][j, inclusion_i] += 1
                    classes_unsorted[j].add(y_key)

        self.classes_ = \
            [list(enumerate(sorted(ys))) for ys in classes_unsorted]

        return self
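A small illustrative check (not part of the diff; toy data made up) of the `y=None` path documented above, where the appended count column is unconditional:

```python
# Without y, a single column holding the plain occurrence count is appended.
from sklearn.preprocessing.data import CountFeaturizer

X = [[0], [0], [0], [1]]
print(CountFeaturizer().fit(X).transform(X))
# the appended last column holds the counts: 3, 3, 3, 1
```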

    def transform(self, X):
        """Transforms X to include the counts learned during 'fit'

        Augments 'X' with a new column containing the counts of each example,
        conditional on 'y'.

        Parameters
        ----------
        X : array
            The 'X' that we augment with count columns
            'X' must not be 1 dimensional

        Returns
        -------
        transformed : numpy.ndarray
            The transformed input

        Notes
        -----
        The data returned from the transformation will always be a
        numpy.ndarray
        """

        check_is_fitted(self, ['count_cache_', 'n_input_features_'])

        X = check_array(X)
        len_data = len(X)
        len_classes = 0
        for ys in self.classes_:
            len_classes += len(ys)

        num_features = len(X[0])
        if self.n_input_features_ != num_features:
            raise ValueError("Dimensions mismatch in X during transform")
        inclusion_used = \
            CountFeaturizer._check_inclusion(
                self.inclusion, n_input_features=self.n_input_features_)

        # the number of added cols is the number of unique y vals
        # multiplied by the number of different inclusion lists
        num_added_cols = len_classes * len(inclusion_used)
        transformed = np.zeros((len_data, num_features + num_added_cols))
        transformed[:, :-num_added_cols] = X

        col_offset_inclusion = 0
        for inclusion_i in range(len(inclusion_used)):
            col_offset_y = 0
            col_offset_inclusion = inclusion_i * len_classes
            for j in range(self.n_output_features_):
                for y_ind, y_key in self.classes_[j]:
                    for i in range(len_data):
                        X_key = tuple(X[i].take(inclusion_used[inclusion_i]))
                        transformed[i, num_features + y_ind +
                                    col_offset_y + col_offset_inclusion] = \
                            self.count_cache_[X_key][y_key][j, inclusion_i]
                col_offset_y += len(self.classes_[j])

        return transformed


class CategoricalEncoder(BaseEstimator, TransformerMixin):
    """Encode categorical features as a numeric array.

Review comment: I would replace the last sentence with: "CountFeaturizer can be used as an alternative to one-hot encoding for non-linear estimators that do not work efficiently on high-cardinality categorical variables, such as ensembles of trees (random forests and gradient boosted trees)."