diff --git a/extensions/mypy_extensions.py b/extensions/mypy_extensions.py index 459d691f65ac..db66f586f051 100644 --- a/extensions/mypy_extensions.py +++ b/extensions/mypy_extensions.py @@ -7,4 +7,16 @@ # NOTE: This module must support Python 2.7 in addition to Python 3.x -# (TODO: Declare TypedDict and other extensions here) + +def TypedDict(typename, fields): + """TypedDict creates a dictionary type that expects all of its + instances to have a certain set of keys, with each key + associated with a value of a consistent type. This expectation + is not checked at runtime but is only enforced by typecheckers. + """ + def new_dict(*args, **kwargs): + return dict(*args, **kwargs) + + new_dict.__name__ = typename + new_dict.__supertype__ = dict + return new_dict diff --git a/mypy/checker.py b/mypy/checker.py index 6a0a987ad783..f857ed67e355 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -20,7 +20,7 @@ Context, ListComprehension, ConditionalExpr, GeneratorExpr, Decorator, SetExpr, TypeVarExpr, NewTypeExpr, PrintStmt, LITERAL_TYPE, BreakStmt, PassStmt, ContinueStmt, ComparisonExpr, StarExpr, - YieldFromExpr, NamedTupleExpr, SetComprehension, + YieldFromExpr, NamedTupleExpr, TypedDictExpr, SetComprehension, DictionaryComprehension, ComplexExpr, EllipsisExpr, TypeAliasExpr, RefExpr, YieldExpr, BackquoteExpr, ImportFrom, ImportAll, ImportBase, AwaitExpr, @@ -2082,6 +2082,10 @@ def visit_namedtuple_expr(self, e: NamedTupleExpr) -> Type: # TODO: Perhaps return a type object type? return AnyType() + def visit_typeddict_expr(self, e: TypedDictExpr) -> Type: + # TODO: Perhaps return a type object type? + return AnyType() + def visit_list_expr(self, e: ListExpr) -> Type: return self.expr_checker.visit_list_expr(e) diff --git a/mypy/nodes.py b/mypy/nodes.py index ab5490c30a31..933a9934bf6c 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1747,7 +1747,7 @@ def accept(self, visitor: NodeVisitor[T]) -> T: class NamedTupleExpr(Expression): - """Named tuple expression namedtuple(...).""" + """Named tuple expression namedtuple(...) or NamedTuple(...).""" # The class representation of this named tuple (its tuple_type attribute contains # the tuple item types) @@ -1760,6 +1760,19 @@ def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_namedtuple_expr(self) +class TypedDictExpr(Expression): + """Typed dict expression TypedDict(...).""" + + # The class representation of this typed dict + info = None # type: TypeInfo + + def __init__(self, info: 'TypeInfo') -> None: + self.info = info + + def accept(self, visitor: NodeVisitor[T]) -> T: + return visitor.visit_typeddict_expr(self) + + class PromoteExpr(Expression): """Ducktype class decorator expression _promote(...).""" @@ -1882,6 +1895,9 @@ class is generic then it will be a type constructor of higher kind. # Is this a named tuple type? is_named_tuple = False + # Is this a typed dict type? + is_typed_dict = False + # Is this a newtype type? is_newtype = False @@ -1893,7 +1909,7 @@ class is generic then it will be a type constructor of higher kind. FLAGS = [ 'is_abstract', 'is_enum', 'fallback_to_any', 'is_named_tuple', - 'is_newtype', 'is_dummy' + 'is_typed_dict', 'is_newtype', 'is_dummy' ] def __init__(self, names: 'SymbolTable', defn: ClassDef, module_name: str) -> None: diff --git a/mypy/semanal.py b/mypy/semanal.py index 864adaea7371..dccd260b3669 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -60,7 +60,7 @@ FuncExpr, MDEF, FuncBase, Decorator, SetExpr, TypeVarExpr, NewTypeExpr, StrExpr, BytesExpr, PrintStmt, ConditionalExpr, PromoteExpr, ComparisonExpr, StarExpr, ARG_POS, ARG_NAMED, MroError, type_aliases, - YieldFromExpr, NamedTupleExpr, NonlocalDecl, SymbolNode, + YieldFromExpr, NamedTupleExpr, TypedDictExpr, NonlocalDecl, SymbolNode, SetComprehension, DictionaryComprehension, TYPE_ALIAS, TypeAliasExpr, YieldExpr, ExecStmt, Argument, BackquoteExpr, ImportBase, AwaitExpr, IntExpr, FloatExpr, UnicodeExpr, EllipsisExpr, @@ -1127,6 +1127,7 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None: self.process_newtype_declaration(s) self.process_typevar_declaration(s) self.process_namedtuple_definition(s) + self.process_typeddict_definition(s) if (len(s.lvalues) == 1 and isinstance(s.lvalues[0], NameExpr) and s.lvalues[0].name == '__all__' and s.lvalues[0].kind == GDEF and @@ -1498,9 +1499,9 @@ def get_typevar_declaration(self, s: AssignmentStmt) -> Optional[CallExpr]: if not isinstance(s.rvalue, CallExpr): return None call = s.rvalue - if not isinstance(call.callee, RefExpr): - return None callee = call.callee + if not isinstance(callee, RefExpr): + return None if callee.fullname != 'typing.TypeVar': return None return call @@ -1579,10 +1580,9 @@ def process_namedtuple_definition(self, s: AssignmentStmt) -> None: # Yes, it's a valid namedtuple definition. Add it to the symbol table. node = self.lookup(name, s) node.kind = GDEF # TODO locally defined namedtuple - # TODO call.analyzed node.node = named_tuple - def check_namedtuple(self, node: Expression, var_name: str = None) -> TypeInfo: + def check_namedtuple(self, node: Expression, var_name: str = None) -> Optional[TypeInfo]: """Check if a call defines a namedtuple. The optional var_name argument is the name of the variable to @@ -1596,9 +1596,9 @@ def check_namedtuple(self, node: Expression, var_name: str = None) -> TypeInfo: if not isinstance(node, CallExpr): return None call = node - if not isinstance(call.callee, RefExpr): - return None callee = call.callee + if not isinstance(callee, RefExpr): + return None fullname = callee.fullname if fullname not in ('collections.namedtuple', 'typing.NamedTuple'): return None @@ -1607,9 +1607,9 @@ def check_namedtuple(self, node: Expression, var_name: str = None) -> TypeInfo: # Error. Construct dummy return value. return self.build_namedtuple_typeinfo('namedtuple', [], []) else: - # Give it a unique name derived from the line number. name = cast(StrExpr, call.args[0]).value if name != var_name: + # Give it a unique name derived from the line number. name += '@' + str(call.line) info = self.build_namedtuple_typeinfo(name, items, types) # Store it as a global just in case it would remain anonymous. @@ -1620,7 +1620,7 @@ def check_namedtuple(self, node: Expression, var_name: str = None) -> TypeInfo: def parse_namedtuple_args(self, call: CallExpr, fullname: str) -> Tuple[List[str], List[Type], bool]: - # TODO Share code with check_argument_count in checkexpr.py? + # TODO: Share code with check_argument_count in checkexpr.py? args = call.args if len(args) < 2: return self.fail_namedtuple_arg("Too few arguments for namedtuple()", call) @@ -1777,6 +1777,114 @@ def analyze_types(self, items: List[Expression]) -> List[Type]: result.append(AnyType()) return result + def process_typeddict_definition(self, s: AssignmentStmt) -> None: + """Check if s defines a TypedDict; if yes, store the definition in symbol table.""" + if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr): + return + lvalue = s.lvalues[0] + name = lvalue.name + typed_dict = self.check_typeddict(s.rvalue, name) + if typed_dict is None: + return + # Yes, it's a valid TypedDict definition. Add it to the symbol table. + node = self.lookup(name, s) + node.kind = GDEF # TODO locally defined TypedDict + node.node = typed_dict + + def check_typeddict(self, node: Expression, var_name: str = None) -> Optional[TypeInfo]: + """Check if a call defines a TypedDict. + + The optional var_name argument is the name of the variable to + which this is assigned, if any. + + If it does, return the corresponding TypeInfo. Return None otherwise. + + If the definition is invalid but looks like a TypedDict, + report errors but return (some) TypeInfo. + """ + if not isinstance(node, CallExpr): + return None + call = node + callee = call.callee + if not isinstance(callee, RefExpr): + return None + fullname = callee.fullname + if fullname != 'mypy_extensions.TypedDict': + return None + items, types, ok = self.parse_typeddict_args(call, fullname) + if not ok: + # Error. Construct dummy return value. + return self.build_typeddict_typeinfo('TypedDict', [], []) + else: + name = cast(StrExpr, call.args[0]).value + if name != var_name: + # Give it a unique name derived from the line number. + name += '@' + str(call.line) + info = self.build_typeddict_typeinfo(name, items, types) + # Store it as a global just in case it would remain anonymous. + self.globals[name] = SymbolTableNode(GDEF, info, self.cur_mod_id) + call.analyzed = TypedDictExpr(info) + call.analyzed.set_line(call.line, call.column) + return info + + def parse_typeddict_args(self, call: CallExpr, + fullname: str) -> Tuple[List[str], List[Type], bool]: + # TODO: Share code with check_argument_count in checkexpr.py? + args = call.args + if len(args) < 2: + return self.fail_typeddict_arg("Too few arguments for TypedDict()", call) + if len(args) > 2: + return self.fail_typeddict_arg("Too many arguments for TypedDict()", call) + # TODO: Support keyword arguments + if call.arg_kinds != [ARG_POS, ARG_POS]: + return self.fail_typeddict_arg("Unexpected arguments to TypedDict()", call) + if not isinstance(args[0], (StrExpr, BytesExpr, UnicodeExpr)): + return self.fail_typeddict_arg( + "TypedDict() expects a string literal as the first argument", call) + if not isinstance(args[1], DictExpr): + return self.fail_typeddict_arg( + "TypedDict() expects a dictionary literal as the second argument", call) + dictexpr = args[1] + items, types, ok = self.parse_typeddict_fields_with_types(dictexpr.items, call) + return items, types, ok + + def parse_typeddict_fields_with_types(self, dict_items: List[Tuple[Expression, Expression]], + context: Context) -> Tuple[List[str], List[Type], bool]: + items = [] # type: List[str] + types = [] # type: List[Type] + for (field_name_expr, field_type_expr) in dict_items: + if isinstance(field_name_expr, (StrExpr, BytesExpr, UnicodeExpr)): + items.append(field_name_expr.value) + else: + return self.fail_typeddict_arg("Invalid TypedDict() field name", field_name_expr) + try: + type = expr_to_unanalyzed_type(field_type_expr) + except TypeTranslationError: + return self.fail_typeddict_arg('Invalid field type', field_type_expr) + types.append(self.anal_type(type)) + return items, types, True + + def fail_typeddict_arg(self, message: str, + context: Context) -> Tuple[List[str], List[Type], bool]: + self.fail(message, context) + return [], [], False + + def build_typeddict_typeinfo(self, name: str, items: List[str], + types: List[Type]) -> TypeInfo: + strtype = self.named_type('__builtins__.str') # type: Type + dictype = (self.named_type_or_none('builtins.dict', [strtype, AnyType()]) + or self.object_type()) + fallback = dictype + + info = self.basic_new_typeinfo(name, fallback) + info.is_typed_dict = True + + # (TODO: Store {items, types} inside "info" somewhere for use later. + # Probably inside a new "info.keys" field which + # would be analogous to "info.names".) + + return info + def visit_decorator(self, dec: Decorator) -> None: for d in dec.decorators: d.accept(self) diff --git a/mypy/strconv.py b/mypy/strconv.py index 38c36917eefe..bb167f1ef8b6 100644 --- a/mypy/strconv.py +++ b/mypy/strconv.py @@ -422,6 +422,10 @@ def visit_namedtuple_expr(self, o: 'mypy.nodes.NamedTupleExpr') -> str: o.info.name(), o.info.tuple_type) + def visit_typeddict_expr(self, o: 'mypy.nodes.TypedDictExpr') -> str: + return 'TypedDictExpr:{}({})'.format(o.line, + o.info.name()) + def visit__promote_expr(self, o: 'mypy.nodes.PromoteExpr') -> str: return 'PromoteExpr:{}({})'.format(o.line, o.type) diff --git a/mypy/test/testsemanal.py b/mypy/test/testsemanal.py index 3486d0fdc528..106d52ff3d17 100644 --- a/mypy/test/testsemanal.py +++ b/mypy/test/testsemanal.py @@ -29,6 +29,7 @@ 'semanal-statements.test', 'semanal-abstractclasses.test', 'semanal-namedtuple.test', + 'semanal-typeddict.test', 'semanal-python2.test'] @@ -78,6 +79,7 @@ def test_semanal(testcase): # TODO the test is not reliable if (not f.path.endswith((os.sep + 'builtins.pyi', 'typing.pyi', + 'mypy_extensions.pyi', 'abc.pyi', 'collections.pyi')) and not os.path.basename(f.path).startswith('_') diff --git a/mypy/treetransform.py b/mypy/treetransform.py index 100ff7854a8c..c030f0e1e5dc 100644 --- a/mypy/treetransform.py +++ b/mypy/treetransform.py @@ -17,7 +17,7 @@ SliceExpr, OpExpr, UnaryExpr, FuncExpr, TypeApplication, PrintStmt, SymbolTable, RefExpr, TypeVarExpr, NewTypeExpr, PromoteExpr, ComparisonExpr, TempNode, StarExpr, Statement, Expression, - YieldFromExpr, NamedTupleExpr, NonlocalDecl, SetComprehension, + YieldFromExpr, NamedTupleExpr, TypedDictExpr, NonlocalDecl, SetComprehension, DictionaryComprehension, ComplexExpr, TypeAliasExpr, EllipsisExpr, YieldExpr, ExecStmt, Argument, BackquoteExpr, AwaitExpr, ) @@ -492,6 +492,9 @@ def visit_newtype_expr(self, node: NewTypeExpr) -> NewTypeExpr: def visit_namedtuple_expr(self, node: NamedTupleExpr) -> NamedTupleExpr: return NamedTupleExpr(node.info) + def visit_typeddict_expr(self, node: TypedDictExpr) -> Node: + return TypedDictExpr(node.info) + def visit__promote_expr(self, node: PromoteExpr) -> PromoteExpr: return PromoteExpr(node.type) diff --git a/mypy/visitor.py b/mypy/visitor.py index b4c2cc86038b..33f287b9863b 100644 --- a/mypy/visitor.py +++ b/mypy/visitor.py @@ -225,6 +225,9 @@ def visit_type_alias_expr(self, o: 'mypy.nodes.TypeAliasExpr') -> T: def visit_namedtuple_expr(self, o: 'mypy.nodes.NamedTupleExpr') -> T: pass + def visit_typeddict_expr(self, o: 'mypy.nodes.TypedDictExpr') -> T: + pass + def visit_newtype_expr(self, o: 'mypy.nodes.NewTypeExpr') -> T: pass diff --git a/test-data/unit/lib-stub/mypy_extensions.pyi b/test-data/unit/lib-stub/mypy_extensions.pyi new file mode 100644 index 000000000000..6c57954d8a81 --- /dev/null +++ b/test-data/unit/lib-stub/mypy_extensions.pyi @@ -0,0 +1,6 @@ +from typing import Dict, Type, TypeVar + +T = TypeVar('T') + + +def TypedDict(typename: str, fields: Dict[str, Type[T]]) -> Type[dict]: ... diff --git a/test-data/unit/semanal-namedtuple.test b/test-data/unit/semanal-namedtuple.test index bb2592004908..b590ee9bdd7d 100644 --- a/test-data/unit/semanal-namedtuple.test +++ b/test-data/unit/semanal-namedtuple.test @@ -154,9 +154,11 @@ N = namedtuple('N', 1) # E: List or tuple literal expected as the second argumen from collections import namedtuple N = namedtuple('N', ['x', 1]) # E: String literal expected as namedtuple() item -[case testNamedTupleWithInvalidArgs] +-- NOTE: The following code works at runtime but is not yet supported by mypy. +-- Keyword arguments may potentially be supported in the future. +[case testNamedTupleWithNonpositionalArgs] from collections import namedtuple -N = namedtuple('N', x=['x']) # E: Unexpected arguments to namedtuple() +N = namedtuple(typename='N', field_names=['x']) # E: Unexpected arguments to namedtuple() [case testInvalidNamedTupleBaseClass] from typing import NamedTuple diff --git a/test-data/unit/semanal-typeddict.test b/test-data/unit/semanal-typeddict.test new file mode 100644 index 000000000000..98aef1a5ef1e --- /dev/null +++ b/test-data/unit/semanal-typeddict.test @@ -0,0 +1,51 @@ +-- Semantic analysis of typed dicts + +[case testCanDefineTypedDictType] +from mypy_extensions import TypedDict +Point = TypedDict('Point', {'x': int, 'y': int}) +[builtins fixtures/dict.pyi] +[out] +MypyFile:1( + ImportFrom:1(mypy_extensions, [TypedDict]) + AssignmentStmt:2( + NameExpr(Point* [__main__.Point]) + TypedDictExpr:2(Point))) + +-- Errors + +[case testTypedDictWithTooFewArguments] +from mypy_extensions import TypedDict +Point = TypedDict('Point') # E: Too few arguments for TypedDict() +[builtins fixtures/dict.pyi] + +[case testTypedDictWithTooManyArguments] +from mypy_extensions import TypedDict +Point = TypedDict('Point', {'x': int, 'y': int}, dict) # E: Too many arguments for TypedDict() +[builtins fixtures/dict.pyi] + +[case testTypedDictWithInvalidName] +from mypy_extensions import TypedDict +Point = TypedDict(dict, {'x': int, 'y': int}) # E: TypedDict() expects a string literal as the first argument +[builtins fixtures/dict.pyi] + +[case testTypedDictWithInvalidItems] +from mypy_extensions import TypedDict +Point = TypedDict('Point', {'x'}) # E: TypedDict() expects a dictionary literal as the second argument +[builtins fixtures/dict.pyi] + +-- NOTE: The following code works at runtime but is not yet supported by mypy. +-- Keyword arguments may potentially be supported in the future. +[case testTypedDictWithNonpositionalArgs] +from mypy_extensions import TypedDict +Point = TypedDict(typename='Point', fields={'x': int, 'y': int}) # E: Unexpected arguments to TypedDict() +[builtins fixtures/dict.pyi] + +[case testTypedDictWithInvalidItemName] +from mypy_extensions import TypedDict +Point = TypedDict('Point', {int: int, int: int}) # E: Invalid TypedDict() field name +[builtins fixtures/dict.pyi] + +[case testTypedDictWithInvalidItemType] +from mypy_extensions import TypedDict +Point = TypedDict('Point', {'x': 1, 'y': 1}) # E: Invalid field type +[builtins fixtures/dict.pyi]