@@ -200,6 +200,20 @@ def _check_generic(cls, parameters, elen):
200
200
f" actual { alen } , expected { elen } " )
201
201
202
202
203
+ def _deduplicate (params ):
204
+ # Weed out strict duplicates, preserving the first of each occurrence.
205
+ all_params = set (params )
206
+ if len (all_params ) < len (params ):
207
+ new_params = []
208
+ for t in params :
209
+ if t in all_params :
210
+ new_params .append (t )
211
+ all_params .remove (t )
212
+ params = new_params
213
+ assert not all_params , all_params
214
+ return params
215
+
216
+
203
217
def _remove_dups_flatten (parameters ):
204
218
"""An internal helper for Union creation and substitution: flatten Unions
205
219
among parameters, then remove duplicates.
@@ -213,38 +227,45 @@ def _remove_dups_flatten(parameters):
213
227
params .extend (p [1 :])
214
228
else :
215
229
params .append (p )
216
- # Weed out strict duplicates, preserving the first of each occurrence.
217
- all_params = set (params )
218
- if len (all_params ) < len (params ):
219
- new_params = []
220
- for t in params :
221
- if t in all_params :
222
- new_params .append (t )
223
- all_params .remove (t )
224
- params = new_params
225
- assert not all_params , all_params
230
+
231
+ return tuple (_deduplicate (params ))
232
+
233
+
234
+ def _flatten_literal_params (parameters ):
235
+ """An internal helper for Literal creation: flatten Literals among parameters"""
236
+ params = []
237
+ for p in parameters :
238
+ if isinstance (p , _LiteralGenericAlias ):
239
+ params .extend (p .__args__ )
240
+ else :
241
+ params .append (p )
226
242
return tuple (params )
227
243
228
244
229
245
_cleanups = []
230
246
231
247
232
- def _tp_cache (func ):
248
+ def _tp_cache (func = None , / , * , typed = False ):
233
249
"""Internal wrapper caching __getitem__ of generic types with a fallback to
234
250
original function for non-hashable arguments.
235
251
"""
236
- cached = functools .lru_cache ()(func )
237
- _cleanups .append (cached .cache_clear )
252
+ def decorator (func ):
253
+ cached = functools .lru_cache (typed = typed )(func )
254
+ _cleanups .append (cached .cache_clear )
238
255
239
- @functools .wraps (func )
240
- def inner (* args , ** kwds ):
241
- try :
242
- return cached (* args , ** kwds )
243
- except TypeError :
244
- pass # All real errors (not unhashable args) are raised below.
245
- return func (* args , ** kwds )
246
- return inner
256
+ @functools .wraps (func )
257
+ def inner (* args , ** kwds ):
258
+ try :
259
+ return cached (* args , ** kwds )
260
+ except TypeError :
261
+ pass # All real errors (not unhashable args) are raised below.
262
+ return func (* args , ** kwds )
263
+ return inner
264
+
265
+ if func is not None :
266
+ return decorator (func )
247
267
268
+ return decorator
248
269
249
270
def _eval_type (t , globalns , localns , recursive_guard = frozenset ()):
250
271
"""Evaluate all forward references in the given type t.
@@ -317,6 +338,13 @@ def __subclasscheck__(self, cls):
317
338
def __getitem__ (self , parameters ):
318
339
return self ._getitem (self , parameters )
319
340
341
+
342
+ class _LiteralSpecialForm (_SpecialForm , _root = True ):
343
+ @_tp_cache (typed = True )
344
+ def __getitem__ (self , parameters ):
345
+ return self ._getitem (self , parameters )
346
+
347
+
320
348
@_SpecialForm
321
349
def Any (self , parameters ):
322
350
"""Special type indicating an unconstrained type.
@@ -434,7 +462,7 @@ def Optional(self, parameters):
434
462
arg = _type_check (parameters , f"{ self } requires a single type." )
435
463
return Union [arg , type (None )]
436
464
437
- @_SpecialForm
465
+ @_LiteralSpecialForm
438
466
def Literal (self , parameters ):
439
467
"""Special typing form to define literal types (a.k.a. value types).
440
468
@@ -458,7 +486,17 @@ def open_helper(file: str, mode: MODE) -> str:
458
486
"""
459
487
# There is no '_type_check' call because arguments to Literal[...] are
460
488
# values, not types.
461
- return _GenericAlias (self , parameters )
489
+ if not isinstance (parameters , tuple ):
490
+ parameters = (parameters ,)
491
+
492
+ parameters = _flatten_literal_params (parameters )
493
+
494
+ try :
495
+ parameters = tuple (p for p , _ in _deduplicate (list (_value_and_type_iter (parameters ))))
496
+ except TypeError : # unhashable parameters
497
+ pass
498
+
499
+ return _LiteralGenericAlias (self , parameters )
462
500
463
501
464
502
class ForwardRef (_Final , _root = True ):
@@ -881,6 +919,22 @@ def __repr__(self):
881
919
return super ().__repr__ ()
882
920
883
921
922
+ def _value_and_type_iter (parameters ):
923
+ return ((p , type (p )) for p in parameters )
924
+
925
+
926
+ class _LiteralGenericAlias (_GenericAlias , _root = True ):
927
+
928
+ def __eq__ (self , other ):
929
+ if not isinstance (other , _LiteralGenericAlias ):
930
+ return NotImplemented
931
+
932
+ return set (_value_and_type_iter (self .__args__ )) == set (_value_and_type_iter (other .__args__ ))
933
+
934
+ def __hash__ (self ):
935
+ return hash (tuple (_value_and_type_iter (self .__args__ )))
936
+
937
+
884
938
class Generic :
885
939
"""Abstract base class for generic types.
886
940
0 commit comments