Summarize integer type promotion rules with a lattice? #101
I could possibly see such a diagram as a supplement to, but not a replacement for, the current set of tables. As a replacement, it is not immediately clear to me that this conveys info effectively and/or more clearly than a table. A table at least has the more practical benefit of being more directly translatable to code, whereas a directed graph requires a bit more work, both in terms of translation and interpretation. |
I think it would be worth including both: for people accustomed to them, lattices are much easier than tables to understand at a glance. |
+1 I'm all for more pictures & diagrams, let's do both |
The comment in JAX's doc source has the networkx code that generates the diagram: https://raw.githubusercontent.com/google/jax/f1b14aa22d91cb9699e5a35503d7bdbe1d58b87a/docs/type_promotion.rst

```python
import networkx as nx
import matplotlib.pyplot as plt

# Each key lists the types it promotes to directly (its immediate upper bounds).
lattice = {
    'b1': ['i*'], 'u1': ['u2', 'i2'], 'u2': ['i4', 'u4'], 'u4': ['u8', 'i8'], 'u8': ['f*'],
    'i*': ['u1', 'i1'], 'i1': ['i2'], 'i2': ['i4'], 'i4': ['i8'], 'i8': ['f*'],
    'f*': ['c*', 'f2', 'bf'], 'bf': ['f4'], 'f2': ['f4'], 'f4': ['c4', 'f8'], 'f8': ['c8'],
    'c*': ['c4'], 'c4': ['c8'], 'c8': [],
}
graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)

# Hand-placed (x, y) node coordinates: unsigned ints on the bottom row, signed ints
# on the top row, floats in the middle, complex types in the top right.
pos = {
    'b1': [0, 0], 'u1': [2, 0], 'u2': [3, 0], 'u4': [4, 0], 'u8': [5, 0],
    'i*': [1, 1], 'i1': [2, 2], 'i2': [3, 2], 'i4': [4, 2], 'i8': [5, 2],
    'f*': [6, 1], 'bf': [7.5, 0.6], 'f2': [7.5, 1.4], 'f4': [9, 1], 'f8': [10, 1],
    'c*': [7, 2], 'c4': [10, 2], 'c8': [11, 2],
}
fig, ax = plt.subplots(figsize=(8, 2.5))
nx.draw(graph, with_labels=True, node_size=600, node_color='lightgray', pos=pos, ax=ax)
fig.savefig('type_lattice.svg', bbox_inches='tight')
```
|
I think the |
Looks great! We need to remove the arrows that go from signed ints to unsigned ints, I think. |
Oops yes, was just messing with styling, my brain doesn't actually work anymore after 11pm:( |
The styling looks really nice! |
On this topic, I've been meaning to write a public design doc on the criteria used to generate the JAX promotion semantics & the lattice linked above. What would be the most useful thing in terms of the array-api effort? |
Not clear to me what role the Python And sorry for the ignorance, but what information/insight is the lattice intended to convey? How would I use it in practice? For a table, this is relatively straightforward: given two operands, find the respective row and column and that is the promoted type. Not clear to me how the lattice provides an alternative mechanism, but I may just be a bit dense. |
It's mixed array-scalar behaviour, which we specify: https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-scalars |
The lattice can be used to generate the type promotion table. I think technically it is a visual representation of a partially-ordered set, and the type promotion between any two elements is the join/least upper bound of the two elements. In real-person terms, if you want to know what two types promote to, you follow the arrows from each and find the first node that is reachable by both (including the node itself, if applicable) |
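Following the arrows like this is mechanical enough to automate, which also shows how a promotion table can be generated from the lattice. A minimal sketch, assuming the lattice is stored as the same adjacency dict used in the networkx snippet earlier in this thread; `upper_bounds` and `join` are hypothetical helper names, not JAX or array API functions:

```python
# Immediate upper bounds of each type (same edges as the networkx snippet above).
LATTICE = {
    'b1': ['i*'], 'u1': ['u2', 'i2'], 'u2': ['i4', 'u4'], 'u4': ['u8', 'i8'], 'u8': ['f*'],
    'i*': ['u1', 'i1'], 'i1': ['i2'], 'i2': ['i4'], 'i4': ['i8'], 'i8': ['f*'],
    'f*': ['c*', 'f2', 'bf'], 'bf': ['f4'], 'f2': ['f4'], 'f4': ['c4', 'f8'], 'f8': ['c8'],
    'c*': ['c4'], 'c4': ['c8'], 'c8': [],
}


def upper_bounds(node):
    """Every node reachable from `node` by following arrows, including `node` itself."""
    seen = {node}
    stack = [node]
    while stack:
        for succ in LATTICE[stack.pop()]:
            if succ not in seen:
                seen.add(succ)
                stack.append(succ)
    return seen


def join(a, b):
    """Least upper bound: the common upper bound from which all other common upper bounds are reachable."""
    common = upper_bounds(a) & upper_bounds(b)
    for candidate in common:
        if common <= upper_bounds(candidate):
            return candidate
    return None  # no common upper bound


# Generate a small promotion table (row type combined with column type).
types = ['u1', 'i1', 'i2', 'f2', 'bf']
for row in types:
    print(row, [join(row, col) for col in types])
```

Because the graph is a lattice, exactly one common upper bound passes the test in `join`, so the iteration order over the set does not matter.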
So the lattice
|
That sounds really useful.
In terms of content, what's on https://jax.readthedocs.io/en/latest/type_promotion.html is already quite useful. The lattice is rather straightforward to understand. The most complex part is the mixed-kind behaviour in the table. The details/trade-offs of why For the API standard that's probably also the most interesting: are there mixed-kind rules that we could sensibly add either now or in a future version. |
MXNet actually has a global switch between casting rules optimal for deep learning (default) and numpy's rules. PyTorch is adding mixed-kind promotion at the moment too, as is |
I am clearly in the minority here, but I am a If we include both, then a user (e.g., me) will wonder: okay, I see a table; got it; now, I see a lattice; what new info do I gain? If they are the same thing, I might wonder if I am missing something. Is the lattice telling me something new? Maybe I am just not seeing it; I'll stare a bit longer. If the table and lattice convey the same information, then the diagrams are redundant and we're wasting reader time by including superfluous information. I understand how directed graphs work, so that is not the issue. The issue is that I have to work a fair amount harder to divine the same info that is explicitly enumerated in a table. |
Fair enough. My first time through the doc, it took me five minutes to peruse the several tables and get a big picture view of what they were saying, and even then I didn't understand how Python scalars play into things until my next read-through when I noticed the "Notes" section below them. I would have grokked all of that in under 15 seconds had there been a lattice displaying the same information. |
I like to have visual aids like this too. |
I would prefer the lattice for human understanding, but the table for writing code. |
@rgommers not quite -- the unsigned types are promoting to the wrong signed types, e.g., uint8 should connect to int16, not int8. |
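In lattice terms, the fix is to point each unsigned node at the signed type with twice its width (`u1 -> i2`, `u2 -> i4`, `u4 -> i8`). The mixed-sign joins of that corrected lattice collapse to a closed-form rule, sketched below in Python. `promote_mixed` is a hypothetical name, and the rule is an observation that reproduces the spec's mixed unsigned/signed table (including leaving `uint64` with any signed type unspecified), not a formula taken from the spec:

```python
def promote_mixed(unsigned_bits: int, signed_bits: int):
    """Result dtype of uintN combined with intM (hypothetical helper)."""
    # A signed type needs twice the unsigned width to represent every uintN value,
    # so uint8 pairs with at least int16, uint16 with int32, and uint32 with int64.
    bits = max(2 * unsigned_bits, signed_bits)
    # There is no int128, so uint64 mixed with any signed type has no specified result.
    return f"int{bits}" if bits <= 64 else None


for u in (8, 16, 32, 64):
    print(f"uint{u}:", [promote_mixed(u, s) for s in (8, 16, 32, 64)])
```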
I suppose bool (either Python or non-Python) can be cast to int too. |
The current spec doesn't include bool-to-int promotion. Is the diagram shown correct? The "mixed" table at https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html shows that mixing an unsigned and a signed integer should always promote to a higher bitness than the unsigned input, but as I read the diagram, it doesn't seem to show that. Also, I would omit Python scalars, as they don't participate in type promotion in the same way (see the notes section in the spec as well as #98). |
Interesting, though making sense. How about the other way around (int to bool)? |
The Python scalar type is always converted "to a 0-D array with the same dtype as that of the array used in the expression". That's what is shown in the graphs above. |
The "always converted" is wrong though. From the conversation in #98, the spec will only require ints to convert when they are in the same range as the dtype. Otherwise the spec doesn't specify the behavior. |
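A minimal sketch of that rule, assuming a Python `int` mixed with an integer array behaves like a 0-D array of the array's dtype only when its value fits that dtype. `INT_RANGES` and `scalar_result_dtype` are hypothetical names, and returning `None` stands in for "unspecified by the spec":

```python
# Inclusive value range of each integer dtype in the spec.
INT_RANGES = {
    "int8": (-2**7, 2**7 - 1),
    "int16": (-2**15, 2**15 - 1),
    "int32": (-2**31, 2**31 - 1),
    "int64": (-2**63, 2**63 - 1),
    "uint8": (0, 2**8 - 1),
    "uint16": (0, 2**16 - 1),
    "uint32": (0, 2**32 - 1),
    "uint64": (0, 2**64 - 1),
}


def scalar_result_dtype(value: int, array_dtype: str):
    """Result dtype of `array <op> value` for a Python int `value` (hypothetical helper).

    The scalar never changes the result dtype (no value-based casting); it is treated
    as a 0-D array of `array_dtype`. Returns None when the value is out of range,
    which is where the spec leaves the behavior unspecified.
    """
    low, high = INT_RANGES[array_dtype]
    return array_dtype if low <= value <= high else None
```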
I found the spec confusing on this point, because it seems to include two contradictory statements:
then shortly after:
|
I added an overview of the behaviour of each library in case Python ints or floats have values that are out of range for the
Yes indeed, that could use a tweak. What is meant by the first statement is "don't do value-based casting". The second statement says "do support When you put it together, that still leaves open the question of what should be done in case Python scalars have values that are too large for the dtype of the array. For
|
A dashed line indicating undefined behavior on overflow sounds like a good solution to me. |
Made one more tweak - adding a dashed line between Python |
@jakevdp suggested that a nice way to summarize the signed/unsigned integer type promotion rules would be with a lattice, e.g.,

`i*` denotes a Python `int` (with unspecified precision). This is a subset of the full type promotion lattice from the JAX docs:
https://jax.readthedocs.io/en/latest/type_promotion.html
The lattice for floats would just be `f* -> f4 -> f8`.
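Because the float lattice is a single chain, every pair of types is comparable and the join reduces to taking the larger element. A minimal sketch, assuming the chain is stored as an ordered list; `FLOAT_CHAIN` and `join_float` are hypothetical names:

```python
# The float lattice from the issue, ordered from least to greatest.
FLOAT_CHAIN = ["f*", "f4", "f8"]


def join_float(a: str, b: str) -> str:
    # In a chain every pair is comparable, so the least upper bound is the later element.
    return max(a, b, key=FLOAT_CHAIN.index)
```

With `f*` at the bottom, mixing a Python `float` with an array yields the array's dtype, e.g. `join_float("f*", "f4")` is `"f4"`.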