Skip to content

gh-112281: Allow Union with unhashable Annotated metadata #112283

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Mar 1, 2024
20 changes: 20 additions & 0 deletions Lib/test/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,26 @@ def test_hash(self):
self.assertEqual(hash(int | str), hash(str | int))
self.assertEqual(hash(int | str), hash(typing.Union[int, str]))

def test_union_of_unhashable(self):
class UnhashableMeta(type):
__hash__ = None

class A(metaclass=UnhashableMeta): ...
class B(metaclass=UnhashableMeta): ...

self.assertEqual((A | B).__args__, (A, B))
union1 = A | B
with self.assertRaises(TypeError):
hash(union1)

union2 = int | B
with self.assertRaises(TypeError):
hash(union2)

union3 = A | int
with self.assertRaises(TypeError):
hash(union3)

def test_instancecheck_and_subclasscheck(self):
for x in (int | str, typing.Union[int, str]):
with self.subTest(x=x):
Expand Down
107 changes: 103 additions & 4 deletions Lib/test/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
import collections
import collections.abc
from collections import defaultdict
from functools import lru_cache, wraps
from functools import lru_cache, wraps, reduce
import gc
import inspect
import itertools
import operator
import pickle
import re
import sys
Expand Down Expand Up @@ -1769,6 +1770,26 @@ def test_union_union(self):
v = Union[u, Employee]
self.assertEqual(v, Union[int, float, Employee])

def test_union_of_unhashable(self):
class UnhashableMeta(type):
__hash__ = None

class A(metaclass=UnhashableMeta): ...
class B(metaclass=UnhashableMeta): ...

self.assertEqual(Union[A, B].__args__, (A, B))
union1 = Union[A, B]
with self.assertRaises(TypeError):
hash(union1)

union2 = Union[int, B]
with self.assertRaises(TypeError):
hash(union2)

union3 = Union[A, int]
with self.assertRaises(TypeError):
hash(union3)

def test_repr(self):
self.assertEqual(repr(Union), 'typing.Union')
u = Union[Employee, int]
Expand Down Expand Up @@ -5506,10 +5527,8 @@ def some(self):
self.assertFalse(hasattr(WithOverride.some, "__override__"))

def test_multiple_decorators(self):
import functools

def with_wraps(f): # similar to `lru_cache` definition
@functools.wraps(f)
@wraps(f)
def wrapper(*args, **kwargs):
return f(*args, **kwargs)
return wrapper
Expand Down Expand Up @@ -8524,6 +8543,76 @@ def test_flatten(self):
self.assertEqual(A.__metadata__, (4, 5))
self.assertEqual(A.__origin__, int)

def test_deduplicate_from_union(self):
# Regular:
self.assertEqual(get_args(Annotated[int, 1] | int),
(Annotated[int, 1], int))
self.assertEqual(get_args(Union[Annotated[int, 1], int]),
(Annotated[int, 1], int))
self.assertEqual(get_args(Annotated[int, 1] | Annotated[int, 2] | int),
(Annotated[int, 1], Annotated[int, 2], int))
self.assertEqual(get_args(Union[Annotated[int, 1], Annotated[int, 2], int]),
(Annotated[int, 1], Annotated[int, 2], int))
self.assertEqual(get_args(Annotated[int, 1] | Annotated[str, 1] | int),
(Annotated[int, 1], Annotated[str, 1], int))
self.assertEqual(get_args(Union[Annotated[int, 1], Annotated[str, 1], int]),
(Annotated[int, 1], Annotated[str, 1], int))

# Duplicates:
self.assertEqual(Annotated[int, 1] | Annotated[int, 1] | int,
Annotated[int, 1] | int)
self.assertEqual(Union[Annotated[int, 1], Annotated[int, 1], int],
Union[Annotated[int, 1], int])

# Unhashable metadata:
self.assertEqual(get_args(str | Annotated[int, {}] | Annotated[int, set()] | int),
(str, Annotated[int, {}], Annotated[int, set()], int))
self.assertEqual(get_args(Union[str, Annotated[int, {}], Annotated[int, set()], int]),
(str, Annotated[int, {}], Annotated[int, set()], int))
self.assertEqual(get_args(str | Annotated[int, {}] | Annotated[str, {}] | int),
(str, Annotated[int, {}], Annotated[str, {}], int))
self.assertEqual(get_args(Union[str, Annotated[int, {}], Annotated[str, {}], int]),
(str, Annotated[int, {}], Annotated[str, {}], int))

