MNT Add estimator check for not calling __array_function__ #14702

Merged: 31 commits, Oct 24, 2019

Commits
9fe7275
estimator check for not calling __array_function__
amueller Aug 20, 2019
76a1032
make sure y is a numpy array in column_or_1d
amueller Aug 20, 2019
f4b5017
allow "may_share_memory"
amueller Aug 20, 2019
66e4f78
actually be conservative so as not to break everything
amueller Aug 20, 2019
22e1a3e
don't use array_functions in the common tests accidentally
amueller Aug 20, 2019
537a93c
some more asarray
amueller Aug 21, 2019
92013f5
add test for sample_weights not an array
amueller Aug 21, 2019
32015ca
fix NotAnArray casting in _check_transformer
amueller Aug 21, 2019
b03fae5
allow 1d y in PLS
amueller Aug 21, 2019
3c95f15
fix error message
amueller Aug 21, 2019
4d2c751
don't care about order
amueller Aug 21, 2019
0668558
add tests for NotAnArray array_function asserts
amueller Aug 21, 2019
1689631
Merge branch 'master' into array_function_test
amueller Aug 26, 2019
73c3355
don't use pytest in this file
amueller Aug 26, 2019
46061bb
Merge branch 'master' into array_function_test
amueller Sep 9, 2019
7f8719f
Merge branch 'master' into array_function_test
amueller Sep 24, 2019
f4132cb
make diff smaller
amueller Sep 24, 2019
e66bdbf
fix typo in message
amueller Sep 24, 2019
8812faa
make data bigger so stackingclassifier can work....
amueller Sep 24, 2019
92c17b0
Merge branch 'master' into array_function_test
amueller Sep 25, 2019
e0f239e
don't use deprecated NotAnArray
amueller Sep 25, 2019
2b17749
Merge branch 'master' into array_function_test
amueller Sep 27, 2019
0d75bea
fix naive bayes sample weights array conversion
amueller Sep 27, 2019
6574e98
fix gaussian NB unique call
amueller Sep 27, 2019
f4b1f92
more deprecated NotAnArray
amueller Sep 27, 2019
60ada49
whatsnew
amueller Oct 7, 2019
1a09940
Update doc/whats_new/v0.22.rst
amueller Oct 8, 2019
a1d1362
Update doc/whats_new/v0.22.rst
amueller Oct 8, 2019
02a42bc
add NEP 18 links
amueller Oct 14, 2019
def3b81
Addressed Joel's comments
NicolasHug Oct 24, 2019
21640e3
Merge branch 'master' of github.com:scikit-learn/scikit-learn into pr…
NicolasHug Oct 24, 2019
12 changes: 12 additions & 0 deletions doc/whats_new/v0.22.rst
@@ -687,6 +687,12 @@ Changelog
Miscellaneous
.............

- |API| Scikit-learn now converts any input data structure implementing a
duck array to a numpy array (using ``__array__``) to ensure consistent
behavior instead of relying on ``__array_function__`` (see `NEP 18
<https://numpy.org/neps/nep-0018-array-function-protocol.html>`_).
:pr:`14702` by `Andreas Müller`_.

- |API| Replace manual checks with ``check_is_fitted``. Errors thrown when
using a non-fitted estimators are now more uniform.
:pr:`13013` by :user:`Agamemnon Krasoulis <agamemnonc>`.
@@ -709,6 +715,12 @@ These changes mostly affect library developers.
Such classifiers need to have the `binary_only=True` estimator tag.
:pr:`13875` by `Trevor Stephens`_.

- Estimators are expected to convert input data (``X``, ``y``,
``sample_weights``) to :class:`numpy.ndarray` and never call
``__array_function__`` on the original datatype that is passed (see `NEP 18
<https://numpy.org/neps/nep-0018-array-function-protocol.html>`_).
:pr:`14702` by `Andreas Müller`_.

- `requires_positive_X` estimator tag (for models that require
X to be non-negative) is now used by :meth:`utils.estimator_checks.check_estimator`
to make sure a proper error message is raised if X contains some negative entries.
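The two changelog entries above contrast conversion through ``__array__`` with dispatch through ``__array_function__``. A minimal sketch of the difference, assuming NumPy 1.17 or later (where the NEP 18 protocol is enabled by default); the class name `DuckArray` is hypothetical and only mirrors the idea of the `_NotAnArray` test helper:

```python
import numpy as np


class DuckArray:
    """Hypothetical array-like wrapper, used only for this illustration."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __array__(self, dtype=None, copy=None):
        # Called by np.asarray / np.array to obtain a real ndarray.
        if dtype is None:
            return self.data
        return self.data.astype(dtype)

    def __array_function__(self, func, types, args, kwargs):
        # Called in place of the NumPy implementation when a dispatched
        # NumPy function receives this object directly (NEP 18).
        raise TypeError("unexpected dispatch to {}".format(func.__name__))


duck = DuckArray([1.0, 2.0, 3.0])

converted = np.asarray(duck)      # goes through __array__
total = float(np.sum(converted))  # operates on a plain ndarray

try:
    np.sum(duck)                  # goes through __array_function__
    dispatched = False
except TypeError:
    dispatched = True
```

Converting first keeps every later NumPy call on a plain ndarray, which is the behavior the changelog entries describe.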
1 change: 1 addition & 0 deletions sklearn/cross_decomposition/_pls_.py
@@ -881,6 +881,7 @@ def transform(self, X, Y=None):
Xr = (X - self.x_mean_) / self.x_std_
x_scores = np.dot(Xr, self.x_weights_)
if Y is not None:
Y = check_array(Y, ensure_2d=False, dtype=np.float64)
if Y.ndim == 1:
Y = Y.reshape(-1, 1)
Yr = (Y - self.y_mean_) / self.y_std_
3 changes: 3 additions & 0 deletions sklearn/dummy.py
@@ -133,6 +133,7 @@ def fit(self, X, y, sample_weight=None):
self.sparse_output_ = sp.issparse(y)

if not self.sparse_output_:
y = np.asarray(y)
Review comment (Member):
Do we not want to have a more complex atleast_1d instead?

Reply (Member Author):
You mean column_or_1d? I honestly don't know why we would need atleast_1d. We usually want scalars to be treated differently.

Reply (Member Author, @amueller, Aug 22, 2019):
There's also check_array(ensure_2d=False), which might be more suitable here?

Reply (Member):
Yep, I tend to forget that we have both column_or_1d and atleast_1d. Ideally, I think I'd rather have ensure_ndims=n in check_array, maybe.

y = np.atleast_1d(y)

if y.ndim == 1:
@@ -470,6 +471,8 @@ def fit(self, X, y, sample_weight=None):
self.n_outputs_ = y.shape[1]

check_consistent_length(X, y, sample_weight)
if sample_weight is not None:
Review comment (Member):
I feel like this belongs to check_array too, doesn't it?

Reply (Member Author):
Not sure what you mean. Using check_array(sample_weights, ensure_2d=True) instead of asarray? We don't have any tests for NaN in sample weights, do we? I'm not sure how much I want to make this PR about adding way more checks to sample weights.

Reply (Member):
We have _check_sample_weight in validation.

Reply (Member Author):
That does a lot more though, right?

sample_weight = np.asarray(sample_weight)

if self.strategy == "mean":
self.constant_ = np.average(y, axis=0, weights=sample_weight)
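The review thread in this file weighs `np.asarray` against `np.atleast_1d`. A small sketch of how the two treat scalars, using made-up inputs; it only illustrates the "scalars are treated differently" remark and is not code from the PR:

```python
import numpy as np

inputs = [2.0, [1.0, 2.0, 3.0], [[1.0, 2.0], [3.0, 4.0]]]

# np.asarray keeps a scalar zero-dimensional ...
asarray_ndims = [np.asarray(w).ndim for w in inputs]
# ... while np.atleast_1d promotes it to a one-element vector.
atleast_1d_ndims = [np.atleast_1d(w).ndim for w in inputs]

# A "more than one dimension" test gives the same answer with either call,
# so swapping one for the other in such a check does not change its result.
too_many_dims_asarray = [n > 1 for n in asarray_ndims]
too_many_dims_atleast = [n > 1 for n in atleast_1d_ndims]
```

Code that needs to tell a scalar from a one-element vector has to look at `ndim == 0` after `np.asarray`, since `np.atleast_1d` erases that distinction.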
10 changes: 7 additions & 3 deletions sklearn/linear_model/_base.py
@@ -120,6 +120,8 @@ def _preprocess_data(X, y, fit_intercept, normalize=False, copy=True,

if isinstance(sample_weight, numbers.Number):
sample_weight = None
if sample_weight is not None:
sample_weight = np.asarray(sample_weight)

if check_input:
X = check_array(X, copy=copy, accept_sparse=['csr', 'csc'],
@@ -181,8 +183,10 @@ def _preprocess_data(X, y, fit_intercept, normalize=False, copy=True,
def _rescale_data(X, y, sample_weight):
"""Rescale data so as to support sample_weight"""
n_samples = X.shape[0]
- sample_weight = np.full(n_samples, sample_weight,
- dtype=np.array(sample_weight).dtype)
+ sample_weight = np.array(sample_weight)
+ if sample_weight.ndim == 0:
+ sample_weight = np.full(n_samples, sample_weight,
+ dtype=sample_weight.dtype)
sample_weight = np.sqrt(sample_weight)
sw_matrix = sparse.dia_matrix((sample_weight, 0),
shape=(n_samples, n_samples))
@@ -467,7 +471,7 @@ def fit(self, X, y, sample_weight=None):
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc', 'coo'],
y_numeric=True, multi_output=True)

- if sample_weight is not None and np.atleast_1d(sample_weight).ndim > 1:
+ if sample_weight is not None and np.asarray(sample_weight).ndim > 1:
raise ValueError("Sample weights must be 1D array or scalar")

X, y, X_offset, y_offset, X_scale = self._preprocess_data(
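The `_rescale_data` hunk above now converts `sample_weight` first and only broadcasts it when it is a scalar. A minimal sketch of that branch in isolation; the function name `expand_sample_weight` is hypothetical, and the dense row scaling at the end is an assumed stand-in for the sparse diagonal-matrix product used in the real function:

```python
import numpy as np


def expand_sample_weight(n_samples, sample_weight):
    """Return sqrt(sample_weight) as a vector of length n_samples."""
    sample_weight = np.array(sample_weight)
    if sample_weight.ndim == 0:
        # A scalar weight is broadcast to one value per sample.
        sample_weight = np.full(n_samples, sample_weight,
                                dtype=sample_weight.dtype)
    return np.sqrt(sample_weight)


X = np.ones((3, 2))
from_scalar = expand_sample_weight(3, 4.0)
from_vector = expand_sample_weight(3, [1, 4, 9])

# Dense equivalent of multiplying by diag(sqrt(w)): scale each row of X.
X_rescaled = X * from_vector[:, np.newaxis]
```

Array-valued weights skip the `np.full` call entirely, so they are no longer passed to it as a fill value.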
2 changes: 1 addition & 1 deletion sklearn/linear_model/_ridge.py
@@ -567,7 +567,7 @@ def fit(self, X, y, sample_weight=None):
solver = self.solver

if ((sample_weight is not None) and
- np.atleast_1d(sample_weight).ndim > 1):
+ np.asarray(sample_weight).ndim > 1):
raise ValueError("Sample weights must be 1D array or scalar")

# when X is sparse we only remove offset from y
4 changes: 4 additions & 0 deletions sklearn/model_selection/_search.py
@@ -643,6 +643,10 @@ def fit(self, X, y=None, groups=None, **fit_params):
refit_metric = 'score'

X, y, groups = indexable(X, y, groups)
# make sure fit_params are sliceable
fit_params_values = indexable(*fit_params.values())
fit_params = dict(zip(fit_params.keys(), fit_params_values))

n_splits = cv.get_n_splits(X, y, groups)

base_estimator = clone(self.estimator)
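The hunk above passes the values of `fit_params` through `indexable` and rebuilds the dict. A rough sketch of that pattern, assuming a simplified stand-in for `sklearn.utils.indexable`; the names `make_indexable` and `validate_fit_params` are hypothetical, and the real helper also handles sparse input:

```python
import numpy as np


def make_indexable(*iterables):
    """Hypothetical, simplified stand-in for sklearn.utils.indexable."""
    result = []
    for item in iterables:
        if item is None or hasattr(item, "__getitem__") or hasattr(item, "iloc"):
            result.append(item)            # already sliceable
        else:
            result.append(np.array(item))  # e.g. objects exposing only __array__
    lengths = sorted({len(item) for item in result if item is not None})
    if len(lengths) > 1:
        raise ValueError(
            "Found input variables with inconsistent numbers of samples: %r"
            % lengths)
    return result


def validate_fit_params(fit_params):
    # Same shape as the diff: convert the values, then rebuild the dict.
    values = make_indexable(*fit_params.values())
    return dict(zip(fit_params.keys(), values))


good = validate_fit_params({"spam": np.ones(10), "eggs": np.zeros(10)})

try:
    validate_fit_params({"spam": np.ones(1), "eggs": np.zeros(10)})
    message = None
except ValueError as exc:
    message = str(exc)
```

Under this sketch, a length mismatch between fit parameters surfaces as a `ValueError` at validation time, which is consistent with the updated expectation in `test_search.py`.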
9 changes: 5 additions & 4 deletions sklearn/model_selection/tests/test_search.py
@@ -209,10 +209,11 @@ def check_hyperparameter_searcher_with_fit_params(klass, **klass_kwargs):
assert_raise_message(AssertionError,
"Expected fit parameter(s) ['eggs'] not seen.",
searcher.fit, X, y, spam=np.ones(10))
- assert_raise_message(AssertionError,
- "Fit parameter spam has length 1; expected",
- searcher.fit, X, y, spam=np.ones(1),
- eggs=np.zeros(10))
+ assert_raise_message(
+ ValueError,
+ "Found input variables with inconsistent numbers of samples: [",
+ searcher.fit, X, y, spam=np.ones(1),
+ eggs=np.zeros(10))
searcher.fit(X, y, spam=np.ones(10), eggs=np.zeros(10))


4 changes: 3 additions & 1 deletion sklearn/naive_bayes.py
@@ -31,7 +31,7 @@
from .utils.extmath import safe_sparse_dot
from .utils.fixes import logsumexp
from .utils.multiclass import _check_partial_fit_first_call
- from .utils.validation import check_is_fitted, check_non_negative
+ from .utils.validation import check_is_fitted, check_non_negative, column_or_1d

__all__ = ['BernoulliNB', 'GaussianNB', 'MultinomialNB', 'ComplementNB',
'CategoricalNB']
@@ -204,6 +204,7 @@ def fit(self, X, y, sample_weight=None):
-------
self : object
"""
y = column_or_1d(y, warn=True)
return self._partial_fit(X, y, np.unique(y), _refit=True,
sample_weight=sample_weight)

@@ -622,6 +623,7 @@ def fit(self, X, y, sample_weight=None):
# this means we also don't have to cast X to floating point
if sample_weight is not None:
Y = Y.astype(np.float64, copy=False)
sample_weight = np.asarray(sample_weight)
sample_weight = np.atleast_2d(sample_weight)
Y *= check_array(sample_weight).T

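The second naive Bayes hunk converts `sample_weight` before reshaping it and multiplying it into the label indicator matrix `Y`. A small sketch of that broadcasting step with made-up data; the `check_array` call from the diff is omitted here so the example depends only on NumPy:

```python
import numpy as np

# One-hot label indicators: 3 samples, 2 classes (illustrative data).
Y = np.array([[1, 0],
              [0, 1],
              [1, 0]])
sample_weight = [0.5, 2.0, 1.0]

Y = Y.astype(np.float64, copy=False)
sample_weight = np.asarray(sample_weight)    # plain ndarray first
sample_weight = np.atleast_2d(sample_weight) # shape (1, 3)
Y *= sample_weight.T                         # (3, 1) broadcasts over classes

weighted_class_counts = Y.sum(axis=0)
```

Each row of `Y` is scaled by its sample's weight, so the column sums become weighted class counts.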
30 changes: 28 additions & 2 deletions sklearn/utils/estimator_checks.py
@@ -80,6 +80,7 @@ def _yield_checks(name, estimator):
yield check_estimators_dtypes
yield check_fit_score_takes_y
yield check_sample_weights_pandas_series
yield check_sample_weights_not_an_array
yield check_sample_weights_list
yield check_sample_weights_invariance
yield check_estimators_fit_returns_self
@@ -538,11 +539,17 @@ class _NotAnArray:
"""

def __init__(self, data):
- self.data = data
+ self.data = np.asarray(data)

def __array__(self, dtype=None):
return self.data

def __array_function__(self, func, types, args, kwargs):
Review comment (Member):
Returning True and raising TypeError needs to be tested?

Reply (Member Author):
You mean explicitly tested? Sure, I could do that. I thought that was a bit overkill for a test helper. They are obviously used in the tests, but on CI there's no __array_function__ protocol.

Reply (Member):
I would be more on board if NotAnArray was private, which goes back to "public vs private utils" #6616 (comment).

On the other hand, we are depending on this raising when something is wrong. When everything is working, our tests do not run __array_function__.

Reply (Member Author):
That's a good point. I'll add a test (that won't be run).

if func.__name__ == "may_share_memory":
return True
raise TypeError("Don't want to call array_function {}!".format(
func.__name__))


@deprecated("NotAnArray is deprecated in version "
"0.22 and will be removed in version 0.24.")
@@ -719,6 +726,23 @@ def check_sample_weights_pandas_series(name, estimator_orig):
"input of type pandas.Series to class weight.")


@ignore_warnings(category=(DeprecationWarning, FutureWarning))
def check_sample_weights_not_an_array(name, estimator_orig):
# check that estimators will accept a 'sample_weight' parameter of
# type _NotAnArray in the 'fit' function.
estimator = clone(estimator_orig)
if has_fit_parameter(estimator, "sample_weight"):
X = np.array([[1, 1], [1, 2], [1, 3], [1, 4],
[2, 1], [2, 2], [2, 3], [2, 4],
[3, 1], [3, 2], [3, 3], [3, 4]])
X = _NotAnArray(pairwise_estimator_convert_X(X, estimator_orig))
y = _NotAnArray([1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2])
weights = _NotAnArray([1] * 12)
if _safe_tags(estimator, "multioutput_only"):
y = _NotAnArray(y.data.reshape(-1, 1))
estimator.fit(X, y, sample_weight=weights)


@ignore_warnings(category=(DeprecationWarning, FutureWarning))
def check_sample_weights_list(name, estimator_orig):
# check that estimators will accept a 'sample_weight' parameter of
@@ -1162,8 +1186,10 @@ def _check_transformer(name, transformer_orig, X, y):
# fit

if name in CROSS_DECOMPOSITION:
- y_ = np.c_[y, y]
+ y_ = np.c_[np.asarray(y), np.asarray(y)]
y_[::2, 1] *= 2
if isinstance(X, _NotAnArray):
y_ = _NotAnArray(y_)
else:
y_ = y

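The new `check_sample_weights_not_an_array` check wraps `X`, `y` and the weights in `_NotAnArray` and calls `fit`. A toy sketch of what such a check catches, assuming NumPy 1.17 or later; `Wrapped`, `ConvertingMean`, `LeakyMean` and `passes_check` are all hypothetical names, not part of scikit-learn:

```python
import numpy as np


class Wrapped:
    """Hypothetical helper in the spirit of _NotAnArray."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __array__(self, dtype=None, copy=None):
        return self.data if dtype is None else self.data.astype(dtype)

    def __array_function__(self, func, types, args, kwargs):
        raise TypeError("Don't want to call array_function {}!".format(
            func.__name__))


class ConvertingMean:
    """Toy estimator that converts its inputs before using them."""

    def fit(self, X, y, sample_weight=None):
        X, y = np.asarray(X), np.asarray(y)
        if sample_weight is not None:
            sample_weight = np.asarray(sample_weight)
        self.constant_ = np.average(y, axis=0, weights=sample_weight)
        return self


class LeakyMean:
    """Toy estimator that hands the raw inputs straight to NumPy."""

    def fit(self, X, y, sample_weight=None):
        self.constant_ = np.average(y, axis=0, weights=sample_weight)
        return self


def passes_check(estimator):
    X = Wrapped(np.arange(24).reshape(12, 2))
    y = Wrapped([1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2])
    weights = Wrapped([1] * 12)
    try:
        estimator.fit(X, y, sample_weight=weights)
    except TypeError:
        return False
    return True


converting_ok = passes_check(ConvertingMean())
leaky_ok = passes_check(LeakyMean())
```

The estimator that converts with `np.asarray` fits normally, while the one that forwards the wrapped objects trips the `TypeError` raised from `__array_function__`.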
4 changes: 3 additions & 1 deletion sklearn/utils/multiclass.py
@@ -353,6 +353,8 @@ def class_distribution(y, sample_weight=None):
class_prior = []

n_samples, n_outputs = y.shape
if sample_weight is not None:
sample_weight = np.asarray(sample_weight)

if issparse(y):
y = y.tocsc()
@@ -362,7 +364,7 @@ def class_distribution(y, sample_weight=None):
col_nonzero = y.indices[y.indptr[k]:y.indptr[k + 1]]
# separate sample weights for zero and non-zero elements
if sample_weight is not None:
- nz_samp_weight = np.asarray(sample_weight)[col_nonzero]
+ nz_samp_weight = sample_weight[col_nonzero]
zeros_samp_weight_sum = (np.sum(sample_weight) -
np.sum(nz_samp_weight))
else:
14 changes: 13 additions & 1 deletion sklearn/utils/tests/test_estimator_checks.py
@@ -13,7 +13,7 @@
ignore_warnings,
assert_warns, assert_raises,
SkipTest)
- from sklearn.utils.estimator_checks import check_estimator
+ from sklearn.utils.estimator_checks import check_estimator, _NotAnArray
from sklearn.utils.estimator_checks \
import check_class_weight_balanced_linear_classifier
from sklearn.utils.estimator_checks import set_random_state
@@ -23,6 +23,7 @@
from sklearn.utils.estimator_checks import check_no_attributes_set_in_init
from sklearn.utils.validation import check_is_fitted
from sklearn.utils.estimator_checks import check_outlier_corruption
from sklearn.utils.fixes import _parse_version
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LinearRegression, SGDClassifier
from sklearn.mixture import GaussianMixture
@@ -305,6 +306,17 @@ def _more_tags(self):
return {"requires_positive_y": True}


def test_not_an_array_array_function():
np_version = _parse_version(np.__version__)
if np_version < (1, 17):
raise SkipTest("array_function protocol not supported in numpy <1.17")
not_array = _NotAnArray(np.ones(10))
msg = "Don't want to call array_function sum!"
assert_raises_regex(TypeError, msg, np.sum, not_array)
# always returns True
assert np.may_share_memory(not_array, None)


def test_check_fit_score_takes_y_works_on_deprecated_fit():
# Tests that check_fit_score_takes_y works on a class with
# a deprecated fit method
1 change: 1 addition & 0 deletions sklearn/utils/validation.py
@@ -765,6 +765,7 @@ def column_or_1d(y, warn=False):
y : array

"""
y = np.asarray(y)
shape = np.shape(y)
if len(shape) == 1:
return np.ravel(y)
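The `column_or_1d` hunk above shows only the start of the function. A sketch of how the added `np.asarray` call fits into a complete helper; everything after the `len(shape) == 1` branch, including the warning category and both messages, is an assumption for illustration rather than the library's exact code:

```python
import warnings

import numpy as np


def column_or_1d_sketch(y, warn=False):
    """Illustrative re-implementation; not the scikit-learn function."""
    y = np.asarray(y)   # convert once, so later calls see a plain ndarray
    shape = np.shape(y)
    if len(shape) == 1:
        return np.ravel(y)
    # Assumed remainder: accept a single column, reject anything else.
    if len(shape) == 2 and shape[1] == 1:
        if warn:
            warnings.warn("A column-vector y was passed when a 1d array "
                          "was expected.", UserWarning, stacklevel=2)
        return np.ravel(y)
    raise ValueError("bad input shape {0}".format(shape))


flat = column_or_1d_sketch([1, 2, 3])
from_column = column_or_1d_sketch([[1], [2], [3]])
```

Without the leading conversion, `np.shape` and `np.ravel` would be called on whatever object the caller passed in.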