From 18a26ce9eeb73f2fb19dbf573df9fb24671d90ff Mon Sep 17 00:00:00 2001 From: Stefanie Molin <24376333+stefmolin@users.noreply.github.com> Date: Sat, 20 Aug 2022 13:47:45 -0400 Subject: [PATCH 1/4] MAINT Add parameter validation to PatchExtractor. --- sklearn/feature_extraction/image.py | 25 +++++++++++++++++++------ sklearn/tests/test_common.py | 1 - 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/sklearn/feature_extraction/image.py b/sklearn/feature_extraction/image.py index 515a7990306b6..fdabf7e2a45d9 100644 --- a/sklearn/feature_extraction/image.py +++ b/sklearn/feature_extraction/image.py @@ -10,12 +10,13 @@ # License: BSD 3 clause from itertools import product -import numbers +from numbers import Integral, Number, Real import numpy as np from scipy import sparse from numpy.lib.stride_tricks import as_strided from ..utils import check_array, check_random_state +from ..utils._param_validation import Interval from ..base import BaseEstimator __all__ = [ @@ -249,11 +250,11 @@ def _compute_n_patches(i_h, i_w, p_h, p_w, max_patches=None): all_patches = n_h * n_w if max_patches: - if isinstance(max_patches, (numbers.Integral)) and max_patches < all_patches: + if isinstance(max_patches, (Integral)) and max_patches < all_patches: return max_patches - elif isinstance(max_patches, (numbers.Integral)) and max_patches >= all_patches: + elif isinstance(max_patches, (Integral)) and max_patches >= all_patches: return all_patches - elif isinstance(max_patches, (numbers.Real)) and 0 < max_patches < 1: + elif isinstance(max_patches, (Real)) and 0 < max_patches < 1: return int(max_patches * all_patches) else: raise ValueError("Invalid value for max_patches: %r" % max_patches) @@ -299,9 +300,9 @@ def _extract_patches(arr, patch_shape=8, extraction_step=1): arr_ndim = arr.ndim - if isinstance(patch_shape, numbers.Number): + if isinstance(patch_shape, Number): patch_shape = tuple([patch_shape] * arr_ndim) - if isinstance(extraction_step, numbers.Number): + if isinstance(extraction_step, Number): extraction_step = tuple([extraction_step] * arr_ndim) patch_strides = arr.strides @@ -502,6 +503,16 @@ class PatchExtractor(BaseEstimator): Patches shape: (545706, 2, 2) """ + _parameter_constraints = { + "patch_size": ["array-like", tuple, None], + "max_patches": [ + None, + Interval(Real, 0, 1, closed="neither"), + Interval(Integral, 0, None, closed="neither"), + ], + "random_state": [None, "random_state"], + } + def __init__(self, *, patch_size=None, max_patches=None, random_state=None): self.patch_size = patch_size self.max_patches = max_patches @@ -526,6 +537,7 @@ def fit(self, X, y=None): self : object Returns the instance itself. """ + self._validate_params() return self def transform(self, X): @@ -547,6 +559,7 @@ def transform(self, X): `n_patches` is either `n_samples * max_patches` or the total number of patches that can be extracted. """ + self._validate_params() self.random_state = check_random_state(self.random_state) n_images, i_h, i_w = X.shape[:3] X = np.reshape(X, (n_images, i_h, i_w, -1)) diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 9d7c53113bcf6..6229284545a27 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -493,7 +493,6 @@ def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator): "OPTICS", "OneVsOneClassifier", "OneVsRestClassifier", - "PatchExtractor", "PolynomialCountSketch", "PolynomialFeatures", "QuadraticDiscriminantAnalysis", From c10d464191e6e68599b351870de7cc545c94e70e Mon Sep 17 00:00:00 2001 From: Stefanie Molin <24376333+stefmolin@users.noreply.github.com> Date: Mon, 22 Aug 2022 09:17:32 -0400 Subject: [PATCH 2/4] Update random_state. --- sklearn/feature_extraction/image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/feature_extraction/image.py b/sklearn/feature_extraction/image.py index fdabf7e2a45d9..f811d5bce71b6 100644 --- a/sklearn/feature_extraction/image.py +++ b/sklearn/feature_extraction/image.py @@ -510,7 +510,7 @@ class PatchExtractor(BaseEstimator): Interval(Real, 0, 1, closed="neither"), Interval(Integral, 0, None, closed="neither"), ], - "random_state": [None, "random_state"], + "random_state": ["random_state"], } def __init__(self, *, patch_size=None, max_patches=None, random_state=None): From aeffa93585267b590ca11cd5b1858b3514da95fd Mon Sep 17 00:00:00 2001 From: Stefanie Molin <24376333+stefmolin@users.noreply.github.com> Date: Tue, 30 Aug 2022 18:29:52 -0400 Subject: [PATCH 3/4] Add type. --- sklearn/feature_extraction/image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/feature_extraction/image.py b/sklearn/feature_extraction/image.py index f811d5bce71b6..1609db26803d5 100644 --- a/sklearn/feature_extraction/image.py +++ b/sklearn/feature_extraction/image.py @@ -503,7 +503,7 @@ class PatchExtractor(BaseEstimator): Patches shape: (545706, 2, 2) """ - _parameter_constraints = { + _parameter_constraints: dict = { "patch_size": ["array-like", tuple, None], "max_patches": [ None, From f0198a4e514369057107772e6b8c94dbd6dd62f5 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 2 Sep 2022 10:58:24 +0200 Subject: [PATCH 4/4] nitpick --- sklearn/feature_extraction/image.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sklearn/feature_extraction/image.py b/sklearn/feature_extraction/image.py index 1609db26803d5..c847b82e01641 100644 --- a/sklearn/feature_extraction/image.py +++ b/sklearn/feature_extraction/image.py @@ -504,11 +504,11 @@ class PatchExtractor(BaseEstimator): """ _parameter_constraints: dict = { - "patch_size": ["array-like", tuple, None], + "patch_size": [tuple, None], "max_patches": [ None, Interval(Real, 0, 1, closed="neither"), - Interval(Integral, 0, None, closed="neither"), + Interval(Integral, 1, None, closed="left"), ], "random_state": ["random_state"], } @@ -559,7 +559,6 @@ def transform(self, X): `n_patches` is either `n_samples * max_patches` or the total number of patches that can be extracted. """ - self._validate_params() self.random_state = check_random_state(self.random_state) n_images, i_h, i_w = X.shape[:3] X = np.reshape(X, (n_images, i_h, i_w, -1))