diff --git a/numpy/_core/src/multiarray/scalartypes.c.src b/numpy/_core/src/multiarray/scalartypes.c.src index 5e3a3ba71d3e..dc8d047917ae 100644 --- a/numpy/_core/src/multiarray/scalartypes.c.src +++ b/numpy/_core/src/multiarray/scalartypes.c.src @@ -2853,6 +2853,12 @@ static PyMethodDef numbertype_methods[] = { {NULL, NULL, 0, NULL} /* sentinel */ }; +static PyMethodDef booleantype_methods[] = { + /* for typing */ + {"__class_getitem__", Py_GenericAlias, METH_CLASS | METH_O, NULL}, + {NULL, NULL, 0, NULL} /* sentinel */ +}; + /**begin repeat * #name = cfloat,clongdouble# */ @@ -4571,6 +4577,7 @@ initialize_numeric_types(void) PyBoolArrType_Type.tp_str = genbool_type_str; PyBoolArrType_Type.tp_repr = genbool_type_repr; + PyBoolArrType_Type.tp_methods = booleantype_methods; /**begin repeat diff --git a/numpy/_core/tests/test_scalar_methods.py b/numpy/_core/tests/test_scalar_methods.py index 2d508a08bb4d..26dad71794e3 100644 --- a/numpy/_core/tests/test_scalar_methods.py +++ b/numpy/_core/tests/test_scalar_methods.py @@ -4,7 +4,7 @@ import fractions import platform import types -from typing import Any +from typing import Any, Literal import pytest @@ -171,8 +171,12 @@ def test_abc_non_numeric(self, cls: type[np.generic]) -> None: @pytest.mark.parametrize("code", np.typecodes["All"]) def test_concrete(self, code: str) -> None: cls = np.dtype(code).type - with pytest.raises(TypeError): - cls[Any] + if cls == np.bool: + # np.bool allows subscript + assert cls[Any] + else: + with pytest.raises(TypeError): + cls[Any] @pytest.mark.parametrize("arg_len", range(4)) def test_subscript_tuple(self, arg_len: int) -> None: @@ -186,6 +190,10 @@ def test_subscript_tuple(self, arg_len: int) -> None: def test_subscript_scalar(self) -> None: assert np.number[Any] + @pytest.mark.parametrize("subscript", [Literal[True], Literal[False]]) + def test_subscript_bool(self, subscript: Literal[True, False]) -> None: + assert isinstance(np.bool[subscript], types.GenericAlias) + class TestBitCount: # derived in part from the cpython test "test_bit_count"