ENH Add dtype preservation to LocalOutlierFactor #22665

Merged (19 commits) on Nov 23, 2022

4 changes: 4 additions & 0 deletions doc/whats_new/v1.2.rst
@@ -596,6 +596,10 @@ Changelog
:class:`neighbors.RadiusNeighborsTransformer`.
:pr:`24075` by :user:`Valentin Laurent <Valentin-Laurent>`.

- |Enhancement| :class:`neighbors.LocalOutlierFactor` now preserves
dtype for `numpy.float32` inputs.
:pr:`22665` by :user:`Julien Jerphanion <jjerphan>`.

:mod:`sklearn.pipeline`
.......................

20 changes: 19 additions & 1 deletion sklearn/neighbors/_lof.py
@@ -291,6 +291,12 @@ def fit(self, X, y=None):
n_neighbors=self.n_neighbors_
)

        # The distances returned by `kneighbors` can be float64 even when X is
        # float32; cast them back so the fitted attributes keep the input dtype.
        if self._fit_X.dtype == np.float32:
            self._distances_fit_X_ = self._distances_fit_X_.astype(
                self._fit_X.dtype,
                copy=False,
            )

self._lrd = self._local_reachability_density(
self._distances_fit_X_, _neighbors_indices_fit_X_
)
@@ -462,7 +468,14 @@ def score_samples(self, X):
distances_X, neighbors_indices_X = self.kneighbors(
X, n_neighbors=self.n_neighbors_
)
X_lrd = self._local_reachability_density(distances_X, neighbors_indices_X)

        # As in `fit`, restore the input dtype on the distances computed by
        # `kneighbors`.
        if X.dtype == np.float32:
            distances_X = distances_X.astype(X.dtype, copy=False)

X_lrd = self._local_reachability_density(
distances_X,
neighbors_indices_X,
)

lrd_ratios_array = self._lrd[neighbors_indices_X] / X_lrd[:, np.newaxis]

@@ -495,3 +508,8 @@ def _local_reachability_density(self, distances_X, neighbors_indices):

        # 1e-10 to avoid `nan` when the number of duplicates > n_neighbors_:
return 1.0 / (np.mean(reach_dist_array, axis=1) + 1e-10)

def _more_tags(self):
return {
"preserves_dtype": [np.float64, np.float32],
}
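
The `preserves_dtype` tag above advertises the new behavior to scikit-learn's common estimator checks. As a minimal standalone sketch of what the change delivers (assuming scikit-learn >= 1.2, where this PR landed; dataset and parameters are illustrative):

import numpy as np
from sklearn.datasets import load_iris
from sklearn.neighbors import LocalOutlierFactor

X_32 = load_iris().data.astype(np.float32)
lof = LocalOutlierFactor(n_neighbors=5, novelty=True).fit(X_32)

# Fitted attributes and predictions now follow the input dtype.
assert lof.negative_outlier_factor_.dtype == np.float32
assert lof.score_samples(X_32).dtype == np.float32
assert lof.decision_function(X_32).dtype == np.float32
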
137 changes: 101 additions & 36 deletions sklearn/neighbors/tests/test_lof.py
@@ -10,13 +10,13 @@
from sklearn import neighbors
import re
import pytest
from numpy.testing import assert_array_equal

from sklearn import metrics
from sklearn.metrics import roc_auc_score

from sklearn.utils import check_random_state
from sklearn.utils._testing import assert_array_almost_equal
from sklearn.utils._testing import assert_allclose
from sklearn.utils._testing import assert_array_equal
from sklearn.utils.estimator_checks import check_outlier_corruption
from sklearn.utils.estimator_checks import parametrize_with_checks

@@ -32,9 +32,12 @@
iris.target = iris.target[perm]


def test_lof():
def test_lof(global_dtype):
# Toy sample (the last two samples are outliers):
X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], [5, 3], [-4, 2]]
X = np.asarray(
[[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], [5, 3], [-4, 2]],
dtype=global_dtype,
)

# Test LocalOutlierFactor:
clf = neighbors.LocalOutlierFactor(n_neighbors=5)
@@ -46,18 +49,21 @@ def test_lof():

# Assert predict() works:
clf = neighbors.LocalOutlierFactor(contamination=0.25, n_neighbors=5).fit(X)
assert_array_equal(clf._predict(), 6 * [1] + 2 * [-1])
assert_array_equal(clf.fit_predict(X), 6 * [1] + 2 * [-1])
expected_predictions = 6 * [1] + 2 * [-1]
assert_array_equal(clf._predict(), expected_predictions)
assert_array_equal(clf.fit_predict(X), expected_predictions)


def test_lof_performance():
def test_lof_performance(global_dtype):
# Generate train/test data
rng = check_random_state(2)
X = 0.3 * rng.randn(120, 2)
X = 0.3 * rng.randn(120, 2).astype(global_dtype, copy=False)
X_train = X[:100]

# Generate some abnormal novel observations
X_outliers = rng.uniform(low=-4, high=4, size=(20, 2))
X_outliers = rng.uniform(low=-4, high=4, size=(20, 2)).astype(
global_dtype, copy=False
)
X_test = np.r_[X[100:], X_outliers]
y_test = np.array([0] * 20 + [1] * 20)

@@ -71,32 +77,32 @@ def test_lof_performance():
assert roc_auc_score(y_test, y_pred) > 0.99
Review comment (Member):
We should be testing somewhere that y_pred is preserving dtype also.

Review comment (Member):
we do it later for decision_function and score_samples

@jjerphan (Member, Author), Nov 4, 2022:
I understand your last message as indicating that we do not need an assertion for dtype preservation here because this is done in test_score_samples. Is this what you meant?
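
A sketch of the kind of assertion being discussed, using hypothetical names (`clf` and `X_test` are illustrative, not taken from the diff):

y_pred = clf.score_samples(X_test.astype(np.float32))
assert y_pred.dtype == np.float32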



def test_lof_values():
def test_lof_values(global_dtype):
# toy samples:
X_train = [[1, 1], [1, 2], [2, 1]]
X_train = np.asarray([[1, 1], [1, 2], [2, 1]], dtype=global_dtype)
clf1 = neighbors.LocalOutlierFactor(
n_neighbors=2, contamination=0.1, novelty=True
).fit(X_train)
clf2 = neighbors.LocalOutlierFactor(n_neighbors=2, novelty=True).fit(X_train)
s_0 = 2.0 * sqrt(2.0) / (1.0 + sqrt(2.0))
s_1 = (1.0 + sqrt(2)) * (1.0 / (4.0 * sqrt(2.0)) + 1.0 / (2.0 + 2.0 * sqrt(2)))
# check predict()
assert_array_almost_equal(-clf1.negative_outlier_factor_, [s_0, s_1, s_1])
assert_array_almost_equal(-clf2.negative_outlier_factor_, [s_0, s_1, s_1])
assert_allclose(-clf1.negative_outlier_factor_, [s_0, s_1, s_1])
assert_allclose(-clf2.negative_outlier_factor_, [s_0, s_1, s_1])
# check predict(one sample not in train)
assert_array_almost_equal(-clf1.score_samples([[2.0, 2.0]]), [s_0])
assert_array_almost_equal(-clf2.score_samples([[2.0, 2.0]]), [s_0])
assert_allclose(-clf1.score_samples([[2.0, 2.0]]), [s_0])
assert_allclose(-clf2.score_samples([[2.0, 2.0]]), [s_0])
# check predict(one sample already in train)
assert_array_almost_equal(-clf1.score_samples([[1.0, 1.0]]), [s_1])
assert_array_almost_equal(-clf2.score_samples([[1.0, 1.0]]), [s_1])
assert_allclose(-clf1.score_samples([[1.0, 1.0]]), [s_1])
assert_allclose(-clf2.score_samples([[1.0, 1.0]]), [s_1])


def test_lof_precomputed(random_state=42):
def test_lof_precomputed(global_dtype, random_state=42):
"""Tests LOF with a distance matrix."""
# Note: smaller samples may result in spurious test success
rng = np.random.RandomState(random_state)
X = rng.random_sample((10, 4))
Y = rng.random_sample((3, 4))
X = rng.random_sample((10, 4)).astype(global_dtype, copy=False)
Y = rng.random_sample((3, 4)).astype(global_dtype, copy=False)
DXX = metrics.pairwise_distances(X, metric="euclidean")
DYX = metrics.pairwise_distances(Y, X, metric="euclidean")
# As a feature matrix (n_samples by n_features)
@@ -113,8 +119,8 @@ def test_lof_precomputed(random_state=42):
pred_D_X = lof_D._predict()
pred_D_Y = lof_D.predict(DYX)

assert_array_almost_equal(pred_X_X, pred_D_X)
assert_array_almost_equal(pred_X_Y, pred_D_Y)
assert_allclose(pred_X_X, pred_D_X)
assert_allclose(pred_X_Y, pred_D_Y)


def test_n_neighbors_attribute():
@@ -129,23 +135,29 @@ def test_n_neighbors_attribute():
assert clf.n_neighbors_ == X.shape[0] - 1


def test_score_samples():
X_train = [[1, 1], [1, 2], [2, 1]]
def test_score_samples(global_dtype):
X_train = np.asarray([[1, 1], [1, 2], [2, 1]], dtype=global_dtype)
X_test = np.asarray([[2.0, 2.0]], dtype=global_dtype)
clf1 = neighbors.LocalOutlierFactor(
n_neighbors=2, contamination=0.1, novelty=True
).fit(X_train)
clf2 = neighbors.LocalOutlierFactor(n_neighbors=2, novelty=True).fit(X_train)
assert_array_equal(
clf1.score_samples([[2.0, 2.0]]),
clf1.decision_function([[2.0, 2.0]]) + clf1.offset_,
)
assert_array_equal(
clf2.score_samples([[2.0, 2.0]]),
clf2.decision_function([[2.0, 2.0]]) + clf2.offset_,

clf1_scores = clf1.score_samples(X_test)
clf1_decisions = clf1.decision_function(X_test)

clf2_scores = clf2.score_samples(X_test)
clf2_decisions = clf2.decision_function(X_test)

assert_allclose(
clf1_scores,
clf1_decisions + clf1.offset_,
)
assert_array_equal(
clf1.score_samples([[2.0, 2.0]]), clf2.score_samples([[2.0, 2.0]])
assert_allclose(
clf2_scores,
clf2_decisions + clf2.offset_,
)
assert_allclose(clf1_scores, clf2_scores)


def test_novelty_errors():
@@ -167,10 +179,10 @@ def test_novelty_errors():
getattr(clf, "fit_predict")


def test_novelty_training_scores():
def test_novelty_training_scores(global_dtype):
# check that the scores of the training samples are still accessible
# when novelty=True through the negative_outlier_factor_ attribute
X = iris.data
X = iris.data.astype(global_dtype)

# fit with novelty=False
clf_1 = neighbors.LocalOutlierFactor()
@@ -182,7 +194,7 @@ def test_novelty_training_scores():
clf_2.fit(X)
scores_2 = clf_2.negative_outlier_factor_

assert_array_almost_equal(scores_1, scores_2)
assert_allclose(scores_1, scores_2)


def test_hasattr_prediction():
@@ -244,3 +256,56 @@ def test_sparse():

lof = neighbors.LocalOutlierFactor(novelty=False)
lof.fit_predict(X)


@pytest.mark.parametrize("algorithm", ["auto", "ball_tree", "kd_tree", "brute"])
@pytest.mark.parametrize("novelty", [True, False])
@pytest.mark.parametrize("contamination", [0.5, "auto"])
def test_lof_input_dtype_preservation(global_dtype, algorithm, contamination, novelty):
"""Check that the fitted attributes are stored using the data type of X."""
X = iris.data.astype(global_dtype, copy=False)

iso = neighbors.LocalOutlierFactor(
n_neighbors=5, algorithm=algorithm, contamination=contamination, novelty=novelty
)
iso.fit(X)

assert iso.negative_outlier_factor_.dtype == global_dtype

for method in ("score_samples", "decision_function"):
if hasattr(iso, method):
y_pred = getattr(iso, method)(X)
assert y_pred.dtype == global_dtype


@pytest.mark.parametrize("algorithm", ["auto", "ball_tree", "kd_tree", "brute"])
@pytest.mark.parametrize("novelty", [True, False])
@pytest.mark.parametrize("contamination", [0.5, "auto"])
def test_lof_dtype_equivalence(algorithm, novelty, contamination):
"""Check the equivalence of the results with 32 and 64 bits input."""

inliers = iris.data[:50] # setosa iris are really distinct from others
outliers = iris.data[-5:] # virginica will be considered as outliers
# lower the precision of the input data to check that we have an equivalence when
# making the computation in 32 and 64 bits.
X = np.concatenate([inliers, outliers], axis=0).astype(np.float32)

lof_32 = neighbors.LocalOutlierFactor(
algorithm=algorithm, novelty=novelty, contamination=contamination
)
X_32 = X.astype(np.float32, copy=True)
lof_32.fit(X_32)

lof_64 = neighbors.LocalOutlierFactor(
algorithm=algorithm, novelty=novelty, contamination=contamination
)
X_64 = X.astype(np.float64, copy=True)
lof_64.fit(X_64)

assert_allclose(lof_32.negative_outlier_factor_, lof_64.negative_outlier_factor_)

for method in ("score_samples", "decision_function", "predict", "fit_predict"):
if hasattr(lof_32, method):
y_pred_32 = getattr(lof_32, method)(X_32)
y_pred_64 = getattr(lof_64, method)(X_64)
assert_allclose(y_pred_32, y_pred_64, atol=0.0002)
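
As a standalone sketch of the property this test asserts (illustrative parameters, independent of the fixtures above):

import numpy as np
from sklearn.datasets import load_iris
from sklearn.neighbors import LocalOutlierFactor

X_32 = load_iris().data.astype(np.float32)
X_64 = X_32.astype(np.float64)

lof_32 = LocalOutlierFactor(n_neighbors=20).fit(X_32)
lof_64 = LocalOutlierFactor(n_neighbors=20).fit(X_64)

# The float32 and float64 fits agree up to float32 rounding error.
np.testing.assert_allclose(
    lof_32.negative_outlier_factor_,
    lof_64.negative_outlier_factor_,
    atol=2e-4,
)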