diff --git a/sklearn/feature_extraction/image.py b/sklearn/feature_extraction/image.py index 515a7990306b6..c847b82e01641 100644 --- a/sklearn/feature_extraction/image.py +++ b/sklearn/feature_extraction/image.py @@ -10,12 +10,13 @@ # License: BSD 3 clause from itertools import product -import numbers +from numbers import Integral, Number, Real import numpy as np from scipy import sparse from numpy.lib.stride_tricks import as_strided from ..utils import check_array, check_random_state +from ..utils._param_validation import Interval from ..base import BaseEstimator __all__ = [ @@ -249,11 +250,11 @@ def _compute_n_patches(i_h, i_w, p_h, p_w, max_patches=None): all_patches = n_h * n_w if max_patches: - if isinstance(max_patches, (numbers.Integral)) and max_patches < all_patches: + if isinstance(max_patches, (Integral)) and max_patches < all_patches: return max_patches - elif isinstance(max_patches, (numbers.Integral)) and max_patches >= all_patches: + elif isinstance(max_patches, (Integral)) and max_patches >= all_patches: return all_patches - elif isinstance(max_patches, (numbers.Real)) and 0 < max_patches < 1: + elif isinstance(max_patches, (Real)) and 0 < max_patches < 1: return int(max_patches * all_patches) else: raise ValueError("Invalid value for max_patches: %r" % max_patches) @@ -299,9 +300,9 @@ def _extract_patches(arr, patch_shape=8, extraction_step=1): arr_ndim = arr.ndim - if isinstance(patch_shape, numbers.Number): + if isinstance(patch_shape, Number): patch_shape = tuple([patch_shape] * arr_ndim) - if isinstance(extraction_step, numbers.Number): + if isinstance(extraction_step, Number): extraction_step = tuple([extraction_step] * arr_ndim) patch_strides = arr.strides @@ -502,6 +503,16 @@ class PatchExtractor(BaseEstimator): Patches shape: (545706, 2, 2) """ + _parameter_constraints: dict = { + "patch_size": [tuple, None], + "max_patches": [ + None, + Interval(Real, 0, 1, closed="neither"), + Interval(Integral, 1, None, closed="left"), + ], + "random_state": ["random_state"], + } + def __init__(self, *, patch_size=None, max_patches=None, random_state=None): self.patch_size = patch_size self.max_patches = max_patches @@ -526,6 +537,7 @@ def fit(self, X, y=None): self : object Returns the instance itself. """ + self._validate_params() return self def transform(self, X): diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index bb65d30731bf6..abd91efbf1ad7 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -473,7 +473,6 @@ def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator): "OPTICS", "OneVsOneClassifier", "OneVsRestClassifier", - "PatchExtractor", "RANSACRegressor", "RidgeCV", "RidgeClassifierCV",