ENH,API: Add a protocol for representing nested sequences #18155

Closed
wants to merge 5 commits into from

@BvB93 BvB93 commented Jan 12, 2021

This PR adds (and exposes) a protocol representing nested sequences of arbitrary depth: npt.NestedSequence.

Despite the lack of formal support for recursive types in mypy (python/mypy#731), it turns out we don't
need it at all for representing recursive sequences. As this PR demonstrates, a simple protocol is sufficient here.
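
To make the idea concrete, here is a minimal sketch of what such a self-referential protocol could look like. It is not the implementation from this PR: `NestedSeqSketch` and `nesting_depth` are hypothetical names, and only a small subset of the sequence methods is declared.

```python
from typing import Iterator, Protocol, TypeVar, Union, runtime_checkable

T_co = TypeVar("T_co", covariant=True)


@runtime_checkable
class NestedSeqSketch(Protocol[T_co]):
    """Hypothetical stand-in; not the protocol added by this PR."""

    def __len__(self) -> int: ...

    # The return type refers back to the protocol itself. This is what allows
    # arbitrary nesting without relying on a recursive type alias.
    def __getitem__(self, index: int) -> Union[T_co, "NestedSeqSketch[T_co]"]: ...

    def __iter__(self) -> Iterator[Union[T_co, "NestedSeqSketch[T_co]"]]: ...


def nesting_depth(seq: NestedSeqSketch[int]) -> int:
    """Count nesting levels by following the first element of each level."""
    depth = 0
    obj: object = seq
    while isinstance(obj, (list, tuple)):
        depth += 1
        if not obj:
            break
        obj = obj[0]
    return depth
```

Note that the runtime `isinstance` check only looks for the presence of the declared methods, so it says nothing about the element type and also accepts objects such as `str`.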

The advantage of the protocol over the currently used union-based approach is threefold:

  • It works for arbitrary levels of nesting, while the union-based approach is only suitable for a
    pre-defined number of nesting levels (e.g. 4).
  • Building on top of the previous point: this means that we do not require an (unsafe) overload for nesting
    levels not captured by the aforementioned union (e.g. ENH: Add dtype support to the array comparison ops #18128 (comment)).
  • As an added bonus, mypy's string representation will be quite a lot more compact for a single type as
    compared to a union consisting of many types.
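
For contrast, the union-based approach referred to in the list above can be sketched roughly as follows. `NestedUpTo4` and `flatten` are hypothetical names chosen for illustration, not NumPy's actual aliases; the depth of 4 simply mirrors the example given in the first bullet.

```python
from typing import List, Sequence, TypeVar, Union, get_args

T = TypeVar("T")

# Hypothetical alias: every supported nesting level is spelled out by hand,
# so the alias stops being applicable one level past the last member.
NestedUpTo4 = Union[
    Sequence[T],
    Sequence[Sequence[T]],
    Sequence[Sequence[Sequence[T]]],
    Sequence[Sequence[Sequence[Sequence[T]]]],
]

# Number of nesting levels the alias can describe.
MAX_DEPTH = len(get_args(NestedUpTo4[int]))


def flatten(seq: NestedUpTo4[int]) -> List[int]:
    """Flatten nested lists/tuples into a flat list of leaves."""
    out: List[int] = []
    for item in seq:
        if isinstance(item, (list, tuple)):
            out.extend(flatten(item))
        else:
            out.append(item)
    return out
```

At runtime `flatten` handles any depth, but a static type checker would be expected to reject an argument nested more deeply than `MAX_DEPTH`, which is the gap that the (unsafe) catch-all overload has to paper over.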

The new npt.NestedSequence protocol introduced herein is exposed to the public numpy.typing API,
as I imagine it will be rather useful for representing array-like objects, be it either in NumPy or downstream.

Examples

from __future__ import annotations
from typing import TYPE_CHECKING
import numpy as np
import numpy.typing as npt

def get_dtype(seq: npt.NestedSequence[int]) -> np.dtype[np.int_]:
    return np.asarray(seq).dtype

a = get_dtype([1])
b = get_dtype([[1]])
c = get_dtype([[[1]]])
d = get_dtype([[[[1]]]])

if TYPE_CHECKING:
    reveal_locals()
    # note: Revealed local types are:
    # note:     a: numpy.dtype[numpy.signedinteger[numpy.typing._64Bit]]
    # note:     b: numpy.dtype[numpy.signedinteger[numpy.typing._64Bit]]
    # note:     c: numpy.dtype[numpy.signedinteger[numpy.typing._64Bit]]
    # note:     d: numpy.dtype[numpy.signedinteger[numpy.typing._64Bit]]

@BvB93 BvB93 force-pushed the recursive-sequence branch from ecfdf18 to dc33633 Compare January 12, 2021 21:23
BvB93 commented Jan 12, 2021

seberg commented Jan 12, 2021

I can see this being nice to have, so I am not opposed. But I am a bit surprised that we want to make nested sequences a common thing for most commands? I.e. I would expect it to be mainly a valid input for np.array and friends, and pretty much all other occurrences would require you to call np.asarray manually to type correctly?

BvB93 commented Jan 12, 2021

I.e. I would expect it to be mainly a valid input for np.array and friends, and pretty much all other occurrences would require you to call np.asarray manually to type correctly?

This is true for scripts that utilize np.array, but if np.array is called under the hood by some function then there is no easy way to reuse its annotations. Barring the case where its signature is absolutely identical, that is (i.e. situations where we can abuse __call__-based protocols).

import numpy as np

# No easy way of grabbing the signature of `np.array` and putting it in `func` 
# (besides manually copying it, that is)
def func(a):  # ??? 
    return np.array(a) * 10

In practice this means that, unfortunately, we're stuck with quite a bit of code duplication and thus the need for the likes of NestedSequence. Hence why I feel that making it public would be worthwhile, especially for downstream libraries.

BvB93 commented Jan 15, 2021

Closing this for now, as further testing has unfortunately revealed a number of detrimental mypy bugs/limitations 😕 :

  • NestedSequence fails as a function annotation if the respective argument has not previously
    been assigned and/or declared.
from typing import TypeVar, List, overload
import numpy.typing as npt

@overload
def func1(a: npt.NestedSequence[bool]) -> bool: ...
@overload
def func1(a: npt.NestedSequence[int]) -> int: ...

int1 = [[1]]
int2: List[List[int]]

# note: Revealed type is 'builtins.int'
# This is ok
reveal_type(func1(int1))

# note: Revealed type is 'builtins.int'
# This is also ok
reveal_type(func1(int2))

# note: Revealed type is 'builtins.bool'
# This is bad; how did we end up at the `bool` overload all of a sudden????
reveal_type(func1([[1]]))  
  • The use of TypeVars is just straight-up broken.
T = TypeVar("T")

def func2(a: npt.NestedSequence[T]) -> T: ...

# error: Argument 1 to "func2" has incompatible type "List[List[int]]"; expected "NestedSequence[<nothing>]"
reveal_type(func2(int1))

# error: Argument 1 to "func2" has incompatible type "List[List[int]]"; expected "NestedSequence[<nothing>]"
reveal_type(func2(int2))

# note: Revealed type is 'Any'
reveal_type(func2([[1]]))
