Skip to content

Commit 90f75e1

Browse files
miss-islingtonsobolevnAlexWaygood
authored
[3.12] gh-112281: Allow Union with unhashable Annotated metadata (GH-112283) (#116213)
Co-authored-by: Nikita Sobolev <mail@sobolevn.me> Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
1 parent 16be4a3 commit 90f75e1

File tree

4 files changed

+156
-18
lines changed

4 files changed

+156
-18
lines changed

Lib/test/test_types.py

+20
Original file line numberDiff line numberDiff line change
@@ -709,6 +709,26 @@ def test_hash(self):
709709
self.assertEqual(hash(int | str), hash(str | int))
710710
self.assertEqual(hash(int | str), hash(typing.Union[int, str]))
711711

712+
def test_union_of_unhashable(self):
713+
class UnhashableMeta(type):
714+
__hash__ = None
715+
716+
class A(metaclass=UnhashableMeta): ...
717+
class B(metaclass=UnhashableMeta): ...
718+
719+
self.assertEqual((A | B).__args__, (A, B))
720+
union1 = A | B
721+
with self.assertRaises(TypeError):
722+
hash(union1)
723+
724+
union2 = int | B
725+
with self.assertRaises(TypeError):
726+
hash(union2)
727+
728+
union3 = A | int
729+
with self.assertRaises(TypeError):
730+
hash(union3)
731+
712732
def test_instancecheck_and_subclasscheck(self):
713733
for x in (int | str, typing.Union[int, str]):
714734
with self.subTest(x=x):

Lib/test/test_typing.py

+103-4
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
import collections
33
import collections.abc
44
from collections import defaultdict
5-
from functools import lru_cache, wraps
5+
from functools import lru_cache, wraps, reduce
66
import gc
77
import inspect
88
import itertools
9+
import operator
910
import pickle
1011
import re
1112
import sys
@@ -1770,6 +1771,26 @@ def test_union_union(self):
17701771
v = Union[u, Employee]
17711772
self.assertEqual(v, Union[int, float, Employee])
17721773

1774+
def test_union_of_unhashable(self):
1775+
class UnhashableMeta(type):
1776+
__hash__ = None
1777+
1778+
class A(metaclass=UnhashableMeta): ...
1779+
class B(metaclass=UnhashableMeta): ...
1780+
1781+
self.assertEqual(Union[A, B].__args__, (A, B))
1782+
union1 = Union[A, B]
1783+
with self.assertRaises(TypeError):
1784+
hash(union1)
1785+
1786+
union2 = Union[int, B]
1787+
with self.assertRaises(TypeError):
1788+
hash(union2)
1789+
1790+
union3 = Union[A, int]
1791+
with self.assertRaises(TypeError):
1792+
hash(union3)
1793+
17731794
def test_repr(self):
17741795
self.assertEqual(repr(Union), 'typing.Union')
17751796
u = Union[Employee, int]
@@ -5295,10 +5316,8 @@ def some(self):
52955316
self.assertFalse(hasattr(WithOverride.some, "__override__"))
52965317

52975318
def test_multiple_decorators(self):
5298-
import functools
5299-
53005319
def with_wraps(f): # similar to `lru_cache` definition
5301-
@functools.wraps(f)
5320+
@wraps(f)
53025321
def wrapper(*args, **kwargs):
53035322
return f(*args, **kwargs)
53045323
return wrapper
@@ -8183,6 +8202,76 @@ def test_flatten(self):
81838202
self.assertEqual(A.__metadata__, (4, 5))
81848203
self.assertEqual(A.__origin__, int)
81858204

8205+
def test_deduplicate_from_union(self):
8206+
# Regular:
8207+
self.assertEqual(get_args(Annotated[int, 1] | int),
8208+
(Annotated[int, 1], int))
8209+
self.assertEqual(get_args(Union[Annotated[int, 1], int]),
8210+
(Annotated[int, 1], int))
8211+
self.assertEqual(get_args(Annotated[int, 1] | Annotated[int, 2] | int),
8212+
(Annotated[int, 1], Annotated[int, 2], int))
8213+
self.assertEqual(get_args(Union[Annotated[int, 1], Annotated[int, 2], int]),
8214+
(Annotated[int, 1], Annotated[int, 2], int))
8215+
self.assertEqual(get_args(Annotated[int, 1] | Annotated[str, 1] | int),
8216+
(Annotated[int, 1], Annotated[str, 1], int))
8217+
self.assertEqual(get_args(Union[Annotated[int, 1], Annotated[str, 1], int]),
8218+
(Annotated[int, 1], Annotated[str, 1], int))
8219+
8220+
# Duplicates:
8221+
self.assertEqual(Annotated[int, 1] | Annotated[int, 1] | int,
8222+
Annotated[int, 1] | int)
8223+
self.assertEqual(Union[Annotated[int, 1], Annotated[int, 1], int],
8224+
Union[Annotated[int, 1], int])
8225+
8226+
# Unhashable metadata:
8227+
self.assertEqual(get_args(str | Annotated[int, {}] | Annotated[int, set()] | int),
8228+
(str, Annotated[int, {}], Annotated[int, set()], int))
8229+
self.assertEqual(get_args(Union[str, Annotated[int, {}], Annotated[int, set()], int]),
8230+
(str, Annotated[int, {}], Annotated[int, set()], int))
8231+
self.assertEqual(get_args(str | Annotated[int, {}] | Annotated[str, {}] | int),
8232+
(str, Annotated[int, {}], Annotated[str, {}], int))
8233+
self.assertEqual(get_args(Union[str, Annotated[int, {}], Annotated[str, {}], int]),
8234+
(str, Annotated[int, {}], Annotated[str, {}], int))
8235+
8236+
self.assertEqual(get_args(Annotated[int, 1] | str | Annotated[str, {}] | int),
8237+
(Annotated[int, 1], str, Annotated[str, {}], int))
8238+
self.assertEqual(get_args(Union[Annotated[int, 1], str, Annotated[str, {}], int]),
8239+
(Annotated[int, 1], str, Annotated[str, {}], int))
8240+
8241+
import dataclasses
8242+
@dataclasses.dataclass
8243+
class ValueRange:
8244+
lo: int
8245+
hi: int
8246+
v = ValueRange(1, 2)
8247+
self.assertEqual(get_args(Annotated[int, v] | None),
8248+
(Annotated[int, v], types.NoneType))
8249+
self.assertEqual(get_args(Union[Annotated[int, v], None]),
8250+
(Annotated[int, v], types.NoneType))
8251+
self.assertEqual(get_args(Optional[Annotated[int, v]]),
8252+
(Annotated[int, v], types.NoneType))
8253+
8254+
# Unhashable metadata duplicated:
8255+
self.assertEqual(Annotated[int, {}] | Annotated[int, {}] | int,
8256+
Annotated[int, {}] | int)
8257+
self.assertEqual(Annotated[int, {}] | Annotated[int, {}] | int,
8258+
int | Annotated[int, {}])
8259+
self.assertEqual(Union[Annotated[int, {}], Annotated[int, {}], int],
8260+
Union[Annotated[int, {}], int])
8261+
self.assertEqual(Union[Annotated[int, {}], Annotated[int, {}], int],
8262+
Union[int, Annotated[int, {}]])
8263+
8264+
def test_order_in_union(self):
8265+
expr1 = Annotated[int, 1] | str | Annotated[str, {}] | int
8266+
for args in itertools.permutations(get_args(expr1)):
8267+
with self.subTest(args=args):
8268+
self.assertEqual(expr1, reduce(operator.or_, args))
8269+
8270+
expr2 = Union[Annotated[int, 1], str, Annotated[str, {}], int]
8271+
for args in itertools.permutations(get_args(expr2)):
8272+
with self.subTest(args=args):
8273+
self.assertEqual(expr2, Union[args])
8274+
81868275
def test_specialize(self):
81878276
L = Annotated[List[T], "my decoration"]
81888277
LI = Annotated[List[int], "my decoration"]
@@ -8203,6 +8292,16 @@ def test_hash_eq(self):
82038292
{Annotated[int, 4, 5], Annotated[int, 4, 5], Annotated[T, 4, 5]},
82048293
{Annotated[int, 4, 5], Annotated[T, 4, 5]}
82058294
)
8295+
# Unhashable `metadata` raises `TypeError`:
8296+
a1 = Annotated[int, []]
8297+
with self.assertRaises(TypeError):
8298+
hash(a1)
8299+
8300+
class A:
8301+
__hash__ = None
8302+
a2 = Annotated[int, A()]
8303+
with self.assertRaises(TypeError):
8304+
hash(a2)
82068305

