MAINT make AdditiveChi2Sampler stateless and check that stateless Transformers don't raise NotFittedError #25190

Merged
27 commits
de5aeb8
MAINT make AdditiveChi2Sampler stateless
Vincent-Maladiere Dec 14, 2022
f0d7a57
apply feedbacks
Vincent-Maladiere Dec 15, 2022
9a85c32
fix tests
Vincent-Maladiere Dec 15, 2022
58f2391
apply new suggestions
Vincent-Maladiere Dec 15, 2022
efb5fa5
typo
Vincent-Maladiere Dec 15, 2022
d5a4fed
Merge branch 'main' into make_AdditiveChi2Sampler_stateless
Vincent-Maladiere Dec 15, 2022
28d334c
iterate on tests
Vincent-Maladiere Dec 22, 2022
cc00fce
typo in docstring
Vincent-Maladiere Dec 22, 2022
71a7b9f
fix docstring
Vincent-Maladiere Dec 28, 2022
26ee464
Merge branch 'main' into make_AdditiveChi2Sampler_stateless
Vincent-Maladiere Dec 28, 2022
d3a36e2
update changelog
Vincent-Maladiere Dec 28, 2022
9eb30b4
improve test coverage
Vincent-Maladiere Dec 28, 2022
c6ef260
add test_common
Vincent-Maladiere Dec 31, 2022
5d84bf3
Merge branch 'main' into make_AdditiveChi2Sampler_stateless
Vincent-Maladiere Dec 31, 2022
74e3998
improve coverage by removing from test_common
Vincent-Maladiere Jan 2, 2023
5895114
Merge branch 'main' into make_AdditiveChi2Sampler_stateless
Vincent-Maladiere Jan 3, 2023
628bfc2
move stateless check into _yield_transformer_checks
Vincent-Maladiere Jan 3, 2023
3823e3a
Merge branch 'main' into make_AdditiveChi2Sampler_stateless
Vincent-Maladiere Jan 3, 2023
19c34bb
apply suggestions
Vincent-Maladiere Jan 4, 2023
ea81616
Merge branch 'main' into make_AdditiveChi2Sampler_stateless
Vincent-Maladiere Jan 4, 2023
2149b26
Merge branch 'main' into make_AdditiveChi2Sampler_stateless
Vincent-Maladiere Jan 9, 2023
b516a5c
Merge branch 'main' into make_AdditiveChi2Sampler_stateless
Vincent-Maladiere Jan 17, 2023
f4d0b0e
Apply suggestions from code review
Vincent-Maladiere Jan 17, 2023
b8245fd
apply suggestions
Vincent-Maladiere Jan 17, 2023
88fdd7e
Update doc/whats_new/v1.3.rst
Vincent-Maladiere Jan 17, 2023
1234066
Merge branch 'main' into make_AdditiveChi2Sampler_stateless
Vincent-Maladiere Jan 18, 2023
ea18e0b
Merge branch 'main' into make_AdditiveChi2Sampler_stateless
Vincent-Maladiere Jan 18, 2023
12 changes: 12 additions & 0 deletions doc/whats_new/v1.3.rst
@@ -174,6 +174,18 @@ Changelog
:pr:`24935` by :user:`Seladus <seladus>`, :user:`Guillaume Lemaitre <glemaitre>`, and
:user:`Dea María Léon <deamarialeon>`, :pr:`25257` by :user:`Gleb Levitski <glevv>`.

- |Fix| :class:`AdditiveChi2Sampler` is now stateless.
The `sample_interval_` attribute is deprecated and will be removed in 1.5.
:pr:`25190` by :user:`Vincent Maladière <Vincent-Maladiere>`.

:mod:`sklearn.utils`
....................

- |API| :func:`estimator_checks.check_transformers_unfitted_stateless` has been
introduced to ensure stateless transformers don't raise `NotFittedError`
during `transform` with no prior call to `fit` or `fit_transform`.
:pr:`25190` by :user:`Vincent Maladière <Vincent-Maladiere>`.

Code and Documentation Contributors
-----------------------------------

97 changes: 70 additions & 27 deletions sklearn/kernel_approximation.py
@@ -24,6 +24,7 @@
from .base import TransformerMixin
from .base import ClassNamePrefixFeaturesOutMixin
from .utils import check_random_state
from .utils import deprecated
from .utils.extmath import safe_sparse_dot
from .utils.validation import check_is_fitted
from .utils.validation import _check_feature_names_in
@@ -600,6 +601,9 @@ class AdditiveChi2Sampler(TransformerMixin, BaseEstimator):
Stored sampling interval. Specified as a parameter if `sample_steps`
not in {1,2,3}.

.. deprecated:: 1.3
`sample_interval_` serves internal purposes only and will be removed in 1.5.

n_features_in_ : int
Number of features seen during :term:`fit`.

@@ -626,6 +630,10 @@ class AdditiveChi2Sampler(TransformerMixin, BaseEstimator):
This estimator approximates a slightly different version of the additive
chi squared kernel than ``metric.additive_chi2`` computes.

This estimator is stateless and does not need to be fitted. However, we
recommend calling :meth:`fit_transform` instead of :meth:`transform`, as
parameter validation is only performed in :meth:`fit`.

References
----------
See `"Efficient additive kernels via explicit feature maps"
@@ -658,7 +666,10 @@ def __init__(self, *, sample_steps=2, sample_interval=None):
self.sample_interval = sample_interval

def fit(self, X, y=None):
"""Set the parameters.
"""Only validates estimator's parameters.

This method (i) validates the estimator's parameters and
(ii) keeps the estimator consistent with the scikit-learn transformer API.

Parameters
----------
@@ -676,27 +687,40 @@ def fit(self, X, y=None):
Returns the transformer.
"""
self._validate_params()

X = self._validate_data(X, accept_sparse="csr")
check_non_negative(X, "X in AdditiveChi2Sampler.fit")

# TODO(1.5): remove the setting of _sample_interval from fit
if self.sample_interval is None:
# See reference, figure 2 c)
# See figure 2 c) of "Efficient additive kernels via explicit feature maps"
# <http://www.robots.ox.ac.uk/~vedaldi/assets/pubs/vedaldi11efficient.pdf>
# A. Vedaldi and A. Zisserman, Pattern Analysis and Machine Intelligence,
# 2011
if self.sample_steps == 1:
self.sample_interval_ = 0.8
self._sample_interval = 0.8
elif self.sample_steps == 2:
self.sample_interval_ = 0.5
self._sample_interval = 0.5
elif self.sample_steps == 3:
self.sample_interval_ = 0.4
self._sample_interval = 0.4
else:
raise ValueError(
"If sample_steps is not in [1, 2, 3],"
" you need to provide sample_interval"
)
else:
self.sample_interval_ = self.sample_interval
self._sample_interval = self.sample_interval

return self

# TODO(1.5): remove
@deprecated( # type: ignore
"The ``sample_interval_`` attribute was deprecated in version 1.3 and "
"will be removed 1.5."
)
@property
def sample_interval_(self):
return self._sample_interval
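The deprecation pattern in this hunk (value kept on a private attribute, public name exposed through a warning property) can be sketched with the standard library alone. This is a minimal illustration, not the PR's code: the `Sampler` class and its warning text are hypothetical, and `warnings.warn` stands in for scikit-learn's `deprecated` decorator.

```python
import warnings


class Sampler:
    """Hypothetical stand-in for an estimator with a deprecated attribute."""

    def __init__(self, sample_interval=0.5):
        self.sample_interval = sample_interval

    def fit(self):
        # Store the value privately so internal code can read it
        # without triggering the deprecation warning.
        self._sample_interval = self.sample_interval
        return self

    @property
    def sample_interval_(self):
        # Public access still works for now, but warns on every read.
        warnings.warn(
            "sample_interval_ is deprecated (illustrative message).",
            FutureWarning,
            stacklevel=2,
        )
        return self._sample_interval
```

Reading `Sampler().fit().sample_interval_` returns the stored value and emits one `FutureWarning`, while internal reads of `_sample_interval` stay silent.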

def transform(self, X):
"""Apply approximate feature map to X.

@@ -713,22 +737,39 @@ def transform(self, X):
Whether the return value is an array or sparse matrix depends on
the type of the input X.
"""
msg = (
"%(name)s is not fitted. Call fit to set the parameters before"
" calling transform"
)
check_is_fitted(self, msg=msg)

X = self._validate_data(X, accept_sparse="csr", reset=False)
check_non_negative(X, "X in AdditiveChi2Sampler.transform")
sparse = sp.issparse(X)

if hasattr(self, "_sample_interval"):
# TODO(1.5): remove this branch
sample_interval = self._sample_interval

else:
if self.sample_interval is None:
# See figure 2 c) of "Efficient additive kernels via explicit feature maps" # noqa
# <http://www.robots.ox.ac.uk/~vedaldi/assets/pubs/vedaldi11efficient.pdf>
# A. Vedaldi and A. Zisserman, Pattern Analysis and Machine Intelligence, # noqa
# 2011
if self.sample_steps == 1:
sample_interval = 0.8
elif self.sample_steps == 2:
sample_interval = 0.5
elif self.sample_steps == 3:
sample_interval = 0.4
else:
raise ValueError(
"If sample_steps is not in [1, 2, 3],"
" you need to provide sample_interval"
)
else:
sample_interval = self.sample_interval
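The default-interval branch now appears in both `fit` and `transform`. One possible way to express that shared logic is a small pure function. The helper below is hypothetical (`resolve_sample_interval` and `_DEFAULT_INTERVALS` are not names from the PR); it only reuses the defaults and the error message shown in the diff.

```python
# Defaults copied from the branch above (figure 2 c of the cited paper).
_DEFAULT_INTERVALS = {1: 0.8, 2: 0.5, 3: 0.4}


def resolve_sample_interval(sample_steps, sample_interval=None):
    """Return the interval to use. Hypothetical helper, not part of the PR."""
    if sample_interval is not None:
        # An explicit user value always wins.
        return sample_interval
    try:
        return _DEFAULT_INTERVALS[sample_steps]
    except KeyError:
        raise ValueError(
            "If sample_steps is not in [1, 2, 3],"
            " you need to provide sample_interval"
        ) from None
```

Because the function reads nothing from a fitted instance, it can be called from `transform` directly, which is the property the stateless design relies on.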

# zeroth component
# 1/cosh = sech
# cosh(0) = 1.0

transf = self._transform_sparse if sparse else self._transform_dense
return transf(X)
return transf(X, self.sample_steps, sample_interval)

def get_feature_names_out(self, input_features=None):
"""Get output feature names for transformation.
@@ -758,20 +799,21 @@ def get_feature_names_out(self, input_features=None):

return np.asarray(names_list, dtype=object)

def _transform_dense(self, X):
@staticmethod
def _transform_dense(X, sample_steps, sample_interval):
non_zero = X != 0.0
X_nz = X[non_zero]

X_step = np.zeros_like(X)
X_step[non_zero] = np.sqrt(X_nz * self.sample_interval_)
X_step[non_zero] = np.sqrt(X_nz * sample_interval)

X_new = [X_step]

log_step_nz = self.sample_interval_ * np.log(X_nz)
step_nz = 2 * X_nz * self.sample_interval_
log_step_nz = sample_interval * np.log(X_nz)
step_nz = 2 * X_nz * sample_interval

for j in range(1, self.sample_steps):
factor_nz = np.sqrt(step_nz / np.cosh(np.pi * j * self.sample_interval_))
for j in range(1, sample_steps):
factor_nz = np.sqrt(step_nz / np.cosh(np.pi * j * sample_interval))

X_step = np.zeros_like(X)
X_step[non_zero] = factor_nz * np.cos(j * log_step_nz)
@@ -783,21 +825,22 @@ def _transform_dense(self, X):

return np.hstack(X_new)
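To make the arithmetic in `_transform_dense` easier to follow, here is a scalar sketch of the same map in pure Python. The function name `chi2_feature_map` is hypothetical, and the sine component is an assumption: the hunk above only shows the cosine line, with the rest of the loop body elided.

```python
import math


def chi2_feature_map(x, sample_steps=2, sample_interval=0.5):
    """Map one non-negative scalar to 2 * sample_steps - 1 components."""
    if x == 0.0:
        # Zeros stay zero, mirroring the `non_zero` mask in the dense code.
        return [0.0] * (2 * sample_steps - 1)

    # Zeroth component.
    features = [math.sqrt(x * sample_interval)]

    log_step = sample_interval * math.log(x)
    step = 2 * x * sample_interval

    for j in range(1, sample_steps):
        factor = math.sqrt(step / math.cosh(math.pi * j * sample_interval))
        features.append(factor * math.cos(j * log_step))
        # Assumed: the matching sine component, elided from the hunk above.
        features.append(factor * math.sin(j * log_step))
    return features
```

Both arguments are passed in explicitly, as in the new static methods, so the function depends on no fitted state.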

def _transform_sparse(self, X):
@staticmethod
def _transform_sparse(X, sample_steps, sample_interval):
indices = X.indices.copy()
indptr = X.indptr.copy()

data_step = np.sqrt(X.data * self.sample_interval_)
data_step = np.sqrt(X.data * sample_interval)
X_step = sp.csr_matrix(
(data_step, indices, indptr), shape=X.shape, dtype=X.dtype, copy=False
)
X_new = [X_step]

log_step_nz = self.sample_interval_ * np.log(X.data)
step_nz = 2 * X.data * self.sample_interval_
log_step_nz = sample_interval * np.log(X.data)
step_nz = 2 * X.data * sample_interval

for j in range(1, self.sample_steps):
factor_nz = np.sqrt(step_nz / np.cosh(np.pi * j * self.sample_interval_))
for j in range(1, sample_steps):
factor_nz = np.sqrt(step_nz / np.cosh(np.pi * j * sample_interval))

data_step = factor_nz * np.cos(j * log_step_nz)
X_step = sp.csr_matrix(
61 changes: 38 additions & 23 deletions sklearn/tests/test_kernel_approximation.py
@@ -114,34 +114,49 @@ def test_additive_chi2_sampler():
Y_neg[0, 0] = -1
msg = "Negative values in data passed to"
with pytest.raises(ValueError, match=msg):
transform.transform(Y_neg)
transform.fit(Y_neg)


@pytest.mark.parametrize("method", ["fit", "fit_transform", "transform"])
@pytest.mark.parametrize("sample_steps", range(1, 4))
def test_additive_chi2_sampler_sample_steps(method, sample_steps):
    """Check that the input sample step doesn't raise an error
    and that sample interval doesn't change after fit.
    """
    transformer = AdditiveChi2Sampler(sample_steps=sample_steps)
    getattr(transformer, method)(X)

    sample_interval = 0.5
    transformer = AdditiveChi2Sampler(
        sample_steps=sample_steps,
        sample_interval=sample_interval,
    )
    getattr(transformer, method)(X)
    # Without `assert`, the comparison result would be silently discarded.
    assert transformer.sample_interval == sample_interval


# TODO(1.5): remove
def test_additive_chi2_sampler_future_warnings():
    """Check that we raise a FutureWarning when accessing `sample_interval_`."""
    transformer = AdditiveChi2Sampler()
    transformer.fit(X)
    msg = re.escape(
        "The ``sample_interval_`` attribute was deprecated in version 1.3 and "
        "will be removed 1.5."
    )
    with pytest.warns(FutureWarning, match=msg):
        assert transformer.sample_interval_ is not None


# test error on invalid sample_steps
transform = AdditiveChi2Sampler(sample_steps=4)
@pytest.mark.parametrize("method", ["fit", "fit_transform", "transform"])
def test_additive_chi2_sampler_wrong_sample_steps(method):
"""Check that we raise a ValueError on invalid sample_steps"""
transformer = AdditiveChi2Sampler(sample_steps=4)
msg = re.escape(
"If sample_steps is not in [1, 2, 3], you need to provide sample_interval"
)
with pytest.raises(ValueError, match=msg):
transform.fit(X)

# test that the sample interval is set correctly
sample_steps_available = [1, 2, 3]
for sample_steps in sample_steps_available:

# test that the sample_interval is initialized correctly
transform = AdditiveChi2Sampler(sample_steps=sample_steps)
assert transform.sample_interval is None

# test that the sample_interval is changed in the fit method
transform.fit(X)
assert transform.sample_interval_ is not None

# test that the sample_interval is set correctly
sample_interval = 0.3
transform = AdditiveChi2Sampler(sample_steps=4, sample_interval=sample_interval)
assert transform.sample_interval == sample_interval
transform.fit(X)
assert transform.sample_interval_ == sample_interval
getattr(transformer, method)(X)


def test_skewed_chi2_sampler():
17 changes: 17 additions & 0 deletions sklearn/utils/estimator_checks.py
@@ -243,6 +243,8 @@ def _yield_transformer_checks(transformer):
yield partial(check_transformer_general, readonly_memmap=True)
if not _safe_tags(transformer, key="stateless"):
yield check_transformers_unfitted
else:
yield check_transformers_unfitted_stateless
# Dependent on external solvers and hence accessing the iter
# param is non-trivial.
external_solver = [
@@ -1554,6 +1556,21 @@ def check_transformers_unfitted(name, transformer):
transformer.transform(X)


@ignore_warnings(category=FutureWarning)
def check_transformers_unfitted_stateless(name, transformer):
    """Check that using transform without prior fitting
    doesn't raise a NotFittedError for stateless transformers.
    """
    rng = np.random.RandomState(0)
    X = rng.uniform(size=(20, 5))
    X = _enforce_estimator_tags_X(transformer, X)

    transformer = clone(transformer)
    X_trans = transformer.transform(X)

    assert X_trans.shape[0] == X.shape[0]
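The contract this check enforces can be illustrated without scikit-learn. The sketch below is an assumption-laden toy: `Doubler`, `MeanCenterer`, and `check_unfitted_stateless` are hypothetical names, and the local `NotFittedError` only mimics the role of `sklearn.exceptions.NotFittedError`.

```python
class NotFittedError(ValueError, AttributeError):
    """Local stand-in for sklearn.exceptions.NotFittedError."""


class Doubler:
    """Toy stateless transformer: transform needs nothing learned in fit."""

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return [[2 * value for value in row] for row in X]


class MeanCenterer:
    """Toy stateful transformer: transform depends on a fitted value."""

    def fit(self, X, y=None):
        self.mean_ = sum(row[0] for row in X) / len(X)
        return self

    def transform(self, X):
        if not hasattr(self, "mean_"):
            raise NotFittedError("MeanCenterer is not fitted.")
        return [[row[0] - self.mean_] for row in X]


def check_unfitted_stateless(transformer):
    """Call transform on an unfitted instance and compare row counts."""
    X = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
    # A NotFittedError raised here propagates and fails the check.
    X_trans = transformer.transform(X)
    assert len(X_trans) == len(X)
    return X_trans
```

`check_unfitted_stateless(Doubler())` passes, while the same call on an unfitted `MeanCenterer()` raises `NotFittedError`, which is why the real check is only yielded for estimators carrying the `stateless` tag.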


def _check_transformer(name, transformer_orig, X, y):
n_samples, n_features = np.asarray(X).shape
transformer = clone(transformer_orig)