Issues with "Mixing arrays and Python scalars" section #98
Comments
Actually, I think
IMHO, the spec should split the operators into four classes:

- Arithmetic operators: should be defined for integer and floating point dtypes, but not required for the boolean dtype.
- Boolean operators: should be defined for integer (bitwise) and boolean dtypes, but are not required for floating point dtypes.
- Comparison operators: defined for all arrays regardless of dtype.
- The array operator (matmul): defined only for integer and floating point dtypes, and with additional shape constraints (the input arrays must have at least 1 (2?) dimensions, which align in a certain way).

Also, arithmetic, boolean, and array operators participate in type promotion. Comparison operators may error on invalid type combinations (which aren't specified anyway), but always return an array with a bool dtype. But I may also be missing or forgetting prior discussions that happened around this.
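To make the proposed split concrete, here is a small sketch using NumPy (NumPy is only an illustration here; the exact required behaviour is what the spec would have to define):

```python
import numpy as np

i = np.array([1, 2, 3], dtype=np.int8)
f = np.array([1.0, 2.0, 3.0], dtype=np.float32)
b = np.array([True, False, True])

print((i + f).dtype)  # arithmetic: participates in type promotion -> float32
print((i & i).dtype)  # boolean operator on integers: bitwise, stays int8
print((b | b).dtype)  # boolean operator on booleans: bool
print((i < f).dtype)  # comparison: result dtype is always bool
print((f @ f).dtype)  # array operator: numeric dtypes only, operands need ndim >= 1
```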
A similar split could be made for the 2-argument array functions.
Also, it looks like NumPy doesn't do what the spec says for float scalars. The spec says:

But NumPy apparently converts float scalars into float64:

>>> (np.array(0.0, dtype=np.float32) * 0.0).dtype
dtype('float64')
>>> (np.array(0.0, dtype=np.float32) * np.array(0.0, dtype=np.float32)).dtype
dtype('float32')

Is it intentional to give a different behavior from NumPy here (I didn't check the other array libraries yet)?
Yes, this is an intentional deviation, copied from libraries such as PyTorch and JAX. On GPU/TPU it is important to avoid inadvertent type promotion to float64.
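For reference, the behaviour being described is easy to check in PyTorch (a quick sketch, assuming PyTorch is installed; JAX behaves similarly):

```python
import torch

# A Python float scalar does not upcast a float32 tensor to float64.
t = torch.zeros(3, dtype=torch.float32)
print((t * 0.0).dtype)  # torch.float32
```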
The NumPy behavior is also apparently only that way for shape () (i.e. 0-d) arrays:

>>> (np.array(0.0, dtype=np.float32) * 0.0).dtype
dtype('float64')
>>> (np.array([0.0], dtype=np.float32) * 0.0).dtype
dtype('float32')

So we should consider the NumPy promotion to float64 as incorrect (I intended to raise this issue on the NumPy tracker but didn't get around to it yet).
NumPy's type promotion rules treat 0d arrays the same as scalars:
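Roughly, the behaviour being referred to can be sketched as follows (a sketch under the promotion rules NumPy had at the time of this discussion; NumPy 2.0's NEP 50 later changed the 0-d case):

```python
import numpy as np

# Pre-NEP 50, 0-d arrays and Python scalars both used value-based promotion,
# so neither forced float32 up to float64; arrays with ndim >= 1 did.
print(np.result_type(np.float32, 0.0))              # float32
print(np.result_type(np.float32, np.array(0.0)))    # float32 (0-d array, old rules)
print(np.result_type(np.float32, np.array([0.0])))  # float64
```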
Yeah, the behavior is the same whether it's an "array scalar" or "scalar constant" (or whatever NumPy calls them). But we don't make any such distinction in the spec, and don't special-case 0-D arrays as far as type promotion is concerned:

"Type promotion rules must apply when determining the common result type for two array operands during an arithmetic operation, regardless of array dimension. Accordingly, zero-dimensional arrays must be subject to the same type promotion rules as dimensional arrays."
Right, we intentionally deviated from NumPy here. There is no notion of separate "scalar" types, and 0d arrays are not treated differently from other arrays.
I agree with all this; it seems like that would help make the operator description and behaviour clearer.

I would not mix this into the type promotion section; that would make it harder to understand for little gain. The individual functions already describe this, and I don't think anyone will be surprised by, for example,
Yes, as far as organization in the spec document, I think the above distinctions should go on the page about operators, not the page about type promotion (or somewhere more general if we also apply it to functions). The type promotion page describes those operator classes that participate in type promotion, but doesn't talk about the other ones.
Here's how libraries currently behave, testing an int8 array combined with a large Python int.

Code:

Results:

So this is a mess. I personally think NumPy and MXNet are behaving the most consistently here; what CuPy does is also not unreasonable, and what PyTorch/TensorFlow/JAX do looks bad.
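A minimal sketch of the kind of check being described, using only NumPy (the comparison above covered NumPy, CuPy, MXNet, PyTorch, TensorFlow and JAX; this is an illustration, not the original script):

```python
import numpy as np

x = np.array([1], dtype=np.int8)
big = 2**40  # far outside the int8 range

try:
    result = x + big
    print("dtype:", result.dtype, "value:", result)
except OverflowError as exc:
    # NumPy >= 2.0 (NEP 50) raises here instead of upcasting to int64.
    print("OverflowError:", exc)
```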
Yes indeed. Here there are only two flavors: upcasting, or keeping the array's dtype.

Only NumPy and CuPy upcast; MXNet, TensorFlow, JAX and PyTorch all keep the original dtype. In both the
Signed integer overflow is technically undefined behavior in C/C++, and that seems to work out OK. It's not great, but I think these inconsistencies are mostly an indication that nobody expects Python's big ints to do something sensible when combined with fixed-size arrays.
Wrt the OP:

Grouping operators into separate classes was addressed in #308. As the concerns raised in this issue have been resolved, I will close this out. Any further concerns and discussion can be raised in a new issue.
A few issues with the "Mixing arrays and Python scalars" section that came up from testing:

- @ (matmul) does not (and should not) support scalar types, so it should be excluded (a quick check is sketched below).
- int8 array + large Python int: should it cast the int, or give an error? Or should this be unspecified?
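For the first point, an illustrative check of what NumPy currently does when @ is given a Python scalar (a sketch only; other libraries may raise a different exception type):

```python
import numpy as np

x = np.ones(3, dtype=np.float32)
try:
    x @ 2.0  # matmul with a Python scalar operand
except (TypeError, ValueError) as exc:
    # NumPy rejects scalar operands for matmul; the point above is to make
    # that exclusion explicit in the spec.
    print(type(exc).__name__, ":", exc)
```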