RFC: update full to accept 0D arrays for fill_value #909

Open
mdhaber opened this issue Feb 28, 2025 · 9 comments
Labels
RFC Request for comments. Feature requests and proposed changes.

Comments

@mdhaber
Contributor

mdhaber commented Feb 28, 2025

According to the current standard, the fill_value of full must be a Python scalar.

full(shape: int | Tuple[int, ...], fill_value: bool | int | float | complex, *, dtype: dtype | None = None, device: device | None = None) → array (https://data-apis.org/array-api/latest/API_specification/generated/array_api.full.html#array_api.full)

Are 0D arrays allowed? This usage seems common because it lets the dtype of the output be determined by the dtype of the fill_value. That's what happens in array_api_strict, for instance:

xp.full(1, xp.asarray(1.0, dtype=xp.float32))
# Array([1.], dtype=array_api_strict.float32)
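
For reference, NumPy already behaves this way (a NumPy-only illustration, not something the spec currently guarantees): when fill_value is a 0-D array and no dtype is given, np.full infers the output dtype from the fill value.

```python
import numpy as np

# np.full infers the output dtype from a 0-D array fill_value
# when no explicit dtype is given.
x = np.asarray(1.0, dtype=np.float32)
out = np.full((1,), x)
print(out, out.dtype)  # [1.] float32
```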
@kgryte
Contributor

kgryte commented Apr 3, 2025

@mdhaber Do you have a sense as to how common array fill_value support is across the ecosystem? (e.g., NumPy, PyTorch, JAX, Dask, MLX, ndonnx, CuPy)

@kgryte kgryte added this to Proposals Apr 3, 2025
@kgryte kgryte moved this to Stage 0 in Proposals Apr 3, 2025
@kgryte kgryte changed the title Should full accept 0D arrays for fill_value? RFC: update full to accept 0D arrays for fill_value Apr 3, 2025
@kgryte kgryte added the RFC Request for comments. Feature requests and proposed changes. label Apr 3, 2025
@mdhaber
Contributor Author

mdhaber commented Apr 3, 2025

Yes, support is already quite good. I haven't tested MLX.

from array_api_compat import numpy, torch, cupy
from array_api_compat.dask import array as dask
import jax.numpy as jax
import tensorflow.experimental.numpy as tf
import ndonnx
xps = [numpy, torch, cupy, dask, jax, tf, ndonnx]
for xp in xps:
    x = xp.asarray(0.)

    try:
        print(xp.full((2, 2), x))
    except Exception as e:
        print(e)
/usr/local/lib/python3.11/dist-packages/ndonnx/__init__.py:16: UserWarning: onnxruntime is not installed. ndonnx will use the (incomplete) reference implementation for value propagation.
  warn(
[[0. 0.]
 [0. 0.]]
tensor([[0., 0.],
        [0., 0.]])
[[0. 0.]
 [0. 0.]]
dask.array<full_like, shape=(2, 2), dtype=float64, chunksize=(2, 2), chunktype=numpy.ndarray>
[[0. 0.]
 [0. 0.]]
tf.Tensor(
[[0. 0.]
 [0. 0.]], shape=(2, 2), dtype=float64)
unable to infer dtype from `array(data: 0.0, dtype=float64)`

ndonnx is the only one of those that fails, and it would want either to update its error message or to support this. It doesn't seem right to say that the data type can't be inferred from an array; the data type is right there.

@asmeurer
Member

asmeurer commented Apr 3, 2025

This makes sense to me. The fill value might come from something like a[0], where it would be a 0-D array. The only concern is if the fill value's dtype disagrees with the dtype argument. I guess the dtype argument should take priority, but we could also say the behavior is undefined if not all libraries do that. The existing text also doesn't necessarily leave the cases where the dtype and the fill value are of different kinds completely undefined; not sure if that is intentional.
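
A concrete instance of the mixed-kinds case (illustrated with NumPy only; other libraries may differ): NumPy casts the fill value to the requested dtype with unsafe casting, so a float fill combined with an integer dtype silently truncates.

```python
import numpy as np

# Mixed kinds: a float 0-D fill_value with an integer dtype argument.
# The dtype argument wins, and the fractional part is silently dropped
# because np.full casts the fill value with unsafe casting.
fill = np.asarray(1.5)                     # float64, 0-D
out = np.full((2,), fill, dtype=np.int32)
print(out, out.dtype)  # [1 1] int32
```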

@mdhaber
Contributor Author

mdhaber commented Apr 3, 2025

from array_api_compat import numpy, torch, cupy
from array_api_compat.dask import array as dask
import jax.numpy as jax
import tensorflow.experimental.numpy as tf
import ndonnx
xps = [numpy, torch, cupy, dask, jax, tf, ndonnx]
for xp in xps:
    x = xp.asarray(0, dtype=xp.int16)

    try:
        print(xp.__name__, xp.full((2, 2), x, dtype=xp.int32).dtype)
    except Exception as e:
        print(e)
array_api_compat.numpy int32
array_api_compat.torch torch.int32
array_api_compat.cupy int32
array_api_compat.dask.array int32
jax.numpy int32
tensorflow.experimental.numpy <dtype: 'int32'>
ndonnx int32

So dtype passed into xp.full takes precedence for all libraries. And ndonnx no longer complains about being unable to infer the dtype.

Similar with full_like. dtype passed into full_like takes precedence over the dtypes of either positional argument, and if no dtype is specified, the first argument's dtype takes precedence over the second.

from array_api_compat import numpy, torch, cupy
from array_api_compat.dask import array as dask
import jax.numpy as jax
import tensorflow.experimental.numpy as tf
import ndonnx
xps = [numpy, torch, cupy, dask, jax, tf, ndonnx]
for xp in xps:
    zeros = xp.zeros((2, 2), dtype=xp.int16)
    x = xp.asarray(0, dtype=xp.int16)

    try:
        print(xp.__name__, xp.full_like(zeros, x, dtype=xp.int32).dtype)
    except Exception as e:
        print(e)
array_api_compat.numpy int32
array_api_compat.torch torch.int32
array_api_compat.cupy int32
array_api_compat.dask.array int32
jax.numpy int32
tensorflow.experimental.numpy <dtype: 'int32'>
ndonnx int32

from array_api_compat import numpy, torch, cupy
from array_api_compat.dask import array as dask
import jax.numpy as jax
import tensorflow.experimental.numpy as tf
import ndonnx
xps = [numpy, torch, cupy, dask, jax, tf, ndonnx]
for xp in xps:
    zeros = xp.zeros((2, 2), dtype=xp.int16)
    x = xp.asarray(0, dtype=xp.int32)

    try:
        print(xp.__name__, xp.full_like(zeros, x).dtype)
    except Exception as e:
        print(e)
array_api_compat.numpy int16
array_api_compat.torch torch.int16
array_api_compat.cupy int16
array_api_compat.dask.array int16
jax.numpy int16
tensorflow.experimental.numpy <dtype: 'int16'>
ndonnx int16

Dask currently produces a warning with:

from array_api_compat.dask import array as xp
float(xp.full((), xp.asarray(1)))
# 1.0
# FutureWarning: The `numpy.copyto` function is not implemented by Dask array. You may want to use the da.map_blocks function or something similar to silence this warning. Your code may stop working in a future release.
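
A caller-side workaround, sketched here with NumPy (a 0-D Dask array would need .compute() first), is to collapse the 0-D array to a Python scalar, which is what the current standard expects, and forward its dtype explicitly:

```python
import numpy as np

# Workaround sketch: turn the 0-D array into a standard-compliant
# Python scalar before calling full(), passing dtype explicitly so
# no type information is lost.
fill = np.asarray(1)                  # stand-in for a 0-D array
out = np.full((), fill.item(), dtype=fill.dtype)
print(float(out))  # 1.0
```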

@crusaderky
Contributor

FYI Dask loudly complains if fill_value is an array: scipy/scipy#22900

@mdhaber
Contributor Author

mdhaber commented Apr 28, 2025

@crusaderky Are you referring to the warning that appears just above #909 (comment)? (Please expand the details.) It is not a very specific complaint, and Dask complains about a lot of things. It looks like a mistake on Dask's side.

@crusaderky
Contributor

crusaderky commented Apr 28, 2025

The warning message is misleading, but the root cause is that Dask doesn't expect a Dask Array as fill_value, so the whole machinery works by pure accident:

>>> import dask.array as da
>>> da.full((2, ), da.asarray(1)).persist()
/home/crusaderky/miniforge3/envs/array-api-compat/lib/python3.10/site-packages/dask/array/core.py:1763: FutureWarning: The `numpy.copyto` function is not implemented by Dask array. You may want to use the da.map_blocks function or something similar to silence this warning. Your code may stop working in a future release.
  warnings.warn(
dask.array<full_like, shape=(2,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>

>>> import array_api_compat.dask.array as xp
>>> xp.full((2, ), xp.asarray(1)).persist()
/home/crusaderky/miniforge3/envs/array-api-compat/lib/python3.10/site-packages/dask/array/core.py:1763: FutureWarning: The `numpy.copyto` function is not implemented by Dask array. You may want to use the da.map_blocks function or something similar to silence this warning. Your code may stop working in a future release.
  warnings.warn(
dask.array<full_like, shape=(2,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>

@mdhaber
Contributor Author

mdhaber commented Apr 28, 2025

So you mean yes - the same warning that appears at the end of the details?

@crusaderky
Contributor

Yes, it's the same warning.
