Skip to content

Commit ac472b3

Browse files
authored
[3.9] bpo-42345: Fix three issues with typing.Literal parameters (GH-23294) (GH-23335)
Literal equality no longer depends on the order of arguments. Fix issue related to `typing.Literal` caching by adding `typed` parameter to `typing._tp_cache` function. Add deduplication of `typing.Literal` arguments. (cherry picked from commit f03d318) Co-authored-by: Yurii Karabas <1998uriyyo@gmail.com>
1 parent 656d50f commit ac472b3

File tree

4 files changed

+105
-23
lines changed

4 files changed

+105
-23
lines changed

Lib/test/test_typing.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,7 @@ def test_repr(self):
532532
self.assertEqual(repr(Literal[int]), "typing.Literal[int]")
533533
self.assertEqual(repr(Literal), "typing.Literal")
534534
self.assertEqual(repr(Literal[None]), "typing.Literal[None]")
535+
self.assertEqual(repr(Literal[1, 2, 3, 3]), "typing.Literal[1, 2, 3]")
535536

536537
def test_cannot_init(self):
537538
with self.assertRaises(TypeError):
@@ -563,6 +564,30 @@ def test_no_multiple_subscripts(self):
563564
with self.assertRaises(TypeError):
564565
Literal[1][1]
565566

567+
def test_equal(self):
568+
self.assertNotEqual(Literal[0], Literal[False])
569+
self.assertNotEqual(Literal[True], Literal[1])
570+
self.assertNotEqual(Literal[1], Literal[2])
571+
self.assertNotEqual(Literal[1, True], Literal[1])
572+
self.assertEqual(Literal[1], Literal[1])
573+
self.assertEqual(Literal[1, 2], Literal[2, 1])
574+
self.assertEqual(Literal[1, 2, 3], Literal[1, 2, 3, 3])
575+
576+
def test_args(self):
577+
self.assertEqual(Literal[1, 2, 3].__args__, (1, 2, 3))
578+
self.assertEqual(Literal[1, 2, 3, 3].__args__, (1, 2, 3))
579+
self.assertEqual(Literal[1, Literal[2], Literal[3, 4]].__args__, (1, 2, 3, 4))
580+
# Mutable arguments will not be deduplicated
581+
self.assertEqual(Literal[[], []].__args__, ([], []))
582+
583+
def test_flatten(self):
584+
l1 = Literal[Literal[1], Literal[2], Literal[3]]
585+
l2 = Literal[Literal[1, 2], 3]
586+
l3 = Literal[Literal[1, 2, 3]]
587+
for l in l1, l2, l3:
588+
self.assertEqual(l, Literal[1, 2, 3])
589+
self.assertEqual(l.__args__, (1, 2, 3))
590+
566591

567592
XK = TypeVar('XK', str, bytes)
568593
XV = TypeVar('XV')

Lib/typing.py

Lines changed: 77 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,20 @@ def _check_generic(cls, parameters, elen):
200200
f" actual {alen}, expected {elen}")
201201

202202

