Summarize integer type promotion rules with a lattice? #101

Closed
shoyer opened this issue Dec 17, 2020 · 32 comments · Fixed by #103

Comments

@shoyer
Contributor

shoyer commented Dec 17, 2020

@jakevdp suggested that a nice way to summarize the signed/unsigned integer type promotion rules would be with a lattice, e.g.,
[image: lattice diagram of the signed/unsigned integer type promotion rules]

i* denotes a Python int (with unspecified precision).

This is a subset of the full type promotion lattice from the JAX docs:
https://jax.readthedocs.io/en/latest/type_promotion.html

The lattice for floats would just be f* -> f4 -> f8.

@kgryte
Contributor

kgryte commented Dec 17, 2020

I could possibly see such a diagram as a supplement to, but not a replacement for, the current set of tables.

As a replacement, it is not immediately clear to me that this conveys info as effectively as, or more clearly than, a table. A table at least has the more practical benefit of being more directly translatable to code, whereas a directed graph requires a bit more work, both in terms of translation and interpretation.

@jakevdp

jakevdp commented Dec 17, 2020

I think it would be worth including both: for people accustomed to them, lattices are much easier than tables to understand at a glance.

@rgommers
Member

+1 I'm all for more pictures & diagrams, let's do both

@jakevdp

jakevdp commented Dec 17, 2020

The comment in JAX's doc source has the networkx code that generates the diagram: https://raw.githubusercontent.com/google/jax/f1b14aa22d91cb9699e5a35503d7bdbe1d58b87a/docs/type_promotion.rst

    import networkx as nx
    import matplotlib.pyplot as plt

    # Each key promotes to the dtypes in its list (one directed edge per arrow).
    lattice = {
      'b1': ['i*'], 'u1': ['u2', 'i2'], 'u2': ['i4', 'u4'], 'u4': ['u8', 'i8'], 'u8': ['f*'],
      'i*': ['u1', 'i1'], 'i1': ['i2'], 'i2': ['i4'], 'i4': ['i8'], 'i8': ['f*'],
      'f*': ['c*', 'f2', 'bf'], 'bf': ['f4'], 'f2': ['f4'], 'f4': ['c4', 'f8'], 'f8': ['c8'],
      'c*': ['c4'], 'c4': ['c8'], 'c8': [],
    }
    graph = nx.from_dict_of_lists(lattice, create_using=nx.DiGraph)

    # Hand-picked (x, y) coordinates for each node in the drawing.
    pos = {
      'b1': [0, 0], 'u1': [2, 0], 'u2': [3, 0], 'u4': [4, 0], 'u8': [5, 0],
      'i*': [1, 1], 'i1': [2, 2], 'i2': [3, 2], 'i4': [4, 2], 'i8': [5, 2],
      'f*': [6, 1], 'bf': [7.5, 0.6], 'f2': [7.5, 1.4], 'f4': [9, 1], 'f8': [10, 1],
      'c*': [7, 2], 'c4': [10, 2], 'c8': [11, 2],
    }
    fig, ax = plt.subplots(figsize=(8, 2.5))
    nx.draw(graph, with_labels=True, node_size=600, node_color='lightgray', pos=pos, ax=ax)
    fig.savefig('type_lattice.svg', bbox_inches='tight')

@rgommers
Member

We may want to use the actual dtype names and make it look a bit prettier, something like:

[image: restyled lattice diagram using the full dtype names]

@rgommers
Member

I think the 'u8' type strings are a combination of NumPy history and keeping the tables readable, no reason to stick to them.

@jakevdp

jakevdp commented Dec 17, 2020

Looks great! We need to remove the arrows that go from signed ints to unsigned ints, I think.

@rgommers
Member

Oops yes, was just messing with styling, my brain doesn't actually work anymore after 11pm:(

@jakevdp

jakevdp commented Dec 17, 2020

The styling looks really nice!

@jakevdp

jakevdp commented Dec 17, 2020

On this topic, I've been meaning to write a public design doc on the criteria used to generate the JAX promotion semantics & the lattice linked above. What would be the most useful thing in terms of the array-api effort?

@kgryte
Contributor

kgryte commented Dec 17, 2020

Not clear to me what role the Python int has in the lattice and whether it should be included in the API spec given that we do not recognize it as an explicitly supported array dtype, nor do we explicitly include it in the type promotion tables.

And sorry for the ignorance, but what information/insight is the lattice intended to convey? How would I use it in practice?

For a table, this is relatively straightforward: given two operands, find the respective row and column and that is the promoted type. Not clear to me how the lattice provides an alternative mechanism, but I may just be a bit dense.

@rgommers
Member

> Not clear to me what role the Python int has in the lattice and whether it should be included in the API spec

It's mixed array-scalar behaviour, which we specify: https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html#mixing-arrays-with-python-scalars

@jakevdp

jakevdp commented Dec 17, 2020

The lattice can be used to generate the type promotion table. I think technically it is a visual representation of a partially-ordered set, and the type promotion between any two elements is the join/least upper bound of the two elements.

In real-person terms, if you want to know what two types promote to, you follow the arrows from each and find the first node that is reachable by both (including the node itself, if applicable)
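
A minimal sketch of that lookup, assuming the integer part of the JAX lattice quoted earlier in this thread, cut off at `f*`. The names `LATTICE`, `reachable`, and `join` are hypothetical and come from neither JAX nor the spec:

```python
# Edges copied from the bool/int part of the JAX lattice quoted earlier in this
# thread, truncated at 'f*'. All names here are illustrative only.
LATTICE = {
    'b1': ['i*'], 'i*': ['u1', 'i1'],
    'u1': ['u2', 'i2'], 'u2': ['u4', 'i4'], 'u4': ['u8', 'i8'], 'u8': ['f*'],
    'i1': ['i2'], 'i2': ['i4'], 'i4': ['i8'], 'i8': ['f*'],
    'f*': [],
}

def reachable(node):
    """Return the node itself plus every node reachable by following arrows."""
    seen, stack = {node}, [node]
    while stack:
        for nxt in LATTICE[stack.pop()]:
            if nxt not in seen:
                seen.add(nxt)
                stack.append(nxt)
    return seen

def join(a, b):
    """Least upper bound: the common upper bound from which all others are reachable."""
    common = reachable(a) & reachable(b)
    least = [n for n in common if common <= reachable(n)]
    if len(least) != 1:
        raise ValueError(f"no unique join for {a!r} and {b!r}")
    return least[0]

print(join('u1', 'i1'))  # i2
print(join('u8', 'i8'))  # f*
print(join('i4', 'i4'))  # i4
```

Looping `join` over every pair of nodes would give a full promotion table for this truncated lattice.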

@jakevdp

jakevdp commented Dec 17, 2020

So the lattice f4->f8 generates the table

       f4 f8
    f4 f4 f8
    f8 f8 f8
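
For a chain like this one the join is simply whichever dtype comes later, so the table can be generated in a few lines. This is only a sketch; `CHAIN` and `promotion_table` are hypothetical names:

```python
# A chain lattice is a total order, so the promoted type of a pair is the
# member that appears later in the chain. Names are illustrative only.
CHAIN = ['f4', 'f8']

def promotion_table(chain):
    """Map every ordered pair of dtypes in the chain to its promoted dtype."""
    rank = {name: i for i, name in enumerate(chain)}
    return {
        (a, b): a if rank[a] >= rank[b] else b
        for a in chain
        for b in chain
    }

table = promotion_table(CHAIN)
for a in CHAIN:
    # Prints one table row per dtype: "f4 f4 f8", then "f8 f8 f8".
    print(a, *(table[a, b] for b in CHAIN))
```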

@rgommers
Member

> On this topic, I've been meaning to write a public design doc on the criteria used to generate the JAX promotion semantics & the lattice linked above.

That sounds really useful.

> What would be the most useful thing in terms of the array-api effort?

In terms of content, what's on https://jax.readthedocs.io/en/latest/type_promotion.html is already quite useful. The lattice is rather straightforward to understand. The most complex part is the mixed-kind behaviour in the table. The details/trade-offs of why (int32, float32) -> float32/64 and (float64, complex64) -> complex64/128 are interesting. It's actually the precision loss for NumPy's behaviour that's harder to explain probably (and not done in the numpy docs).
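
One way to see the trade-off in the (int32, float32) case, sketched with only the standard library: float32 has a 24-bit significand, so it cannot hold every int32 value exactly, while float64 can. The helper `to_float32` is a hypothetical name, and the snippet illustrates the precision question only, not what any particular library does:

```python
import struct

def to_float32(x):
    """Round a Python float (a float64) to the nearest float32 value."""
    return struct.unpack('f', struct.pack('f', x))[0]

n = 2**24 + 1                      # 16777217, well inside the int32 range
print(to_float32(float(n)) == n)   # False: float32 rounds it to 16777216
print(float(n) == n)               # True: float64 holds it exactly

int32_max = 2**31 - 1
print(float(int32_max) == int32_max)  # True: every int32 fits in float64
```

Promoting to float32 keeps memory use down at the cost of exactness for large integers; promoting to float64 keeps the values exact but widens the result.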

For the API standard that's probably also the most interesting: are there mixed-kind rules that we could sensibly add either now or in a future version.

@rgommers
Member

MXNet actually has a global switch between casting rules optimal for deep learning (default) and numpy's rules. PyTorch is adding mixed-kind promotion at the moment too, as is tf.experimental.numpy. So I think it would be healthy to have good rationales and to see whether convergence between all those libraries is possible.

@kgryte
Contributor

kgryte commented Dec 17, 2020

I am clearly in the minority here, but I am a -1 on the need for including a lattice in the spec. While the lattice may be an effective data structure for generating the type promotion table, I don't see what value it adds to the specification beyond being a pretty figure, as it doesn't convey any new information. Comparing the two side by side, I, for one, struggled to understand what information I was supposed to glean from one representation that the other did not provide.

If we include both, then a user (e.g., me) will wonder: okay, I see a table; got it; now, I see a lattice; what new info do I gain? If they are the same thing, I might wonder if I am missing something. Is the lattice telling me something new? Maybe I am just not seeing it; I'll stare a bit longer. If the table and lattice convey the same information, then the diagrams are redundant and we're wasting reader time by including superfluous information.

I understand how directed graphs work, so that is not the issue. The issue is that I have to work a fair amount harder to divine the same info that is explicitly enumerated in a table.

@jakevdp

jakevdp commented Dec 17, 2020

Fair enough.

My first time through the doc, it took me five minutes to peruse the several tables and get a big picture view of what they were saying, and even then I didn't understand how Python scalars play into things until my next read-through when I noticed the "Notes" section below them. I would have grokked all of that in under 15 seconds had there been a lattice displaying the same information.

@leofang
Contributor

leofang commented Dec 18, 2020

I like to have visual aids like this too.

@shoyer
Contributor Author

shoyer commented Dec 18, 2020

I would prefer the lattice for human understanding, but the table for writing code.

@rgommers
Member

> I would prefer the lattice for human understanding, but the table for writing code.

Same here.

I can open a PR, and clarify the purpose of the diagram in a caption. Does this look good and complete?

[image: proposed type promotion lattice diagram]

@shoyer
Contributor Author

shoyer commented Dec 18, 2020

@rgommers not quite -- the unsigned types are promoting to the wrong signed types, e.g., uint8 should connect to int16, not int8.
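
The arithmetic behind that correction, as a rough sketch: the promoted signed type has to cover the full unsigned range, and int8 stops at 127 while uint8 goes up to 255. The helpers `int_range` and `smallest_signed_holding` are hypothetical names:

```python
def int_range(bits, signed):
    """Return (min, max) for a two's-complement or unsigned integer of `bits` bits."""
    if signed:
        return -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    return 0, 2 ** bits - 1

def smallest_signed_holding(unsigned_bits, widths=(8, 16, 32, 64)):
    """Narrowest signed width whose range contains the unsigned range, else None."""
    _, unsigned_max = int_range(unsigned_bits, signed=False)
    for bits in widths:
        if int_range(bits, signed=True)[1] >= unsigned_max:
            return bits
    return None

print(int_range(8, signed=False))   # (0, 255)
print(int_range(8, signed=True))    # (-128, 127)
print(smallest_signed_holding(8))   # 16
print(smallest_signed_holding(32))  # 64
print(smallest_signed_holding(64))  # None
```

The `None` for 64 bits lines up with the JAX lattice quoted earlier in the thread, where `u8` has no signed integer above it and points to `f*` instead.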

@leofang
Contributor

leofang commented Dec 18, 2020

I suppose bool (either Python or non-Python) can be cast to int too.

@asmeurer
Member

The current spec doesn't include bool to int promotion.

Is the diagram shown correct? The "mixed" table at https://data-apis.github.io/array-api/latest/API_specification/type_promotion.html shows that an unsigned int and a signed int should always promote to a type with more bits than the unsigned input. But my understanding of the diagram doesn't seem to show that.

Also, I would omit Python scalars, as they don't participate in type promotion in the same way (see the notes section in the spec as well as #98).

@leofang
Contributor

leofang commented Dec 18, 2020

> The current spec doesn't include bool to int promotion.

Interesting, though it makes sense. How about the other way around (int to bool)?

@shoyer
Contributor Author

shoyer commented Dec 18, 2020

> Also, I would omit Python scalars, as they don't participate in type promotion in the same way (see the notes section in the spec as well as #98).

The Python scalar type is always converted "to a 0-D array with the same dtype as that of the array used in the expression". That's what is shown in the graphs above.

@asmeurer
Member

The "always converted" is wrong though. From the conversation in #98, the spec will only require ints to convert when they are in the same range as the dtype. Otherwise the spec doesn't specify the behavior.
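
A sketch of the distinction, assuming the usual two's-complement ranges for the integer dtypes. `INT_BOUNDS` and `scalar_fits` are hypothetical names, not part of the spec or of any library:

```python
# Inclusive (min, max) for each integer dtype; illustrative helper only.
INT_BOUNDS = {
    'int8': (-2**7, 2**7 - 1), 'int16': (-2**15, 2**15 - 1),
    'int32': (-2**31, 2**31 - 1), 'int64': (-2**63, 2**63 - 1),
    'uint8': (0, 2**8 - 1), 'uint16': (0, 2**16 - 1),
    'uint32': (0, 2**32 - 1), 'uint64': (0, 2**64 - 1),
}

def scalar_fits(value, dtype):
    """True if a Python int is representable in `dtype`.

    Only this in-range case has specified behaviour per the discussion in #98;
    what happens otherwise is left unspecified.
    """
    lo, hi = INT_BOUNDS[dtype]
    return lo <= value <= hi

print(scalar_fits(200, 'uint8'))  # True
print(scalar_fits(300, 'uint8'))  # False
print(scalar_fits(-1, 'uint8'))   # False
```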

@jakevdp

jakevdp commented Dec 19, 2020

I found the spec confusing on this point, because it seems to include two contradictory statements:

> Non-array (“scalar”) operands must not participate in type promotion.

then shortly after:

> Using Python scalars (i.e., instances of bool, int, float) together with arrays must be supported for...

@rgommers
Member

> Also, I would omit Python scalars, as they don't participate in type promotion in the same way (see the notes section in the spec as well as #98).

> The Python scalar type is always converted "to a 0-D array with the same dtype as that of the array used in the expression". That's what is shown in the graphs above.

I added an overview of the behaviour of each library in case Python ints or floats have values that are out of range for the dtype of the array in #98 (comment).

> I found the spec confusing on this point, because it seems to include two contradictory statements:

Yes indeed, that could use a tweak. What is meant by the first statement is "don't do value-based casting". The second statement says "do support array <operator> scalar".

When you put it together, that still leaves open the question of what should be done in case Python scalars have values that are too large for the dtype of the array. For float32 the obvious choice is to make the results inf. For integers the options are:

@rgommers
Member

In the diagram maybe that should be indicated with a dashed line (and explained in its caption):

[image: lattice diagram with dashed lines for the Python scalar conversions]

@shoyer
Contributor Author

shoyer commented Dec 19, 2020

A dashed line indicating undefined behavior on overflow sounds like a good solution to me.

@rgommers
Member

Made one more tweak - adding a dashed line between Python int and float scalars, because 2 * x should work if x has a floating-point dtype - and opened gh-103 to add the diagram and clarify the text.

6 participants