
DEP: Deprecate registering dtype names with np.sctypeDict? #24699


ngoldbaum opened this issue Sep 13, 2023 · 12 comments

@ngoldbaum
Member

ngoldbaum commented Sep 13, 2023

On import, ml_dtypes adds new entries to np.sctypeDict so that e.g. np.dtype("int4") returns an int4 dtype defined outside NumPy.

Since jax currently documents this behavior to users and relies on it internally, I don’t think we can reasonably break it without a deprecation story and a migration story.

For deprecating it, we would keep a list of all the strings that NumPy accepts out of the box, and if any other string is passed in and we somehow get back a valid dtype, we would raise a deprecation warning. I don't know whether there are other ways of injecting a string dtype name into NumPy's internals besides manipulating sctypeDict, but if there are, this check would catch those shenanigans too.
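A minimal pure-Python sketch of that allowlist check, assuming a toy stand-in for np.sctypeDict rather than NumPy's real C-level lookup. `_BUILTIN_NAMES`, `_type_dict`, and `lookup_dtype_name` are hypothetical names, and the allowlist is a tiny illustrative subset, not NumPy's actual list:

```python
import warnings

# Hypothetical allowlist: a tiny subset standing in for every name NumPy
# accepts out of the box.
_BUILTIN_NAMES = frozenset({"int8", "int16", "float32", "float64"})

# Toy stand-in for np.sctypeDict; strings play the role of dtype objects.
_type_dict = {name: f"<builtin {name}>" for name in _BUILTIN_NAMES}


def lookup_dtype_name(name):
    """Resolve a dtype name, warning if it was injected from outside."""
    try:
        result = _type_dict[name]
    except KeyError:
        raise TypeError(f"data type {name!r} not understood") from None
    if name not in _BUILTIN_NAMES:
        # A valid entry came back for a name not on the allowlist, so
        # something other than NumPy must have registered it.
        warnings.warn(
            f"np.dtype({name!r}) relies on an externally registered name "
            "and is deprecated; pass the dtype object instead.",
            DeprecationWarning,
            stacklevel=2,
        )
    return result


# A downstream package injecting a name, the way ml_dtypes does today.
_type_dict["bfloat16"] = "<external bfloat16>"
```

In this sketch the warning fires at lookup time, not at registration time, so a package that only registers names stays silent until someone actually resolves an injected name.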

Should we also deprecate np.sctypeDict itself?

In a few releases after adding the deprecation, we could make it so np.dtype can only return dtype instances with a mapping defined out of the box in NumPy or via some as-yet unwritten mechanism to associate string names with dtypes, probably with some kind of support for namespacing.
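One hypothetical shape for such a mechanism, sketched in pure Python: names are qualified by the registering package, so two libraries could each register a "bfloat16" without clashing. `DtypeNameRegistry` and its methods are illustrative only; no such NumPy API exists, and strings stand in for real dtype objects:

```python
class DtypeNameRegistry:
    """Hypothetical registry mapping "namespace.name" strings to dtypes."""

    def __init__(self):
        self._entries = {}

    def register(self, namespace, name, dtype):
        # Dots inside either part would make "a.b.c" ambiguous, so reject them.
        if "." in namespace or "." in name:
            raise ValueError("namespace and name must not contain '.'")
        key = f"{namespace}.{name}"
        if key in self._entries:
            raise ValueError(f"dtype name {key!r} is already registered")
        self._entries[key] = dtype

    def lookup(self, qualified_name):
        try:
            return self._entries[qualified_name]
        except KeyError:
            raise TypeError(f"data type {qualified_name!r} not understood") from None


registry = DtypeNameRegistry()
# Two packages can each claim "bfloat16" because the namespace disambiguates.
registry.register("ml_dtypes", "bfloat16", "<ml_dtypes bfloat16>")
registry.register("otherlib", "bfloat16", "<otherlib bfloat16>")
print(registry.lookup("ml_dtypes.bfloat16"))
```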

As far as I know jax is the only downstream library that injects dtype names into the np.dtype("dtype_name") mechanism.

The deprecation should not be added until we have a clear migration story for the jax library and any possible impacts on jax users are considered.

xref #24376 (comment) and the discussion that follows for context

@ngoldbaum ngoldbaum changed the title Deprecate registering dtype names with np.sctypeDict ENH: Deprecate registering dtype names with np.sctypeDict Sep 13, 2023
@ngoldbaum ngoldbaum changed the title ENH: Deprecate registering dtype names with np.sctypeDict DEP: Deprecate registering dtype names with np.sctypeDict? Sep 13, 2023
@rgommers
Member

Thanks for the summary @ngoldbaum!

I think it would be useful to get the actual requirements and constraints here clear first, before thinking about potential solutions to make string names for external dtypes work.

Since jax currently documents this behavior to users and relies on it internally, I don’t think we can reasonably break it without a deprecation story and a migration story.

Isn't this completely untested and undocumented, and it just happened to work by relying on NumPy-internal implementation details? If we had documented or recommended it, that would make a difference here. But my current impression is that this could have broken at any time, and it's bad practice for a package to modify global state in a way that can break other packages.

np.dtype("int4") returns an int4 dtype defined outside NumPy.

I don't see a real reason to support this, at least if the only rationale is to save a few characters (since np.dtype(ml_dtypes.int4) works fine apparently). Is there a better reason?

Also, how often do you actually need this dtype instance? Idiomatic code in numpy uses func(x, dtype=np.int8) and not func(x, dtype=np.dtype('int8')).

@hawkinsp
Contributor

@rgommers Well, the entire numpy type extension API is more or less undocumented, so to a certain extent that's true of anything related to user-defined types. I know there are plans to change that (NEP 42) but they aren't ready yet.

This extension has existed in various forms since 2017 (originally as part of TensorFlow), and it looks like we started adding entries to np.typeDict (later np.sctypeDict) in 2021 (tensorflow/tensorflow@16671ca). The extension was later moved into its own package (ml_dtypes) so it could be shared between projects without requiring a TF dependency.

The issue is not so much that we write this; it's that our users may write this or expect to be able to write it. I don't think it's that prevalent, but some users definitely do it.

And it strikes me as something that's reasonable to expect: we're trying to add additional NumPy integer types (int4, uint4) and floating point types (bfloat16, float8_e4m3b11fnuz, float8_e4m3fn, float8_e4m3fnuz, float8_e5m2, float8_e5m2fnuz) that to the extent possible are supposed to look and feel exactly like the builtin int and float types.

There are other instances where doing this right may require updating global state; for instance, users expect to write things like np.finfo(bfloat16) (although we don't actually do that one in practice, for reasons not worth discussing here). In that case, we just provide ml_dtypes.finfo and ml_dtypes.iinfo because NumPy wasn't sufficiently extensible.

That said, I suspect this problem is less pressing than it once was: now that ml_dtypes is a self-contained package, one can easily name the type directly (ml_dtypes.bfloat16). So if you want to disallow doing this, we can adapt.

@rgommers
Member

And it strikes me as something that's reasonable to expect .... are supposed to look and feel exactly like the builtin int and float types.

Sure, but part of the point of all the cleanups in 2.0 is to reduce the many ways of doing the same thing, and to give users better guidance on how to do things. And this to me does not look like recommended usage. E.g.:

>>> x = np.ones(2, dtype=np.int8)  # the canonical way
>>> x = np.ones(2, dtype='int8')
>>> x = np.ones(2, dtype='i1')
>>> x = np.ones(2, dtype=np.dtype('int8'))
>>> x = np.ones(2, dtype=np.byte)
>>> ... # etc.

This is quite a mess. I believe the recommended way to write code for ml_dtypes users should be np.ones(2, dtype=ml_dtypes.int4), analogous to the NumPy canonical way.
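For the builtin case, a quick check (assuming only NumPy is installed) that every spelling in the list above resolves to the same dtype. The ml_dtypes analogue is left as a comment so the snippet runs without that package:

```python
import numpy as np

# Each spelling from the list above, canonical type object first.
spellings = [np.int8, "int8", "i1", np.dtype("int8"), np.byte]
dtypes = [np.ones(2, dtype=s).dtype for s in spellings]
print(dtypes)

# With ml_dtypes installed, the analogous canonical form would be:
# x = np.ones(2, dtype=ml_dtypes.int4)
```

Passing the type object directly also avoids any dependence on a global name table.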

The issue is not so much that we write this; it's that our users may write this or expect to be able to write it. I don't think it's that prevalent, but some users definitely do it.

Sure. We can think about ways to keep this working. However, before we conclude "we cannot touch this at all for 2.0", let's first figure out how this is supposed to look in the future.

Only after that should we decide how to move forward in a way that isn't too disruptive for JAX. I'll note that you have to do a new release for 2.0 support anyway, so if you cannot deprecate it fast enough and we give you some other (perhaps temporary or private) way to keep np.dtype('int4') working, that should be fine.

@jakevdp
Contributor

jakevdp commented Sep 14, 2023

I think one challenge here is that JAX/ml_dtypes cannot deprecate np.dtype('bfloat16') from the JAX side – that's a numpy API, and we cannot make it raise any sort of deprecation warning (except maybe by some sort of monkey patching, but I wouldn't consider that a viable solution). The only knob we have is to break users by no longer registering the type name.

We could certainly update all instances in our own code and other code that we have control over, but any more gradual runtime deprecation behavior would have to come from NumPy.

Isn't this completely untested and undocumented, and it just happened to work by relying on NumPy-internal implementation details?

No, it's explicitly documented. e.g. in the README at https://github.com/jax-ml/ml_dtypes:

Importing ml_dtypes also registers the data types with numpy, so that they may be referred to by their string name:

>>> np.dtype('bfloat16')
dtype(bfloat16)
>>> np.dtype('float8_e5m2')
dtype(float8_e5m2)

That's probably my fault – I had assumed this sort of registration was intentionally supported by NumPy, and so I advertised it as so.

@rgommers
Copy link
Member

I think one challenge here is that JAX/ml_dtypes cannot deprecate np.dtype('bfloat16') from the JAX side – that's a numpy API, and we cannot make it raise any sort of deprecation warning

Can't you simply insert {'bfloat16': ml_dtypes._bfloat16} into sctypeDict, where _bfloat16 raises a deprecation warning on access and then returns your public ml_dtypes.bfloat16?

No, it's explicitly documented. e.g. in the README at https://github.com/jax-ml/ml_dtypes:

I meant in the numpy docs.

That's probably my fault – I had assumed this sort of registration was intentionally supported by NumPy, and so I advertised it as so.

No worries at all. It's not impossible that this was in some numpy tutorial or docs. And yes, NumPy historically had neither docs for this kind of thing nor any sort of reasonable public/private split. So who knows if this dict was ever meant to be public for reading from (let alone for writing into).

@jakevdp
Contributor

jakevdp commented Sep 14, 2023

Can't you simply insert {'bfloat16': ml_dtypes._bfloat16} into sctypeDict, where _bfloat16 raises a deprecation warning on access and then returns your public ml_dtypes.bfloat16?

I don't totally follow: say we create a shadow _bfloat16 scalar type that looks just like the real ml_dtypes.bfloat16 scalar type. Where exactly would we raise the deprecation warning? I suppose we'd probably want the equivalent of np.dtype(ml_dtypes._bfloat16) to result in a warning, but I'm unclear on what methods of the shadow _bfloat16 object would be called in this case.

@ngoldbaum
Copy link
Member Author

ngoldbaum commented Sep 14, 2023

I haven't done this with the legacy API, but could you add a tp_new implementation to the dtype class and create the warning when a dtype instance is instantiated? It looks like the ml_dtypes types don't have tp_new implementations, but I think you could add one to the struct where it is filled in e.g. here for int4.

@jakevdp
Copy link
Contributor

jakevdp commented Sep 14, 2023

Does accessing the dtype singleton involve instantiating the scalar type?

@jakevdp
Copy link
Contributor

jakevdp commented Sep 14, 2023

OK, so I traced it through: I think np.dtype(typ) will eventually end up here if typ is not an actual scalar type:

PyArray_Descr *ret = _try_convert_from_dtype_attr(obj);

So as long as sctypeDict['bfloat16'] doesn't have to actually contain a scalar type, we could emit the deprecation warning like this:

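import warnings
import ml_dtypes
import numpy as np

class _DeprecatedBfloat16Meta(type):
  # np.dtype() reads `.dtype` from the class object itself, so the property
  # lives on a metaclass. (Stacking @classmethod on @property was deprecated
  # in Python 3.11 and no longer works in Python 3.13.)
  @property
  def dtype(cls):
    warnings.warn("np.dtype('bfloat16') is deprecated. Use np.dtype(ml_dtypes.bfloat16) instead.")
    return np.dtype(ml_dtypes.bfloat16)

class _deprecated_bfloat16(metaclass=_DeprecatedBfloat16Meta):
  pass

# Replace the entry ml_dtypes registered on import with the warning shim.
np.sctypeDict['bfloat16'] = _deprecated_bfloat16

print(repr(np.dtype('bfloat16')))
# UserWarning: np.dtype('bfloat16') is deprecated. Use np.dtype(ml_dtypes.bfloat16) instead.
# dtype(bfloat16)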

It would be nice to make _deprecated_bfloat16 a subclass of the true bfloat16 to cover other use cases of sctypeDict['bfloat16'], but in that case np.dtype('bfloat16') bottoms out in this condition:

if (PyType_IsSubtype(typ, &PyGenericArrType_Type)) {
    return PyArray_DescrFromTypeObject(obj);
}

I don't see any good overloading route from the Python side within PyArray_DescrFromTypeObject, since it's only accessing C-level attributes.
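If the non-subclass shim route were taken for all of the ml_dtypes names, the shim could be generated by a small factory. The sketch below is hypothetical (`make_deprecated_alias` is not part of ml_dtypes or NumPy). To run without ml_dtypes, it registers a made-up alias for the builtin `np.float32`; it assumes, as traced above, that np.dtype() consults np.sctypeDict for unknown strings and reads the `.dtype` attribute of classes that are not scalar types:

```python
import warnings

import numpy as np


def make_deprecated_alias(name, target, replacement):
    """Return a class that np.dtype() resolves via its `.dtype` attribute.

    `target` is the real scalar type; `replacement` is the spelling users
    should switch to, used only in the warning message.
    """
    def dtype(cls):
        warnings.warn(
            f"np.dtype({name!r}) is deprecated. Use np.dtype({replacement}) instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return np.dtype(target)

    # The property must live on the metaclass so that reading `.dtype` on
    # the class object itself (which is what np.dtype() does) triggers it.
    meta = type(f"_{name}_meta", (type,), {"dtype": property(dtype)})
    return meta(f"_deprecated_{name}", (), {})


# Stand-in for e.g. ("bfloat16", ml_dtypes.bfloat16, "ml_dtypes.bfloat16").
np.sctypeDict["legacy_float32"] = make_deprecated_alias(
    "legacy_float32", np.float32, "np.float32"
)
```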

Let me know if you have better ideas!

@seberg
Copy link
Member

seberg commented Sep 18, 2023

I agree that the deprecation should be emitted by NumPy. I'm happy to nudge you towards doing it that way; hacking it into ml_dtypes just seems harder and more confusing.
At some point during np.dtype() construction we fall back to looking up the name in the scalar type dict, and that should be a good place to add the warning. It could probably also be done by making sctypeDict a dict-like object or dict subclass, but I doubt that's better.
(np.dtype() construction is very messy, but it probably doesn't even matter for this purpose.)
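For the dict-subclass variant, a toy sketch (hypothetical `_WarnOnWriteDict`, strings standing in for dtypes) illustrates one reason to doubt it: an overridden `__setitem__` only fires for Python-level item assignment, while `dict.update`, keyword construction, and C-API calls such as `PyDict_SetItem` bypass it:

```python
import warnings


class _WarnOnWriteDict(dict):
    """Hypothetical dict subclass that warns whenever a name is assigned."""

    def __setitem__(self, key, value):
        warnings.warn(
            f"Registering dtype name {key!r} by writing into this dict is deprecated.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__setitem__(key, value)


# Keyword construction does not go through __setitem__, so this is silent.
type_dict = _WarnOnWriteDict(int8="<int8>", float32="<float32>")
```

Here `type_dict["bfloat16"] = ...` warns, but `type_dict.update(bfloat16=...)` does not, and neither would a write from C.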

The only thing that would make me pause is if you or your users strongly want to keep it working. In that case we should add a new way to do a proper registration with a function (and I might be fine with just making the old way fail).

@jakevdp
Copy link
Contributor

jakevdp commented Sep 18, 2023

The only thing that would make me pause is if you or your users strongly want to keep it working.

I think this would be ideal: after all, the internal dtype system already has a notion of a dtype's string name, and does string-to-dtype lookups at the C level, even for user-registered dtypes, without any reference to sctypeDict. Supporting that at the Python level as well doesn't seem like too much of a stretch.

@seberg
Copy link
Member

seberg commented Sep 18, 2023

I don't really like it, because you cannot control name clashes well and the next thing will be someone asking if we can allow np.dtype("mydtype[days]") (parameters). But, I can live with it with those caveats because unfortunately that is a typical pattern (and maybe convenient, although I am not sure it is meaningfully convenient).

But as I said, I would be fine with hiding a registration function somewhere (which could internally just insert into the sctypeDict).

5 participants