Skip to content

Commit 75c79e5

Browse files
authored
Merge pull request #29370 from riku-sakamoto/add_boolean_type_subscribe
ENH: Allow subscript access for `np.bool` by adding `__class_getitem__`
2 parents f1a9e8e + 003281b commit 75c79e5

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

numpy/_core/src/multiarray/scalartypes.c.src

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2853,6 +2853,12 @@ static PyMethodDef numbertype_methods[] = {
28532853
{NULL, NULL, 0, NULL} /* sentinel */
28542854
};
28552855

2856+
static PyMethodDef booleantype_methods[] = {
2857+
/* for typing */
2858+
{"__class_getitem__", Py_GenericAlias, METH_CLASS | METH_O, NULL},
2859+
{NULL, NULL, 0, NULL} /* sentinel */
2860+
};
2861+
28562862
/**begin repeat
28572863
* #name = cfloat,clongdouble#
28582864
*/
@@ -4571,6 +4577,7 @@ initialize_numeric_types(void)
45714577

45724578
PyBoolArrType_Type.tp_str = genbool_type_str;
45734579
PyBoolArrType_Type.tp_repr = genbool_type_repr;
4580+
PyBoolArrType_Type.tp_methods = booleantype_methods;
45744581

45754582

45764583
/**begin repeat

numpy/_core/tests/test_scalar_methods.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import fractions
55
import platform
66
import types
7-
from typing import Any
7+
from typing import Any, Literal
88

99
import pytest
1010

@@ -171,8 +171,12 @@ def test_abc_non_numeric(self, cls: type[np.generic]) -> None:
171171
@pytest.mark.parametrize("code", np.typecodes["All"])
172172
def test_concrete(self, code: str) -> None:
173173
cls = np.dtype(code).type
174-
with pytest.raises(TypeError):
175-
cls[Any]
174+
if cls == np.bool:
175+
# np.bool allows subscript
176+
assert cls[Any]
177+
else:
178+
with pytest.raises(TypeError):
179+
cls[Any]
176180

177181
@pytest.mark.parametrize("arg_len", range(4))
178182
def test_subscript_tuple(self, arg_len: int) -> None:
@@ -186,6 +190,10 @@ def test_subscript_tuple(self, arg_len: int) -> None:
186190
def test_subscript_scalar(self) -> None:
187191
assert np.number[Any]
188192

193+
@pytest.mark.parametrize("subscript", [Literal[True], Literal[False]])
194+
def test_subscript_bool(self, subscript: Literal[True, False]) -> None:
195+
assert isinstance(np.bool[subscript], types.GenericAlias)
196+
189197

190198
class TestBitCount:
191199
# derived in part from the cpython test "test_bit_count"

0 commit comments

Comments
 (0)