
BUG: polynomial evaluation functions assume numpy array and work inconsistently with other array types #29680

@FelixBenning

Describe the issue:

The following check, used by the polynomial evaluation functions (e.g. `polyval`), works with NumPy arrays, lists and tuples, but not with other array types such as `jax.numpy` arrays:

```python
if isinstance(x, (tuple, list)):
    x = np.asarray(x)
if isinstance(x, np.ndarray) and tensor:
    c = c.reshape(c.shape + (1,) * x.ndim)
```
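To make the asymmetry concrete, here is a minimal sketch that mirrors the check above in isolation. `tensor_reshape_applies` and `DuckArray` are hypothetical names used only for illustration: `DuckArray` stands in for a non-NumPy array type such as a JAX array, exposing `ndim` and `__array__` but not being an `np.ndarray` instance.

```python
import numpy as np


def tensor_reshape_applies(x, tensor=True):
    """Hypothetical mirror of the quoted check: True when the coefficient
    array would be reshaped for tensor (outer-product) evaluation."""
    if isinstance(x, (tuple, list)):
        x = np.asarray(x)  # lists and tuples become ndarrays first
    return isinstance(x, np.ndarray) and tensor


class DuckArray:
    """Hypothetical stand-in for a non-NumPy array type (e.g. a JAX array)."""

    ndim = 2

    def __array__(self, dtype=None, copy=None):
        return np.zeros((2, 3), dtype=dtype)


print(tensor_reshape_applies(np.zeros((2, 3))))  # True
print(tensor_reshape_applies([[0, 0, 0]]))       # True
print(tensor_reshape_applies(DuckArray()))       # False: reshape is skipped
```

The duck array is array-like by every usual criterion, yet it never reaches the reshape, so `tensor=True` has no effect for it.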

Reproduce the code example:

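This results in inconsistent behavior: the same values passed as NumPy arrays and as JAX arrays give results of different shapes ((3, 2, 3) vs. (2, 3)). For the JAX array, `tensor=True` is effectively ignored.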

```python
>>> import numpy as np
>>> import jax.numpy
>>> from numpy.polynomial import polynomial as poly
>>> b = jax.numpy.array(
...     [[1, 2, 3],
...      [2, 3, 4]]
... )
>>> poly.polyval(np.array(b), np.array(b), tensor=True)
array([[[ 3.,  5.,  7.],
        [ 5.,  7.,  9.]],

       [[ 5.,  8., 11.],
        [ 8., 11., 14.]],

       [[ 7., 11., 15.],
        [11., 15., 19.]]])
>>> poly.polyval(b, b, tensor=True)
Array([[ 3.,  8., 15.],
       [ 5., 11., 19.]], dtype=float32)
```
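The JAX result has exactly the values that `tensor=False` produces. This can be confirmed with plain NumPy arrays alone (no JAX required), by evaluating the same input both ways:

```python
import numpy as np
from numpy.polynomial import polynomial as poly

a = np.array([[1, 2, 3],
              [2, 3, 4]])

# tensor=True: every column of the coefficient array is evaluated at every
# element of x, so the result has shape c.shape[1:] + x.shape.
outer = poly.polyval(a, a, tensor=True)

# tensor=False: the coefficient columns are broadcast against x instead.
broadcast = poly.polyval(a, a, tensor=False)

print(outer.shape)         # (3, 2, 3)
print(broadcast.tolist())  # [[3.0, 8.0, 15.0], [5.0, 11.0, 19.0]]
```

The second result matches the JAX output above value for value, which is consistent with the JAX array skipping the reshape in the `isinstance(x, np.ndarray)` branch.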

Python and NumPy Versions:

```python
>>> import sys, numpy; print(numpy.__version__); print(sys.version)
2.3.2
3.13.7 (v3.13.7:bcee1c32211, Aug 14 2025, 19:10:51) [Clang 16.0.0 (clang-1600.0.26.6)]
```

Context for the issue:

It may be better to identify an array-like by checking for the `.ndim` attribute, which is used immediately afterwards, or for the `__array__` attribute.
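A minimal sketch of the `.ndim`-based alternative, assuming the rest of `polyval` stays as it is. `reshape_coefficients` and `polyval_sketch` are hypothetical names used only for illustration, not NumPy API. With this check, any object exposing `ndim` (such as a JAX array) receives the same tensor reshape as an `np.ndarray`:

```python
import numpy as np


def reshape_coefficients(c, x, tensor):
    """Hypothetical helper: duck-typed version of the tensor check.

    Lists and tuples are still converted as before; afterwards any object
    with an ``ndim`` attribute (e.g. NumPy or JAX arrays) triggers the
    reshape, instead of only ``np.ndarray`` instances.
    """
    if isinstance(x, (tuple, list)):
        x = np.asarray(x)
    if tensor and hasattr(x, "ndim"):
        # Append one length-1 axis per dimension of x so that each
        # coefficient vector broadcasts against the whole of x.
        c = c.reshape(c.shape + (1,) * x.ndim)
    return c, x


def polyval_sketch(x, c, tensor=True):
    """Hypothetical stand-in for polyval using the duck-typed check."""
    c = np.array(c, ndmin=1)
    if c.dtype.char in "?bBhHiIlLqQpP":
        c = c + 0.0  # promote bool/integer coefficients to float
    c, x = reshape_coefficients(c, x, tensor)
    # Horner's scheme, starting from the highest-degree coefficient.
    c0 = c[-1] + x * 0
    for i in range(2, len(c) + 1):
        c0 = c[-i] + c0 * x
    return c0
```

Checking `ndim` targets exactly the attribute the reshape uses, so `x` does not need to be converted to a NumPy array just to decide whether to reshape. Plain Python scalars have no `ndim` and are left untouched, as before.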

I would be willing to open a pull request.
