Describe the issue:
The following code in `polyval` works for NumPy arrays, lists, and tuples, but gives the wrong result for other array types, for example `jax.numpy` arrays:
numpy/numpy/polynomial/polynomial.py, lines 748 to 751 in f675dbb:

```python
if isinstance(x, (tuple, list)):
    x = np.asarray(x)
if isinstance(x, np.ndarray) and tensor:
    c = c.reshape(c.shape + (1,) * x.ndim)
```
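The problem does not appear to be JAX-specific. The sketch below uses a hypothetical `DuckArray` class (not part of NumPy or JAX) that exposes `.ndim` and `__array__` but is not an `np.ndarray` subclass, which is assumed to be enough to mimic a JAX array for this check. It fails the `isinstance` test, so `polyval` skips the tensor reshape and returns a differently shaped result:

```python
import numpy as np
from numpy.polynomial import polynomial as poly


class DuckArray:
    """Hypothetical minimal array-like standing in for jax.Array:
    it has .ndim and __array__, but is not an np.ndarray subclass."""

    def __init__(self, data):
        self._data = np.asarray(data)

    @property
    def ndim(self):
        return self._data.ndim

    def __array__(self, dtype=None, copy=None):
        return np.asarray(self._data, dtype=dtype)

    def __mul__(self, other):
        # Just enough arithmetic for polyval's first step (x * 0).
        return self._data * other


b = [[1, 2, 3], [2, 3, 4]]
duck = DuckArray(b)

print(isinstance(duck, np.ndarray))                     # False
print(poly.polyval(np.array(b), b, tensor=True).shape)  # (3, 2, 3)
print(poly.polyval(duck, b, tensor=True).shape)         # (2, 3)
```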
Reproduce the code example:
With `tensor=True`, passing a JAX array gives a different result than passing the equivalent NumPy array:
```python
>>> import jax
>>> import numpy as np
>>> from numpy.polynomial import polynomial as poly
>>> b = jax.numpy.array(
...     [[1, 2, 3],
...      [2, 3, 4]]
... )
>>> poly.polyval(np.array(b), np.array(b), tensor=True)
array([[[ 3.,  5.,  7.],
        [ 5.,  7.,  9.]],

       [[ 5.,  8., 11.],
        [ 8., 11., 14.]],

       [[ 7., 11., 15.],
        [11., 15., 19.]]])
>>> poly.polyval(b, b, tensor=True)
Array([[ 3.,  8., 15.],
       [ 5., 11., 19.]], dtype=float32)
```
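The values returned for the JAX input are exactly what `tensor=False` gives, i.e. the reshape was silently skipped. This can be checked with NumPy inputs only (no JAX needed):

```python
import numpy as np
from numpy.polynomial import polynomial as poly

b = np.array([[1, 2, 3],
              [2, 3, 4]])

# tensor=True: c is reshaped to (2, 3, 1, 1), so each column of
# coefficients is evaluated at every element of x -> shape (3, 2, 3).
tensor_result = poly.polyval(b, b, tensor=True)

# tensor=False: c is used as-is and broadcast against x -> shape (2, 3).
broadcast_result = poly.polyval(b, b, tensor=False)

print(tensor_result.shape)         # (3, 2, 3)
print(broadcast_result.tolist())   # [[3.0, 8.0, 15.0], [5.0, 11.0, 19.0]]
```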
Python and NumPy Versions:
```
>>> import sys, numpy; print(numpy.__version__); print(sys.version)
2.3.2
3.13.7 (v3.13.7:bcee1c32211, Aug 14 2025, 19:10:51) [Clang 16.0.0 (clang-1600.0.26.6)]
```
Context for the issue:
Instead of `isinstance(x, np.ndarray)`, it may be better to identify an array-like by checking for the `.ndim` attribute, which the reshape uses anyway, or for the `__array__` attribute.
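A minimal sketch of the `.ndim` variant, assuming that the presence of `.ndim` is a sufficient signal. `polyval_sketch` and `DuckArray` are hypothetical names for illustration only; apart from the changed condition, the function follows the lines quoted above and the Horner loop `polyval` uses:

```python
import numpy as np


def polyval_sketch(x, c, tensor=True):
    """Hypothetical polyval variant: detects array-likes by duck typing
    (hasattr(x, "ndim")) instead of isinstance(x, np.ndarray)."""
    c = np.atleast_1d(np.asarray(c))  # coefficients as an at-least-1-D array
    if c.dtype.char in '?bBhHiIlLqQpP':
        c = c + 0.0  # promote bool/integer coefficients to float
    if isinstance(x, (tuple, list)):
        x = np.asarray(x)
    # Proposed change: any object exposing .ndim (np.ndarray, jax.Array,
    # other array-likes) gets the tensor reshape.
    if tensor and hasattr(x, "ndim"):
        c = c.reshape(c.shape + (1,) * x.ndim)
    # Horner's scheme.
    c0 = c[-1] + x * 0
    for i in range(2, len(c) + 1):
        c0 = c[-i] + c0 * x
    return c0


class DuckArray:
    """Hypothetical stand-in for a non-NumPy array such as jax.Array."""

    def __init__(self, data):
        self._data = np.asarray(data)

    @property
    def ndim(self):
        return self._data.ndim

    def __array__(self, dtype=None, copy=None):
        return np.asarray(self._data, dtype=dtype)

    def __mul__(self, other):
        return self._data * other


b = [[1, 2, 3], [2, 3, 4]]
ref = polyval_sketch(np.array(b), b)
duck = polyval_sketch(DuckArray(b), b)
print(ref.shape, duck.shape)      # (3, 2, 3) (3, 2, 3)
print(np.array_equal(ref, duck))  # True
```

NumPy scalars also have `.ndim` (equal to 0), so the reshape is a no-op for them, and Python scalars have no `.ndim` and skip the branch, so scalar inputs should behave as before. Checking for `__array__` would be the alternative, but `.ndim` is the attribute the reshape actually needs.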
I would be willing to open a pull request.