Skip to content

Commit f11f657

Browse files
committed
iforest behaviour param
fix + cosmit cosmit common test update examples cosmit + fix travis attribute error instead of value + mask warning in test whatsnew
1 parent f037200 commit f11f657

File tree

8 files changed

+100
-17
lines changed

8 files changed

+100
-17
lines changed

benchmarks/bench_isolation_forest.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ def print_outlier_ratio(y):
119119
y_test = y[n_samples_train:]
120120

121121
print('--- Fitting the IsolationForest estimator...')
122-
model = IsolationForest(n_jobs=-1, random_state=random_state)
122+
model = IsolationForest(behaviour='new', n_jobs=-1,
123+
random_state=random_state)
123124
tstart = time()
124125
model.fit(X_train)
125126
fit_time = time() - tstart

doc/whats_new/v0.20.rst

+9
Original file line numberDiff line numberDiff line change
@@ -861,6 +861,15 @@ Outlier Detection models
861861
``raw_values`` parameter is deprecated as the shifted Mahalanobis distance
862862
will be always returned in 0.22. :issue:`9015` by `Nicolas Goix`_.
863863

864+
- A ``behaviour`` parameter has been introduced in :class:`ensemble.IsolationForest`
865+
to ensure backward compatibility.
866+
In the old behaviour, the ``decision_function`` is independent of the ``contamination``
867+
parameter. A threshold attribute depending on the ``contamination`` parameter is thus
868+
used.
869+
In the new behaviour the ``decision_function`` is dependent on the ``contamination``
870+
parameter, in such a way that 0 becomes its natural threshold to detect outliers.
871+
:issue:`11553` by `Nicolas Goix`_.
872+
864873
Covariance
865874

866875
- The :func:`covariance.graph_lasso`, :class:`covariance.GraphLasso` and

