Skip to content

MNT Updated DistanceMetric API with new ABC/interface #26471

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 3, 2023

Conversation

Micky774
Copy link
Contributor

Reference Issues/PRs

Related to #26267
Addresses #26267 (comment)

What does this implement/fix? Explain your changes.

  • Renames DistanceMetric-->DistanceMetric64 in line with DistanceMetric32
  • Adds DistanceMetric as an abstract base class for DistanceMetric{32, 64}
  • Adds a general get_metric method to DistanceMetric which accepts a dtype argument to decide on metric specialization internally, without requiring explicit use of DistanceMetric{32, 64}.get_metric

Somewhat optional:

  • Removes sklearn.neighbors.DistanceMetric as a completion of its due deprecation (clears otherwise confusing namespace)

Any other comments?

The removal of sklearn.neighbors.DistanceMetric is not strictly necessary for this, however it seems appropriate.

@Micky774
Copy link
Contributor Author

@jjerphan pinging in case you would like to take a look at the changes.

Note that this PR doesn't make full use of the new flexibility offered by the API change, however it makes it far easier for subsequent PRs to do so (e.g. #25914)

Copy link
Member

@ogrisel ogrisel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had a quick look and I like the general idea although it's breaking the build at the moment:

Error compiling Cython file:
------------------------------------------------------------
...

            next_node_min_reach = min_reachability[j]
            next_node_source = current_sources[j]

            pair_distance = dist_metric.dist(
                &raw_data[current_node, 0],
               ^
------------------------------------------------------------

sklearn/cluster/_hdbscan/_linkage.pyx:182:16: Cannot convert 'const float64_t *' to Python object

Error compiling Cython file:
------------------------------------------------------------
...
            next_node_min_reach = min_reachability[j]
            next_node_source = current_sources[j]

            pair_distance = dist_metric.dist(
                &raw_data[current_node, 0],
                &raw_data[j, 0],
               ^
------------------------------------------------------------

sklearn/cluster/_hdbscan/_linkage.pyx:183:16: Cannot convert 'const float64_t *' to Python object

@Micky774
Copy link
Contributor Author

Micky774 commented Jun 1, 2023

Thanks for pointing that out @ogrisel! I forgot to adjust for the newly-merged HDBSCAN. Should be fixed now :)

Copy link
Member

@jjerphan jjerphan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, Meekail.

Just a few comments.

Could you also add tests to check that the dispatch on np.float{32,64} data return correct instances of respectively DistanceMetric{32,64} and that it raise an error otherwise?

Copy link
Member

@ogrisel ogrisel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM once Julien's comments are addressed.

@Micky774
Copy link
Contributor Author

Micky774 commented Jun 2, 2023

@jjerphan All concerns should be addressed now 😄

Copy link
Member

@jjerphan jjerphan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Merci!

@jjerphan jjerphan merged commit 314f7ba into scikit-learn:main Jun 3, 2023
@Micky774 Micky774 deleted the distance_metric_refactor branch June 3, 2023 13:16
Shreesha3112 pushed a commit to Shreesha3112/scikit-learn that referenced this pull request Jun 5, 2023
@jjerphan jjerphan added this to the 1.4 milestone Jun 10, 2023
@jjerphan
Copy link
Member

As of 0e253d9, we have:

In [1]: from sklearn.metrics import DistanceMetric

In [2]: DistanceMetric.get_metric("euclidean")
Out[2]: <sklearn.metrics._dist_metrics.EuclideanDistance64 at 0x7f7785681b60>

As mentioned in #25914 (comment), we might want to hide DistanceMetric's dtype-specific implementation details.

This PR is not required at all for 1.3, but might be misleading if it is released now. Hence, I am for now pinning this PR for 1.4 not to have users adhere to this change of behavior.

What do you think, @Micky774?

@Micky774
Copy link
Contributor Author

Just for visibility sake, copying my comment from the linked discussion:

On a separate note, I agree we should propogate that pattern to DistanceMetrics as well at some point, though it is admittedly not high on my personal priorities right now 😅

manudarmi pushed a commit to primait/scikit-learn that referenced this pull request Jun 12, 2023
@jjerphan jjerphan removed this from the 1.4 milestone Jun 14, 2023
@jjerphan jjerphan added this to the 1.3 milestone Jun 14, 2023
@jjerphan
Copy link
Member

Actually, the consensus is to keep the current behavior in main. I reverted the milestone to 1.3.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants