diff --git a/mypy/build.py b/mypy/build.py index 69884f255b0e..cf5a069b77fa 100644 --- a/mypy/build.py +++ b/mypy/build.py @@ -1935,7 +1935,7 @@ def parse_file(self) -> None: # definitions in the file to the symbol table. We must do # this before processing imports, since this may mark some # import statements as unreachable. - first = SemanticAnalyzerPass1(manager.semantic_analyzer) + first = SemanticAnalyzerPass1(manager.semantic_analyzer, modules) with self.wrap_context(): first.visit_file(self.tree, self.xpath, self.id, self.options) diff --git a/mypy/semanal.py b/mypy/semanal.py index daac4612f6de..3c02c2f94cef 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -32,6 +32,7 @@ traverse the entire AST. """ +from collections import defaultdict from contextlib import contextmanager from typing import ( @@ -216,6 +217,9 @@ class SemanticAnalyzerPass2(NodeVisitor[None], # Postponed functions collected if # postpone_nested_functions_stack[-1] == FUNCTION_FIRST_PHASE_POSTPONE_SECOND. postponed_functions_stack = None # type: List[List[Node]] + # Callbacks to add names from modules with import * that have not yet been + # processed (used in pass 1, but defined here because the object is longer-lived). + import_all_callbacks = None # type: Dict[str, List[SymbolTable]] loop_depth = 0 # Depth of breakable loops cur_mod_id = '' # Current module id (or None) (phase 2) @@ -256,6 +260,7 @@ def __init__(self, # for processing module top levels in fine-grained incremental mode. self.recurse_into_functions = True self.scope = Scope() + self.import_all_callbacks = defaultdict(list) def visit_file(self, file_node: MypyFile, fnam: str, options: Options, patches: List[Tuple[int, Callable[[], None]]]) -> None: diff --git a/mypy/semanal_pass1.py b/mypy/semanal_pass1.py index 2037685656bb..3647ce4efa4e 100644 --- a/mypy/semanal_pass1.py +++ b/mypy/semanal_pass1.py @@ -16,8 +16,7 @@ This pass also infers the reachability of certain if staments, such as those with platform checks. """ - -from typing import List, Tuple +from typing import Dict, List, Tuple from mypy import experiments from mypy.nodes import ( @@ -38,11 +37,22 @@ class SemanticAnalyzerPass1(NodeVisitor[None]): """First phase of semantic analysis. - See docstring of 'analyze()' below for a description of what this does. + See the module docstring for a description of what this does. """ - def __init__(self, sem: SemanticAnalyzerPass2) -> None: + def __init__(self, sem: SemanticAnalyzerPass2, modules: Dict[str, MypyFile]) -> None: self.sem = sem + self.modules = modules + self.is_stub_file = False + + def add_import_all_callback(self, imported: str, globals: SymbolTable, + mod_node: MypyFile, node: ImportAll) -> None: + self.sem.import_all_callbacks[imported].append((globals, mod_node, node)) + + def process_import_all_callback(self, imported: str, source_globals: SymbolTable) -> None: + callbacks = self.sem.import_all_callbacks.pop(imported, []) + for target_globals, mod_node, import_node in callbacks: + self.import_globals_into(source_globals, target_globals, mod_node, import_node) def visit_file(self, file: MypyFile, fnam: str, mod_id: str, options: Options) -> None: """Perform the first analysis pass. @@ -63,6 +73,7 @@ def visit_file(self, file: MypyFile, fnam: str, mod_id: str, options: Options) - self.sem.options = options # Needed because we sometimes call into it self.pyversion = options.python_version self.platform = options.platform + self.is_stub_file = fnam.endswith('.pyi') sem.cur_mod_id = mod_id sem.cur_mod_node = file sem.errors.set_file(fnam, mod_id, scope=sem.scope) @@ -130,6 +141,7 @@ def visit_file(self, file: MypyFile, fnam: str, mod_id: str, options: Options) - del self.sem.options sem.scope.leave() + self.process_import_all_callback(mod_id, sem.globals) def visit_block(self, b: Block) -> None: if b.is_unreachable: @@ -238,6 +250,21 @@ def process_nested_classes(self, outer_def: ClassDef) -> None: node.accept(self) self.sem.leave_class() + def import_globals_into(self, source_module: SymbolTable, target_module: SymbolTable, + mod_node: MypyFile, import_node: ImportAll) -> None: + print('import globals from', source_module, 'into', target_module) + for name, orig_node in source_module.items(): + if (not orig_node.module_public or + (name.startswith('_') and '__all__' not in source_module)): + continue + if name not in target_module: + sym = create_indirect_imported_name(mod_node, + import_node.id, + import_node.relative, + name) + if sym: + target_module[name] = sym + def visit_import_from(self, node: ImportFrom) -> None: # We can't bind module names during the first pass, as the target module might be # unprocessed. However, we add dummy unbound imported names to the symbol table so @@ -254,6 +281,7 @@ def visit_import_from(self, node: ImportFrom) -> None: node.relative, name) if sym: + sym.module_public = not self.is_stub_file or as_name is not None self.add_symbol(imported_name, sym, context=node) def visit_import(self, node: Import) -> None: @@ -266,10 +294,20 @@ def visit_import(self, node: Import) -> None: # For 'import a.b.c' we create symbol 'a'. imported_id = imported_id.split('.')[0] if imported_id not in self.sem.globals: - self.add_symbol(imported_id, SymbolTableNode(UNBOUND_IMPORTED, None), node) + module_public = not self.is_stub_file or as_id is not None + sym = SymbolTableNode(UNBOUND_IMPORTED, None, module_public=module_public) + self.add_symbol(imported_id, sym, node) def visit_import_all(self, node: ImportAll) -> None: node.is_top_level = self.sem.is_module_scope() + if not self.sem.is_module_scope(): + return + # TODO relative import + if node.id in self.modules: + self.import_globals_into(self.modules[node.id].names, self.sem.globals, + self.sem.cur_mod_node, node) + else: + self.add_import_all_callback(node.id, self.sem.globals, self.sem.cur_mod_node, node) def visit_while_stmt(self, s: WhileStmt) -> None: if self.sem.is_module_scope(): diff --git a/test-data/unit/check-modules.test b/test-data/unit/check-modules.test index 6186664a746e..363ba6ec7314 100644 --- a/test-data/unit/check-modules.test +++ b/test-data/unit/check-modules.test @@ -262,6 +262,23 @@ main:1: error: Cannot find module named 'nonexistent' main:1: note: (Perhaps setting MYPYPATH or using the "--ignore-missing-imports" flag would help) main:2: error: Unsupported left operand type for + ("None") +[case testRepeatedImportStar] +import a + +[file a.py] +from b import * +from c import greet + +[file b.py] +HELLO = 'hello there' + +[file c.py] +from a import HELLO +def greet() -> None: + pass + +[builtins fixtures/list.pyi] + [case testAccessingUnknownModule] import xyz xyz.foo() @@ -2074,6 +2091,14 @@ reveal_type(A().g) # E: Revealed type is 'def (x: builtins.list[builtins.int])' from typing import List [builtins fixtures/list.pyi] +[case testStarImportWithinCycle] +import a +[file a.py] +from b import f +[file b.py] +from a import * +def f() -> None: pass + [case testIndirectStarImportWithinCycle1] import a [file a.py]