self.assertEqual(get_args(Annotated[int, 1] | str | Annotated[str, {}] | int),
(Annotated[int, 1], str, Annotated[str, {}], int))
self.assertEqual(get_args(Union[Annotated[int, 1], str, Annotated[str, {}], int]),
(Annotated[int, 1], str, Annotated[str, {}], int))

import dataclasses
@dataclasses.dataclass
class ValueRange:
lo: int
hi: int
v = ValueRange(1, 2)
self.assertEqual(get_args(Annotated[int, v] | None),
(Annotated[int, v], types.NoneType))
self.assertEqual(get_args(Union[Annotated[int, v], None]),
(Annotated[int, v], types.NoneType))
self.assertEqual(get_args(Optional[Annotated[int, v]]),
(Annotated[int, v], types.NoneType))

# Unhashable metadata duplicated:
self.assertEqual(Annotated[int, {}] | Annotated[int, {}] | int,
Annotated[int, {}] | int)
self.assertEqual(Annotated[int, {}] | Annotated[int, {}] | int,
int | Annotated[int, {}])
self.assertEqual(Union[Annotated[int, {}], Annotated[int, {}], int],
Union[Annotated[int, {}], int])
self.assertEqual(Union[Annotated[int, {}], Annotated[int, {}], int],
Union[int, Annotated[int, {}]])

def test_order_in_union(self):
expr1 = Annotated[int, 1] | str | Annotated[str, {}] | int
for args in itertools.permutations(get_args(expr1)):
with self.subTest(args=args):
self.assertEqual(expr1, reduce(operator.or_, args))

expr2 = Union[Annotated[int, 1], str, Annotated[str, {}], int]
for args in itertools.permutations(get_args(expr2)):
with self.subTest(args=args):
self.assertEqual(expr2, Union[args])

def test_specialize(self):
L = Annotated[List[T], "my decoration"]
LI = Annotated[List[int], "my decoration"]
Expand All @@ -8544,6 +8633,16 @@ def test_hash_eq(self):
{Annotated[int, 4, 5], Annotated[int, 4, 5], Annotated[T, 4, 5]},
{Annotated[int, 4, 5], Annotated[T, 4, 5]}
)
# Unhashable `metadata` raises `TypeError`:
a1 = Annotated[int, []]
with self.assertRaises(TypeError):
hash(a1)

class A:
__hash__ = None
a2 = Annotated[int, A()]
with self.assertRaises(TypeError):
hash(a2)

def test_instantiate(self):
class C:
Expand Down
45 changes: 31 additions & 14 deletions Lib/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,19 +308,33 @@ def _unpack_args(args):
newargs.append(arg)
return newargs

def _deduplicate(params):
def _deduplicate(params, *, unhashable_fallback=False):
# Weed out strict duplicates, preserving the first of each occurrence.
all_params = set(params)
if len(all_params) < len(params):
new_params = []
for t in params:
if t in all_params:
new_params.append(t)
all_params.remove(t)
params = new_params
assert not all_params, all_params
return params

try:
return dict.fromkeys(params)
except TypeError:
if not unhashable_fallback:
raise
# Happens for cases like `Annotated[dict, {'x': IntValidator()}]`
return _deduplicate_unhashable(params)

def _deduplicate_unhashable(unhashable_params):
new_unhashable = []
for t in unhashable_params:
if t not in new_unhashable:
new_unhashable.append(t)
return new_unhashable

def _compare_args_orderless(first_args, second_args):
first_unhashable = _deduplicate_unhashable(first_args)
second_unhashable = _deduplicate_unhashable(second_args)
t = list(second_unhashable)
try:
for elem in first_unhashable:
t.remove(elem)
except ValueError:
return False
return not t

def _remove_dups_flatten(parameters):
"""Internal helper for Union creation and substitution.
Expand All @@ -335,7 +349,7 @@ def _remove_dups_flatten(parameters):
else:
params.append(p)

return tuple(_deduplicate(params))
return tuple(_deduplicate(params, unhashable_fallback=True))


def _flatten_literal_params(parameters):
Expand Down Expand Up @@ -1555,7 +1569,10 @@ def copy_with(self, params):
def __eq__(self, other):
if not isinstance(other, (_UnionGenericAlias, types.UnionType)):
return NotImplemented
return set(self.__args__) == set(other.__args__)
try: # fast path
return set(self.__args__) == set(other.__args__)
except TypeError: # not hashable, slow path
return _compare_args_orderless(self.__args__, other.__args__)

def __hash__(self):
return hash(frozenset(self.__args__))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Allow creating :ref:`union of types<types-union>` for
:class:`typing.Annotated` with unhashable metadata.