diff --git a/tests/snippets/import.py b/tests/snippets/import.py index 0a36e63429..a481e3f1ae 100644 --- a/tests/snippets/import.py +++ b/tests/snippets/import.py @@ -22,6 +22,28 @@ except ImportError: pass + +test = __import__("import_target") +assert test.X == import_target.X + +import builtins +class OverrideImportContext(): + + def __enter__(self): + self.original_import = builtins.__import__ + + def __exit__(self, exc_type, exc_val, exc_tb): + builtins.__import__ = self.original_import + +with OverrideImportContext(): + def fake_import(name, globals=None, locals=None, fromlist=(), level=0): + return len(name) + + builtins.__import__ = fake_import + import test + assert test == 4 + + # TODO: Once we can determine current directory, use that to construct this # path: #import sys diff --git a/vm/src/builtins.rs b/vm/src/builtins.rs index 5a0c151957..9611f63276 100644 --- a/vm/src/builtins.rs +++ b/vm/src/builtins.rs @@ -5,8 +5,10 @@ // use std::ops::Deref; use std::char; use std::io::{self, Write}; +use std::path::PathBuf; use crate::compile; +use crate::import::import_module; use crate::obj::objbool; use crate::obj::objdict; use crate::obj::objint; @@ -713,8 +715,27 @@ fn builtin_sum(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(sum) } +// Should be renamed to builtin___import__? +fn builtin_import(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(name, Some(vm.ctx.str_type()))], + optional = [ + (_globals, Some(vm.ctx.dict_type())), + (_locals, Some(vm.ctx.dict_type())) + ] + ); + let current_path = { + let mut source_pathbuf = PathBuf::from(&vm.current_frame().code.source_path); + source_pathbuf.pop(); + source_pathbuf + }; + + import_module(vm, current_path, &objstr::get_value(name)) +} + // builtin_vars -// builtin___import__ pub fn make_module(ctx: &PyContext) -> PyObjectRef { let py_mod = py_module!(ctx, "__builtins__", { @@ -783,6 +804,7 @@ pub fn make_module(ctx: &PyContext) -> PyObjectRef { "tuple" => ctx.tuple_type(), "type" => ctx.type_type(), "zip" => ctx.zip_type(), + "__import__" => ctx.new_rustfunc(builtin_import), // Constants "NotImplemented" => ctx.not_implemented.clone(), diff --git a/vm/src/frame.rs b/vm/src/frame.rs index 48d3d6229c..2519733cb7 100644 --- a/vm/src/frame.rs +++ b/vm/src/frame.rs @@ -1,6 +1,5 @@ use std::cell::RefCell; use std::fmt; -use std::path::PathBuf; use std::rc::Rc; use num_bigint::BigInt; @@ -9,7 +8,6 @@ use rustpython_parser::ast; use crate::builtins; use crate::bytecode; -use crate::import::{import, import_module}; use crate::obj::objbool; use crate::obj::objbuiltinfunc::PyBuiltinFunction; use crate::obj::objcode; @@ -22,8 +20,8 @@ use crate::obj::objslice::PySlice; use crate::obj::objstr; use crate::obj::objtype; use crate::pyobject::{ - DictProtocol, IdProtocol, PyContext, PyFuncArgs, PyObject, PyObjectRef, PyResult, PyValue, - TryFromObject, TypeProtocol, + AttributeProtocol, DictProtocol, IdProtocol, PyContext, PyFuncArgs, PyObject, PyObjectRef, + PyResult, PyValue, TryFromObject, TypeProtocol, }; use crate::vm::VirtualMachine; @@ -806,30 +804,31 @@ impl Frame { module: &str, symbol: &Option, ) -> FrameResult { - let current_path = { - let mut source_pathbuf = PathBuf::from(&self.code.source_path); - source_pathbuf.pop(); - source_pathbuf + let module = vm.import(module)?; + + // If we're importing a symbol, look it up and use it, otherwise construct a module and return + // that + let obj = match symbol { + Some(symbol) => module.get_attr(symbol).map_or_else( + || { + let import_error = vm.context().exceptions.import_error.clone(); + Err(vm.new_exception(import_error, format!("cannot import name '{}'", symbol))) + }, + Ok, + ), + None => Ok(module), }; - let obj = import(vm, current_path, module, symbol)?; - // Push module on stack: - self.push_value(obj); + self.push_value(obj?); Ok(None) } fn import_star(&self, vm: &mut VirtualMachine, module: &str) -> FrameResult { - let current_path = { - let mut source_pathbuf = PathBuf::from(&self.code.source_path); - source_pathbuf.pop(); - source_pathbuf - }; + let module = vm.import(module)?; // Grab all the names from the module and put them in the context - let obj = import_module(vm, current_path, module)?; - - for (k, v) in obj.get_key_value_pairs().iter() { + for (k, v) in module.get_key_value_pairs().iter() { self.scope.store_name(&vm, &objstr::get_value(k), v.clone()); } Ok(None) diff --git a/vm/src/import.rs b/vm/src/import.rs index 3031ed38b7..a36168c287 100644 --- a/vm/src/import.rs +++ b/vm/src/import.rs @@ -63,28 +63,6 @@ pub fn import_module( Ok(module) } -pub fn import( - vm: &mut VirtualMachine, - current_path: PathBuf, - module_name: &str, - symbol: &Option, -) -> PyResult { - let module = import_module(vm, current_path, module_name)?; - // If we're importing a symbol, look it up and use it, otherwise construct a module and return - // that - if let Some(symbol) = symbol { - module.get_attr(symbol).map_or_else( - || { - let import_error = vm.context().exceptions.import_error.clone(); - Err(vm.new_exception(import_error, format!("cannot import name '{}'", symbol))) - }, - Ok, - ) - } else { - Ok(module) - } -} - fn find_source(vm: &VirtualMachine, current_path: PathBuf, name: &str) -> Result { let sys_path = vm.sys_module.get_attr("path").unwrap(); let mut paths: Vec = objsequence::get_elements(&sys_path) diff --git a/vm/src/vm.rs b/vm/src/vm.rs index 5dca659467..851345f1d4 100644 --- a/vm/src/vm.rs +++ b/vm/src/vm.rs @@ -13,8 +13,7 @@ use std::sync::{Mutex, MutexGuard}; use crate::builtins; use crate::bytecode; -use crate::frame::ExecutionResult; -use crate::frame::Scope; +use crate::frame::{ExecutionResult, Frame, Scope}; use crate::obj::objbool; use crate::obj::objbuiltinfunc::PyBuiltinFunction; use crate::obj::objcode; @@ -94,9 +93,13 @@ impl VirtualMachine { result } - pub fn current_scope(&self) -> &Scope { + pub fn current_frame(&self) -> &Frame { let current_frame = &self.frames[self.frames.len() - 1]; - let frame = objframe::get_value(current_frame); + objframe::get_value(current_frame) + } + + pub fn current_scope(&self) -> &Scope { + let frame = self.current_frame(); &frame.scope } @@ -231,6 +234,17 @@ impl VirtualMachine { self.call_method(obj, "__repr__", vec![]) } + pub fn import(&mut self, module: &str) -> PyResult { + let builtins_import = self.builtins.get_item("__import__"); + match builtins_import { + Some(func) => self.invoke(func, vec![self.ctx.new_str(module.to_string())]), + None => Err(self.new_exception( + self.ctx.exceptions.import_error.clone(), + "__import__ not found".to_string(), + )), + } + } + /// Determines if `obj` is an instance of `cls`, either directly, indirectly or virtually via /// the __instancecheck__ magic method. pub fn isinstance(&mut self, obj: &PyObjectRef, cls: &PyObjectRef) -> PyResult { diff --git a/wasm/lib/src/browser_module.rs b/wasm/lib/src/browser_module.rs index 6469d64b1f..035954a7c7 100644 --- a/wasm/lib/src/browser_module.rs +++ b/wasm/lib/src/browser_module.rs @@ -4,9 +4,10 @@ use js_sys::Promise; use num_traits::cast::ToPrimitive; use rustpython_vm::obj::{objint, objstr}; use rustpython_vm::pyobject::{ - PyContext, PyFuncArgs, PyObject, PyObjectRef, PyResult, PyValue, TypeProtocol, + AttributeProtocol, PyContext, PyFuncArgs, PyObject, PyObjectRef, PyResult, PyValue, + TypeProtocol, }; -use rustpython_vm::{import::import, VirtualMachine}; +use rustpython_vm::{import::import_module, VirtualMachine}; use std::path::PathBuf; use wasm_bindgen::{prelude::*, JsCast}; use wasm_bindgen_futures::{future_to_promise, JsFuture}; @@ -177,12 +178,10 @@ pub fn get_promise_value(obj: &PyObjectRef) -> Promise { } pub fn import_promise_type(vm: &mut VirtualMachine) -> PyResult { - import( - vm, - PathBuf::default(), - BROWSER_NAME, - &Some("Promise".into()), - ) + match import_module(vm, PathBuf::default(), BROWSER_NAME)?.get_attr("Promise".into()) { + Some(promise) => Ok(promise), + None => Err(vm.new_not_implemented_error("No Promise".to_string())), + } } fn promise_then(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { diff --git a/wasm/lib/src/convert.rs b/wasm/lib/src/convert.rs index 2f0640f1b6..a76f802447 100644 --- a/wasm/lib/src/convert.rs +++ b/wasm/lib/src/convert.rs @@ -128,13 +128,10 @@ pub fn py_to_js(vm: &mut VirtualMachine, py_obj: PyObjectRef) -> JsValue { } arr.into() } else { - let dumps = rustpython_vm::import::import( - vm, - std::path::PathBuf::default(), - "json", - &Some("dumps".into()), - ) - .expect("Couldn't get json.dumps function"); + let dumps = rustpython_vm::import::import_module(vm, std::path::PathBuf::default(), "json") + .expect("Couldn't get json module") + .get_attr("dumps".into()) + .expect("Couldn't get json dumps"); match vm.invoke(dumps, pyobject::PyFuncArgs::new(vec![py_obj], vec![])) { Ok(value) => { let json = vm.to_pystr(&value).unwrap(); @@ -231,13 +228,10 @@ pub fn js_to_py(vm: &mut VirtualMachine, js_val: JsValue) -> PyObjectRef { // Because `JSON.stringify(undefined)` returns undefined vm.get_none() } else { - let loads = rustpython_vm::import::import( - vm, - std::path::PathBuf::default(), - "json", - &Some("loads".into()), - ) - .expect("json.loads function to be available"); + let loads = rustpython_vm::import::import_module(vm, std::path::PathBuf::default(), "json") + .expect("Couldn't get json module") + .get_attr("loads".into()) + .expect("Couldn't get json dumps"); let json = match js_sys::JSON::stringify(&js_val) { Ok(json) => String::from(json),