PERF speedup classification_report by attaching unique values to dtype.metadata #29738
Merged
15 commits
6047d62  PERF speedup classification_report by attaching unique values to dtyp… (adrinjalali)
39339f8  API cleanup (adrinjalali)
9663448  Merge remote-tracking branch 'upstream/main' into unique-cache (adrinjalali)
7209447  changelog (adrinjalali)
11f4295  move caching out of _check_targets (adrinjalali)
24afeae  Merge remote-tracking branch 'upstream/main' into unique-cache (adrinjalali)
fd31259  add docstrings (adrinjalali)
ee83711  Merge remote-tracking branch 'upstream/main' into unique-cache (adrinjalali)
9a51cea  Merge remote-tracking branch 'upstream/main' into unique-cache (adrinjalali)
075f918  use np.unique (adrinjalali)
45fa9f5  Merge remote-tracking branch 'upstream/main' into unique-cache (adrinjalali)
bb9537c  add tests (adrinjalali)
9ec3865  add another test (adrinjalali)
8b495cf  add comment (adrinjalali)
c6ced7c  Guillaume's comments (adrinjalali)
@@ -0,0 +1,108 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import numpy as np

from sklearn.utils._array_api import get_namespace


def _attach_unique(y):
    """Attach unique values of y to y and return the result.

    The result is a view of y, and the metadata (unique) is not attached to y.
    """
    if not isinstance(y, np.ndarray):
        return y
    try:
        # avoid recalculating unique in nested calls.
        if "unique" in y.dtype.metadata:
            return y
    except (AttributeError, TypeError):
        pass

    unique = np.unique(y)
    unique_dtype = np.dtype(y.dtype, metadata={"unique": unique})
    return y.view(dtype=unique_dtype)


def attach_unique(*ys, return_tuple=False):
    """Attach unique values of ys to ys and return the results.

    The result is a view of y, and the metadata (unique) is not attached to y.

    IMPORTANT: The output of this function should NEVER be returned in functions.
    This is to avoid this pattern:

    .. code:: python

        y = np.array([1, 2, 3])
        y = attach_unique(y)
        y[1] = -1
        # now np.unique(y) will be different from cached_unique(y)

    Parameters
    ----------
    *ys : sequence of array-like
        Input data arrays.

    return_tuple : bool, default=False
        If True, always return a tuple even if there is only one array.

    Returns
    -------
    ys : tuple of array-like or array-like
        Input data with unique values attached.
    """
    res = tuple(_attach_unique(y) for y in ys)
    if len(res) == 1 and not return_tuple:
        return res[0]
    return res


def _cached_unique(y, xp=None):
    """Return the unique values of y.

    Use the cached values from dtype.metadata if present.

    This function does NOT cache the values in y, i.e. it doesn't change y.

    Call `attach_unique` to attach the unique values to y.
    """
    try:
        if y.dtype.metadata is not None and "unique" in y.dtype.metadata:
            return y.dtype.metadata["unique"]
    except AttributeError:
        # in case y is not a numpy array
        pass
    xp, _ = get_namespace(y, xp=xp)
    return xp.unique_values(y)


def cached_unique(*ys, xp=None):
    """Return the unique values of ys.

    Use the cached values from dtype.metadata if present.

    This function does NOT cache the values in y, i.e. it doesn't change y.

    Call `attach_unique` to attach the unique values to y.

    Parameters
    ----------
    *ys : sequence of array-like
        Input data arrays.

    xp : module, default=None
        Precomputed array namespace module. When passed, typically from a caller
        that has already performed inspection of its own inputs, skips array
        namespace inspection.

    Returns
    -------
    res : tuple of array-like or array-like
        Unique values of ys.
    """
    res = tuple(_cached_unique(y, xp=xp) for y in ys)
    if len(res) == 1:
        return res[0]
    return res
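
For readers who have not used `dtype.metadata` before, the mechanism in the file above can be sketched with plain NumPy. This is only an illustration: `attach_unique_sketch` and `cached_unique_sketch` are hypothetical names, not the helpers added in this PR, and the fallback uses `np.unique` instead of the array API namespace lookup.

```python
import numpy as np


def attach_unique_sketch(y):
    """Hypothetical stand-in: return a view of y whose dtype carries np.unique(y)."""
    unique = np.unique(y)
    tagged_dtype = np.dtype(y.dtype, metadata={"unique": unique})
    # view() reinterprets the same buffer with the tagged dtype; no data is copied.
    return y.view(dtype=tagged_dtype)


def cached_unique_sketch(y):
    """Hypothetical stand-in: use the cached values if present, else compute them."""
    metadata = getattr(getattr(y, "dtype", None), "metadata", None)
    if metadata is not None and "unique" in metadata:
        return metadata["unique"]
    return np.unique(y)


y = np.array([3, 1, 2, 3, 1])
y_tagged = attach_unique_sketch(y)

original_metadata = y.dtype.metadata          # the input array itself is not tagged
shares_buffer = np.shares_memory(y, y_tagged)  # the tagged array is a view of y
cached = cached_unique_sketch(y_tagged)
fallback = cached_unique_sketch([3, 1, 2, 3, 1])  # plain list: computed on the fly
```

The point of the design is that the unique values travel with the array object itself, so nested helpers do not need an extra `labels` argument threaded through every call.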
@@ -0,0 +1,54 @@
import numpy as np
from numpy.testing import assert_array_equal

from sklearn.utils._unique import attach_unique, cached_unique
from sklearn.utils.validation import check_array


def test_attach_unique_attaches_unique_to_array():
    arr = np.array([1, 2, 2, 3, 4, 4, 5])
    arr_ = attach_unique(arr)
    assert_array_equal(arr_.dtype.metadata["unique"], np.array([1, 2, 3, 4, 5]))
    assert_array_equal(arr_, arr)


def test_cached_unique_returns_cached_unique():
    my_dtype = np.dtype(np.float64, metadata={"unique": np.array([1, 2])})
    arr = np.array([1, 2, 2, 3, 4, 4, 5], dtype=my_dtype)
    assert_array_equal(cached_unique(arr), np.array([1, 2]))


def test_attach_unique_not_ndarray():
    """Test that when not np.ndarray, we don't touch the array."""
    arr = [1, 2, 2, 3, 4, 4, 5]
    arr_ = attach_unique(arr)
    assert arr_ is arr


def test_attach_unique_returns_view():
    """Test that attach_unique returns a view of the array."""
    arr = np.array([1, 2, 2, 3, 4, 4, 5])
    arr_ = attach_unique(arr)
    assert arr_.base is arr


def test_attach_unique_return_tuple():
    """Test return_tuple argument of the function."""
    arr = np.array([1, 2, 2, 3, 4, 4, 5])
    arr_tuple = attach_unique(arr, return_tuple=True)
    assert isinstance(arr_tuple, tuple)
    assert len(arr_tuple) == 1
    assert_array_equal(arr_tuple[0], arr)

    arr_single = attach_unique(arr, return_tuple=False)
    assert isinstance(arr_single, np.ndarray)
    assert_array_equal(arr_single, arr)


def test_check_array_keeps_unique():
    """Test that check_array keeps the unique metadata."""
    arr = np.array([[1, 2, 2, 3, 4, 4, 5]])
    arr_ = attach_unique(arr)
    arr_ = check_array(arr_)
    assert_array_equal(arr_.dtype.metadata["unique"], np.array([1, 2, 3, 4, 5]))
    assert_array_equal(arr_, arr)
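
The "avoid recalculating unique in nested calls" comment is the motivation for the speedup, and it may be easier to see with a counter. The sketch below is not scikit-learn code: `counted_unique`, `attach`, `labels_of`, `n_classes`, `label_range`, and `summary` are all hypothetical names, and it assumes the inputs are NumPy arrays.

```python
import numpy as np

unique_calls = 0  # counts how often the full scan over y actually runs


def counted_unique(y):
    global unique_calls
    unique_calls += 1
    return np.unique(y)


def attach(y):
    """Tag y with its unique values unless an outer caller already did."""
    metadata = y.dtype.metadata
    if metadata is not None and "unique" in metadata:
        return y
    tagged = np.dtype(y.dtype, metadata={"unique": counted_unique(y)})
    return y.view(dtype=tagged)


def labels_of(y):
    """Read the cached labels if present, otherwise compute them."""
    metadata = y.dtype.metadata
    if metadata is not None and "unique" in metadata:
        return metadata["unique"]
    return counted_unique(y)


def n_classes(y):
    return len(labels_of(y))


def label_range(y):
    labels = labels_of(y)
    return labels[0], labels[-1]


def summary(y):
    # Attach once at the outermost entry point; the tagged view stays local
    # and is never returned to the caller.
    y = attach(y)
    return {"n_classes": n_classes(y), "range": label_range(y)}


y = np.array([2, 0, 1, 2, 2, 0])
result = summary(y)
calls_with_cache = unique_calls

unique_calls = 0
uncached = (n_classes(y), label_range(y))  # same helpers, no attach step
calls_without_cache = unique_calls
```

With the attach step, the two helpers share one scan; without it, each helper pays for its own.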
Can we include a comment here that states that the output of `attach_unique` should never be returned from a public function?
Yep, added a comment for this.
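
The concern behind this request can be reproduced with NumPy alone. The following is a minimal sketch of the stale-cache pattern described in the `attach_unique` docstring, building the tagged view by hand rather than calling the PR's helper:

```python
import numpy as np

y = np.array([1, 2, 3])

# Build the tagged view by hand (same idea as attach_unique, without sklearn).
tagged = y.view(dtype=np.dtype(y.dtype, metadata={"unique": np.unique(y)}))

# A caller who received `tagged` mutates it in place ...
tagged[1] = -1

cached = tagged.dtype.metadata["unique"]  # frozen at attach time
actual = np.unique(tagged)                # reflects the mutation
```

After the assignment, `cached` still holds the values from attach time while `actual` reflects the new contents, so any code trusting the metadata would see the wrong labels. Since `tagged` is a view, `y` is modified too. Keeping tagged arrays internal to a single call avoids handing users an array whose cache can silently go out of date.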