API: Introduce np.isdtype function [Array API] #25054
Conversation
(force-pushed from 9f3a100 to 0264091)
Bit of a fly-by review out of curiosity, but one small comment inline.
In favour of adding 'flexible' would seem to be that it is fairly common to have functions that work well with either float or complex, and it may just have been an oversight of the array API (which initially had no support for complex numbers).
Thanks @mtsokol!
It's not an oversight. One of the upsides is that it's often unclear how well complex dtypes are supported by a function; if we read np.isdtype(dtype, ('real floating', 'complex floating')), that is quite clear, and easier to understand than the old numpy names for this. Now I can see a case for the opposite point as well, but I think we should start with only implementing what the standard says, and only reconsider if it becomes clear that this is cumbersome. At that point a naming discussion should be had, I think.
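To make that concrete, here is a small usage sketch of the kind-string spelling (illustration only; the results in the comments follow the standard's definitions of the named kinds):

    import numpy as np

    # A check for real floating-point dtypes only:
    np.isdtype(np.dtype(np.float64), "real floating")      # True
    np.isdtype(np.dtype(np.complex128), "real floating")   # False

    # A function that supports both real and complex floats can say so with a
    # tuple of kinds, which reads more clearly than legacy abstract scalar
    # names such as np.inexact:
    np.isdtype(np.dtype(np.complex128), ("real floating", "complex floating"))  # True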
I think you should enforce the typing of the standard - so it'd either be a dtype object or a string - and if it's a string then it must be one of the strings that is a named collection of dtypes in the standard. NumPy has always had a ton of different ways to say …
Yes, makes sense to have a replacement for "flexible" only if it turns out there is an actual need.
@rgommers Sure! I updated the implementation (and tests) to match the standard exactly.
(force-pushed from 28e435e to 3e1c06e)
(force-pushed from 3e1c06e to a227f5d)
(title changed: API: Introduce np.isdtype function → API: Introduce np.isdtype function [Array API])
numpy/_core/numerictypes.py (outdated)

    """
    # validate and preprocess arguments
    if not isinstance(dtype, (type, ma.dtype)):
I'm not sure if there's an easy way to spell this that's more accurate than just checking for type, but this would let you do e.g. np.isdtype(np.int64, dict) and return False. I guess that is true but is perhaps against the spirit of this check. I don't know if there's an easy way to get a list of all the dtype scalar types numpy knows about at any given moment (which could include user dtypes or dtypes defined in downstream packages).
I don't know it either. Let's ask about it during today's community call.
Maybe checking issubclass(dtype, np.generic) would make sense?
In [31]: np.int64.mro()
Out[31]:
[numpy.int64,
numpy.signedinteger,
numpy.integer,
numpy.number,
numpy.generic,
object]
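For illustration, a minimal sketch of that kind of check (hypothetical helper name, not code from this PR):

    import numpy as np

    def _is_numpy_scalar_type_or_dtype(obj):
        # Accept np.dtype instances directly; for classes, require that they
        # sit under np.generic in the scalar-type hierarchy shown above.
        if isinstance(obj, np.dtype):
            return True
        return isinstance(obj, type) and issubclass(obj, np.generic)

    _is_numpy_scalar_type_or_dtype(np.int64)   # True: np.generic is in the MRO
    _is_numpy_scalar_type_or_dtype(dict)       # False: unrelated classes are rejected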
We decided at the triage meeting the other day that we wanted to make this error reject all dtypes or scalar types that aren't in the array API standard.
reject all dtypes or scalar types that aren't in the array API standard
That would make it harder to use while supporting, e.g., float16, and seems like an odd suggestion - what is the justification?
Looking at the jax bfloat16 type as a real-world example, I see:
In [10]: np.issubdtype(ml_dtypes.bfloat16, np.floating)
Out[10]: False
In [11]: ml_dtypes.bfloat16.mro()
Out[11]: [ml_dtypes.bfloat16, numpy.generic, object]
Which isn't particularly useful but is correct according to the implementation of issubdtype.
So I think we might need to just not handle third-party dtypes for now, until we have a better story for registering dtypes in a type hierarchy. Does that sound reasonable?
Handling all the builtin numpy dtypes here makes sense, no worries about doing that.
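For reference, the behaviour above can be reproduced with two checks (ml_dtypes is the third-party package providing bfloat16; the results match the outputs reported above):

    import numpy as np
    import ml_dtypes

    issubclass(ml_dtypes.bfloat16, np.generic)       # True: it derives from np.generic
    np.issubdtype(ml_dtypes.bfloat16, np.floating)   # False: not classified as floating

    # A kind such as "real floating" therefore cannot be answered correctly for
    # bfloat16 from the scalar-type hierarchy alone, which is why third-party
    # dtypes are left out for now.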
So I think we might need to just not handle third-party dtypes for now, until we have a better story for registering dtypes in a type hierarchy. Does that sound reasonable?
Thanks for checking. That sounds fine to me.
I updated the implementation so now it accepts all NumPy's dtypes as inputs.
For what it's worth, the reason that np.issubdtype(ml_dtypes.bfloat16, np.floating) returns False is that in several places NumPy hard-codes the assumption that there is only a single 16-bit floating-point type, so we could not make bfloat16 a subclass of np.floating without causing collisions across the codebase. If you have suggestions for how to do better, we'd love to hear it!
I’d love to relax that assumption inside numpy; I’m sure patches that do that would be reviewed if you can get anyone to work on it. I also wasn’t privy to previous discussions about this, but at this point I think numpy should probably support bfloat16 natively. Of course that’s just my opinion. Also, with the new dtype API it should be much more straightforward to upstream a new dtype into numpy; there’s no need anymore to mess with all the complicated custom templating and codegen.
(force-pushed from 7e7b550 to cd8a56c)
(force-pushed from a22def4 to c74dbb4)
Thanks @mtsokol, this is looking pretty good to me - a number of minor comments.
    False
    >>> np.isdtype(np.int64, (np.uint64, "signed integer"))
    True
It'd be nice to explicitly add the example in the PR discussion regarding checking for real-floating only vs. an API that supports both real and complex floating.
Do you mean in the Examples section? Or the PR's description?

I added two more examples to the docstring: one for checking real floating only, and one for real and complex floating.
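For instance, docstring examples along these lines (a sketch of what they could look like; the exact examples added in the PR may differ):

    >>> # real floating-point dtypes only
    >>> np.isdtype(np.float64, "real floating")
    True
    >>> np.isdtype(np.complex128, "real floating")
    False
    >>> # both real and complex floating-point dtypes
    >>> np.isdtype(np.complex128, ("real floating", "complex floating"))
    True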
numpy/_core/numerictypes.py (outdated)

    True

    """
    # validate and preprocess arguments
It would be useful to extend this comment to state what is actually happening here. The intended inputs are np.float32 & co, which aren't instances of multiarray.dtype, so I'm not sure that this validation is right. Also, rejection of non-compliant objects may be good to do more explicitly; right now the errors may be a bit obscure:
In [9]: np.isdtype(np.float32, np.ones(1))
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[9], line 1
----> 1 np.isdtype(np.float32, np.ones(1))
File ~/code/numpy/build-install/usr/lib/python3.11/site-packages/numpy/_core/numerictypes.py:436, in isdtype(dtype, kind)
    434 if isinstance(kind, ma.dtype):
    435     kind = kind.type
--> 436 if kind not in allTypes.values():
    437     raise TypeError(
    438         "kind argument must be comprised of NumPy dtypes or "
    439         f"strings only, but it is a {kind}."
    440     )
    442 processed_kinds.add(kind)
TypeError: descriptor '__array_wrap__' for 'numpy.generic' objects doesn't apply to a 'numpy.ndarray' object
I added a short docstring for the helper function instead. Basically, if we have a dtype instance we need to extract its type. Then we check if the dtype is in allTypes (I added a check for numpy.ndarray to avoid the error you posted).
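Roughly, the preprocessing described here could look like the following (a simplified sketch, not the exact code in the PR; allTypes is the internal numerictypes mapping visible in the traceback above, and string kinds are assumed to be handled separately):

    def _preprocess(arg):
        # dtype instances carry their scalar type on the .type attribute, so
        # np.float32 and np.dtype(np.float32) end up being treated the same.
        if isinstance(arg, np.dtype):
            arg = arg.type
        # Check for ndarray explicitly before the membership test, since
        # `arg in allTypes.values()` on an array triggers element-wise
        # comparison and produces the obscure error shown above.
        if isinstance(arg, np.ndarray) or arg not in allTypes.values():
            raise TypeError(f"expected a NumPy dtype or scalar type, got {arg!r}")
        return arg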
(force-pushed from f830455 to 3a3684b)
(force-pushed from 3a3684b to 151d70a)
(force-pushed from 151d70a to 4598f0c)
Sorry for not merging this earlier in the week; I accidentally dropped it without looking at it again. Thanks for bringing this one home @mtsokol!
Hi @rgommers @ngoldbaum,
This PR adds np.isdtype, mentioned in NEP 52 and the tracking issue #23999. I followed the description in https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html.
Questions:
1. In the kind argument, a flexible kind is missing (union of floating and complex types), but integral is present. Is it on purpose?
2. The standard expects the dtype argument to be a dtype, but in the implementation I accept anything that can be consumed by np.dtype, such as strings "int64" etc. Should I enforce dtype instances only? (The two spellings are sketched after this list.)
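A small illustration of the two input styles question 2 contrasts (hypothetical; the results assume the standard's semantics, and the looser spelling was later tightened during review to match the standard exactly):

    import numpy as np

    # A proper dtype instance, as the standard describes:
    np.isdtype(np.dtype("int64"), "integral")   # True

    # The looser form the question asks about, where the first argument is
    # merely something that np.dtype() can consume:
    # np.isdtype("int64", "integral")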