ENH Add Array API compatibility for entropy #29141
Conversation
Thanks for the PR. By the look of it there is still significant work to do, see:
```python
labels1 = xp.asarray([0, 0, 42.0])
labels2 = xp.asarray([])
labels3 = xp.asarray([1, 1, 1, 1])
```
Let's be explicit about dtypes and let's actually test on the device returned by the fixture:
```diff
-labels1 = xp.asarray([0, 0, 42.0])
-labels2 = xp.asarray([])
-labels3 = xp.asarray([1, 1, 1, 1])
+float_labels = xp.asarray([0, 0, 42.0], device=device, dtype=dtype_name)
+empty_int32_labels = xp.asarray([], dtype="int32", device=device)
+int_labels = xp.asarray([1, 1, 1, 1], device=device)
```
the rest of the test will need to be updated accordingly (along with the code, I think).
You can use https://gist.github.com/EdAbati/ff3bdc06bafeb92452b3740686cc8d7c to launch this test on a machine with pytorch running on a non-CPU device.
What do you think about adding integer dtype_names in `yield_namespace_device_dtype_combinations`? Right now there are many asserts that do not depend on `dtype_name`, so they won't check anything new. E.g. the assert for `empty_int32_labels` runs identically with `xp=torch, device=cpu, dtype_name=float64` and with `xp=torch, device=cpu, dtype_name=float32`.
Adding integer dtypes could help remove such repetition.
For example, this test could be rewritten as
@pytest.mark.parametrize(
"array_namespace, device, dtype_name",
yield_namespace_device_dtype_combinations(yield_integers=True),
)
def test_entropy_array_api(array_namespace, device, dtype_name):
xp = _array_api_for_tests(array_namespace, device)
labels = xp.asarray(np.asarray([0, 0, 42.0], dtype=dtype_name), device=device)
empty_labels = xp.asarray(np.asarray([], dtype=dtype_name), device=device)
constant_labels = xp.asarray(np.asarray([1, 1, 1, 1], dtype=dtype_name), device=device)
with config_context(array_api_dispatch=True):
assert_almost_equal(entropy(labels), 0.6365141, 5)
assert entropy(empty_labels) == 1
assert entropy(constant_labels) == 0
If it makes sense, I can open a separate issue for discussion.
I don't want to have a combinatorial explosion of test cases where most cases are not likely to yield interesting things to test.

For tests that require integer dtypes, we can just use either `xp.int32` or `xp.int64`: that should be enough for most scikit-learn use cases, and both are supported by all known platforms.

For floating-point values, it's interesting to test with both `xp.float32` and `xp.float64` (when possible), because most GPUs work best in `xp.float32` and sometimes do not support `xp.float64`, which is the default dtype of numpy. Hence the need to test both when possible, conditionally on the choice of namespace and device.
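One way to make that conditional float64 testing concrete (a hypothetical sketch, not the helper scikit-learn actually uses; `float_dtypes_for` is an invented name) is to probe whether the namespace/device pair accepts a float64 allocation and only include it in the parametrization when it succeeds:

```python
import numpy as np


def float_dtypes_for(xp, device=None):
    """Hypothetical helper: list the float dtypes worth testing for a
    given namespace/device pair (float64 only when actually supported)."""
    dtypes = [xp.float32]
    try:
        # Some accelerator devices reject float64 allocations entirely,
        # so attempt one and fall back to float32-only on failure.
        kwargs = {} if device is None else {"device": device}
        xp.asarray([0.0], dtype=xp.float64, **kwargs)
        dtypes.append(xp.float64)
    except Exception:
        pass
    return dtypes


# With numpy as the namespace, both float dtypes are available.
print(float_dtypes_for(np))
```

The same probe-and-skip idea is what the conditional parametrization in the test suite has to express, however it is spelled internally.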
Thanks for the update, almost LGTM! See below:
```diff
-pi = np.bincount(label_idx).astype(np.float64)
-pi = pi[pi > 0]
+pi = xp.astype(xp.unique_counts(labels)[1], xp.float64)
```
Nice code simplification by the way :)
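To see why the two formulations agree (a standalone numpy sketch, not the exact scikit-learn source): unique counts only include labels that actually occur, so the `pi[pi > 0]` filtering step that `bincount` required becomes unnecessary:

```python
import numpy as np

labels = np.asarray([0, 0, 42.0])

# Counts from np.unique contain no zeros, unlike np.bincount, which
# emits a zero for every unused non-negative integer bin.
pi = np.unique(labels, return_counts=True)[1].astype(np.float64)

# Shannon entropy of the empirical label distribution (natural log);
# matches the 0.6365141 expected by the test above to 5 decimals.
pi = pi / pi.sum()
print(-np.sum(pi * np.log(pi)))
```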
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Once the above two comments are dealt with and assuming tests still pass, LGTM. Thanks for the PR.
LGTM, thanks for the PR!
LGTM. Thanks @Tialo
Reference Issues/PRs

Towards #26024

What does this implement/fix? Explain your changes.

Any other comments?