Skip to content

Commit ec84a00

Browse files
committed
Merge pull request #4302 from amueller/isotonic_regression_duplicate_fixes
[MRG+1] Isotonic regression duplicate fixes
2 parents 6c40cd2 + c1fa16f commit ec84a00

File tree

7 files changed

+1925
-1158
lines changed

7 files changed

+1925
-1158
lines changed

doc/whats_new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,9 @@ Bug fixes
333333
to make it consistent with the documentation and
334334
``decision_function``. By Artem Sobolev.
335335

336+
- Fixed handling of ties in :class:`isotonic.IsotonicRegression`.
337+
We now use the weighted average of targets (secondary method). By
338+
`Andreas Müller`_ and `Michael Bommarito <http://bommaritollc.com/>`_.
336339

337340
API changes summary
338341
-------------------

sklearn/_isotonic.c

Lines changed: 1757 additions & 1080 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

sklearn/_isotonic.pyx

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,56 @@ def _isotonic_regression(np.ndarray[DOUBLE, ndim=1] y,
6969
break
7070

7171
return solution
72+
73+
74+
@cython.boundscheck(False)
75+
@cython.wraparound(False)
76+
@cython.cdivision(True)
77+
def _make_unique(np.ndarray[dtype=np.float64_t] X,
78+
np.ndarray[dtype=np.float64_t] y,
79+
np.ndarray[dtype=np.float64_t] sample_weights):
80+
"""Average targets for duplicate X, drop duplicates.
81+
82+
Aggregates duplicate X values into a single X value where
83+
the target y is a (sample_weighted) average of the individual
84+
targets.
85+
86+
Assumes that X is ordered, so that all duplicates follow each other.
87+
"""
88+
unique_values = len(np.unique(X))
89+
if unique_values == len(X):
90+
return X, y, sample_weights
91+
cdef np.ndarray[dtype=np.float64_t] y_out = np.empty(unique_values)
92+
cdef np.ndarray[dtype=np.float64_t] x_out = np.empty(unique_values)
93+
cdef np.ndarray[dtype=np.float64_t] weights_out = np.empty(unique_values)
94+
95+
cdef float current_x = X[0]
96+
cdef float current_y = 0
97+
cdef float current_weight = 0
98+
cdef float y_old = 0
99+
cdef int i = 0
100+
cdef int current_count = 0
101+
cdef int j
102+
cdef float x
103+
cdef int n_samples = len(X)
104+
for j in range(n_samples):
105+
x = X[j]
106+
if x != current_x:
107+
# next unique value
108+
x_out[i] = current_x
109+
weights_out[i] = current_weight / current_count
110+
y_out[i] = current_y / current_weight
111+
i += 1
112+
current_x = x
113+
current_weight = sample_weights[j]
114+
current_y = y[j] * sample_weights[j]
115+
current_count = 1
116+
else:
117+
current_weight += sample_weights[j]
118+
current_y += y[j] * sample_weights[j]
119+
current_count += 1
120+
121+
x_out[i] = current_x
122+
weights_out[i] = current_weight / current_count
123+
y_out[i] = current_y / current_weight
124+
return x_out, y_out, weights_out

sklearn/calibration.py

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818

1919
from .base import BaseEstimator, ClassifierMixin, RegressorMixin, clone
2020
from .preprocessing import LabelBinarizer
21-
from .utils import check_random_state
2221
from .utils import check_X_y, check_array, indexable, column_or_1d
2322
from .utils.validation import check_is_fitted
2423
from .isotonic import IsotonicRegression
@@ -59,9 +58,6 @@ class CalibratedClassifierCV(BaseEstimator, ClassifierMixin):
5958
If "prefit" is passed, it is assumed that base_estimator has been
6059
fitted already and all data is used for calibration.
6160
62-
random_state : int, RandomState instance or None (default=None)
63-
Used to randomly break ties when method is 'isotonic'.
64-
6561
Attributes
6662
----------
6763
classes_ : array, shape (n_classes)
@@ -86,12 +82,10 @@ class CalibratedClassifierCV(BaseEstimator, ClassifierMixin):
8682
.. [4] Predicting Good Probabilities with Supervised Learning,
8783
A. Niculescu-Mizil & R. Caruana, ICML 2005
8884
"""
89-
def __init__(self, base_estimator=None, method='sigmoid', cv=3,
90-
random_state=None):
85+
def __init__(self, base_estimator=None, method='sigmoid', cv=3):
9186
self.base_estimator = base_estimator
9287
self.method = method
9388
self.cv = cv
94-
self.random_state = random_state
9589

9690
def fit(self, X, y, sample_weight=None):
9791
"""Fit the calibrated model
@@ -116,7 +110,6 @@ def fit(self, X, y, sample_weight=None):
116110
X, y = indexable(X, y)
117111
lb = LabelBinarizer().fit(y)
118112
self.classes_ = lb.classes_
119-
random_state = check_random_state(self.random_state)
120113

121114
# Check that we each cross-validation fold can have at least one
122115
# example per class
@@ -136,7 +129,7 @@ def fit(self, X, y, sample_weight=None):
136129

137130
if self.cv == "prefit":
138131
calibrated_classifier = _CalibratedClassifier(
139-
base_estimator, method=self.method, random_state=random_state)
132+
base_estimator, method=self.method)
140133
if sample_weight is not None:
141134
calibrated_classifier.fit(X, y, sample_weight)
142135
else:
@@ -164,8 +157,7 @@ def fit(self, X, y, sample_weight=None):
164157
this_estimator.fit(X[train], y[train])
165158