203+
def _deduplicate(params):
204+
# Weed out strict duplicates, preserving the first of each occurrence.
205+
all_params = set(params)
206+
if len(all_params) < len(params):
207+
new_params = []
208+
for t in params:
209+
if t in all_params:
210+
new_params.append(t)
211+
all_params.remove(t)
212+
params = new_params
213+
assert not all_params, all_params
214+
return params
215+
216+
203217
def _remove_dups_flatten(parameters):
204218
"""An internal helper for Union creation and substitution: flatten Unions
205219
among parameters, then remove duplicates.
@@ -213,38 +227,45 @@ def _remove_dups_flatten(parameters):
213227
params.extend(p[1:])
214228
else:
215229
params.append(p)
216-
# Weed out strict duplicates, preserving the first of each occurrence.
217-
all_params = set(params)
218-
if len(all_params) < len(params):
219-
new_params = []
220-
for t in params:
221-
if t in all_params:
222-
new_params.append(t)
223-
all_params.remove(t)
224-
params = new_params
225-
assert not all_params, all_params
230+
231+
return tuple(_deduplicate(params))
232+
233+
234+
def _flatten_literal_params(parameters):
235+
"""An internal helper for Literal creation: flatten Literals among parameters"""
236+
params = []
237+
for p in parameters:
238+
if isinstance(p, _LiteralGenericAlias):
239+
params.extend(p.__args__)
240+
else:
241+
params.append(p)
226242
return tuple(params)
227243

228244

229245
_cleanups = []
230246

231247

232-
def _tp_cache(func):
248+
def _tp_cache(func=None, /, *, typed=False):
233249
"""Internal wrapper caching __getitem__ of generic types with a fallback to
234250
original function for non-hashable arguments.
235251
"""
236-
cached = functools.lru_cache()(func)
237-
_cleanups.append(cached.cache_clear)
252+
def decorator(func):
253+
cached = functools.lru_cache(typed=typed)(func)
254+
_cleanups.append(cached.cache_clear)
238255

239-
@functools.wraps(func)
240-
def inner(*args, **kwds):
241-
try:
242-
return cached(*args, **kwds)
243-
except TypeError:
244-
pass # All real errors (not unhashable args) are raised below.
245-
return func(*args, **kwds)
246-
return inner
256+
@functools.wraps(func)
257+
def inner(*args, **kwds):
258+
try:
259+
return cached(*args, **kwds)
260+
except TypeError:
261+
pass # All real errors (not unhashable args) are raised below.
262+
return func(*args, **kwds)
263+
return inner
264+
265+
if func is not None:
266+
return decorator(func)
247267

268+
return decorator
248269

249270
def _eval_type(t, globalns, localns, recursive_guard=frozenset()):
250271
"""Evaluate all forward references in the given type t.
@@ -317,6 +338,13 @@ def __subclasscheck__(self, cls):
317338
def __getitem__(self, parameters):
318339
return self._getitem(self, parameters)
319340

341+
342+
class _LiteralSpecialForm(_SpecialForm, _root=True):
343+
@_tp_cache(typed=True)
344+
def __getitem__(self, parameters):
345+
return self._getitem(self, parameters)
346+
347+
320348
@_SpecialForm
321349
def Any(self, parameters):
322350
"""Special type indicating an unconstrained type.
@@ -434,7 +462,7 @@ def Optional(self, parameters):
434462
arg = _type_check(parameters, f"{self} requires a single type.")
435463
return Union[arg, type(None)]
436464

437-
@_SpecialForm
465+
@_LiteralSpecialForm
438466
def Literal(self, parameters):
439467
"""Special typing form to define literal types (a.k.a. value types).
440468
@@ -458,7 +486,17 @@ def open_helper(file: str, mode: MODE) -> str:
458486
"""
459487
# There is no '_type_check' call because arguments to Literal[...] are
460488
# values, not types.
461-
return _GenericAlias(self, parameters)
489+
if not isinstance(parameters, tuple):
490+
parameters = (parameters,)
491+
492+
parameters = _flatten_literal_params(parameters)
493+
494+
try:
495+
parameters = tuple(p for p, _ in _deduplicate(list(_value_and_type_iter(parameters))))
496+
except TypeError: # unhashable parameters
497+
pass
498+
499+
return _LiteralGenericAlias(self, parameters)
462500

463501

464502
class ForwardRef(_Final, _root=True):
@@ -881,6 +919,22 @@ def __repr__(self):
881919
return super().__repr__()
882920

883921

922+
def _value_and_type_iter(parameters):
923+
return ((p, type(p)) for p in parameters)
924+
925+
926+
class _LiteralGenericAlias(_GenericAlias, _root=True):
927+
928+
def __eq__(self, other):
929+
if not isinstance(other, _LiteralGenericAlias):
930+
return NotImplemented
931+
932+
return set(_value_and_type_iter(self.__args__)) == set(_value_and_type_iter(other.__args__))
933+
934+
def __hash__(self):
935+
return hash(tuple(_value_and_type_iter(self.__args__)))
936+
937+
884938
class Generic:
885939
"""Abstract base class for generic types.
886940

Misc/ACKS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -855,6 +855,7 @@ Jan Kanis
855855
Rafe Kaplan
856856
Jacob Kaplan-Moss
857857
Allison Kaptur
858+
Yurii Karabas
858859
Janne Karila
859860
Per Øyvind Karlsen
860861
Anton Kasyanov
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Fix various issues with ``typing.Literal`` parameter handling (flatten,
2+
deduplicate, use type to cache key). Patch provided by Yurii Karabas.

0 commit comments

Comments
 (0)