ENH: at: add __setitem__ fancy indexing fallback #395
thanks @amacati !
The fallback is now implemented as error handling for One remaining concern is out-of-bounds indices. For anything negative or positive
Seems fine to leave that for now. I guess the pattern would be to hide things behind
looks good to me, thanks @amacati !
Would you like to take a look @crusaderky ?
the codecov failure is not a worry
lint failures are real but trivial
I'll release 0.8.1 soon, and merge this for 0.9.0 by the end of the week if there is no further feedback
I have concerns regarding user experience.
This PR introduces support for integer array indices on backends that don't support it in __setitem__.
However, it makes it work
- exclusively on set(); not on add() nor any other methods;
- exclusively when it is on axis 0;
- exclusively when it is not expressed as a tuple;
- exclusively for data of trivial size (anything in the magnitude of megabytes will crash with MemoryError).
This I suspect will result in a rather unpleasant user experience for those that don't fall in this very specific use case.
>>> import numpy as np
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> xpx.at(np.asarray([0]), np.asarray([0, 0])).set(np.asarray([2, 3]))
array([3])
>>> xpx.at(xp.asarray([0]), xp.asarray([0, 0])).set(xp.asarray([2, 3]))
Array([3], dtype=array_api_strict.int64)
This example leaves me confused. I don't think it adds anything?
The aim is to show that np's and xpx's behavior is identical. For torch tensors on the GPU you would see
>>> xpx.at(torch.tensor([0]).cuda(), torch.tensor([0, 0]).cuda()).set(torch.tensor([2, 3]).cuda())
torch.Tensor([2], dtype=torch.int64)
except IndexError as e:
    if "Fancy indexing" not in str(e):  # Avoid masking other index errors
        raise e
This is quite fragile as it cherry-picks array-api-strict's behaviour. Different libraries would have different error messages and different exceptions.
Suggested change:
- except IndexError as e:
-     if "Fancy indexing" not in str(e):  # Avoid masking other index errors
-         raise e
+ except Exception as e:
Yes, I thought about this as well. However, I am strongly opposed to a blanket except. We would mask errors for regular frameworks, which would subsequently enter an unexpected code path that may throw obscure errors. Hence the comment on masking other index errors. This feels almost worse than the added benefit of having array-api-strict support for integer indexing.
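To make the trade-off concrete, here is a minimal sketch of the narrow-matching pattern under discussion. It is not the PR's code: `setitem_with_fallback`, `NoFancySetitem` and `loop_fallback` are hypothetical names, and the stand-in class only imitates a backend that rejects integer-array indices on assignment.

```python
import numpy as np


def setitem_with_fallback(x, idx, y, fallback):
    """Try x[idx] = y; use `fallback` only for one recognised failure."""
    try:
        x[idx] = y
    except IndexError as e:
        # Only the recognised "unsupported index" message triggers the
        # fallback; any other IndexError (e.g. out of bounds) propagates.
        if "Fancy indexing" not in str(e):
            raise
        return fallback(x, idx, y)
    return x


class NoFancySetitem:
    """Hypothetical backend that refuses integer-array indices in __setitem__."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __setitem__(self, idx, value):
        if isinstance(idx, np.ndarray) and idx.dtype.kind in "iu":
            raise IndexError("Fancy indexing is not supported in __setitem__")
        self.data[idx] = value


def loop_fallback(x, idx, y):
    # Illustrative fallback: assign one scalar index at a time.
    for i, v in zip(idx.tolist(), np.asarray(y).tolist()):
        x[i] = v
    return x


x = NoFancySetitem([0, 1, 2])
out = setitem_with_fallback(x, np.asarray([0, 2]), np.asarray([7, 9]), loop_fallback)
```

The weakness the reviewer points out is visible here: the string match is tied to one library's wording, so a backend that raises a different exception type or message never reaches the fallback.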
src/array_api_extra/_lib/_at.py (outdated)
)
raise IndexError(msg) from e

x_mask = xp.any(xp.arange(x.shape[0])[..., None] == u_idx_pos, axis=-1)
Suggested change:
- x_mask = xp.any(xp.arange(x.shape[0])[..., None] == u_idx_pos, axis=-1)
+ x_rng = xp.arange(x.shape[0], device=device(u_idx_pos))
+ x_mask = xp.any(x_rng[..., None] == u_idx_pos, axis=-1)
Could you add a comment explaining what you're doing here?
I think you need to add in the documentation above a note warning that the implementation is quadratic.
If x is 10 MiB along axis 0 and the u_idx_pos is 10 MiB, this line transitorily consumes 100 terabytes of RAM.
Have you considered using searchsorted?
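A rough sketch of what the searchsorted suggestion could look like, written with NumPy for illustration only (the PR operates on a generic `xp` namespace). The function names are hypothetical, and the sketch assumes a 1-D, non-empty index array that has already been normalised to non-negative values.

```python
import numpy as np


def mask_quadratic(n, idx):
    # Broadcasts to an (n, len(idx)) boolean temporary before reducing.
    return np.any(np.arange(n)[:, None] == idx, axis=-1)


def mask_searchsorted(n, idx):
    # Sort once, then binary-search every position 0..n-1 in the sorted
    # indices. Temporaries are only of size n or len(idx).
    s = np.sort(idx)
    rng = np.arange(n)
    pos = np.searchsorted(s, rng)           # insertion points in [0, len(idx)]
    pos = np.minimum(pos, s.shape[0] - 1)   # clamp so s[pos] is a valid read
    return s[pos] == rng


idx = np.asarray([3, 0, 3])
m_quad = mask_quadratic(5, idx)
m_fast = mask_searchsorted(5, idx)
```

Both functions mark the same positions of `x` as selected; the second avoids the n-by-k intermediate array, at the cost of a sort of the indices.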
@@ -355,9 +370,70 @@ def _op(
        # Backends without boolean indexing (other than JAX) crash here
        if in_place_op:  # add(), subtract(), ...
            x[idx] = in_place_op(x[idx], y)
These remain broken.
Yes. I'm not sure if we should attempt to fix them. See my general comment.
Correct. I think this should be discussed before moving forward. The only way to use array indices for these operations while remaining compliant with the standard (please correct me if I am wrong!) is to use boolean masking. However, masks that exactly replicate the selection from integer indices are hard to create. The aim of restricting axis to 0, no tuples, and only for
If we can't express the logic of integer indexing easily with existing array API functions, we might want to rethink whether adding it makes sense.
I have updated the code with most of the suggested changes from the review. However: The current implementation is broken for non-unique indices. One test case should fail. To see why, consider the following code

import array_api_extra as xpx
import array_api_strict as xp

x = xp.asarray([0, 1, 2])
y = xp.asarray([3, 4, 5])
idx = xp.asarray([0, 1, 0])
print(xpx.at(x, idx).set(y))  # [4, 5, 2], should be [5, 4, 2]

The reason why this happens is that we construct two masks for the set operation, one for
However, we have no way to express the fact that the first
@crusaderky What's your opinion on this? These workarounds don't exactly feel great given we are currently doing this only for array-api-strict, which is only used for testing. Having such a complex separate code path where the testing framework deviates from the other frameworks kind of defeats the point.
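For illustration only, one way to make the duplicate-index case order-independent is to keep just the last occurrence of each index before assigning. The sketch below uses NumPy, and `last_occurrence` is a hypothetical helper that is not part of the PR or of array-api-extra; it relies on `np.unique(..., return_index=True)`, which is NumPy-specific.

```python
import numpy as np


def last_occurrence(idx):
    """Positions in `idx` of the last occurrence of each distinct value."""
    rev = idx[::-1]
    # return_index gives the first occurrence in the reversed array,
    # which corresponds to the last occurrence in the original order.
    _, first_in_rev = np.unique(rev, return_index=True)
    return idx.shape[0] - 1 - first_in_rev


x = np.asarray([0, 1, 2])
y = np.asarray([3, 4, 5])
idx = np.asarray([0, 1, 0])

keep = last_occurrence(idx)     # positions 2 and 1: the writes that should survive
out = x.copy()
out[idx[keep]] = y[keep]        # indices are now unique, so write order is irrelevant
```

After deduplication, position 0 receives 5 (the later write) and position 1 receives 4, which matches the "should be" result quoted above. Whether this can be expressed compactly with standard-only functions is exactly the open question in this thread.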
array-api-strict is a proxy for not explicitly tested, possibly not yet known or even existing, additional libraries. I honestly doubt that this PR should be merged, given its complication and its way-too-many caveats.
I sadly have to agree with the feeling. There are too many complications for a too narrow scope. If it was a general solution for integer indexing for all operations it might have been worth the trouble, but as it stands it's probably best to wait for the standard to advance (if it ever does).
To jog my memory, is the blocker for the standard that existing libraries implement orthogonal semantics, so standardising one set of behaviour would mean a significant break for at least some library?
Fancy indexing is currently not supported for __setitem__, which blocks some PRs in scipy (scipy/scipy#23425). As discussed in data-apis/array-api#864 (comment), all frameworks of the array api already implement this feature, but not necessarily in a consistent manner for duplicate indices. It is currently not part of the standard. This PR adds a workaround for array api strict to allow fancy indexing in xpx.at(x, ...).