From 2ccfc8c4bdf66db005d7681757b4145842944fb9 Mon Sep 17 00:00:00 2001
From: Franck Charras <29153872+fcharras@users.noreply.github.com>
Date: Wed, 30 Aug 2023 14:19:02 +0200
Subject: [PATCH 1/2] Engine plugin API and engine entry point for Lloyd's
 KMeans

Co-authored-by: Tim Head
Co-authored-by: Olivier Grisel
Co-authored-by: Franck Charras <29153872+fcharras@users.noreply.github.com>
---
 doc/computing.rst                     |   1 +
 doc/computing/engine.rst              |  29 ++
 doc/whats_new/v1.4.rst                |  18 ++
 setup.py                              |   1 +
 sklearn/_config.py                    |  59 ++++
 sklearn/_engine/__init__.py           |   3 +
 sklearn/_engine/base.py               | 186 ++++++++++++
 sklearn/_engine/testing.py            |  37 +++
 sklearn/_engine/tests/__init__.py     |   0
 sklearn/_engine/tests/test_engines.py | 317 ++++++++++++++++++++
 sklearn/base.py                       |  39 +++
 sklearn/cluster/_kmeans.py            | 404 +++++++++++++++++++++-----
 sklearn/exceptions.py                 |  14 +
 sklearn/tests/test_config.py          |   6 +
 14 files changed, 1046 insertions(+), 68 deletions(-)
 create mode 100644 doc/computing/engine.rst
 mode change 100755 => 100644 setup.py
 create mode 100644 sklearn/_engine/__init__.py
 create mode 100644 sklearn/_engine/base.py
 create mode 100644 sklearn/_engine/testing.py
 create mode 100644 sklearn/_engine/tests/__init__.py
 create mode 100644 sklearn/_engine/tests/test_engines.py

diff --git a/doc/computing.rst b/doc/computing.rst
index 6732b754918b0..8b355f22ec641 100644
--- a/doc/computing.rst
+++ b/doc/computing.rst
@@ -14,3 +14,4 @@ Computing with scikit-learn
     computing/scaling_strategies
     computing/computational_performance
     computing/parallelism
+    computing/engine
diff --git a/doc/computing/engine.rst b/doc/computing/engine.rst
new file mode 100644
index 0000000000000..66c3302113324
--- /dev/null
+++ b/doc/computing/engine.rst
@@ -0,0 +1,29 @@
+.. Places parent toc into the sidebar
+
+:parenttoc: True
+
+.. _engine:
+
+Computation Engines (experimental)
+==================================
+
+**This API is experimental**, which means that it is subject to change without
+any backward compatibility guarantees.
+
+TODO: explain goals here
+
+Activating an engine
+--------------------
+
+TODO: installing third party engine provider packages
+
+TODO: how to list installed engines
+
+TODO: how to install a plugin
+
+Writing a new engine provider
+-----------------------------
+
+TODO: show engine API of a given estimator.
+
+TODO: give example setup.py with setuptools to define an entrypoint.
diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst
index c13922c6cb22e..077d8e4d84bb0 100644
--- a/doc/whats_new/v1.4.rst
+++ b/doc/whats_new/v1.4.rst
@@ -41,6 +41,24 @@ Changes impacting all modules
   to work with our estimators and functions.
   :pr:`26464` by `Thomas Fan`_.

+- |Enhancement| Experimental engine API (no backward compatibility guarantees)
+  to allow for external packages to contribute alternative implementations for
+  the core computational routines of some selected scikit-learn estimators.
+
+  Currently, the following estimators allow alternative implementations:
+
+  - :class:`~sklearn.cluster.KMeans` (only for the Lloyd algorithm).
+  - TODO: add more when available.
+
+  External engine providers include:
+
+  - https://github.com/soda-inria/sklearn-numba-dpex, which provides a KMeans
+    engine optimized for OpenCL-enabled GPUs.
+  - TODO: add more here
+
+  :pr:`25535` by :user:`ogrisel`, :user:`fcharras` and :user:`betatim`.
+
+
 Changelog
 ---------

diff --git a/setup.py b/setup.py
old mode 100755
new mode 100644
index f9ae13c94502b..adaebd7f68be9
--- a/setup.py
+++ b/setup.py
@@ -600,6 +600,7 @@ def setup_package():
         python_requires=python_requires,
         install_requires=min_deps.tag_to_packages["install"],
         package_data={"": ["*.csv", "*.gz", "*.txt", "*.pxd", "*.rst", "*.jpg"]},
+        entry_points={"pytest11": ["sklearn_plugin_testing = sklearn._engine.testing"]},
         zip_safe=False,  # the package can run out of an .egg file
         extras_require={
             key: min_deps.tag_to_packages[key]
diff --git a/sklearn/_config.py b/sklearn/_config.py
index 91d149c81dc59..3e20989eef3b5 100644
--- a/sklearn/_config.py
+++ b/sklearn/_config.py
@@ -1,5 +1,6 @@
 """Global configuration state and functions for management
 """
+import inspect
 import os
 import threading
 from contextlib import contextmanager as contextmanager
@@ -14,6 +15,8 @@
     ),
     "enable_cython_pairwise_dist": True,
     "array_api_dispatch": False,
+    "engine_provider": (),
+    "engine_attributes": "engine_types",
     "transform_output": "default",
     "enable_metadata_routing": False,
     "skip_parameter_validation": False,
@@ -55,6 +58,8 @@ def set_config(
     pairwise_dist_chunk_size=None,
     enable_cython_pairwise_dist=None,
     array_api_dispatch=None,
+    engine_provider=None,
+    engine_attributes=None,
     transform_output=None,
     enable_metadata_routing=None,
     skip_parameter_validation=None,
@@ -126,6 +131,26 @@ def set_config(

         .. versionadded:: 1.2

+    engine_provider : str or sequence of {str, engine class}, default=None
+        Specify the list of enabled computational engine implementations
+        provided by third party packages. Engines are enabled by listing
+        either the name of a provider or an engine class directly.
+
+        See the :ref:`User Guide <engine>` for more details.
+
+        .. versionadded:: 1.4
+
+    engine_attributes : str, default=None
+        Enable conversion of estimator attributes to scikit-learn native
+        types by setting this to "sklearn_types". By default, attributes are
+        stored using engine native types. This avoids additional conversions
+        and memory transfers between host and device when calling `predict`/
+        `transform` after `fit` of an engine-aware estimator.
+
+        See the :ref:`User Guide <engine>` for more details.
+
+        .. versionadded:: 1.4
+
     transform_output : str, default=None
         Configure output of `transform` and `fit_transform`.
@@ -185,6 +210,18 @@ def set_config(
         _check_array_api_dispatch(array_api_dispatch)
         local_config["array_api_dispatch"] = array_api_dispatch

+    if engine_provider is not None:
+        # A single provider name was passed in
+        if isinstance(engine_provider, str):
+            engine_provider = (engine_provider,)
+        # Allow direct registration of engine classes to ease testing, debugging
+        # and benchmarking without having to register a fake package with metadata
+        # just to use a custom engine not meant to be used by end-users.
+        elif inspect.isclass(engine_provider):
+            engine_provider = (engine_provider,)
+        local_config["engine_provider"] = engine_provider
+    if engine_attributes is not None:
+        local_config["engine_attributes"] = engine_attributes
     if transform_output is not None:
         local_config["transform_output"] = transform_output
     if enable_metadata_routing is not None:
@@ -203,6 +240,8 @@ def config_context(
     pairwise_dist_chunk_size=None,
     enable_cython_pairwise_dist=None,
     array_api_dispatch=None,
+    engine_provider=None,
+    engine_attributes=None,
     transform_output=None,
     enable_metadata_routing=None,
     skip_parameter_validation=None,
@@ -273,6 +312,24 @@ def config_context(

         .. versionadded:: 1.2
+    engine_provider : str or sequence of {str, engine class}, default=None
+        Specify the list of enabled computational engine implementations
+        provided by third party packages. Engines are enabled by listing
+        either the name of a provider or an engine class directly.
+
+        See the :ref:`User Guide <engine>` for more details.
+
+        .. versionadded:: 1.4
+
+    engine_attributes : str, default=None
+        Enable conversion of estimator attributes to scikit-learn native
+        types by setting this to "sklearn_types". By default, attributes are
+        stored using engine native types.
+
+        See the :ref:`User Guide <engine>` for more details.
+
+        .. versionadded:: 1.4
+
     transform_output : str, default=None
         Configure output of `transform` and `fit_transform`.
@@ -344,6 +401,8 @@ def config_context(
         pairwise_dist_chunk_size=pairwise_dist_chunk_size,
         enable_cython_pairwise_dist=enable_cython_pairwise_dist,
         array_api_dispatch=array_api_dispatch,
+        engine_provider=engine_provider,
+        engine_attributes=engine_attributes,
         transform_output=transform_output,
         enable_metadata_routing=enable_metadata_routing,
         skip_parameter_validation=skip_parameter_validation,
diff --git a/sklearn/_engine/__init__.py b/sklearn/_engine/__init__.py
new file mode 100644
index 0000000000000..a2ffa2e3d5ab7
--- /dev/null
+++ b/sklearn/_engine/__init__.py
@@ -0,0 +1,3 @@
+from .base import convert_attributes, get_engine_classes, list_engine_provider_names
+
+__all__ = ["convert_attributes", "get_engine_classes", "list_engine_provider_names"]
diff --git a/sklearn/_engine/base.py b/sklearn/_engine/base.py
new file mode 100644
index 0000000000000..31a66404bbae5
--- /dev/null
+++ b/sklearn/_engine/base.py
@@ -0,0 +1,186 @@
+import inspect
+import warnings
+from functools import lru_cache, wraps
+from importlib import import_module
+from importlib.metadata import entry_points
+
+from sklearn._config import get_config
+
+SKLEARN_ENGINES_ENTRY_POINT = "sklearn_engines"
+
+
+class EngineSpec:
+    __slots__ = ["name", "provider_name", "module_name", "engine_qualname"]
+
+    def __init__(self, name, provider_name, module_name, engine_qualname):
+        self.name = name
+        self.provider_name = provider_name
+        self.module_name = module_name
+        self.engine_qualname = engine_qualname
+
+    def get_engine_class(self):
+        engine = import_module(self.module_name)
+        for attr in self.engine_qualname.split("."):
+            engine = getattr(engine, attr)
+        return engine
+
+
+def _parse_entry_point(entry_point):
+    module_name, engine_qualname = entry_point.value.split(":")
+    provider_name = next(iter(module_name.split(".", 1)))
+    return EngineSpec(entry_point.name, provider_name, module_name, engine_qualname)
+
+
+@lru_cache
+def _parse_entry_points(provider_names=None):
+    specs = []
+    all_entry_points = entry_points()
+    if hasattr(all_entry_points, "select"):
+        engine_entry_points = all_entry_points.select(group=SKLEARN_ENGINES_ENTRY_POINT)
+    else:
+        engine_entry_points = all_entry_points.get(SKLEARN_ENGINES_ENTRY_POINT, ())
+    for entry_point in engine_entry_points:
+        try:
+            spec = _parse_entry_point(entry_point)
+            if provider_names is not None and spec.provider_name not in provider_names:
+                # Skip entry points that do not match the requested provider names.
+                continue
+            specs.append(spec)
+        except Exception as e:
+            # Do not raise an exception in case an invalid package has been
+            # installed in the same Python env as scikit-learn: just warn and
+            # skip.
+            warnings.warn(
+                f"Invalid {SKLEARN_ENGINES_ENTRY_POINT} entry point"
+                f" {entry_point.name} with value {entry_point.value}: {e}"
+            )
+    if provider_names is not None:
+        observed_provider_names = {spec.provider_name for spec in specs}
+        missing_providers = set(provider_names) - observed_provider_names
+        if missing_providers:
+            raise RuntimeError(
+                "Could not find any provider for the"
+                f" {SKLEARN_ENGINES_ENTRY_POINT} entry point with name(s):"
+                f" {', '.join(repr(p) for p in sorted(missing_providers))}"
+            )
+    return specs
+
+
+def list_engine_provider_names():
+    """Find the list of sklearn_engine provider names.
+
+    This function only inspects the metadata and should not trigger any
+    module import.
+    """
+    return sorted({spec.provider_name for spec in _parse_entry_points()})
+
+
+def _get_engine_classes(engine_name, provider_names, engine_specs, default):
+    specs_by_provider = {}
+    for spec in engine_specs:
+        if spec.name != engine_name:
+            continue
+        specs_by_provider.setdefault(spec.provider_name, spec)
+
+    for provider_name in provider_names:
+        if inspect.isclass(provider_name):
+            # The provider name is actually a ready-to-go engine class.
+            # Instead of a made up string to name this ad-hoc provider
+            # we use the class itself. This mirrors what the user used
+            # when they set the config (ad-hoc class or string naming
+            # a provider).
+            engine_class = provider_name
+            if getattr(engine_class, "engine_name", None) != engine_name:
+                continue
+            yield engine_class, engine_class
+
+        spec = specs_by_provider.get(provider_name)
+        if spec is not None:
+            yield spec.provider_name, spec.get_engine_class()
+
+    yield "default", default
+
+
+def get_engine_classes(engine_name, default, verbose=False):
+    """Find all possible providers of `engine_name`.
+
+    Provider candidates are found by parsing entry point definitions that
+    match the names of the enabled engine providers, as well as ad-hoc
+    providers given as engine classes in the list of enabled engine
+    providers.
+
+    Parameters
+    ----------
+    engine_name : str
+        The name of the algorithm for which to find engine classes.
+
+    default : class
+        The default engine class to use if no other provider is found.
+
+    verbose : bool, default=False
+        If True, print the name of the engine classes that are tried.
+
+    Yields
+    ------
+    provider : str or class
+        The "name" of each matching provider. The "name" corresponds to the
+        entry in the `engine_provider` configuration. It can be a string or a
+        class for programmatically registered ad-hoc providers.
+
+    engine_class :
+        The engine class that implements the algorithm for the given provider.
+    """
+    provider_names = get_config()["engine_provider"]
+
+    if not provider_names:
+        yield "default", default
+        return
+
+    engine_specs = _parse_entry_points(
+        provider_names=tuple(
+            [name for name in provider_names if not inspect.isclass(name)]
+        )
+    )
+    for provider, engine_class in _get_engine_classes(
+        engine_name=engine_name,
+        provider_names=provider_names,
+        engine_specs=engine_specs,
+        default=default,
+    ):
+        if verbose:
+            print(
+                f"trying engine {engine_class.__module__}.{engine_class.__qualname__}."
+            )
+        yield provider, engine_class
+
+
+def convert_attributes(method):
+    """Convert estimator attributes after calling the decorated method.
+
+    The attributes of an estimator can be stored in "engine native" types
+    (default) or "scikit-learn native" types. This decorator will call the
+    engine's conversion function when needed. Use this decorator on methods
+    that set estimator attributes.
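+
+    A minimal usage sketch (the estimator below is hypothetical; the
+    decorator only assumes the `engine_attributes` configuration, the
+    `_engine_class`/`_engine_provider` attributes set during engine
+    selection, a `_default_engine` class attribute and the engine's
+    `convert_to_sklearn_types` static method)::
+
+        class MyEstimator(EngineAwareMixin):
+            _engine_name = "my-algorithm"
+            _default_engine = MyDefaultEngine  # hypothetical engine class
+
+            @convert_attributes
+            def fit(self, X):
+                engine = self._get_engine(X, reset=True)
+                # ... engine-specific work that sets fitted attributes
+                return self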
+ """ + + @wraps(method) + def wrapper(self, *args, **kwargs): + r = method(self, *args, **kwargs) + convert_attributes = get_config()["engine_attributes"] + + if convert_attributes == "sklearn_types": + engine = self._engine_class + for name, value in vars(self).items(): + # All attributes are passed to the engine, which can + # either convert the value (engine specific types) or + # return it as is (native Python types) + converted = engine.convert_to_sklearn_types(name, value) + setattr(self, name, converted) + + # No matter which engine was used to fit, after the attribute + # conversion to the sklearn native types the default engine + # is used. + self._engine_class = self._default_engine + self._engine_provider = "default" + + return r + + return wrapper diff --git a/sklearn/_engine/testing.py b/sklearn/_engine/testing.py new file mode 100644 index 0000000000000..57f0305d8a288 --- /dev/null +++ b/sklearn/_engine/testing.py @@ -0,0 +1,37 @@ +from pytest import hookimpl, xfail + +from sklearn import config_context +from sklearn.exceptions import NotSupportedByEngineError + + +# TODO: document this pytest plugin + write a tutorial on how to develop a new plugin +# and explain good practices regarding testing against sklearn test modules. +def pytest_addoption(parser): + group = parser.getgroup("Sklearn plugin testing") + group.addoption( + "--sklearn-engine-provider", + action="store", + nargs=1, + type=str, + help="Name of the an engine provider for sklearn to activate for all tests.", + ) + + +@hookimpl(hookwrapper=True) +def pytest_pyfunc_call(pyfuncitem): + engine_provider = pyfuncitem.config.getoption("sklearn_engine_provider") + if engine_provider is None: + yield + return + + with config_context(engine_provider=engine_provider): + try: + outcome = yield + outcome.get_result() + except NotSupportedByEngineError: + xfail( + reason=( + "This test cover features that are not supported by the " + f"engine provided by {engine_provider}." 
+                )
+            )
diff --git a/sklearn/_engine/tests/__init__.py b/sklearn/_engine/tests/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sklearn/_engine/tests/test_engines.py b/sklearn/_engine/tests/test_engines.py
new file mode 100644
index 0000000000000..5b92b9203a843
--- /dev/null
+++ b/sklearn/_engine/tests/test_engines.py
@@ -0,0 +1,317 @@
+import re
+from collections import namedtuple
+
+import numpy as np
+import pytest
+
+from sklearn._config import config_context
+from sklearn._engine import (
+    convert_attributes,
+    get_engine_classes,
+    list_engine_provider_names,
+)
+from sklearn._engine.base import EngineSpec, _get_engine_classes, _parse_entry_point
+from sklearn.base import EngineAwareMixin
+
+
+class FakeDefaultEngine:
+    pass
+
+
+class FakeEngine:
+    pass
+
+
+class FakeEngineHolder:
+    class NestedFakeEngine:
+        pass
+
+
+# Dummy classes used to test engine resolution
+class DefaultEngine:
+    engine_name = "test-engine"
+
+    def __init__(self, estimator):
+        self.estimator = estimator
+
+    def accepts(self, X, y=None, sample_weight=None):
+        return True
+
+
+class NeverAcceptsEngine:
+    engine_name = "test-engine"
+
+    def __init__(self, estimator):
+        self.estimator = estimator
+
+    def accepts(self, X, y=None, sample_weight=None):
+        return False
+
+
+class AlwaysAcceptsEngine:
+    engine_name = "test-engine"
+
+    def __init__(self, estimator):
+        self.estimator = estimator
+
+    def accepts(self, X, y=None, sample_weight=None):
+        return True
+
+
+class AlsoAlwaysAcceptsEngine(AlwaysAcceptsEngine):
+    pass
+
+
+class FakeEstimator(EngineAwareMixin):
+    _engine_name = "test-engine"
+    _default_engine = DefaultEngine
+
+
+FakeEntryPoint = namedtuple("FakeEntryPoint", ["name", "value"])
+
+
+def test_parse_entry_point():
+    fake_entry_point = FakeEntryPoint(
+        name="fake_engine",
+        value="sklearn._engine.tests.test_engines:FakeEngine",
+    )
+    spec = _parse_entry_point(fake_entry_point)
+    assert spec.name == "fake_engine"
+    assert spec.provider_name == "sklearn"  # or should it be scikit-learn?
+    assert spec.get_engine_class() is FakeEngine
+
+
+def test_parse_entry_point_for_nested_engine_class():
+    fake_entry_point = FakeEntryPoint(
+        name="nested_fake_engine",
+        value="sklearn._engine.tests.test_engines:FakeEngineHolder.NestedFakeEngine",
+    )
+    spec = _parse_entry_point(fake_entry_point)
+    assert spec.name == "nested_fake_engine"
+    assert spec.provider_name == "sklearn"  # or should it be scikit-learn?
+    assert spec.get_engine_class() is FakeEngineHolder.NestedFakeEngine
+
+
+def test_list_engine_provider_names():
+    provider_names = list_engine_provider_names()
+    for provider_name in provider_names:
+        assert isinstance(provider_name, str)
+
+
+def test_get_engine_class_with_default():
+    # Use config_context with an empty provider tuple to make sure that no
+    # providers are available for test_missing_engine_name
+    with config_context(engine_provider=()):
+        engine_classes = list(
+            get_engine_classes("test_missing_engine_name", default=FakeEngine)
+        )
+        assert engine_classes == [("default", FakeEngine)]
+
+
+def test_get_engine_class():
+    engine_specs = (
+        EngineSpec(
+            "kmeans", "provider3", "sklearn._engine.tests.test_engines", "FakeEngine"
+        ),
+        EngineSpec(
+            "kmeans",
+            "provider4",
+            "sklearn._engine.tests.test_engines",
+            "FakeEngineHolder.NestedFakeEngine",
+        ),
+    )
+
+    engine_class = list(
+        _get_engine_classes(
+            engine_name="missing",
+            provider_names=("provider1", "provider3"),
+            engine_specs=engine_specs,
+            default=FakeDefaultEngine,
+        )
+    )
+    assert engine_class == [("default", FakeDefaultEngine)]
+
+    engine_class = list(
+        _get_engine_classes(
+            engine_name="kmeans",
+            provider_names=("provider3", "provider4"),
+            engine_specs=engine_specs,
+            default=FakeDefaultEngine,
+        )
+    )
+    assert engine_class == [
+        ("provider3", FakeEngine),
+        ("provider4", FakeEngineHolder.NestedFakeEngine),
+        ("default", FakeDefaultEngine),
+    ]
+
+    engine_class = list(
+        _get_engine_classes(
+            engine_name="kmeans",
+            provider_names=("provider4", "provider3"),
+            engine_specs=engine_specs,
+            default=FakeDefaultEngine,
+        )
+    )
+    assert engine_class == [
+        ("provider4", FakeEngineHolder.NestedFakeEngine),
+        ("provider3", FakeEngine),
+        ("default", FakeDefaultEngine),
+    ]
+
+    engine_specs = engine_specs + (
+        EngineSpec(
+            "kmeans",
+            "provider1",
+            "sklearn.provider1.somewhere",
+            "OtherEngine",
+        ),
+    )
+
+    # Invalid imports are delayed until they are actually needed.
+    engine_classes = _get_engine_classes(
+        engine_name="kmeans",
+        provider_names=("provider4", "provider3", "provider1"),
+        engine_specs=engine_specs,
+        default=FakeDefaultEngine,
+    )
+
+    next(engine_classes)
+    next(engine_classes)
+    with pytest.raises(ImportError, match=re.escape("sklearn.provider1")):
+        next(engine_classes)
+
+
+@pytest.mark.parametrize(
+    "attribute_types,converted", [("sklearn_types", True), ("engine_types", False)]
+)
+def test_attribute_conversion(attribute_types, converted):
+    """Test attribute conversion logic
+
+    The estimator uses Numpy Array API arrays as its native type.
+    """
+    np_array_api = pytest.importorskip("numpy.array_api")
+
+    class Engine:
+        @staticmethod
+        def convert_to_sklearn_types(name, value):
+            return np.asarray(value)
+
+    class Estimator:
+        _default_engine = DefaultEngine
+        # Set up the attribute as if `Engine` had previously been selected;
+        # we want to test attribute conversion, not engine resolution.
+        _engine_class = Engine
+
+        @convert_attributes
+        def fit(self, X):
+            self.X_ = np_array_api.asarray(X)
+
+    X = np.array([1, 2, 3])
+    est = Estimator()
+    with config_context(engine_attributes=attribute_types):
+        est.fit(X)
+
+    assert isinstance(est.X_, np.ndarray) == converted
+    if converted:
+        assert est._engine_class == est._default_engine == DefaultEngine
+    else:
+        assert est._engine_class == Engine
+
+
+def test_engine_selection():
+    """Check that the correct engine is selected."""
+    # Values aren't important, just need something to pass as argument
+    # to _get_engine
+    X = [[1, 2], [3, 4]]
+
+    # If no engine accepts, the default engine should be selected
+    with config_context(engine_provider=NeverAcceptsEngine):
+        est = FakeEstimator()
+        engine = est._get_engine(X)
+        assert isinstance(engine, DefaultEngine)
+
+    with config_context(engine_provider=AlwaysAcceptsEngine):
+        est = FakeEstimator()
+        engine = est._get_engine(X)
+        assert isinstance(engine, AlwaysAcceptsEngine)
+
+    # The engine with second priority (AlwaysAcceptsEngine) is selected
+    with config_context(engine_provider=(NeverAcceptsEngine, AlwaysAcceptsEngine)):
+        est = FakeEstimator()
+        engine = est._get_engine(X)
+        assert isinstance(engine, AlwaysAcceptsEngine)
+
+
+def test_engine_selection_is_frozen():
+    """Check that a previously selected engine keeps being used.
+
+    Engine selection is only performed once; after that the same engine
+    is used. Re-selecting the engine is possible when explicitly requested.
+    """
+    # Values aren't important, just need something to pass as argument
+    # to _get_engine
+    X = [[1, 2], [3, 4]]
+
+    est = FakeEstimator()
+
+    with config_context(engine_provider=(NeverAcceptsEngine, AlwaysAcceptsEngine)):
+        engine = est._get_engine(X)
+        assert isinstance(engine, AlwaysAcceptsEngine)
+
+    # Even though `AlsoAlwaysAcceptsEngine` is listed first, it should not
+    # be selected
+    with config_context(engine_provider=(AlsoAlwaysAcceptsEngine, AlwaysAcceptsEngine)):
+        engine = est._get_engine(X)
+        assert isinstance(engine, AlwaysAcceptsEngine)
+
+    # Explicitly ask for engine re-selection
+    with config_context(engine_provider=(AlsoAlwaysAcceptsEngine, AlwaysAcceptsEngine)):
+        engine = est._get_engine(X, reset=True)
+        assert isinstance(engine, AlsoAlwaysAcceptsEngine)
+
+
+def test_missing_engine_raises():
+    """Check an exception is raised when a previously configured engine is
+    no longer available.
+    """
+    # Values aren't important, just need something to pass as argument
+    # to _get_engine
+    X = [[1, 2], [3, 4]]
+
+    est = FakeEstimator()
+
+    with config_context(engine_provider=(NeverAcceptsEngine, AlwaysAcceptsEngine)):
+        engine = est._get_engine(X)
+        assert isinstance(engine, AlwaysAcceptsEngine)
+
+    # Raise an exception because the previously selected engine isn't available
+    with config_context(engine_provider=(AlsoAlwaysAcceptsEngine,)):
+        with pytest.raises(RuntimeError, match="Previously selected engine.*"):
+            est._get_engine(X)
+
+    # Doesn't raise because `reset=True`
+    with config_context(engine_provider=(NeverAcceptsEngine, AlsoAlwaysAcceptsEngine)):
+        engine = est._get_engine(X, reset=True)
+        assert isinstance(engine, AlsoAlwaysAcceptsEngine)
+
+
+def test_default_engine_always_works():
+    """Check that an estimator that uses the default engine works, even when
+    no engines are explicitly configured.
+ """ + # Values aren't important, just need something to pass as argument + # to _get_engine + X = [[1, 2], [3, 4]] + + est = FakeEstimator() + + with config_context(engine_provider=NeverAcceptsEngine): + engine = est._get_engine(X) + assert isinstance(engine, DefaultEngine) + + assert est._engine_class == DefaultEngine + + # With no explicit config, the default engine should still be selected + engine = est._get_engine(X) + assert isinstance(engine, DefaultEngine) diff --git a/sklearn/base.py b/sklearn/base.py index a7c93937ebe72..4a91adbdc44d2 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -15,6 +15,7 @@ from . import __version__ from ._config import config_context, get_config +from ._engine import get_engine_classes from .exceptions import InconsistentVersionWarning from .utils import _IS_32BIT from .utils._estimator_html_repr import estimator_html_repr @@ -1122,6 +1123,44 @@ def _more_tags(self): } +class EngineAwareMixin: + """Mixin for estimators that use a pluggable engine to do the work""" + + def _get_engine(self, X, y=None, sample_weight=None, reset=False): + """Determine the engine for the estimator to use. + + All enabled engine providers are tried in turn, the first one that + accepts is selected. The choice of engine is stored and re-used + unless `reset=True`, in which case the selection process is started + again. + """ + if hasattr(self, "_engine_provider") and not reset: + configured_providers = get_config()["engine_provider"] + # Special case: the default engine can be selected + # when no provider is explicitly configured and it is not an error + # to keep using it when nothing is explicitly configured. + if ( + self._engine_provider != "default" + and self._engine_provider not in configured_providers + ): + raise RuntimeError( + f"Previously selected engine ({self._engine_provider}) is no longer" + " configured." + ) + + return self._engine_class(self) + + for provider, engine_class in get_engine_classes( + self._engine_name, + default=self._default_engine, + ): + engine = engine_class(self) + if engine.accepts(X, y=y, sample_weight=sample_weight): + self._engine_provider = provider + self._engine_class = engine_class + return engine + + def is_classifier(estimator): """Return True if the given estimator is (probably) a classifier. diff --git a/sklearn/cluster/_kmeans.py b/sklearn/cluster/_kmeans.py index d1da355290073..0dcc7363714c9 100644 --- a/sklearn/cluster/_kmeans.py +++ b/sklearn/cluster/_kmeans.py @@ -18,10 +18,12 @@ import numpy as np import scipy.sparse as sp +from .._engine import convert_attributes from ..base import ( BaseEstimator, ClassNamePrefixFeaturesOutMixin, ClusterMixin, + EngineAwareMixin, TransformerMixin, _fit_context, ) @@ -280,15 +282,196 @@ def _kmeans_plusplus( # K-means batch estimation by EM (expectation maximization) -def _tolerance(X, tol): - """Return a tolerance which is dependent on the dataset.""" - if tol == 0: - return 0 - if sp.issparse(X): - variances = mean_variance_axis(X, axis=0)[1] - else: - variances = np.var(X, axis=0) - return np.mean(variances) * tol +class _IgnoreParam: + pass + + +class KMeansCythonEngine: + """Cython-based implementation of the core k-means routines + + This implementation is meant to be swappable by alternative implementations + in third-party packages via the sklearn_engines entry-point and the + `engine_provider` kwarg of `sklearn.config_context`. + + TODO: see URL for more details. 
+ """ + + @staticmethod + def convert_to_sklearn_types(name, value): + """Convert estimator attributes to scikit-learn types. + + Users can configure whether estimator attributes should be stored + using engine native types or scikit-learn types. This function is + used to convert attributes from engine to scikit-learn native types. + + Scikit-learn native types are ndarrays and basic Python types. There + is no need to convert these. + + Parameters + ---------- + name : str + Name of the attribute being converted. + + value + Value of the attribute being converted. + + Returns + -------- + converted + Attribute value converted to a scikit-learn native type. + """ + # XXX Maybe a bit useless as it should never get called, but it + # does demonstrate the API + return value + + def __init__(self, estimator): + self.estimator = estimator + + def accepts(self, X, y=None, sample_weight=None): + """Determine if input data and hyper-parameters are supported by + this engine. + + Determine if this engine can handle the hyper-parameters of the + estimator as well as the input data. If not, return `False`. This + method is called during engine selection where each enabled engine + is tried in the user defined order. + + Should fail as quickly as possible. + """ + # The default engine accepts everything + return True + + def prepare_fit(self, X, y=None, sample_weight=None): + estimator = self.estimator + + X = estimator._validate_data( + X, + accept_sparse="csr", + dtype=[np.float64, np.float32], + order="C", + copy=estimator.copy_x, + accept_large_sparse=False, + ) + # this sets estimator _algorithm implicitly + # XXX: shall we explose this logic as part of then engine API? + # or is the current API flexible enough? + estimator._check_params_vs_input(X) + + # TODO: delegate rng and sample weight checks to engine + random_state = check_random_state(estimator.random_state) + sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype) + + # Also store the number of threads on the estimator to be reused at + # prediction time XXX: shall we wrap engine-specific private fit + # attributes in a predict context dict set as attribute on the + # estimator? + estimator._n_threads = self._n_threads = _openmp_effective_n_threads() + + # Validate init array + init = estimator.init + init_is_array_like = _is_arraylike_not_scalar(init) + if init_is_array_like: + init = check_array(init, dtype=X.dtype, copy=True, order="C") + estimator._validate_center_shape(X, init) + + # subtract of mean of x for more accurate distance computations + if not sp.issparse(X): + X_mean = X.mean(axis=0) + # The copy was already done above + X -= X_mean + + if init_is_array_like: + init -= X_mean + + self.X_mean = X_mean + + # precompute squared norms of data points + x_squared_norms = row_norms(X, squared=True) + + if estimator._algorithm == "elkan": + kmeans_single = _kmeans_single_elkan + else: + kmeans_single = _kmeans_single_lloyd + estimator._check_mkl_vcomp(X, X.shape[0]) + + self.x_squared_norms = x_squared_norms + self.kmeans_single_func = kmeans_single + self.random_state = random_state + self.tol = self.scale_tolerance(X, estimator.tol) + self.init = init + return X, y, sample_weight + + def init_centroids(self, X, sample_weight): + # XXX: the actual implementation of the centroids init should also be + # moved to the engine. 
+        return self.estimator._init_centroids(
+            X,
+            x_squared_norms=self.x_squared_norms,
+            init=self.init,
+            random_state=self.random_state,
+            sample_weight=sample_weight,
+        )
+
+    def scale_tolerance(self, X, tol):
+        """Return a tolerance which is dependent on the dataset."""
+        if tol == 0:
+            return 0
+        if sp.issparse(X):
+            _, variances = mean_variance_axis(X, axis=0)
+        else:
+            variances = np.var(X, axis=0)
+        return np.mean(variances) * tol
+
+    def unshift_centers(self, X, best_centers):
+        if not sp.issparse(X):
+            if not self.estimator.copy_x:
+                X += self.X_mean
+            best_centers += self.X_mean
+
+    def is_same_clustering(self, labels, best_labels, n_clusters):
+        return _is_same_clustering(labels, best_labels, n_clusters)
+
+    def count_distinct_clusters(self, cluster_labels):
+        """Count the number of distinct cluster labels."""
+        return len(set(cluster_labels))
+
+    def kmeans_single(self, X, sample_weight, centers_init):
+        return self.kmeans_single_func(
+            X,
+            sample_weight,
+            centers_init,
+            max_iter=self.estimator.max_iter,
+            tol=self.tol,
+            n_threads=self._n_threads,
+            verbose=self.estimator.verbose,
+        )
+
+    def prepare_prediction(self, X, sample_weight):
+        X = self.estimator._check_test_data(X)
+        sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
+        return X, sample_weight
+
+    def get_labels(self, X, sample_weight):
+        labels, _ = _labels_inertia_threadpool_limit(
+            X,
+            sample_weight,
+            self.estimator.cluster_centers_,
+            n_threads=self.estimator._n_threads,
+        )
+
+        return labels
+
+    def prepare_transform(self, X):
+        return self.estimator._check_test_data(X)
+
+    def get_euclidean_distances(self, X):
+        return euclidean_distances(X, self.estimator.cluster_centers_)
+
+    def get_score(self, X, sample_weight):
+        _, scores = _labels_inertia_threadpool_limit(
+            X, sample_weight, self.estimator.cluster_centers_, self.estimator._n_threads
+        )
+        return scores


 @validate_params(
@@ -867,9 +1050,6 @@ def _check_params_vs_input(self, X, default_n_init=None):
                 f"n_samples={X.shape[0]} should be >= n_clusters={self.n_clusters}."
             )

-        # tol
-        self._tol = _tolerance(X, self.tol)
-
         # n-init
         # TODO(1.4): Remove
         self._n_init = self.n_init
@@ -1208,7 +1388,7 @@ def _more_tags(self):
         }


-class KMeans(_BaseKMeans):
+class KMeans(_BaseKMeans, EngineAwareMixin):
     """K-Means clustering.

     Read more in the :ref:`User Guide <k_means>`.
@@ -1386,6 +1566,9 @@ class KMeans(_BaseKMeans):
         ],
     }

+    _engine_name = "kmeans"
+    _default_engine = KMeansCythonEngine
+
     def __init__(
         self,
         n_clusters=8,
@@ -1444,6 +1627,7 @@ def _warn_mkl_vcomp(self, n_active_threads):
             f" variable OMP_NUM_THREADS={n_active_threads}."
         )

+    @convert_attributes
     @_fit_context(prefer_skip_nested_validation=True)
     def fit(self, X, y=None, sample_weight=None):
         """Compute k-means clustering.
@@ -1472,69 +1656,27 @@ def fit(self, X, y=None, sample_weight=None):
         self : object
             Fitted estimator.
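+
+        Examples
+        --------
+        A sketch of fitting through a third-party engine (the provider name
+        below is hypothetical; such a package must be installed and expose a
+        `sklearn_engines` entry point named "kmeans")::
+
+            >>> from sklearn import config_context
+            >>> with config_context(engine_provider="my_provider"):
+            ...     kmeans = KMeans(n_clusters=2).fit(X)  # doctest: +SKIP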
""" - X = self._validate_data( + engine = self._get_engine(X, y, sample_weight, reset=True) + + X, y, sample_weight = engine.prepare_fit( X, - accept_sparse="csr", - dtype=[np.float64, np.float32], - order="C", - copy=self.copy_x, - accept_large_sparse=False, + y=y, + sample_weight=sample_weight, ) - self._check_params_vs_input(X) - - random_state = check_random_state(self.random_state) - sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype) - self._n_threads = _openmp_effective_n_threads() - - # Validate init array - init = self.init - init_is_array_like = _is_arraylike_not_scalar(init) - if init_is_array_like: - init = check_array(init, dtype=X.dtype, copy=True, order="C") - self._validate_center_shape(X, init) - - # subtract of mean of x for more accurate distance computations - if not sp.issparse(X): - X_mean = X.mean(axis=0) - # The copy was already done above - X -= X_mean - - if init_is_array_like: - init -= X_mean - - # precompute squared norms of data points - x_squared_norms = row_norms(X, squared=True) - - if self._algorithm == "elkan": - kmeans_single = _kmeans_single_elkan - else: - kmeans_single = _kmeans_single_lloyd - self._check_mkl_vcomp(X, X.shape[0]) - best_inertia, best_labels = None, None for i in range(self._n_init): # Initialize centers - centers_init = self._init_centroids( - X, - x_squared_norms=x_squared_norms, - init=init, - random_state=random_state, - sample_weight=sample_weight, - ) + centers_init = engine.init_centroids(X, sample_weight) if self.verbose: print("Initialization complete") # run a k-means once - labels, inertia, centers, n_iter_ = kmeans_single( + labels, inertia, centers, n_iter_ = engine.kmeans_single( X, sample_weight, centers_init, - max_iter=self.max_iter, - verbose=self.verbose, - tol=self._tol, - n_threads=self._n_threads, ) # determine if these results are the best so far @@ -1544,19 +1686,17 @@ def fit(self, X, y=None, sample_weight=None): # permuted labels, due to rounding errors) if best_inertia is None or ( inertia < best_inertia - and not _is_same_clustering(labels, best_labels, self.n_clusters) + and not engine.is_same_clustering(labels, best_labels, self.n_clusters) ): best_labels = labels best_centers = centers best_inertia = inertia best_n_iter = n_iter_ - if not sp.issparse(X): - if not self.copy_x: - X += X_mean - best_centers += X_mean + engine.unshift_centers(X, best_centers) + + distinct_clusters = engine.count_distinct_clusters(best_labels) - distinct_clusters = len(set(best_labels)) if distinct_clusters < self.n_clusters: warnings.warn( "Number of distinct clusters ({}) found smaller than " @@ -1573,6 +1713,125 @@ def fit(self, X, y=None, sample_weight=None): self.n_iter_ = best_n_iter return self + def predict(self, X, sample_weight="deprecated"): + """Predict the closest cluster each sample in X belongs to. + + In the vector quantization literature, `cluster_centers_` is called + the code book and each value returned by `predict` is the index of + the closest code in the code book. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + New data to predict. + + sample_weight : array-like of shape (n_samples,), default=None + The weights for each observation in X. If None, all observations + are assigned equal weight. + + .. deprecated:: 1.3 + The parameter `sample_weight` is deprecated in version 1.3 + and will be removed in 1.5. + + Returns + ------- + labels : ndarray of shape (n_samples,) + Index of the cluster each sample belongs to. 
+ """ + check_is_fitted(self) + engine = self._get_engine(X, sample_weight=sample_weight) + if isinstance(sample_weight, str) and sample_weight == "deprecated": + # Caller left the default value of sample_weight unchanged. + sample_weight = None + else: + # Caller explicitly passed sample_weight, so we warn. + warnings.warn( + "'sample_weight' was deprecated in version 1.3 and " + "will be removed in 1.5.", + FutureWarning, + ) + X, sample_weight = engine.prepare_prediction(X, sample_weight) + return engine.get_labels(X, sample_weight) + + def fit_transform(self, X, y=None, sample_weight=None): + """Compute clustering and transform X to cluster-distance space. + + Equivalent to fit(X).transform(X), but more efficiently implemented. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + New data to transform. + + y : Ignored + Not used, present here for API consistency by convention. + + sample_weight : array-like of shape (n_samples,), default=None + The weights for each observation in X. If None, all observations + are assigned equal weight. + + Returns + ------- + X_new : ndarray of shape (n_samples, n_clusters) + X transformed in the new space. + """ + self.fit(X, sample_weight=sample_weight) + engine = self._get_engine(X, y=y, sample_weight=sample_weight) + return self._transform(X, engine) + + def transform(self, X): + """Transform X to a cluster-distance space. + + In the new space, each dimension is the distance to the cluster + centers. Note that even if X is sparse, the array returned by + `transform` will typically be dense. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + New data to transform. + + Returns + ------- + X_new : ndarray of shape (n_samples, n_clusters) + X transformed in the new space. + """ + check_is_fitted(self) + engine = self._get_engine(X) + X = engine.prepare_transform(X) + return self._transform(X, engine) + + def _transform(self, X, engine): + """Guts of transform method; no input validation.""" + return engine.get_euclidean_distances(X) + + def score(self, X, y=None, sample_weight=None): + """Opposite of the value of X on the K-means objective. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + New data. + + y : Ignored + Not used, present here for API consistency by convention. + + sample_weight : array-like of shape (n_samples,), default=None + The weights for each observation in X. If None, all observations + are assigned equal weight. + + Returns + ------- + score : float + Opposite of the value of X on the K-means objective. 
+ """ + check_is_fitted(self) + engine = self._get_engine(X, y=y, sample_weight=sample_weight) + + X, sample_weight = engine.prepare_prediction(X, sample_weight) + + return -engine.get_score(X, sample_weight) + def _mini_batch_step( X, @@ -1933,6 +2192,15 @@ def __init__( def _check_params_vs_input(self, X): super()._check_params_vs_input(X, default_n_init=3) + if self.tol > 0: + if sp.issparse(X): + _, variances = mean_variance_axis(X, axis=0) + else: + variances = np.var(X, axis=0) + self._tol = np.mean(variances) * self.tol + else: + self._tol = 0.0 + self._batch_size = min(self.batch_size, X.shape[0]) # init_size diff --git a/sklearn/exceptions.py b/sklearn/exceptions.py index ad7ae08c1fec0..712b9731235e4 100644 --- a/sklearn/exceptions.py +++ b/sklearn/exceptions.py @@ -5,6 +5,7 @@ __all__ = [ "NotFittedError", + "NotSupportedByEngineError", "ConvergenceWarning", "DataConversionWarning", "DataDimensionalityWarning", @@ -64,6 +65,19 @@ class NotFittedError(ValueError, AttributeError): """ +class NotSupportedByEngineError(NotImplementedError): + """External plugins might not support all the combinations of parameters and + input types that the vanilla sklearn implementation otherwise supports. In such + cases, plugins can raise this exception. When running the sklearn test modules + using the sklearn pytest plugin, all the unit tests that fail by raising this + exception class will be automatically marked as "xfail", this enables sorting out + the tests that fail because they test features that are not supported by the plugin + and tests that fail because the plugin misbehave on supported features. + + .. versionadded:: 1.4 + """ + + class ConvergenceWarning(UserWarning): """Custom warning to capture convergence problems diff --git a/sklearn/tests/test_config.py b/sklearn/tests/test_config.py index 1b92d58a5f28e..6e0dde421c364 100644 --- a/sklearn/tests/test_config.py +++ b/sklearn/tests/test_config.py @@ -18,6 +18,8 @@ def test_config_context(): "array_api_dispatch": False, "pairwise_dist_chunk_size": 256, "enable_cython_pairwise_dist": True, + "engine_provider": (), + "engine_attributes": "engine_types", "transform_output": "default", "enable_metadata_routing": False, "skip_parameter_validation": False, @@ -36,6 +38,8 @@ def test_config_context(): "array_api_dispatch": False, "pairwise_dist_chunk_size": 256, "enable_cython_pairwise_dist": True, + "engine_provider": (), + "engine_attributes": "engine_types", "transform_output": "default", "enable_metadata_routing": False, "skip_parameter_validation": False, @@ -71,6 +75,8 @@ def test_config_context(): "array_api_dispatch": False, "pairwise_dist_chunk_size": 256, "enable_cython_pairwise_dist": True, + "engine_provider": (), + "engine_attributes": "engine_types", "transform_output": "default", "enable_metadata_routing": False, "skip_parameter_validation": False, From 7d52073b15ee920c6f49208c777e7ce7663ff74b Mon Sep 17 00:00:00 2001 From: Tim Head Date: Mon, 6 Nov 2023 15:20:53 +0100 Subject: [PATCH 2/2] Missing changes --- doc/whats_new/v1.4.rst | 3 --- 1 file changed, 3 deletions(-) diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst index ca62ec2e0a1da..7aa1b5e2501a5 100644 --- a/doc/whats_new/v1.4.rst +++ b/doc/whats_new/v1.4.rst @@ -48,7 +48,6 @@ Changes impacting all modules to work with our estimators and functions. :pr:`26464` by `Thomas Fan`_. 
-<<<<<<< HEAD
 - |Enhancement| Experimental engine API (no backward compatibility guarantees)
   to allow for external packages to contribute alternative implementations for
   the core computational routines of some selected scikit-learn estimators.

   Currently, the following estimators allow alternative implementations:

   - :class:`~sklearn.cluster.KMeans` (only for the Lloyd algorithm).
   - TODO: add more when available.

   External engine providers include:

   - https://github.com/soda-inria/sklearn-numba-dpex, which provides a KMeans
     engine optimized for OpenCL-enabled GPUs.
   - TODO: add more here

   :pr:`25535` by :user:`ogrisel`, :user:`fcharras` and :user:`betatim`.

-=======
 - |Enhancement| The HTML representation of estimators now includes a link to the
   documentation and is color-coded to denote whether the estimator is fitted or
   not (unfitted estimators are orange, fitted estimators are blue).
@@ -197,7 +195,6 @@ and classes are impacted:
   :user:`Yao Xiao `;
 - :class:`preprocessing.PolynomialFeatures` in :pr:`27166` by
   :user:`Mohit Joshi `.
->>>>>>> main

 Changelog
 ---------
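
For reference, a sketch of how a provider package could register an engine
under the `sklearn_engines` entry point group introduced by this patch series.
All package and class names below are hypothetical; the entry point name must
match the estimator's `_engine_name` ("kmeans" for `KMeans`) and the value
uses the "module:qualname" form expected by `_parse_entry_point`:

    # setup.py of a hypothetical engine provider package
    from setuptools import setup

    setup(
        name="my_provider",
        packages=["my_provider"],
        entry_points={
            "sklearn_engines": [
                "kmeans = my_provider.kmeans:MyKMeansEngine",
            ],
        },
    )

Once installed, the provider can be enabled by the name of its top-level
module (`_parse_entry_point` derives the provider name from the first
component of the module path):

    from sklearn import config_context
    from sklearn.cluster import KMeans

    # X is any array-like of shape (n_samples, n_features)
    with config_context(engine_provider="my_provider"):
        KMeans(n_clusters=8).fit(X)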