166159
calibrated_classifier = _CalibratedClassifier(
167-
this_estimator, method=self.method,
168-
random_state=random_state)
160+
this_estimator, method=self.method)
169161
if sample_weight is not None:
170162
calibrated_classifier.fit(X[test], y[test],
171163
sample_weight[test])
@@ -242,9 +234,6 @@ class _CalibratedClassifier(object):
242234
corresponds to Platt's method or 'isotonic' which is a
243235
non-parameteric approach based on isotonic regression.
244236
245-
random_state : int, RandomState instance or None (default=None)
246-
Used to randomly break ties when method is 'isotonic'.
247-
248237
References
249238
----------
250239
.. [1] Obtaining calibrated probability estimates from decision trees
@@ -259,11 +248,9 @@ class _CalibratedClassifier(object):
259248
.. [4] Predicting Good Probabilities with Supervised Learning,
260249
A. Niculescu-Mizil & R. Caruana, ICML 2005
261250
"""
262-
def __init__(self, base_estimator, method='sigmoid',
263-
random_state=None):
251+
def __init__(self, base_estimator, method='sigmoid'):
264252
self.base_estimator = base_estimator
265253
self.method = method
266-
self.random_state = random_state
267254

268255
def _preproc(self, X):
269256
n_classes = len(self.classes_)
@@ -312,13 +299,6 @@ def fit(self, X, y, sample_weight=None):
312299
for k, this_df in zip(idx_pos_class, df.T):
313300
if self.method == 'isotonic':
314301
calibrator = IsotonicRegression(out_of_bounds='clip')
315-
# XXX: isotonic regression cannot deal correctly with
316-
# situations in which multiple inputs are identical but
317-
# have different outputs. Since this is not untypical
318-
# when calibrating, we add some small random jitter to
319-
# the inputs.
320-
jitter = self.random_state.normal(0, 1e-10, this_df.shape[0])
321-
this_df = this_df + jitter
322302
elif self.method == 'sigmoid':
323303
calibrator = _SigmoidCalibration()
324304
else:

sklearn/isotonic.py

Lines changed: 27 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from scipy.stats import spearmanr
99
from .base import BaseEstimator, TransformerMixin, RegressorMixin
1010
from .utils import as_float_array, check_array, check_consistent_length
11-
from ._isotonic import _isotonic_regression
11+
from .utils.fixes import astype
12+
from ._isotonic import _isotonic_regression, _make_unique
1213
import warnings
1314
import math
1415

@@ -200,12 +201,24 @@ class IsotonicRegression(BaseEstimator, TransformerMixin, RegressorMixin):
200201
The stepwise interpolating function that covers the domain
201202
X_.
202203
204+
Notes
205+
-----
206+
Ties are broken using the secondary method from Leeuw, 1977.
207+
203208
References
204209
----------
205210
Isotonic Median Regression: A Linear Programming Approach
206211
Nilotpal Chakravarti
207212
Mathematics of Operations Research
208213
Vol. 14, No. 2 (May, 1989), pp. 303-308
214+
215+
Isotone Optimization in R : Pool-Adjacent-Violators
216+
Algorithm (PAVA) and Active Set Methods
217+
Leeuw, Hornik, Mair
218+
Journal of Statistical Software 2009
219+
220+
Correctness of Kruskal's algorithms for monotone regression with ties
221+
Leeuw, Psychometrica, 1977
209222
"""
210223
def __init__(self, y_min=None, y_max=None, increasing=True,
211224
out_of_bounds='nan'):
@@ -228,8 +241,12 @@ def _build_f(self, X, y):
228241
.format(self.out_of_bounds))
229242

