BUG: NDArray does not allow parametrizing with a TypeVar #25153

Closed
Jacob-Stevens-Haas opened this issue Nov 15, 2023 · 4 comments
Labels
33 - Question (Question about NumPy usage or development), 41 - Static typing

Comments

@Jacob-Stevens-Haas
Contributor

Jacob-Stevens-Haas commented Nov 15, 2023

Describe the issue:

I'm looking to parametrize a NumPy array type with a type variable. I'm not sure whether this is meant to work (making this a bug) or is not yet implemented (making it a feature request). Pylance flags the following code:

Reproduce the code example:

from numpy.typing import NDArray
from typing import Annotated, TypeVar

T = TypeVar("T")
GridsearchResult = Annotated[NDArray[T], "(n_metrics, n_plot_axis)"]

Error message:

Could not specialize type "NDArray[_ScalarType_co@NDArray]"
  Type "T@GridsearchResult" cannot be assigned to type "generic"
    "object" is incompatible with "generic"

Runtime information:

[{'numpy_version': '1.26.2',
'python': '3.10.12 (main, Jun 11 2023, 05:26:28) [GCC 11.4.0]',
'uname': uname_result(system='Linux', node='pontus', release='6.2.0-36-generic', version='#37~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Mon Oct 9 15:34:04 UTC 2', machine='x86_64')},
{'simd_extensions': {'baseline': ['SSE', 'SSE2', 'SSE3'],
'found': ['SSSE3',
'SSE41',
'POPCNT',
'SSE42',
'AVX',
'F16C',
'FMA3',
'AVX2'],
'not_found': ['AVX512F',
'AVX512CD',
'AVX512_KNL',
'AVX512_KNM',
'AVX512_SKX',
'AVX512_CLX',
'AVX512_CNL',
'AVX512_ICL']}},
{'architecture': 'Haswell',
'filepath': '/home/xenophon/github/gen-experiments/env/lib/python3.10/site-packages/numpy.libs/libopenblas64_p-r0-0cf96a72.3.23.dev.so',
'internal_api': 'openblas',
'num_threads': 4,
'prefix': 'libopenblas',
'threading_layer': 'pthreads',
'user_api': 'blas',
'version': '0.3.23.dev'}]

Context for the issue:

My gridsearch code returns a tuple of two 2D arrays of the same shape. One holds the maxes of a larger array over several axes; the other holds the corresponding argmaxes over those same axes. So the two types are:

Annotated[NDArray[np.float64], "(n_metrics, plot_axis_length)"]
Annotated[NDArray[np.void], "(n_metrics, plot_axis_length)"]  # structured dtype np.dtype("i,i,i,i")
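A type expression can't contain a call such as np.dtype("i,i,i,i"), because NDArray is parametrized by a scalar type, not a dtype instance. A quick runtime check (a small sketch, nothing here is specific to the gridsearch code) shows which scalar type a structured dtype like this one uses:

```python
import numpy as np

# The structured dtype used for the argmax array: four C-int fields.
index_dtype = np.dtype("i,i,i,i")

# Structured dtypes use np.void as their scalar type, so the matching
# annotation is NDArray[np.void] rather than NDArray[np.dtype(...)].
print(index_dtype.type)   # <class 'numpy.void'>
print(index_dtype.names)  # ('f0', 'f1', 'f2', 'f3')
```

One consequence: NDArray[np.void] can't say how many fields the dtype has, so the length of the index tuple remains a runtime property.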

Here, the length of the index tuple is variable (determined at runtime), but I would like a function signature along these lines (ignoring the shape annotation for now):

def maxes_and_argmaxes[T, U](arr: NDArray[T], axes: tuple[int, ...]) -> tuple[NDArray[T], NDArray[U]]: ...

Here, U would be the structured dtype holding a variable-length tuple of ints.

I admit I'm a bit new to type variables, so this may not be the right approach. But I imagine the correct final form, if I want to annotate the NDArrays, is:

T = TypeVar("T")
GridsearchResult = Annotated[NDArray[T], "(n_metrics, n_plot_axis)"]

def maxes_and_argmaxes[T, U](
     arr: NDArray[T], axes: tuple[int, ...]
): -> tuple[GridsearchResult[T], GridsearchResult[U]]
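
A runtime sketch of the function described above, under a few assumptions of my own: the name maxes_and_argmaxes and the field names f0, f1, ... follow the issue, but the body is illustrative, not the reporter's code. The [T, U] syntax (PEP 695) needs Python 3.12 while the runtime info above shows 3.10, so this uses a module-level TypeVar bound to np.generic, as suggested in the reply below. U appears only in the return type, so a checker has nothing to infer it from; the sketch uses the concrete NDArray[np.void] instead.

```python
from typing import TypeVar

import numpy as np
from numpy.typing import NDArray

# Bound to np.generic so that NDArray[T] type-checks.
T = TypeVar("T", bound=np.generic)


def maxes_and_argmaxes(
    arr: NDArray[T], axes: tuple[int, ...]
) -> tuple[NDArray[T], NDArray[np.void]]:
    """Reduce ``arr`` over ``axes``, returning the maxes and their positions.

    Positions are stored in a structured array with one C-int field per
    reduced axis (f0, f1, ...), like the "i,i,..." dtype from the issue.
    """
    axes = tuple(ax % arr.ndim for ax in axes)  # normalize negative axes
    kept = [ax for ax in range(arr.ndim) if ax not in axes]
    kept_shape = tuple(arr.shape[ax] for ax in kept)
    reduced_shape = tuple(arr.shape[ax] for ax in axes)

    # Move the reduced axes last and flatten them into a single axis.
    flat = np.transpose(arr, kept + list(axes)).reshape(kept_shape + (-1,))
    maxes = flat.max(axis=-1)
    flat_idx = flat.argmax(axis=-1)

    # Convert flat positions back into one index per reduced axis.
    per_axis = np.unravel_index(flat_idx, reduced_shape)
    index_dtype = np.dtype([(f"f{k}", np.intc) for k in range(len(axes))])
    argmaxes = np.empty(kept_shape, dtype=index_dtype)
    for name, idx in zip(index_dtype.names, per_axis):
        argmaxes[name] = idx
    return maxes, argmaxes
```

Flattening the reduced axes lets a single argmax call handle any number of them; np.unravel_index then recovers the per-axis positions. The dtype is built from a list of fields rather than a string like "i,i" because a single "i" would give a plain int dtype with no fields.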
@vnmabus

vnmabus commented Nov 17, 2023

You just need to assign a bound of np.generic to your TypeVar.
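
A minimal sketch of that fix applied to the alias from the issue. The helper best_per_metric is hypothetical, added only to show the alias in use; the future import keeps annotations unevaluated at runtime, so only a static checker looks at them.

```python
from __future__ import annotations  # keep annotations unevaluated at runtime

from typing import Annotated, TypeVar

import numpy as np
from numpy.typing import NDArray

# The bound is the fix: NDArray's scalar-type parameter only accepts
# subclasses of np.generic, and an unbounded TypeVar is implicitly `object`
# (hence the '"object" is incompatible with "generic"' error).
T = TypeVar("T", bound=np.generic)
GridsearchResult = Annotated[NDArray[T], "(n_metrics, n_plot_axis)"]


def best_per_metric(result: GridsearchResult[T]) -> NDArray[T]:
    """Hypothetical helper: best value for each metric (row)."""
    return result.max(axis=1)
```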

rgommers added the 41 - Static typing and 33 - Question labels and removed the 00 - Bug label on Nov 17, 2023
@Jacob-Stevens-Haas
Contributor Author

OK, thanks! That works.

@mscheltienne

But can we then constrain the TypeVar? For example, what if I know my array should always contain either np.float32 or np.float64: would TypeVar("T", np.float32, np.float64) with NDArray[T] be valid?
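
For reference, a minimal sketch of the constrained form being asked about. This is the standard PEP 484 constrained TypeVar; since both constraints subclass np.generic, NDArray[F] should be accepted, though I haven't checked every type checker. The helper scale is hypothetical.

```python
from typing import TypeVar

import numpy as np
from numpy.typing import NDArray

# Constrained TypeVar: F is solved to exactly np.float32 or np.float64.
# Each constraint subclasses np.generic, which NDArray's parameter requires.
F = TypeVar("F", np.float32, np.float64)


def scale(arr: NDArray[F], factor: float) -> NDArray[F]:
    """Hypothetical helper: multiply by ``factor`` without changing precision."""
    # With NumPy's stubs, astype(arr.dtype) keeps the static type NDArray[F];
    # copy=False avoids a copy when the product already has that dtype.
    return (arr * factor).astype(arr.dtype, copy=False)
```

Because F is solved to one of the listed types, a float32 input gives NDArray[np.float32] back rather than a union of the two.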

@Jacob-Stevens-Haas
Contributor Author

@mscheltienne try np.floating, a subtype of np.generic. For more general reference, there's a table of generic subtypes on this page.
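
A minimal sketch of the np.floating suggestion, reading it as a TypeVar bound (my interpretation; the comment doesn't spell out the exact form). Unlike the constrained version, a bound accepts any np.floating subclass. In NumPy's stubs np.floating is generic over its precision, and a bare np.floating bound should be read as np.floating[Any]. The helper clip_unit is hypothetical.

```python
from typing import TypeVar

import numpy as np
from numpy.typing import NDArray

# Upper bound instead of a fixed list: any np.floating subclass is accepted
# (float16, float32, float64, longdouble, ...). np.floating itself subclasses
# np.generic, so NDArray[FloatT] is valid.
FloatT = TypeVar("FloatT", bound=np.floating)


def clip_unit(arr: NDArray[FloatT]) -> NDArray[FloatT]:
    """Hypothetical helper: clip values to [0, 1], keeping the input precision."""
    return np.clip(arr, 0.0, 1.0).astype(arr.dtype, copy=False)
```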

4 participants