ENH Add Array API compatibility for entropy #29141


Merged: 18 commits into scikit-learn:main on Jun 14, 2024

Conversation

@Tialo (Contributor) commented May 30, 2024

Reference Issues/PRs

Towards #26024

github-actions bot commented May 30, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit be5fe2f.

@ogrisel (Member) left a comment

Thanks for the PR. By the looks of it, there is still significant work to do; see:

Comment on lines 269 to 271
labels1 = xp.asarray([0, 0, 42.0])
labels2 = xp.asarray([])
labels3 = xp.asarray([1, 1, 1, 1])

Let's be explicit about dtypes and let's actually test on the device returned by the fixture:

Suggested change
labels1 = xp.asarray([0, 0, 42.0])
labels2 = xp.asarray([])
labels3 = xp.asarray([1, 1, 1, 1])
float_labels = xp.asarray([0, 0, 42.0], device=device, dtype=dtype_name)
empty_int32_labels = xp.asarray([], dtype="int32", device=device)
int_labels = xp.asarray([1, 1, 1, 1], device=device)

the rest of the test will need to be updated accordingly (along with the code, I think).

You can use https://gist.github.com/EdAbati/ff3bdc06bafeb92452b3740686cc8d7c to launch this test on a machine with pytorch running on a non-CPU device.

@Tialo (Contributor, author) commented

What do you think about adding integer dtype_names to yield_namespace_device_dtype_combinations? Right now many asserts do not depend on dtype_name, so repeated parametrizations check nothing new: for example, the assert for empty_int32_labels runs identically for xp=torch, device=cpu, dtype_name=float64 and for xp=torch, device=cpu, dtype_name=float32.

Adding integer dtypes could help remove such repetition.
For example, this test could be rewritten as

@pytest.mark.parametrize(
    "array_namespace, device, dtype_name",
    yield_namespace_device_dtype_combinations(yield_integers=True),
)
def test_entropy_array_api(array_namespace, device, dtype_name):
    xp = _array_api_for_tests(array_namespace, device)
    labels = xp.asarray(np.asarray([0, 0, 42.0], dtype=dtype_name), device=device)
    empty_labels = xp.asarray(np.asarray([], dtype=dtype_name), device=device)
    constant_labels = xp.asarray(np.asarray([1, 1, 1, 1], dtype=dtype_name), device=device)
    with config_context(array_api_dispatch=True):
        assert_almost_equal(entropy(labels), 0.6365141, 5)
        assert entropy(empty_labels) == 1
        assert entropy(constant_labels) == 0

If this makes sense, I can open a separate issue for discussion.
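As a sanity check on the expected values asserted in the test above, here is a minimal NumPy-only sketch (not scikit-learn's implementation; `label_entropy` is a hypothetical helper name) of the Shannon entropy of a label assignment:

```python
import numpy as np

def label_entropy(labels):
    """Shannon entropy (natural log) of a label assignment. Sketch only."""
    labels = np.asarray(labels)
    if labels.size == 0:
        # matches the convention asserted in the test above: entropy of
        # an empty label array is defined as 1
        return 1.0
    counts = np.unique(labels, return_counts=True)[1].astype(np.float64)
    total = counts.sum()
    # sum(p * log(1/p)) written as sum((c/n) * (log(n) - log(c)));
    # each term is non-negative, so constant labels give exactly 0.0
    return float(np.sum((counts / total) * (np.log(total) - np.log(counts))))

print(round(label_entropy([0, 0, 42.0]), 7))
print(label_entropy([]))
print(label_entropy([1, 1, 1, 1]))
```

For `[0, 0, 42.0]` the counts are (2, 1), giving ln 3 − (2/3) ln 2 ≈ 0.6365141, which is the value the test asserts to 5 decimal places.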

@ogrisel (Member) commented Jun 7, 2024

I don't want to have a combinatorial explosion of test cases where most cases are not likely to yield interesting things to test.

For tests that require integer dtypes, we can just use either xp.int32 or xp.int64; those should be enough for most scikit-learn use cases and are supported by all known platforms.

For floating-point values, it's interesting to test with both xp.float32 and xp.float64 (when possible): most GPUs work best in xp.float32 and sometimes do not support xp.float64, which is NumPy's default dtype. Hence the need to test both, conditionally on the choice of namespace and device.
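The dtype policy described above can be sketched as follows (a minimal illustration using NumPy as a stand-in for the array API namespace `xp`; the shape of the parametrization is an assumption, not scikit-learn's actual test fixture):

```python
import numpy as np

xp = np  # stand-in for the array API namespace under test

# Integer inputs: one fixed dtype (here int32) is enough; it is
# supported by all known platforms, so no parametrization is needed.
int_labels = xp.asarray([1, 1, 1, 1], dtype=xp.int32)
assert int_labels.dtype == xp.int32

# Floating-point inputs: parametrize over float32 and float64, since
# GPUs often work best in float32 and may not support float64 at all.
for float_dtype in (xp.float32, xp.float64):
    float_labels = xp.asarray([0, 0, 42.0], dtype=float_dtype)
    assert float_labels.dtype == float_dtype
```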

@ogrisel (Member) left a comment

Thanks for the update, almost LGTM! See below:

Before:

pi = np.bincount(label_idx).astype(np.float64)
pi = pi[pi > 0]

After:

pi = xp.astype(xp.unique_counts(labels)[1], xp.float64)

Nice code simplification by the way :)
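To see why the simplification holds, here is a hedged equivalence check, using NumPy in place of `xp` (`np.unique(..., return_counts=True)` standing in for the array API's `unique_counts`) and assuming `label_idx` comes from factorizing the labels:

```python
import numpy as np

labels = np.asarray([0, 0, 42, 7, 7, 7])

# Old formulation: factorize labels into indices, count occurrences
# with bincount, then drop any empty bins.
_, label_idx = np.unique(labels, return_inverse=True)
pi_old = np.bincount(label_idx).astype(np.float64)
pi_old = pi_old[pi_old > 0]

# New formulation: unique values and their counts in a single call,
# already free of empty bins and in the same sorted-by-label order.
pi_new = np.unique(labels, return_counts=True)[1].astype(np.float64)

assert np.array_equal(pi_old, pi_new)
print(pi_new)  # counts per unique label, sorted by label value
```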

Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
@ogrisel (Member) left a comment

Once the above two comments are dealt with and assuming tests still pass, LGTM. Thanks for the PR.

@ogrisel added the "Waiting for Second Reviewer" label (first reviewer is done, need a second one!) Jun 7, 2024
@ogrisel (Member) left a comment

LGTM, thanks for the PR!

@OmarManzoor (Contributor) left a comment

LGTM. Thanks @Tialo

@OmarManzoor removed the "Waiting for Second Reviewer" label Jun 14, 2024
@OmarManzoor enabled auto-merge (squash) June 14, 2024 07:10
@OmarManzoor merged commit 5ced13c into scikit-learn:main Jun 14, 2024
29 checks passed
@Tialo deleted the array-api/entropy branch June 14, 2024 08:02
@jeremiedbb mentioned this pull request Jul 2, 2024