RFC: update `full` to accept 0D arrays for `fill_value` #909
@mdhaber Do you have a sense of how common it is for array libraries' `full` to accept 0D arrays for `fill_value`?
Yes, support is already quite good. I haven't tested MLX.

```python
from array_api_compat import numpy, torch, cupy
from array_api_compat.dask import array as dask
import jax.numpy as jax
import tensorflow.experimental.numpy as tf
import ndonnx

xps = [numpy, torch, cupy, dask, jax, tf, ndonnx]
for xp in xps:
    # 0D float array used as fill_value
    x = xp.asarray(0.)
    try:
        print(xp.full((2, 2), x))
    except Exception as e:
        print(e)
```
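The loop above aborts at import time if any one of the libraries is missing. A more tolerant harness is sketched below under some assumptions: it imports each namespace by module path with `importlib`, records `"unavailable"` when the import fails, and records the exception text when the call under test fails. The module paths are the ones used in this thread, plus plain `numpy` (an addition for comparison); `probe_backends` and `full_with_0d_fill` are hypothetical names.

```python
import importlib

# Module paths used in this thread, plus plain NumPy (added for comparison).
BACKENDS = [
    "numpy",
    "array_api_compat.numpy",
    "array_api_compat.torch",
    "array_api_compat.cupy",
    "array_api_compat.dask.array",
    "jax.numpy",
    "tensorflow.experimental.numpy",
    "ndonnx",
]


def probe_backends(probe, backends=BACKENDS):
    """Hypothetical helper: call probe(xp) per namespace, record the outcome."""
    results = {}
    for name in backends:
        try:
            xp = importlib.import_module(name)
        except Exception:  # not installed, or failed to initialise
            results[name] = "unavailable"
            continue
        try:
            results[name] = str(probe(xp))
        except Exception as e:  # the library rejected the call under test
            results[name] = f"{type(e).__name__}: {e}"
    return results


def full_with_0d_fill(xp):
    """Hypothetical probe: does full() accept a 0D fill_value? Return dtype."""
    return xp.full((2, 2), xp.asarray(0.0)).dtype


if __name__ == "__main__":
    for name, outcome in probe_backends(full_with_0d_fill).items():
        print(f"{name:32} {outcome}")
```

Any of the other snippets in this thread can be turned into a probe function of the same shape and passed to `probe_backends`.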
This makes sense to me. The fill value might come from something like …

```python
from array_api_compat import numpy, torch, cupy
from array_api_compat.dask import array as dask
import jax.numpy as jax
import tensorflow.experimental.numpy as tf
import ndonnx

xps = [numpy, torch, cupy, dask, jax, tf, ndonnx]
for xp in xps:
    # int16 0D fill value, but an explicit int32 output dtype
    x = xp.asarray(0, dtype=xp.int16)
    try:
        print(xp.__name__, xp.full((2, 2), x, dtype=xp.int32).dtype)
    except Exception as e:
        print(e)
```
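For reference, plain NumPy's `full` already behaves the way this thread discusses: without `dtype`, the output takes the 0D fill value's dtype, and an explicit `dtype` takes precedence (the fill value is cast into it). A minimal check, assuming only NumPy is installed:

```python
import numpy as np

fill = np.asarray(7, dtype=np.int16)  # 0D array used as fill_value

# No dtype given: NumPy infers the output dtype from the fill value.
inferred = np.full((2, 2), fill)

# Explicit dtype given: it wins, and the fill value is cast to it.
overridden = np.full((2, 2), fill, dtype=np.int32)

print(inferred.dtype, overridden.dtype)  # int16 int32
```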
Similarly with `full_like`:

```python
from array_api_compat import numpy, torch, cupy
from array_api_compat.dask import array as dask
import jax.numpy as jax
import tensorflow.experimental.numpy as tf
import ndonnx

xps = [numpy, torch, cupy, dask, jax, tf, ndonnx]
for xp in xps:
    # int16 template and int16 0D fill value, explicit int32 output dtype
    zeros = xp.zeros((2, 2), dtype=xp.int16)
    x = xp.asarray(0, dtype=xp.int16)
    try:
        print(xp.__name__, xp.full_like(zeros, x, dtype=xp.int32).dtype)
    except Exception as e:
        print(e)
```
```python
from array_api_compat import numpy, torch, cupy
from array_api_compat.dask import array as dask
import jax.numpy as jax
import tensorflow.experimental.numpy as tf
import ndonnx

xps = [numpy, torch, cupy, dask, jax, tf, ndonnx]
for xp in xps:
    # int16 template, int32 0D fill value, no explicit output dtype
    zeros = xp.zeros((2, 2), dtype=xp.int16)
    x = xp.asarray(0, dtype=xp.int32)
    try:
        print(xp.__name__, xp.full_like(zeros, x).dtype)
    except Exception as e:
        print(e)
```
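In plain NumPy, `full_like` resolves the output dtype differently from `full`: the template array, not the fill value, decides it unless `dtype` is passed explicitly. A minimal NumPy-only sketch of both `full_like` cases probed above:

```python
import numpy as np

template = np.zeros((2, 2), dtype=np.int16)
fill = np.asarray(7, dtype=np.int32)  # 0D fill with a different dtype

# No dtype given: the template's dtype (int16) is kept and the
# int32 fill value is cast into it.
kept = np.full_like(template, fill)

# Explicit dtype given: it overrides the template's dtype.
overridden = np.full_like(template, fill, dtype=np.int32)

print(kept.dtype, overridden.dtype)  # int16 int32
```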
Dask currently produces a warning with:

```python
import array_api_compat.dask.array as xp

float(xp.full((), xp.asarray(1)))
# 1.0
# FutureWarning: The `numpy.copyto` function is not implemented by Dask array. You may want to use the da.map_blocks function or something similar to silence this warning. Your code may stop working in a future release.
```
FYI Dask loudly complains if `fill_value` is an array: scipy/scipy#22900
@crusaderky Are you referring to the warning that appears just above, in #909 (comment)? (Please expand the details.) It is not a very specific complaint, and Dask complains about a lot of things. It looks like a mistake on Dask's side.
The warning message is misleading, but the root cause is that Dask doesn't expect a Dask Array as `fill_value`:

```python
>>> import dask.array as da
>>> da.full((2, ), da.asarray(1)).persist()
/home/crusaderky/miniforge3/envs/array-api-compat/lib/python3.10/site-packages/dask/array/core.py:1763: FutureWarning: The `numpy.copyto` function is not implemented by Dask array. You may want to use the da.map_blocks function or something similar to silence this warning. Your code may stop working in a future release.
  warnings.warn(
dask.array<full_like, shape=(2,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>
>>> import array_api_compat.dask.array as xp
>>> xp.full((2, ), xp.asarray(1)).persist()
/home/crusaderky/miniforge3/envs/array-api-compat/lib/python3.10/site-packages/dask/array/core.py:1763: FutureWarning: The `numpy.copyto` function is not implemented by Dask array. You may want to use the da.map_blocks function or something similar to silence this warning. Your code may stop working in a future release.
  warnings.warn(
dask.array<full_like, shape=(2,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>
```
So you mean yes, it's the same warning that appears at the end of the details?
Yes, it's the same warning.
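One possible workaround, when a library warns about or rejects an array-valued `fill_value`, is to skip `full` and broadcast the 0D array instead. The sketch below (`full_from_0d` is a hypothetical name) uses only `broadcast_to`, which is part of the Array API standard, so the result keeps the fill value's dtype and the fill value is never converted to a Python scalar (a `float(...)` or `int(...)` conversion would force a compute in a lazy library such as Dask). It has not been verified against every backend in this thread, and in NumPy the result is a read-only view. The usage shown runs with NumPy:

```python
import numpy as np


def full_from_0d(xp, shape, fill):
    """Hypothetical helper: broadcast the 0D array `fill` to `shape`.

    Relies only on `broadcast_to` from the Array API standard, so the
    output dtype is `fill.dtype` and `fill` stays an array throughout.
    In NumPy the result is a read-only view; copy it before writing.
    """
    return xp.broadcast_to(fill, shape)


filled = full_from_0d(np, (2, 3), np.asarray(5, dtype=np.int16))
print(filled.dtype, filled.shape)  # int16 (2, 3)

# NumPy-specific: make an independent, writable array if mutation is needed.
writable = filled.copy()
```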
According to the current standard, the `fill_value` of `full` must be a Python scalar. Are 0D arrays allowed? This usage seems common as a way to have the dtype of the output determined by the dtype of `fill_value`. That's what happens in `array_api_strict`, for instance.
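The rule being asked about can be written as a small dtype-resolution function. This is a hypothetical sketch, not the standard's text: step 1 (an explicit `dtype` wins) and step 3 (a Python scalar maps to the default dtype of its kind) follow the current standard, while step 2 (a 0D array contributes its own dtype) is the proposed extension. `ZeroDArray` is a toy stand-in for a real 0D array, and the default dtype names are assumptions, since the standard leaves the default integer and floating-point dtypes to each implementation.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass(frozen=True)
class ZeroDArray:
    """Hypothetical stand-in for a 0D array: a value plus a dtype name."""
    value: Union[bool, int, float, complex]
    dtype: str


def resolve_full_dtype(fill_value, dtype: Optional[str] = None,
                       default_int: str = "int64",
                       default_float: str = "float64",
                       default_complex: str = "complex128") -> str:
    """Hypothetical output-dtype rule for full(shape, fill_value, dtype=...)."""
    # 1. An explicit dtype always wins.
    if dtype is not None:
        return dtype
    # 2. Proposed: a 0D array fill_value contributes its own dtype.
    if isinstance(fill_value, ZeroDArray):
        return fill_value.dtype
    # 3. Python scalars map to the default dtype of their kind.
    #    bool is checked first because bool is a subclass of int.
    if isinstance(fill_value, bool):
        return "bool"
    if isinstance(fill_value, int):
        return default_int
    if isinstance(fill_value, float):
        return default_float
    if isinstance(fill_value, complex):
        return default_complex
    raise TypeError(f"unsupported fill_value: {fill_value!r}")


print(resolve_full_dtype(ZeroDArray(0, "int16")))                 # int16
print(resolve_full_dtype(ZeroDArray(0, "int16"), dtype="int32"))  # int32
print(resolve_full_dtype(1.5))                                    # float64
```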