82078306
def test_instantiate(self):
82088307
class C:

Lib/typing.py

+31-14
Original file line numberDiff line numberDiff line change
@@ -314,19 +314,33 @@ def _unpack_args(args):
314314
newargs.append(arg)
315315
return newargs
316316

317-
def _deduplicate(params):
317+
def _deduplicate(params, *, unhashable_fallback=False):
318318
# Weed out strict duplicates, preserving the first of each occurrence.
319-
all_params = set(params)
320-
if len(all_params) < len(params):
321-
new_params = []
322-
for t in params:
323-
if t in all_params:
324-
new_params.append(t)
325-
all_params.remove(t)
326-
params = new_params
327-
assert not all_params, all_params
328-
return params
329-
319+
try:
320+
return dict.fromkeys(params)
321+
except TypeError:
322+
if not unhashable_fallback:
323+
raise
324+
# Happens for cases like `Annotated[dict, {'x': IntValidator()}]`
325+
return _deduplicate_unhashable(params)
326+
327+
def _deduplicate_unhashable(unhashable_params):
328+
new_unhashable = []
329+
for t in unhashable_params:
330+
if t not in new_unhashable:
331+
new_unhashable.append(t)
332+
return new_unhashable
333+
334+
def _compare_args_orderless(first_args, second_args):
335+
first_unhashable = _deduplicate_unhashable(first_args)
336+
second_unhashable = _deduplicate_unhashable(second_args)
337+
t = list(second_unhashable)
338+
try:
339+
for elem in first_unhashable:
340+
t.remove(elem)
341+
except ValueError:
342+
return False
343+
return not t
330344

331345
def _remove_dups_flatten(parameters):
332346
"""Internal helper for Union creation and substitution.
@@ -341,7 +355,7 @@ def _remove_dups_flatten(parameters):
341355
else:
342356
params.append(p)
343357

344-
return tuple(_deduplicate(params))
358+
return tuple(_deduplicate(params, unhashable_fallback=True))
345359

346360

347361
def _flatten_literal_params(parameters):
@@ -1548,7 +1562,10 @@ def copy_with(self, params):
15481562
def __eq__(self, other):
15491563
if not isinstance(other, (_UnionGenericAlias, types.UnionType)):
15501564
return NotImplemented
1551-
return set(self.__args__) == set(other.__args__)
1565+
try: # fast path
1566+
return set(self.__args__) == set(other.__args__)
1567+
except TypeError: # not hashable, slow path
1568+
return _compare_args_orderless(self.__args__, other.__args__)
15521569

15531570
def __hash__(self):
15541571
return hash(frozenset(self.__args__))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Allow creating :ref:`union of types<types-union>` for
2+
:class:`typing.Annotated` with unhashable metadata.

0 commit comments

Comments
 (0)