@amacati amacati commented Aug 21, 2025

Fancy indexing is currently not supported for __setitem__, which blocks some PRs in scipy (scipy/scipy#23425).

As discussed in data-apis/array-api#864 (comment), all array API frameworks already implement this feature, though not necessarily consistently for duplicate indices. It is currently not part of the standard. This PR adds a workaround for array-api-strict to allow fancy indexing in xpx.at(x, ...).

@amacati amacati changed the title from "Add __setitem__ fancy indexing support for array-api-strict" to "ENH: Add __setitem__ fancy indexing support for array-api-strict" Aug 21, 2025
@lucascolley lucascolley added the enhancement New feature or request label Aug 22, 2025
Member

@lucascolley lucascolley left a comment

thanks @amacati !

@lucascolley lucascolley changed the title from "ENH: Add __setitem__ fancy indexing support for array-api-strict" to "ENH: Add __setitem__ fancy indexing fallback" Aug 22, 2025
@lucascolley lucascolley changed the title from "ENH: Add __setitem__ fancy indexing fallback" to "ENH: at: add __setitem__ fancy indexing fallback" Aug 22, 2025
@amacati
Author

amacati commented Aug 23, 2025

The fallback is now implemented as error handling for __setitem__ and the behaviour for data races is documented. I also fixed the handling of negative indices and the assignment of scalar values.

One remaining concern is out-of-bounds indices. For anything negative or positive > len(x), the fallback currently does not throw an error. We would need to add value-based checks to catch that.

@lucascolley
Member

For anything negative or positive > len(x), the fallback currently does not throw an error. We would need to add value-based checks to catch that.

Seems fine to leave that for now. I guess the pattern would be to hide things behind not is_lazy_array(x) like SciPy?
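
A minimal sketch of the pattern suggested here, under stated assumptions: `normalize_indices` is a hypothetical helper (not part of the PR), NumPy stands in for the array namespace, and the `lazy` flag stands in for the result of a laziness check such as `is_lazy_array(x)`. It wraps in-range negative indices and runs the value-based bounds check only for eager arrays, since that check has to inspect actual values.

```python
import numpy as np


def normalize_indices(idx, size, *, lazy=False):
    """Hypothetical helper: wrap negative indices along an axis of length `size`.

    `lazy` stands in for a check like `is_lazy_array(x)`; when it is True the
    value-based bounds check is skipped because it would force computation.
    """
    if not lazy:
        out_of_bounds = (idx < -size) | (idx >= size)
        if bool(np.any(out_of_bounds)):
            msg = f"index out of bounds for axis 0 with size {size}"
            raise IndexError(msg)
    # Map e.g. -1 to size - 1; non-negative indices pass through unchanged.
    return np.where(idx < 0, idx + size, idx)


print(normalize_indices(np.asarray([-1, 0, 2]), 3).tolist())  # [2, 0, 2]
```

With `lazy=True` an out-of-bounds index passes through silently, which mirrors the current behaviour of the fallback described above.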

Member

@lucascolley lucascolley left a comment

looks good to me, thanks @amacati !

Would you like to take a look @crusaderky ?

@lucascolley
Member

the codecov failure is not a worry

@lucascolley lucascolley added this to the 0.9.0 milestone Aug 23, 2025
@lucascolley
Member

lint failures are real but trivial

@lucascolley
Member

I'll release 0.8.1 soon, and merge this for 0.9.0 by the end of the week if there is no further feedback

Contributor

@crusaderky crusaderky left a comment

I have concerns regarding user experience.

This PR introduces support for integer array indices on backends that don't support it in __setitem__.

However, it makes it work

  • exclusively on set(); not on add() nor any other methods;
  • exclusively when it is on axis 0;
  • exclusively when it is not expressed as a tuple;
  • exclusively for data of trivial size (anything in the magnitude of megabytes will crash with MemoryError).

This I suspect will result in a rather unpleasant user experience for those that don't fall in this very specific use case.

Comment on lines +158 to +164
>>> import numpy as np
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> xpx.at(np.asarray([0]), np.asarray([0, 0])).set(np.asarray([2, 3]))
array([3])
>>> xpx.at(xp.asarray([0]), xp.asarray([0, 0])).set(xp.asarray([2, 3]))
Array([3], dtype=array_api_strict.int64)
Contributor

This example leaves me confused. I don't think it adds anything?

Author

The aim is to show that np's and xpx's behavior is identical. For torch tensors on the GPU you would see

>>> xpx.at(torch.tensor([0]).cuda(), torch.tensor([0, 0]).cuda()).set(torch.tensor([2, 3]).cuda())
torch.Tensor([2], dtype=torch.int64)

Comment on lines +378 to +380
except IndexError as e:
    if "Fancy indexing" not in str(e):  # Avoid masking other index errors
        raise e
Contributor

This is quite fragile as it cherry-picks array-api-strict's behaviour. Different libraries would have different error messages and different exceptions.

Suggested change
-except IndexError as e:
-    if "Fancy indexing" not in str(e):  # Avoid masking other index errors
-        raise e
+except Exception as e:

Author

Yes, I thought about this as well. However, I am strongly opposed to a blanket except. We would mask errors for regular frameworks, which would then enter an unexpected code path that may throw obscure errors. Hence the comment on masking other index errors. This feels almost worse than the added benefit of having array-api-strict support for integer indexing.
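
The trade-off under discussion can be sketched as follows. This is an illustration under assumptions, not the PR's code: `NoFancySetitem` is a hypothetical stand-in backend, its error message is assumed wording, and `set_with_fallback` is a made-up name. The narrow handler takes the fallback path only for the one recognised failure and re-raises everything else.

```python
import numpy as np


class NoFancySetitem:
    """Hypothetical backend whose __setitem__ rejects integer-array indices."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __setitem__(self, idx, value):
        if isinstance(idx, np.ndarray) and idx.dtype.kind in "iu":
            # Assumed wording; real libraries may use other messages or types.
            raise IndexError("Fancy indexing is not supported in __setitem__")
        self.data[idx] = value


def set_with_fallback(x, idx, value, fallback):
    """Try the native __setitem__; use `fallback` only for the known failure."""
    try:
        x[idx] = value
    except IndexError as e:
        if "Fancy indexing" not in str(e):  # Avoid masking other index errors
            raise
        fallback(x, idx, value)


calls = []
x = NoFancySetitem([0, 1, 2])
set_with_fallback(x, np.asarray([0, 2]), 9, lambda *args: calls.append(args))
print(len(calls))  # 1
```

An out-of-bounds scalar index such as `set_with_fallback(x, 5, 1, ...)` still raises the original IndexError here, whereas a blanket `except Exception` would route it into the fallback. The cost, as noted in the review, is that the string match only recognises one library's message.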

)
raise IndexError(msg) from e

x_mask = xp.any(xp.arange(x.shape[0])[..., None] == u_idx_pos, axis=-1)
Contributor

Suggested change
-x_mask = xp.any(xp.arange(x.shape[0])[..., None] == u_idx_pos, axis=-1)
+x_rng = xp.arange(x.shape[0], device=device(u_idx_pos))
+x_mask = xp.any(x_rng[..., None] == u_idx_pos, axis=-1)

Could you add a comment explaining what you're doing here?

Contributor

@crusaderky crusaderky Aug 28, 2025

I think you need to add in the documentation above a note warning that the implementation is quadratic.

If x is 10 MiB along axis 0 and u_idx_pos is 10 MiB, this line transiently consumes 100 terabytes of RAM.
Have you considered using searchsorted?
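
A minimal sketch of what the searchsorted suggestion could look like, written against NumPy for illustration. The function names are hypothetical, and the indices are assumed to be already wrapped to non-negative values (as the name u_idx_pos suggests). The broadcast version materialises a boolean array of shape (size, len(idx)); the sorted version only allocates arrays of length size or len(idx).

```python
import numpy as np


def mask_quadratic(size, idx):
    """Broadcast comparison: materialises a (size, len(idx)) boolean array."""
    return np.any(np.arange(size)[:, None] == idx, axis=-1)


def mask_searchsorted(size, idx):
    """Same mask via sort + searchsorted, with O(size + len(idx)) memory."""
    if idx.shape[0] == 0:
        return np.zeros(size, dtype=bool)
    sorted_idx = np.sort(idx)
    positions = np.arange(size)
    # Leftmost slot where each position would be inserted into sorted_idx.
    slot = np.searchsorted(sorted_idx, positions)
    # Positions above the largest index get slot == len(idx); clamp them.
    slot = np.minimum(slot, sorted_idx.shape[0] - 1)
    # A position is selected iff the value at its slot equals the position.
    return np.take(sorted_idx, slot) == positions


idx = np.asarray([3, 0, 3])
print(mask_searchsorted(5, idx).tolist())  # [True, False, False, True, False]
```

The sketch sticks to `sort`, `searchsorted`, `minimum` and `take`, which have counterparts in the array API standard, but it has not been checked against array-api-strict.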

@@ -355,9 +370,70 @@ def _op(
# Backends without boolean indexing (other than JAX) crash here
if in_place_op:  # add(), subtract(), ...
    x[idx] = in_place_op(x[idx], y)
Contributor

These remain broken.

Author

Yes. I'm not sure if we should attempt to fix them. See my general comment.

@amacati
Author

amacati commented Aug 28, 2025

I have concerns regarding user experience.

This PR introduces support for integer array indices on backends that don't support it in __setitem__.

However, it makes it work

  • exclusively on set(); not on add() nor any other methods;
  • exclusively when it is on axis 0;
  • exclusively when it is not expressed as a tuple;
  • exclusively for data of trivial size (anything in the magnitude of megabytes will crash with MemoryError).

This I suspect will result in a rather unpleasant user experience for those that don't fall in this very specific use case.

Correct. I think this should be discussed before moving forward.

The only way to use array indices for these operations while remaining compliant with the standard (please correct me if I am wrong!) is to use boolean masking. However, masks that exactly replicate the selection from integer indices are hard to create. The aim of restricting the fallback to axis 0, no tuples, and set() only was to reduce the potential for errors and keep the complexity manageable. The current implementation already has several special cases that need to be accounted for (e.g. wrapping negative indices, uniqueness, etc.). I expect this to grow significantly if we open up the narrow scope of this PR.

If we can't express the logic of integer indexing easily with existing array API functions, we might want to rethink if adding it makes sense.

@amacati
Author

amacati commented Aug 28, 2025

I have updated the code with most of the suggested changes from the review.

However:

The current implementation is broken for non-unique indices. One test case should fail. To see why, consider the following code

import array_api_extra as xpx
import array_api_strict as xp


x = xp.asarray([0, 1, 2])
y = xp.asarray([3, 4, 5])
idx = xp.asarray([0, 1, 0])
print(xpx.at(x, idx).set(y))  # [4, 5, 2], should be [5, 4, 2]

The reason why this happens is that we construct two masks for the set operation, one for x and one for y. The x mask is True where the index of array x is in idx, which in this case is [True, True, False]. The y mask is True where the unique indices in idx appear for the last time, i.e. [False, True, True].

However, we have no way to express the fact that the first True in the x mask belongs to the second True in the y mask. Thus, the values get assigned in the wrong order. One way to fix this would be to create an integer index for y instead of a mask and shuffle the values around such that the order matches x. But this isn't exactly making things less brittle.
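
The mismatch and the reordering idea can be reproduced with plain NumPy. This is an illustrative sketch, not the PR's implementation: the way the two masks are built here is an assumption, and it relies on NumPy's own boolean-mask assignment and integer indexing into y rather than on the restricted array-api-strict surface.

```python
import numpy as np

x = np.asarray([0, 1, 2])
y = np.asarray([3, 4, 5])
idx = np.asarray([0, 1, 0])

# True where a position of x appears in idx: [True, True, False]
x_mask = np.any(np.arange(x.shape[0])[:, None] == idx, axis=-1)
# True where an index value occurs for the last time: [False, True, True]
repeated_later = np.triu(idx[:, None] == idx[None, :], k=1)
y_mask = ~np.any(repeated_later, axis=-1)

# Two masks: values arrive in y's order, not in x's order.
broken = x.copy()
broken[x_mask] = y[y_mask]
print(broken.tolist())  # [4, 5, 2]

# Integer index into y, reordered by the index value each entry targets.
last_pos = np.nonzero(y_mask)[0]
order = np.argsort(idx[last_pos])
fixed = x.copy()
fixed[x_mask] = y[last_pos[order]]
print(fixed.tolist())  # [5, 4, 2]
```

Sorting the surviving entries of y by their target index lines them up with the True positions of the x mask, which are always in ascending order. It also adds another sort and another indexing step to the fallback.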

@crusaderky What's your opinion on this? These workarounds don't exactly feel great given we are currently doing this only for array-api-strict, which is only used for testing. Having such a complex separate code path where the testing framework deviates from the other frameworks kind of defeats the point.

@crusaderky
Contributor

These workarounds don't exactly feel great given we are currently doing this only for array-api-strict, which is only used for testing

array-api-strict is a proxy for not explicitly tested, possibly not yet known or even existing, additional libraries.

I honestly doubt that this PR should be merged, given its complication and its way-too-many caveats.
However, it still has value in demonstrating how painful it is to work around this limitation of the Array API, for functionality that is being leveraged as we speak by scipy. It should be used as exhibit A when pushing for inclusion of integer array indices in __setitem__ in the Array API standard.

@amacati
Author

amacati commented Aug 29, 2025

I sadly have to agree with the feeling. There are too many complications for too narrow a scope. If it were a general solution for integer indexing for all operations it might have been worth the trouble, but as it stands it's probably best to wait for the standard to advance (if it ever does).

@lucascolley
Member

lucascolley commented Aug 29, 2025

To jog my memory, is the blocker for the standard that existing libraries implement orthogonal semantics, so standardising one set of behaviour would mean a significant break for at least some library?

Labels
enhancement New feature or request

3 participants