Skip to content

Commit 08607f5

Browse files
committed
#5778: Replaced cut_points_ attribute with zero_intervals_ and searched_points_
1 parent d97d8bf commit 08607f5

File tree

1 file changed

+89
-18
lines changed

1 file changed

+89
-18
lines changed

sklearn/preprocessing/discretization.py

+89-18
Original file line numberDiff line numberDiff line change
@@ -37,33 +37,59 @@ class Discretizer(BaseEstimator, TransformerMixin):
3737
max_ : float
3838
The maximum value of the input data.
3939
40-
cut_points_ : array, shape [numBins - 1]
40+
cut_points_ : array, shape [numBins - 1, n_continuous_features]
4141
Contains the boundaries for which the data lies. Each interval
4242
has an open left boundary, and a closed right boundary.
4343
44+
Given a feature, the width of each interval is given by
45+
(max - min) / n_bins.
46+
47+
zero_intervals_ : list of tuples of length n_continuous_features
48+
A list of 2-tuples that represents the intervals for which a number
49+
would be discretized to zero.
50+
51+
searched_points_ : array, shape [numBins - 2, n_continuous_features]
52+
An array of cut points used for discretization.
53+
4454
n_features_ : int
4555
The number of features from the original dataset.
4656
57+
n_continuous_features_ : int
58+
The number of continuous features.
59+
4760
continuous_features_ : list
4861
Contains the indices of continuous columns in the dataset.
4962
This list is sorted.
5063
5164
Example
5265
-------
53-
>>> # X has two examples, with X[:, 2] as categorical
54-
>>> X = [[0, 1, 0, 4], \
55-
[6, 7, 1, 5]]
66+
X has two examples, with X[:, 2] as categorical
67+
68+
>>> X = [[-3, 1, 0, 5 ], \
69+
[-2, 7, 8, 4.5], \
70+
[3, 3, 1, 4 ]]
5671
>>> from sklearn.preprocessing import Discretizer
57-
>>> discretizer = Discretizer(n_bins=3, categorical_features=[2])
72+
>>> discretizer = Discretizer(n_bins=4, categorical_features=[2])
5873
>>> discretizer.fit(X) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
5974
Discretizer(...)
6075
>>> discretizer.cut_points_ # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
61-
array([[ 2. , 3. , 4.3333...],
62-
[ 4. , 5. , 4.6666...]])
63-
>>> # Transforming X will move the categorical features to the last indices
64-
>>> discretizer.transform(X)
65-
array([[0, 0, 0, 0],
66-
[2, 2, 2, 1]])
76+
array([[-1.5 , 2.5 , 4.25],
77+
[ 0. , 4. , 4.5 ],
78+
[ 1.5 , 5.5 , 4.75]])
79+
80+
>>> discretizer.zero_intervals_
81+
[(0.0, 1.5), (-inf, 2.5), (-inf, 4.25)]
82+
83+
>>> discretizer.searched_points_ # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
84+
array([[-1.5 , 4. , 4.5 ],
85+
[ 1.5 , 5.5 , 4.75]])
86+
87+
Transforming X will move the categorical features to the last indices
88+
89+
>>> discretizer.transform(X) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
90+
array([[ 0., 0., 3., 0.],
91+
[ 0., 3., 1., 8.],
92+
[ 3., 1., 0., 1.]])
6793
"""
6894
sparse_formats = ['csr', 'csc']
6995

@@ -74,16 +100,16 @@ def __init__(self, n_bins=2, categorical_features=None):
74100
# Attributes
75101
self.min_ = None
76102
self.max_ = None
77-
self.cut_points_ = None
78103
self.n_features_ = None
79104
self.continuous_features_ = categorical_features
105+
self.searched_points_ = None
106+
self.zero_intervals_ = None
80107

81108
def _set_continuous_features(self):
82109
"""Sets a boolean array that determines which columns are
83110
continuous features.
84111
"""
85112
if self.categorical_features is None:
86-
self.n_continuous_features_ = self.n_features_
87113
self.continuous_features_ = range(self.n_features_)
88114
return
89115

@@ -102,8 +128,6 @@ def _set_continuous_features(self):
102128
"categorical indices. Input was: {0}" \
103129
.format(self.continuous_features_))
104130

105-
self.n_continuous_features_ = len(self.continuous_features_)
106-
107131
def fit(self, X, y=None):
108132
"""Finds the intervals of interest from the input data.
109133
@@ -138,13 +162,59 @@ def fit(self, X, y=None):
138162
else:
139163
self.min_ = continuous.min(axis=0)
140164
self.max_ = continuous.max(axis=0)
141-
cut_points = list()
165+
166+
searched_points = list()
167+
zero_intervals = list()
168+
142169
for min_, max_ in zip(self.min_, self.max_):
170+
143171
points = np.linspace(min_, max_, num=self.n_bins, endpoint=False)[1:]
144-
cut_points.append(points.reshape(-1, 1))
145-
self.cut_points_ = np.hstack(cut_points)
172+
173+
# Get index of where zero goes. Omit this index in
174+
# the rebuilt array
175+
# TODO: Watch out for when there is only two intervals
176+
zero_index = np.searchsorted(points, 0)
177+
178+
if zero_index == 0:
179+
zero_int = (-np.inf, points[zero_index])
180+
else:
181+
zero_int = (points[zero_index], points[zero_index + 1])
182+
zero_intervals.append(zero_int)
183+
184+
searched = np.hstack((points[:zero_index], points[zero_index + 1:]))
185+
searched_points.append(searched.reshape(-1, 1))
186+
187+
## TODO: CHANGE HERE
188+
self.searched_points_ = np.hstack(searched_points)
189+
self.zero_intervals_ = zero_intervals
146190
return self
147191

192+
@property
193+
def n_continuous_features_(self):
194+
if not self.continuous_features_:
195+
return None
196+
return len(self.continuous_features_)
197+
198+
@property
199+
def cut_points_(self):
200+
if not self.zero_intervals_:
201+
return None
202+
zero_intervals = self.zero_intervals_
203+
searched_points = self.searched_points_
204+
205+
cut_points = list()
206+
for (lower, upper), col in zip(zero_intervals, searched_points.T):
207+
zero_index = np.searchsorted(col, lower)
208+
209+
# Case when lower == -np.inf
210+
if zero_index == 0:
211+
cut_column = np.insert(col, 0, upper)
212+
else:
213+
cut_column = np.insert(col, zero_index, lower)
214+
cut_points.append(cut_column.reshape(-1, 1))
215+
cut_points = np.hstack(cut_points)
216+
return cut_points
217+
148218
def _check_sparse(self, X, ravel=True):
149219
if ravel:
150220
return X.toarray().ravel() if sp.issparse(X) else X
@@ -177,6 +247,7 @@ def transform(self, X, y=None):
177247

178248
continuous = X[:, self.continuous_features_]
179249
discretized = list()
250+
180251
for cut_points, cont in zip(self.cut_points_.T, continuous.T):
181252
cont = self._check_sparse(cont) # np.searchsorted can't handle sparse
182253
dis_features = np.searchsorted(cut_points, cont)

0 commit comments

Comments
 (0)