diff --git a/tests/snippets/dir_module/__init__.py b/tests/snippets/dir_module/__init__.py new file mode 100644 index 0000000000..5d9faff33d --- /dev/null +++ b/tests/snippets/dir_module/__init__.py @@ -0,0 +1,2 @@ +from .relative import value +from .dir_module_inner import value2 diff --git a/tests/snippets/dir_module/dir_module_inner/__init__.py b/tests/snippets/dir_module/dir_module_inner/__init__.py new file mode 100644 index 0000000000..20e95590f1 --- /dev/null +++ b/tests/snippets/dir_module/dir_module_inner/__init__.py @@ -0,0 +1,3 @@ +from ..relative import value + +value2 = value + 2 diff --git a/tests/snippets/dir_module/relative.py b/tests/snippets/dir_module/relative.py new file mode 100644 index 0000000000..78ed77f40c --- /dev/null +++ b/tests/snippets/dir_module/relative.py @@ -0,0 +1 @@ +value = 5 diff --git a/tests/snippets/import_module.py b/tests/snippets/import_module.py new file mode 100644 index 0000000000..b82aad5091 --- /dev/null +++ b/tests/snippets/import_module.py @@ -0,0 +1,3 @@ +import dir_module +assert dir_module.value == 5 +assert dir_module.value2 == 7 diff --git a/vm/src/frame.rs b/vm/src/frame.rs index 8ea1148547..7844902dc8 100644 --- a/vm/src/frame.rs +++ b/vm/src/frame.rs @@ -912,7 +912,9 @@ impl Frame { .iter() .map(|symbol| vm.ctx.new_str(symbol.to_string())) .collect(); - let module = vm.import(module, &vm.ctx.new_tuple(from_list))?; + let level = module.chars().take_while(|char| *char == '.').count(); + let module_name = &module[level..]; + let module = vm.import(module_name, &vm.ctx.new_tuple(from_list), level)?; if symbols.is_empty() { self.push_value(module); @@ -928,7 +930,8 @@ impl Frame { } fn import_star(&self, vm: &VirtualMachine, module: &str) -> FrameResult { - let module = vm.import(module, &vm.ctx.new_tuple(vec![]))?; + let level = module.chars().take_while(|char| *char == '.').count(); + let module = vm.import(module, &vm.ctx.new_tuple(vec![]), level)?; // Grab all the names from the module and put them in the context if let Some(dict) = &module.dict { diff --git a/vm/src/vm.rs b/vm/src/vm.rs index 1857307588..9bcb77128c 100644 --- a/vm/src/vm.rs +++ b/vm/src/vm.rs @@ -136,7 +136,7 @@ impl VirtualMachine { pub fn try_class(&self, module: &str, class: &str) -> PyResult { let class = self - .get_attribute(self.import(module, &self.ctx.new_tuple(vec![]))?, class)? + .get_attribute(self.import(module, &self.ctx.new_tuple(vec![]), 0)?, class)? .downcast() .expect("not a class"); Ok(class) @@ -144,7 +144,7 @@ impl VirtualMachine { pub fn class(&self, module: &str, class: &str) -> PyClassRef { let module = self - .import(module, &self.ctx.new_tuple(vec![])) + .import(module, &self.ctx.new_tuple(vec![]), 0) .unwrap_or_else(|_| panic!("unable to import {}", module)); let class = self .get_attribute(module.clone(), class) @@ -302,7 +302,7 @@ impl VirtualMachine { TryFromObject::try_from_object(self, repr) } - pub fn import(&self, module: &str, from_list: &PyObjectRef) -> PyResult { + pub fn import(&self, module: &str, from_list: &PyObjectRef, level: usize) -> PyResult { let sys_modules = self .get_attribute(self.sys_module.clone(), "modules") .unwrap(); @@ -314,9 +314,14 @@ impl VirtualMachine { func, vec![ self.ctx.new_str(module.to_string()), - self.get_none(), + if self.current_frame().is_some() { + self.get_locals().into_object() + } else { + self.get_none() + }, self.get_none(), from_list.clone(), + self.ctx.new_int(level), ], ), Err(_) => Err(self.new_exception(