[DRAFT] Engine plugin API and engine entry point for Lloyd's KMeans #25535
Draft: ogrisel wants to merge 3 commits into main from feature/engine-api
@@ -0,0 +1,29 @@
.. Places parent toc into the sidebar

:parenttoc: True

.. _engine:

Computation Engines (experimental)
==================================

**This API is experimental** which means that it is subject to change without
any backward compatibility guarantees.

TODO: explain goals here

Activating an engine
--------------------

TODO: installing third party engine provider packages

TODO: how to list installed engines

TODO: how to install a plugin

Writing a new engine provider
-----------------------------

TODO: show engine API of a given estimator.

TODO: give example setup.py with setuptools to define an entrypoint.
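The last TODO asks for an example `setup.py`. The following is only a minimal sketch of what such a file might contain. It assumes a hypothetical third-party package `my_provider` that ships a class `KMeansEngine` in `my_provider/kmeans.py`, and it assumes the engine is registered under the name `kmeans`. The `sklearn_engines` group and the `module:qualname` value format are taken from the parsing code in this PR; every other name is made up for illustration.

```python
# Metadata that a hypothetical provider package ("my_provider") could hand
# to setuptools. Package, module and class names are assumptions.
SETUP_KWARGS = {
    "name": "my-provider",
    "version": "0.1.0",
    "packages": ["my_provider"],
    "entry_points": {
        # Entry point group read by scikit-learn in this PR.
        "sklearn_engines": [
            # <engine name> = <module path>:<qualified class name>
            "kmeans = my_provider.kmeans:KMeansEngine",
        ],
    },
}


def run_setup():
    """Pass the metadata above to setuptools (intended to be called from setup.py)."""
    from setuptools import setup

    setup(**SETUP_KWARGS)
```

In an actual `setup.py`, `run_setup()` would be called at module level. With the parsing logic of this PR, the provider name would then be derived from the top-level package of the module path, here `my_provider`.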
@@ -0,0 +1,3 @@
from .base import convert_attributes, get_engine_classes, list_engine_provider_names

__all__ = ["convert_attributes", "get_engine_classes", "list_engine_provider_names"]
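To illustrate what the exported `list_engine_provider_names` helper is meant to report, here is a standalone sketch that reads the same `sklearn_engines` entry point group directly with `importlib.metadata`, without importing scikit-learn or any plugin module. The function name `provider_names_from_entry_points` is hypothetical, and the result depends entirely on which packages are installed in the current environment.

```python
from importlib.metadata import entry_points


def provider_names_from_entry_points(group="sklearn_engines"):
    """Return the sorted top-level package names that register `group`."""
    all_entry_points = entry_points()
    if hasattr(all_entry_points, "select"):
        # Selection API available on recent Python versions.
        selected = all_entry_points.select(group=group)
    else:
        # Older versions return a dict mapping group names to entry points.
        selected = all_entry_points.get(group, ())
    names = set()
    for entry_point in selected:
        # The value has the form "package.module:QualName".
        module_name = entry_point.value.split(":")[0]
        names.add(module_name.split(".", 1)[0])
    return sorted(names)


print(provider_names_from_entry_points())
```

Only package metadata is inspected here, so no provider module is imported as a side effect.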
@@ -0,0 +1,186 @@
import inspect
import warnings
from functools import lru_cache, wraps
from importlib import import_module
from importlib.metadata import entry_points

from sklearn._config import get_config

SKLEARN_ENGINES_ENTRY_POINT = "sklearn_engines"


class EngineSpec:
    __slots__ = ["name", "provider_name", "module_name", "engine_qualname"]

    def __init__(self, name, provider_name, module_name, engine_qualname):
        self.name = name
        self.provider_name = provider_name
        self.module_name = module_name
        self.engine_qualname = engine_qualname

    def get_engine_class(self):
        engine = import_module(self.module_name)
        for attr in self.engine_qualname.split("."):
            engine = getattr(engine, attr)
        return engine


def _parse_entry_point(entry_point):
    module_name, engine_qualname = entry_point.value.split(":")
    provider_name = next(iter(module_name.split(".", 1)))
    return EngineSpec(entry_point.name, provider_name, module_name, engine_qualname)


@lru_cache
def _parse_entry_points(provider_names=None):
    specs = []
    all_entry_points = entry_points()
    if hasattr(all_entry_points, "select"):
        engine_entry_points = all_entry_points.select(group=SKLEARN_ENGINES_ENTRY_POINT)
    else:
        engine_entry_points = all_entry_points.get(SKLEARN_ENGINES_ENTRY_POINT, ())
    for entry_point in engine_entry_points:
        try:
            spec = _parse_entry_point(entry_point)
            if provider_names is not None and spec.provider_name not in provider_names:
                # Skip entry points that do not match the requested provider names.
                continue
            specs.append(spec)
        except Exception as e:
            # Do not raise an exception in case an invalid package has been
            # installed in the same Python env as scikit-learn: just warn and
            # skip.
            warnings.warn(
                f"Invalid {SKLEARN_ENGINES_ENTRY_POINT} entry point"
                f" {entry_point.name} with value {entry_point.value}: {e}"
            )
    if provider_names is not None:
        observed_provider_names = {spec.provider_name for spec in specs}
        missing_providers = set(provider_names) - observed_provider_names
        if missing_providers:
            raise RuntimeError(
                "Could not find any provider for the"
                f" {SKLEARN_ENGINES_ENTRY_POINT} entry point with name(s):"
                f" {', '.join(repr(p) for p in sorted(missing_providers))}"
            )
    return specs


def list_engine_provider_names():
    """Find the list of sklearn_engine provider names

    This function only inspects the metadata and should not trigger any
    module import.
    """
    return sorted({spec.provider_name for spec in _parse_entry_points()})


def _get_engine_classes(engine_name, provider_names, engine_specs, default):
    specs_by_provider = {}
    for spec in engine_specs:
        if spec.name != engine_name:
            continue
        specs_by_provider.setdefault(spec.provider_name, spec)

    for provider_name in provider_names:
        if inspect.isclass(provider_name):
            # The provider name is actually a ready-to-go engine class.
            # Instead of a made up string to name this ad-hoc provider
            # we use the class itself. This mirrors what the user used
            # when they set the config (ad-hoc class or string naming
            # a provider).
            engine_class = provider_name
            if getattr(engine_class, "engine_name", None) != engine_name:
                continue
            yield engine_class, engine_class

        spec = specs_by_provider.get(provider_name)
        if spec is not None:
            yield spec.provider_name, spec.get_engine_class()

    yield "default", default


def get_engine_classes(engine_name, default, verbose=False):
    """Find all possible providers of `engine_name`.

    Provider candidates are found by parsing the entrypoint definitions that
    match the names of enabled engine providers, as well as ad-hoc providers
    given as engine classes in the list of enabled engine providers.

    Parameters
    ----------
    engine_name : str
        The name of the algorithm for which to find engine classes.

    default : class
        The default engine class to use if no other provider is found.

    verbose : bool, default=False
        If True, print the name of the engine classes that are tried.

    Yields
    ------
    provider : str or class
        The "name" of each matching provider. The "name" corresponds to the
        entry in the `engine_provider` configuration. It can be a string or a
        class for programmatically registered ad-hoc providers.

    engine_class :
        The engine class that implements the algorithm for the given provider.
    """
    provider_names = get_config()["engine_provider"]

    if not provider_names:
        yield "default", default
        return

    engine_specs = _parse_entry_points(
        provider_names=tuple(
            [name for name in provider_names if not inspect.isclass(name)]
        )
    )
    for provider, engine_class in _get_engine_classes(
        engine_name=engine_name,
        provider_names=provider_names,
        engine_specs=engine_specs,
        default=default,
    ):
        if verbose:
            print(
                f"trying engine {engine_class.__module__}.{engine_class.__qualname__}."
            )
        yield provider, engine_class


def convert_attributes(method):
    """Convert estimator attributes after calling the decorated method.

    The attributes of an estimator can be stored in "engine native" types
    (default) or "scikit-learn native" types. This decorator will call the
    engine's conversion function when needed. Use this decorator on methods
    that set estimator attributes.
    """

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        r = method(self, *args, **kwargs)
        convert_attributes = get_config()["engine_attributes"]

        if convert_attributes == "sklearn_types":
            engine = self._engine_class
            for name, value in vars(self).items():
                # All attributes are passed to the engine, which can
                # either convert the value (engine specific types) or
                # return it as is (native Python types)
                converted = engine.convert_to_sklearn_types(name, value)
                setattr(self, name, converted)

            # No matter which engine was used to fit, after the attribute
            # conversion to the sklearn native types the default engine
            # is used.
            self._engine_class = self._default_engine
            self._engine_provider = "default"

        return r

    return wrapper
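To make the attribute conversion step of `convert_attributes` more concrete, here is a self-contained toy sketch. It is not the PR's implementation: it drops the `get_config()` lookup and the reset to the default engine, and always converts. `TupleEngine`, `ToyEstimator` and `convert_attributes_sketch` are hypothetical names, and treating `convert_to_sklearn_types` as a static method is an assumption based on how the decorator calls it with `(name, value)`.

```python
from functools import wraps


def convert_attributes_sketch(method):
    """Simplified stand-in for the decorator: always converts attributes."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        result = method(self, *args, **kwargs)
        engine = self._engine_class
        # Every instance attribute is offered to the engine for conversion.
        for name, value in vars(self).items():
            setattr(self, name, engine.convert_to_sklearn_types(name, value))
        return result

    return wrapper


class TupleEngine:
    """Hypothetical engine whose native fitted values are tuples."""

    @staticmethod
    def convert_to_sklearn_types(name, value):
        # Convert engine-native tuples to lists, return anything else as is.
        if name.endswith("_") and isinstance(value, tuple):
            return list(value)
        return value


class ToyEstimator:
    _engine_class = TupleEngine

    @convert_attributes_sketch
    def fit(self, X):
        # Column-wise maxima, stored in the engine's native type.
        self.maxima_ = tuple(max(column) for column in zip(*X))
        return self


est = ToyEstimator().fit([[1, 5], [3, 2]])
print(est.maxima_)  # [3, 5]
```

The engine decides per attribute whether a conversion is needed, which is why the decorator can pass every attribute through without knowing their types.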
WDYT of raising an error in this case? Or making `engine_name` optional? (I just realized my configuration wasn't working properly because the class I was passing to the config context was missing the `engine_name` attribute.)
I think all classes that want to be an engine should have an `engine_name` attribute. Otherwise it is unclear what engine the class implements. Imagine you wrap a pipeline that has several steps that each have a plugin. So along those lines, maybe we should use `if engine_class.engine_name != engine_name:` here, which would lead to an exception if the class doesn't define an `engine_name`.
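The difference between the two lookups discussed here can be sketched in isolation. `WithName`, `WithoutName`, `matches_lenient` and `matches_strict` are hypothetical names used only for this illustration; the lenient variant mirrors the `getattr(..., None)` check in the diff, the strict one mirrors the suggested direct attribute access.

```python
class WithName:
    """Hypothetical ad-hoc engine class that declares its engine name."""

    engine_name = "kmeans"


class WithoutName:
    """Hypothetical ad-hoc engine class that forgot the attribute."""


def matches_lenient(engine_class, engine_name):
    # Lookup as in the diff: a missing attribute silently counts as a mismatch.
    return getattr(engine_class, "engine_name", None) == engine_name


def matches_strict(engine_class, engine_name):
    # Suggested lookup: a missing attribute raises AttributeError.
    return engine_class.engine_name == engine_name


print(matches_lenient(WithName, "kmeans"))
print(matches_lenient(WithoutName, "kmeans"))
try:
    matches_strict(WithoutName, "kmeans")
except AttributeError as exc:
    print(f"strict lookup failed: {exc}")
```

With the lenient check, a misconfigured class is skipped without any message, which matches the situation described in the first comment of this thread.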
+1 for letting it fail if the attribute is not defined.
The default Cython engine for kmeans doesn't define this attribute. Even if it's useless here, wouldn't it be helpful to define it, for the same reasons `accepts` is?
That is true. I forgot that for the default engine and all engines that aren't "ad-hoc" engines (is this a good name for classes directly passed in, compared to those registered via an entrypoint?) the lookup has an additional level of indirection via an engine spec. The engine spec has a `name` attribute that contains the name of the engine. Just `name` seemed too generic as a name of the attribute for the "ad-hoc" classes, but maybe we should unify it?
I'm unsure about defining it also for the non ad-hoc engine classes. It adds an additional location where this information is stored (attribute of the class and name of the entrypoint) and they could become out of sync. However, I also like having symmetry/no differences between the entrypoint and ad-hoc engines :-/
Maybe the engine name could only be defined on the class itself and we could update the engine specs in the importlib metadata to only register the class?
But that means that we will need to import the class to discover its name, which is a bit sad, because it could trigger import overhead from plugins unrelated to the specific algorithm of interest.
But maybe this is not a big deal.
We could add an `engine_name` attribute to the engine class when it is loaded via `EngineSpec.get_engine_class`?
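A rough sketch of that last idea, assuming the attribute is simply assigned from the spec's `name` at load time. `EngineSpecSketch` is a hypothetical variant of `EngineSpec`, and the `fake_provider` module is built in memory purely so the example runs without any installed plugin.

```python
import sys
import types
from importlib import import_module


class EngineSpecSketch:
    """Hypothetical variant of EngineSpec that tags the loaded class."""

    __slots__ = ["name", "provider_name", "module_name", "engine_qualname"]

    def __init__(self, name, provider_name, module_name, engine_qualname):
        self.name = name
        self.provider_name = provider_name
        self.module_name = module_name
        self.engine_qualname = engine_qualname

    def get_engine_class(self):
        engine = import_module(self.module_name)
        for attr in self.engine_qualname.split("."):
            engine = getattr(engine, attr)
        # Proposed addition: copy the entry point name onto the class so
        # entry point engines and ad-hoc engines expose the same attribute.
        engine.engine_name = self.name
        return engine


# Stand-in for an installed provider package (made up for this sketch).
fake_provider = types.ModuleType("fake_provider")


class KMeansEngine:
    pass


fake_provider.KMeansEngine = KMeansEngine
sys.modules["fake_provider"] = fake_provider

spec = EngineSpecSketch("kmeans", "fake_provider", "fake_provider", "KMeansEngine")
engine_class = spec.get_engine_class()
print(engine_class.engine_name)  # kmeans
```

One caveat with this approach: if the same class were registered under two entry point names, the attribute would hold whichever name was loaded last, so it may need more thought before being adopted.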