Skip to content

API make PatchExtractor being a real scikit-learn transformer #24230

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Feb 28, 2023
2 changes: 1 addition & 1 deletion doc/modules/feature_extraction.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1033,7 +1033,7 @@ on overlapping areas::

The :class:`PatchExtractor` class works in the same way as
:func:`extract_patches_2d`, only it supports multiple images as input. It is
implemented as an estimator, so it can be used in pipelines. See::
implemented as a scikit-learn transformer, so it can be used in pipelines. See::

>>> five_images = np.arange(5 * 4 * 4 * 3).reshape(5, 4, 4, 3)
>>> patches = image.PatchExtractor(patch_size=(2, 2)).transform(five_images)
Expand Down
9 changes: 9 additions & 0 deletions doc/whats_new/v1.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,15 @@ Changelog
inconsistent with the sckit-learn verion the estimator was pickled with.
:pr:`25297` by `Thomas Fan`_.

:mod:`sklearn.feature_extraction`
.................................

- |API| :class:`feature_extraction.image.PatchExtractor` now follows the
transformer API of scikit-learn. This class is defined as a stateless transformer
meaning that it is note required to call `fit` before calling `transform`.
Parameter validation only happens at `fit` time.
:pr:`24230` by :user:`Guillaume Lemaitre <glemaitre>`.

:mod:`sklearn.impute`
.....................

Expand Down
75 changes: 50 additions & 25 deletions sklearn/feature_extraction/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
from scipy import sparse
from numpy.lib.stride_tricks import as_strided

from ..base import BaseEstimator, TransformerMixin
from ..utils import check_array, check_random_state
from ..utils._param_validation import Hidden, Interval, validate_params
from ..base import BaseEstimator

__all__ = [
"PatchExtractor",
Expand Down Expand Up @@ -491,7 +491,7 @@ def reconstruct_from_patches_2d(patches, image_size):
return img


class PatchExtractor(BaseEstimator):
class PatchExtractor(TransformerMixin, BaseEstimator):
"""Extracts patches from a collection of images.

Read more in the :ref:`User Guide <image_feature_extraction>`.
Expand All @@ -518,18 +518,23 @@ class PatchExtractor(BaseEstimator):
--------
reconstruct_from_patches_2d : Reconstruct image from all of its patches.

Notes
-----
This estimator is stateless and does not need to be fitted. However, we
recommend to call :meth:`fit_transform` instead of :meth:`transform`, as
parameter validation is only performed in :meth:`fit`.

Examples
--------
>>> from sklearn.datasets import load_sample_images
>>> from sklearn.feature_extraction import image
>>> # Use the array data from the second image in this dataset:
>>> X = load_sample_images().images[1]
>>> print('Image shape: {}'.format(X.shape))
>>> print(f"Image shape: {X.shape}")
Image shape: (427, 640, 3)
>>> pe = image.PatchExtractor(patch_size=(2, 2))
>>> pe_fit = pe.fit(X)
>>> pe_trans = pe.transform(X)
>>> print('Patches shape: {}'.format(pe_trans.shape))
>>> print(f"Patches shape: {pe_trans.shape}")
Patches shape: (545706, 2, 2)
"""

Expand All @@ -549,15 +554,18 @@ def __init__(self, *, patch_size=None, max_patches=None, random_state=None):
self.random_state = random_state

def fit(self, X, y=None):
"""Do nothing and return the estimator unchanged.
"""Only validate the parameters of the estimator.

This method is just there to implement the usual API and hence
work in pipelines.
This method allows to: (i) validate the parameters of the estimator and
(ii) be consistent with the scikit-learn transformer API.

Parameters
----------
X : array-like of shape (n_samples, n_features)
Training data.
X : ndarray of shape (n_samples, image_height, image_width) or \
(n_samples, image_height, image_width, n_channels)
Array of images from which to extract patches. For color images,
the last dimension specifies the channel: a RGB image would have
`n_channels=3`.

y : Ignored
Not used, present for API consistency by convention.
Expand All @@ -576,32 +584,49 @@ def transform(self, X):
Parameters
----------
X : ndarray of shape (n_samples, image_height, image_width) or \
(n_samples, image_height, image_width, n_channels)
(n_samples, image_height, image_width, n_channels)
Array of images from which to extract patches. For color images,
the last dimension specifies the channel: a RGB image would have
`n_channels=3`.

Returns
-------
patches : array of shape (n_patches, patch_height, patch_width) or \
(n_patches, patch_height, patch_width, n_channels)
The collection of patches extracted from the images, where
`n_patches` is either `n_samples * max_patches` or the total
number of patches that can be extracted.
(n_patches, patch_height, patch_width, n_channels)
The collection of patches extracted from the images, where
`n_patches` is either `n_samples * max_patches` or the total
number of patches that can be extracted.
"""
self.random_state = check_random_state(self.random_state)
n_images, i_h, i_w = X.shape[:3]
X = np.reshape(X, (n_images, i_h, i_w, -1))
n_channels = X.shape[-1]
X = self._validate_data(
X=X,
ensure_2d=False,
allow_nd=True,
ensure_min_samples=1,
ensure_min_features=1,
reset=False,
)
random_state = check_random_state(self.random_state)
n_imgs, img_height, img_width = X.shape[:3]
if self.patch_size is None:
patch_size = i_h // 10, i_w // 10
patch_size = img_height // 10, img_width // 10
else:
if len(self.patch_size) != 2:
raise ValueError(
f"patch_size must be a tuple of two integers. Got {self.patch_size}"
" instead."
)
patch_size = self.patch_size

n_imgs, img_height, img_width = X.shape[:3]
X = np.reshape(X, (n_imgs, img_height, img_width, -1))
n_channels = X.shape[-1]

# compute the dimensions of the patches array
p_h, p_w = patch_size
n_patches = _compute_n_patches(i_h, i_w, p_h, p_w, self.max_patches)
patches_shape = (n_images * n_patches,) + patch_size
patch_height, patch_width = patch_size
n_patches = _compute_n_patches(
img_height, img_width, patch_height, patch_width, self.max_patches
)
patches_shape = (n_imgs * n_patches,) + patch_size
if n_channels > 1:
patches_shape += (n_channels,)

Expand All @@ -612,9 +637,9 @@ def transform(self, X):
image,
patch_size,
max_patches=self.max_patches,
random_state=self.random_state,
random_state=random_state,
)
return patches

def _more_tags(self):
return {"X_types": ["3darray"]}
return {"X_types": ["3darray"], "stateless": True}
9 changes: 9 additions & 0 deletions sklearn/feature_extraction/tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,3 +364,12 @@ def test_width_patch():
extract_patches_2d(x, (4, 1))
with pytest.raises(ValueError):
extract_patches_2d(x, (1, 4))


def test_patch_extractor_wrong_input():
"""Check that an informative error is raised if the patch_size is not valid."""
faces = _make_images(orange_face)
err_msg = "patch_size must be a tuple of two integers"
extractor = PatchExtractor(patch_size=(8, 8, 8))
with pytest.raises(ValueError, match=err_msg):
extractor.transform(faces)
2 changes: 2 additions & 0 deletions sklearn/tests/test_docstring_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,8 @@ def test_fit_docstring_attributes(name, Estimator):
est.fit(y)
elif "2dlabels" in est._get_tags()["X_types"]:
est.fit(np.c_[y, y])
elif "3darray" in est._get_tags()["X_types"]:
est.fit(X[np.newaxis, ...], y)
else:
est.fit(X, y)

Expand Down