230243
bounds_error = self.out_of_bounds == "raise"
231-
self.f_ = interpolate.interp1d(X, y, kind='slinear',
232-
bounds_error=bounds_error)
244+
if len(y) == 1:
245+
# single y, constant prediction
246+
self.f_ = lambda x: y.repeat(x.shape)
247+
else:
248+
self.f_ = interpolate.interp1d(X, y, kind='slinear',
249+
bounds_error=bounds_error)
233250

234251
def _build_y(self, X, y, sample_weight):
235252
"""Build the y_ IsotonicRegression."""
@@ -249,8 +266,13 @@ def _build_y(self, X, y, sample_weight):
249266

250267
order = np.lexsort((y, X))
251268
order_inv = np.argsort(order)
252-
self.X_ = as_float_array(X[order], copy=False)
253-
self.y_ = isotonic_regression(y[order], sample_weight, self.y_min,
269+
if sample_weight is None:
270+
sample_weight = np.ones(len(y))
271+
X, y, sample_weight = [astype(array[order], np.float64, copy=False)
272+
for array in [X, y, sample_weight]]
273+
unique_X, unique_y, unique_sample_weight = _make_unique(X, y, sample_weight)
274+
self.X_ = unique_X
275+
self.y_ = isotonic_regression(unique_y, unique_sample_weight, self.y_min,
254276
self.y_max, increasing=self.increasing_)
255277

256278
return order_inv
@@ -319,44 +341,6 @@ def transform(self, T):
319341
T = np.clip(T, self.X_min_, self.X_max_)
320342
return self.f_(T)
321343

322-
def fit_transform(self, X, y, sample_weight=None):
323-
"""Fit model and transform y by linear interpolation.
324-
325-
Parameters
326-
----------
327-
X : array-like, shape=(n_samples,)
328-
Training data.
329-
330-
y : array-like, shape=(n_samples,)
331-
Training target.
332-
333-
sample_weight : array-like, shape=(n_samples,), optional, default: None
334-
Weights. If set to None, all weights will be equal to 1 (equal
335-
weights).
336-
337-
Returns
338-
-------
339-
y_ : array, shape=(n_samples,)
340-
The transformed data.
341-
342-
Notes
343-
-----
344-
X doesn't influence the result of `fit_transform`. It is however stored
345-
for future use, as `transform` needs X to interpolate new input
346-
data.
347-
"""
348-
# Build y_
349-
order_inv = self._build_y(X, y, sample_weight)
350-
351-
# Handle the left and right bounds on X
352-
self.X_min_ = np.min(self.X_)
353-
self.X_max_ = np.max(self.X_)
354-
355-
# Build f_
356-
self._build_f(self.X_, self.y_)
357-
358-
return self.y_[order_inv]
359-
360344
def predict(self, T):
361345
"""Predict new data by linear interpolation.
362346

sklearn/tests/test_calibration.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,7 @@ def test_sample_weight_warning():
114114

115115
for method in ['sigmoid', 'isotonic']:
116116
base_estimator = LinearSVC(random_state=42)
117-
calibrated_clf = CalibratedClassifierCV(base_estimator, method=method,
118-
random_state=42)
117+
calibrated_clf = CalibratedClassifierCV(base_estimator, method=method)
119118
# LinearSVC does not currently support sample weights but they
120119
# can still be used for the calibration step (with a warning)
121120
msg = "LinearSVC does not support sample_weight."

0 commit comments

Comments
 (0)