diff --git a/extmod/uasyncio/__init__.py b/extmod/uasyncio/__init__.py index fa64438f6b2a0..3361159587f24 100644 --- a/extmod/uasyncio/__init__.py +++ b/extmod/uasyncio/__init__.py @@ -25,6 +25,6 @@ def __getattr__(attr): mod = _attrs.get(attr, None) if mod is None: raise AttributeError(attr) - value = getattr(__import__(mod, None, None, True, 1), attr) + value = getattr(__import__(mod, globals(), None, True, 1), attr) globals()[attr] = value return value diff --git a/py/builtinimport.c b/py/builtinimport.c index 094959f97d381..19d45d4bac673 100644 --- a/py/builtinimport.c +++ b/py/builtinimport.c @@ -257,7 +257,7 @@ STATIC void do_load(mp_module_context_t *module_obj, vstr_t *file) { // Convert a relative (to the current module) import, going up "level" levels, // into an absolute import. -STATIC void evaluate_relative_import(mp_int_t level, const char **module_name, size_t *module_name_len) { +STATIC void evaluate_relative_import(mp_int_t level, const char **module_name, size_t *module_name_len, mp_obj_t globals) { // What we want to do here is to take the name of the current module, // remove trailing components, and concatenate the passed-in // module name. @@ -266,7 +266,7 @@ STATIC void evaluate_relative_import(mp_int_t level, const char **module_name, s // module's position in the package hierarchy." // http://legacy.python.org/dev/peps/pep-0328/#relative-imports-and-name - mp_obj_t current_module_name_obj = mp_obj_dict_get(MP_OBJ_FROM_PTR(mp_globals_get()), MP_OBJ_NEW_QSTR(MP_QSTR___name__)); + mp_obj_t current_module_name_obj = mp_obj_dict_get(globals, MP_OBJ_NEW_QSTR(MP_QSTR___name__)); assert(current_module_name_obj != MP_OBJ_NULL); #if MICROPY_MODULE_OVERRIDE_MAIN_IMPORT && MICROPY_CPYTHON_COMPAT @@ -274,12 +274,12 @@ STATIC void evaluate_relative_import(mp_int_t level, const char **module_name, s // This is a module loaded by -m command-line switch (e.g. unix port), // and so its __name__ has been set to "__main__". Get its real name // that we stored during import in the __main__ attribute. - current_module_name_obj = mp_obj_dict_get(MP_OBJ_FROM_PTR(mp_globals_get()), MP_OBJ_NEW_QSTR(MP_QSTR___main__)); + current_module_name_obj = mp_obj_dict_get(globals, MP_OBJ_NEW_QSTR(MP_QSTR___main__)); } #endif // If we have a __path__ in the globals dict, then we're a package. - bool is_pkg = mp_map_lookup(&mp_globals_get()->map, MP_OBJ_NEW_QSTR(MP_QSTR___path__), MP_MAP_LOOKUP); + bool is_pkg = mp_map_lookup(mp_obj_dict_get_map(globals), MP_OBJ_NEW_QSTR(MP_QSTR___path__), MP_MAP_LOOKUP); #if DEBUG_PRINT DEBUG_printf("Current module/package: "); @@ -480,6 +480,17 @@ mp_obj_t mp_builtin___import__(size_t n_args, const mp_obj_t *args) { // "from ...foo.bar import baz" --> module_name="foo.bar" mp_obj_t module_name_obj = args[0]; + // This is the dict with all global symbols. + mp_obj_t globals = mp_const_none; + if (n_args >= 2) { + globals = args[1]; + } + if (globals == mp_const_none) { + globals = MP_OBJ_FROM_PTR(mp_globals_get()); + } else if (!mp_obj_is_type(globals, &mp_type_dict)) { + mp_raise_TypeError(MP_ERROR_TEXT("globals must be dict")); + } + // These are the imported names. // i.e. "from foo.bar import baz, zap" --> fromtuple=("baz", "zap",) // Note: There's a special case on the Unix port, where this is set to mp_const_false which means that it's __main__. @@ -505,7 +516,7 @@ mp_obj_t mp_builtin___import__(size_t n_args, const mp_obj_t *args) { if (level != 0) { // Turn "foo.bar" into ".foo.bar". - evaluate_relative_import(level, &module_name, &module_name_len); + evaluate_relative_import(level, &module_name, &module_name_len, globals); } if (module_name_len == 0) { diff --git a/py/runtime.c b/py/runtime.c index 8c93f539e04e6..b1fed7096981c 100644 --- a/py/runtime.c +++ b/py/runtime.c @@ -1408,7 +1408,7 @@ mp_obj_t mp_import_name(qstr name, mp_obj_t fromlist, mp_obj_t level) { // build args array mp_obj_t args[5]; args[0] = MP_OBJ_NEW_QSTR(name); - args[1] = mp_const_none; // TODO should be globals + args[1] = MP_OBJ_FROM_PTR(mp_globals_get()); // globals of the current context args[2] = mp_const_none; // TODO should be locals args[3] = fromlist; args[4] = level; diff --git a/tests/import/import_override2.py b/tests/import/import_override2.py new file mode 100644 index 0000000000000..2cd2da21ebd55 --- /dev/null +++ b/tests/import/import_override2.py @@ -0,0 +1,29 @@ +# test overriding __import__ combined with importing from the filesystem + + +def custom_import(name, globals, locals, fromlist, level): + print("import", name, fromlist, level) + return orig_import(name, globals, locals, fromlist, level) + + +orig_import = __import__ +try: + __import__("builtins").__import__ = custom_import +except AttributeError: + print("SKIP") + raise SystemExit + +# import calls __import__ behind the scenes +import pkg7.subpkg1.subpkg2.mod3 + + +try: + # globals must be a dict or None, not a string + orig_import("builtins", "globals", None, None, 0) +except TypeError: + print("TypeError") +try: + # ... same for relative imports (level > 0) + orig_import("builtins", "globals", None, None, 1) +except TypeError: + print("TypeError") diff --git a/tests/import/import_override2.py.exp b/tests/import/import_override2.py.exp new file mode 100644 index 0000000000000..99dd8d7ac2fcc --- /dev/null +++ b/tests/import/import_override2.py.exp @@ -0,0 +1,15 @@ +import pkg7.subpkg1.subpkg2.mod3 None 0 +pkg __name__: pkg7 +pkg __name__: pkg7.subpkg1 +pkg __name__: pkg7.subpkg1.subpkg2 +import ('mod1',) 3 +import pkg7.mod1 True 0 +mod1 +import mod2 ('bar',) 3 +mod2 +mod1.foo +mod2.bar +import ('mod1',) 4 +ImportError +TypeError +TypeError