examples/covariance/plot_outlier_detection.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
"""
3333

3434
import numpy as np
35-
from scipy import stats
3635
import matplotlib.pyplot as plt
3736
import matplotlib.font_manager
3837

@@ -58,7 +57,8 @@
5857
"One-Class SVM": svm.OneClassSVM(nu=0.95 * outliers_fraction + 0.05,
5958
kernel="rbf", gamma=0.1),
6059
"Robust covariance": EllipticEnvelope(contamination=outliers_fraction),
61-
"Isolation Forest": IsolationForest(max_samples=n_samples,
60+
"Isolation Forest": IsolationForest(behaviour='new',
61+
max_samples=n_samples,
6262
contamination=outliers_fraction,
6363
random_state=rng),
6464
"Local Outlier Factor": LocalOutlierFactor(

examples/ensemble/plot_isolation_forest.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@
4040
X_outliers = rng.uniform(low=-4, high=4, size=(20, 2))
4141

4242
# fit the model
43-
clf = IsolationForest(max_samples=100, random_state=rng, contamination='auto')
43+
clf = IsolationForest(behaviour='new', max_samples=100,
44+
random_state=rng, contamination='auto')
4445
clf.fit(X_train)
4546
y_pred_train = clf.predict(X_train)
4647
y_pred_test = clf.predict(X_test)

examples/plot_anomaly_comparison.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@
5454
("Robust covariance", EllipticEnvelope(contamination=outliers_fraction)),
5555
("One-Class SVM", svm.OneClassSVM(nu=outliers_fraction, kernel="rbf",
5656
gamma=0.1)),
57-
("Isolation Forest", IsolationForest(contamination=outliers_fraction,
57+
("Isolation Forest", IsolationForest(behaviour='new',
58+
contamination=outliers_fraction,
5859
random_state=42)),
5960
("Local Outlier Factor", LocalOutlierFactor(
6061
n_neighbors=35, contamination=outliers_fraction))]

sklearn/ensemble/iforest.py

+48-10
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,15 @@ class IsolationForest(BaseBagging, OutlierMixin):
8989
The number of jobs to run in parallel for both `fit` and `predict`.
9090
If -1, then the number of jobs is set to the number of cores.
9191
92+
behaviour: str, optional (default='old')
93+
Accepted values are 'old' or 'new'. Behaviour of the decision_function.
94+
Default "behaviour" parameter will change to "new" in version 0.22.
95+
Passing behaviour="new" makes the decision_function change to match
96+
other anomaly detection algorithm API, as explained in details in the
97+
offset_ attribute documentation. Basically, the decision_function
98+
becomes dependent on the contamination parameter, in such a way that
99+
0 becomes its natural threshold to detect outliers.
100+
92101
random_state : int, RandomState instance or None, optional (default=None)
93102
If int, random_state is the seed used by the random number generator;
94103
If RandomState instance, random_state is the random number generator;
@@ -114,12 +123,16 @@ class IsolationForest(BaseBagging, OutlierMixin):
114123
offset_ : float
115124
Offset used to define the decision function from the raw scores.
116125
We have the relation: ``decision_function = score_samples - offset_``.
126+
Assuming behaviour == 'new', offset_ is defined as follows.
117127
When the contamination parameter is set to "auto", the offset is equal
118128
to -0.5 as the scores of inliers are close to 0 and the scores of
119129
outliers are close to -1. When a contamination parameter different
120130
than "auto" is provided, the offset is defined in such a way we obtain
121131
the expected number of outliers (samples with decision function < 0)
122132
in training.
133+
Assuming the behaviour parameter is set to 'old', we always have
134+
offset_ = -0.5, making the decision function independent from the
135+
contamination parameter.
123136
124137
References
125138
----------
@@ -138,6 +151,7 @@ def __init__(self,
138151
max_features=1.,
139152
bootstrap=False,
140153
n_jobs=1,
154+
behaviour='old',
141155
random_state=None,
142156
verbose=0):
143157
super(IsolationForest, self).__init__(
@@ -154,8 +168,17 @@ def __init__(self,
154168
n_jobs=n_jobs,
155169
random_state=random_state,
156170
verbose=verbose)
171+
172+
self.behaviour = behaviour
157173
self.contamination = contamination
158174

175+
if behaviour == 'old':
176+
warnings.warn('Default "behaviour" parameter will change to "new" '
177+
'in version 0.22. Passing behaviour="new" makes '
178+
'IsolationForest decision_function change to match '
179+
'other anomaly detection algorithm API.',
180+
FutureWarning)
181+
159182
def _set_oob_score(self, X, y):
160183
raise NotImplementedError("OOB score not supported by iforest")
161184

@@ -226,16 +249,29 @@ def fit(self, X, y=None, sample_weight=None):
226249
max_depth=max_depth,
227250
sample_weight=sample_weight)
228251

252+
if self.behaviour == 'old':
253+
# in this case, decision_function = 0.5 + self.score_samples(X):
254+
if self._contamination == "auto":
255+
raise ValueError("contamination parameter cannot be set to "
256+
"'auto' when behaviour == 'old'.")
257+
258+
self.offset_ = -0.5
259+
self._threshold_ = sp.stats.scoreatpercentile(
260+
self.decision_function(X), 100. * self._contamination)
261+
262+
return self
263+
264+
# else, self.behaviour == 'new':
229265
if self._contamination == "auto":
230266
# 0.5 plays a special role as described in the original paper.
231267
# we take the opposite as we consider the opposite of their score.
232268
self.offset_ = -0.5
233-
# need to save (depreciated) threshold_ in this case:
234-
self._threshold_ = sp.stats.scoreatpercentile(
235-
self.score_samples(X), 100. * 0.1)
236-
else:
237-
self.offset_ = sp.stats.scoreatpercentile(
238-
self.score_samples(X), 100. * self._contamination)
269+
return self
270+
271+
# else, define offset_ wrt contamination parameter, so that the
272+
# threshold_ attribute is implicitly 0 and is not needed anymore:
273+
self.offset_ = sp.stats.scoreatpercentile(
274+
self.score_samples(X), 100. * self._contamination)
239275

240276
return self
241277

@@ -258,7 +294,8 @@ def predict(self, X):
258294
check_is_fitted(self, ["offset_"])
259295
X = check_array(X, accept_sparse='csr')
260296
is_inlier = np.ones(X.shape[0], dtype=int)
261-
is_inlier[self.decision_function(X) < 0] = -1
297+
threshold = self.threshold_ if self.behaviour == 'old' else 0
298+
is_inlier[self.decision_function(X) < threshold] = -1
262299
return is_inlier
263300

264301
def decision_function(self, X):
@@ -359,11 +396,12 @@ def score_samples(self, X):
359396

360397
@property
361398
def threshold_(self):
399+
if self.behaviour != 'old':
400+
raise AttributeError("threshold_ attribute does not exist when "
401+
"behaviour != 'old'")
362402
warnings.warn("threshold_ attribute is deprecated in 0.20 and will"
363403
" be removed in 0.22.", DeprecationWarning)
364-
if self.contamination == 'auto':
365-
return self._threshold_
366-
return self.offset_
404+
return self._threshold_
367405

368406

369407
def _average_path_length(n_samples_leaf):

sklearn/ensemble/tests/test_iforest.py

+28-2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from sklearn.utils.testing import assert_array_equal
1616
from sklearn.utils.testing import assert_array_almost_equal
1717
from sklearn.utils.testing import assert_raises
18+
from sklearn.utils.testing import assert_raises_regex
1819
from sklearn.utils.testing import assert_warns_message
1920
from sklearn.utils.testing import assert_equal
2021
from sklearn.utils.testing import assert_greater
@@ -47,6 +48,7 @@
4748
boston.target = boston.target[perm]
4849

4950

51+
@pytest.mark.filterwarnings('ignore:threshold_ attribute')
5052
def test_iforest():
5153
"""Check Isolation Forest for various parameter settings."""
5254
X_train = np.array([[0, 1], [1, 2]])
@@ -63,6 +65,7 @@ def test_iforest():
6365

6466

6567
@pytest.mark.filterwarnings('ignore:default contamination')
68+
@pytest.mark.filterwarnings('ignore:threshold_ attribute')
6669
def test_iforest_sparse():
6770
"""Check IForest for various parameter settings on sparse input."""
6871
rng = check_random_state(0)
@@ -91,6 +94,7 @@ def test_iforest_sparse():
9194

9295

9396
@pytest.mark.filterwarnings('ignore:default contamination')
97+
@pytest.mark.filterwarnings('ignore:threshold_ attribute')
9498
def test_iforest_error():
9599
"""Test that it gives proper exception on deficient input."""
96100
X = iris.data
@@ -128,6 +132,11 @@ def test_iforest_error():
128132
# test X_test n_features match X_train one:
129133
assert_raises(ValueError, IsolationForest().fit(X).predict, X[:, 1:])
130134

135+
# test threshold_ attribute error when behaviour is not old:
136+
msg = "threshold_ attribute does not exist when behaviour != 'old'"
137+
assert_raises_regex(AttributeError, msg, getattr,
138+
IsolationForest(behaviour='new'), 'threshold_')
139+
131140

132141
@pytest.mark.filterwarnings('ignore:default contamination')
133142
def test_recalculate_max_depth():
@@ -155,6 +164,7 @@ def test_max_samples_attribute():
155164

156165

157166
@pytest.mark.filterwarnings('ignore:default contamination')
167+
@pytest.mark.filterwarnings('ignore:threshold_ attribute')
158168
def test_iforest_parallel_regression():
159169
"""Check parallel regression."""
160170
rng = check_random_state(0)
@@ -204,13 +214,15 @@ def test_iforest_performance():
204214
assert_greater(roc_auc_score(y_test, y_pred), 0.98)
205215

206216

217+
@pytest.mark.filterwarnings('ignore:threshold_ attribute')
207218
def test_iforest_works():
208219
# toy sample (the last two samples are outliers)
209220
X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], [6, 3], [-4, 7]]
210221

211222
# Test IsolationForest
212223
for contamination in [0.25, "auto"]:
213-
clf = IsolationForest(random_state=rng, contamination=contamination)
224+
clf = IsolationForest(behaviour='new', random_state=rng,
225+
contamination=contamination)
214226
clf.fit(X)
215227
decision_func = - clf.decision_function(X)
216228
pred = clf.predict(X)
@@ -228,6 +240,7 @@ def test_max_samples_consistency():
228240

229241

230242
@pytest.mark.filterwarnings('ignore:default contamination')
243+
@pytest.mark.filterwarnings('ignore:threshold_ attribute')
231244
def test_iforest_subsampled_features():
232245
# It tests non-regression for #5732 which failed at predict.
233246
rng = check_random_state(0)
@@ -274,8 +287,21 @@ def test_deprecation():
274287
'in version 0.22 to "auto"',
275288
clf.fit, X)
276289

277-
clf = IsolationForest(contamination='auto').fit(X)
290+
assert_warns_message(FutureWarning,
291+
'Default "behaviour" parameter will change to "new" '
292+
'in version 0.22',
293+
IsolationForest, )
294+
295+
clf = IsolationForest().fit(X)
278296
assert_warns_message(DeprecationWarning,
279297
"threshold_ attribute is deprecated in 0.20 and will"
280298
" be removed in 0.22.",
281299
getattr, clf, "threshold_")
300+
301+
302+
def test_behaviour_param():
303+
X_train = [[1, 1], [1, 2], [2, 1]]
304+
clf1 = IsolationForest(behaviour='old').fit(X_train)
305+
clf2 = IsolationForest(behaviour='new', contamination='auto').fit(X_train)
306+
assert_array_equal(clf1.decision_function([[2., 2.]]),
307+
clf2.decision_function([[2., 2.]]))

sklearn/utils/estimator_checks.py

+7
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,13 @@ def set_checking_parameters(estimator):
366366
if estimator.__class__.__name__ == "TheilSenRegressor":
367367
estimator.max_subpopulation = 100
368368

369+
if estimator.__class__.__name__ == "IsolationForest":
370+
# XXX to be removed in 0.22.
371+
# this is used because the old IsolationForest does not
372+
# respect the outlier detection API and thus and does not
373+
# pass the outlier detection common tests.
374+
estimator.set_params(behaviour='new')
375+
369376
if isinstance(estimator, BaseRandomProjection):
370377
# Due to the jl lemma and often very few samples, the number
371378
# of components of the random matrix projection will be probably

0 commit comments

Comments
 (0)