Skip to content

Commit 070f5aa

Browse files
committed
Merge master into pyvaluepayload
2 parents e49d714 + 83788b9 commit 070f5aa

30 files changed

+755
-969
lines changed

Cargo.toml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,3 @@ rustpython_parser = {path = "parser"}
1515
rustpython_vm = {path = "vm"}
1616
rustyline = "2.1.0"
1717
xdg = "2.2.0"
18-
19-
[profile.release]
20-
opt-level = "s"

tests/snippets/builtin_dict.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,7 @@
88
d = {}
99
d['a'] = d
1010
assert repr(d) == "{'a': {...}}"
11+
12+
assert {'a': 123}.get('a') == 123
13+
assert {'a': 123}.get('b') == None
14+
assert {'a': 123}.get('b', 456) == 456

tests/snippets/import.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,28 @@
2222
except ImportError:
2323
pass
2424

25+
26+
test = __import__("import_target")
27+
assert test.X == import_target.X
28+
29+
import builtins
30+
class OverrideImportContext():
31+
32+
def __enter__(self):
33+
self.original_import = builtins.__import__
34+
35+
def __exit__(self, exc_type, exc_val, exc_tb):
36+
builtins.__import__ = self.original_import
37+
38+
with OverrideImportContext():
39+
def fake_import(name, globals=None, locals=None, fromlist=(), level=0):
40+
return len(name)
41+
42+
builtins.__import__ = fake_import
43+
import test
44+
assert test == 4
45+
46+
2547
# TODO: Once we can determine current directory, use that to construct this
2648
# path:
2749
#import sys

vm/src/builtins.rs

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
// use std::ops::Deref;
66
use std::char;
77
use std::io::{self, Write};
8+
use std::path::PathBuf;
89

910
use crate::compile;
11+
use crate::import::import_module;
1012
use crate::obj::objbool;
1113
use crate::obj::objdict;
1214
use crate::obj::objint;
@@ -674,7 +676,7 @@ fn builtin_round(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
674676
} else {
675677
// without a parameter, the result type is coerced to int
676678
let rounded = &vm.call_method(number, "__round__", vec![])?;
677-
Ok(vm.ctx.new_int(objint::get_value(rounded)))
679+
Ok(vm.ctx.new_int(objint::get_value(rounded).clone()))
678680
}
679681
}
680682

@@ -713,8 +715,27 @@ fn builtin_sum(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
713715
Ok(sum)
714716
}
715717

718+
// Should be renamed to builtin___import__?
719+
fn builtin_import(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
720+
arg_check!(
721+
vm,
722+
args,
723+
required = [(name, Some(vm.ctx.str_type()))],
724+
optional = [
725+
(_globals, Some(vm.ctx.dict_type())),
726+
(_locals, Some(vm.ctx.dict_type()))
727+
]
728+
);
729+
let current_path = {
730+
let mut source_pathbuf = PathBuf::from(&vm.current_frame().code.source_path);
731+
source_pathbuf.pop();
732+
source_pathbuf
733+
};
734+
735+
import_module(vm, current_path, &objstr::get_value(name))
736+
}
737+
716738
// builtin_vars
717-
// builtin___import__
718739

719740
pub fn make_module(ctx: &PyContext) -> PyObjectRef {
720741
let py_mod = py_module!(ctx, "__builtins__", {
@@ -783,6 +804,7 @@ pub fn make_module(ctx: &PyContext) -> PyObjectRef {
783804
"tuple" => ctx.tuple_type(),
784805
"type" => ctx.type_type(),
785806
"zip" => ctx.zip_type(),
807+
"__import__" => ctx.new_rustfunc(builtin_import),
786808

787809
// Constants
788810
"NotImplemented" => ctx.not_implemented.clone(),

vm/src/exceptions.rs

Lines changed: 22 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -108,56 +108,29 @@ pub struct ExceptionZoo {
108108
}
109109

110110
impl ExceptionZoo {
111-
pub fn new(
112-
type_type: &PyObjectRef,
113-
object_type: &PyObjectRef,
114-
dict_type: &PyObjectRef,
115-
) -> Self {
111+
pub fn new(type_type: &PyObjectRef, object_type: &PyObjectRef) -> Self {
116112
// Sorted By Hierarchy then alphabetized.
117-
let base_exception_type =
118-
create_type("BaseException", &type_type, &object_type, &dict_type);
119-
120-
let exception_type = create_type("Exception", &type_type, &base_exception_type, &dict_type);
121-
122-
let arithmetic_error =
123-
create_type("ArithmeticError", &type_type, &exception_type, &dict_type);
124-
let assertion_error =
125-
create_type("AssertionError", &type_type, &exception_type, &dict_type);
126-
let attribute_error =
127-
create_type("AttributeError", &type_type, &exception_type, &dict_type);
128-
let import_error = create_type("ImportError", &type_type, &exception_type, &dict_type);
129-
let index_error = create_type("IndexError", &type_type, &exception_type, &dict_type);
130-
let key_error = create_type("KeyError", &type_type, &exception_type, &dict_type);
131-
let name_error = create_type("NameError", &type_type, &exception_type, &dict_type);
132-
let os_error = create_type("OSError", &type_type, &exception_type, &dict_type);
133-
let runtime_error = create_type("RuntimeError", &type_type, &exception_type, &dict_type);
134-
let stop_iteration = create_type("StopIteration", &type_type, &exception_type, &dict_type);
135-
let syntax_error = create_type("SyntaxError", &type_type, &exception_type, &dict_type);
136-
let type_error = create_type("TypeError", &type_type, &exception_type, &dict_type);
137-
let value_error = create_type("ValueError", &type_type, &exception_type, &dict_type);
138-
139-
let overflow_error =
140-
create_type("OverflowError", &type_type, &arithmetic_error, &dict_type);
141-
let zero_division_error = create_type(
142-
"ZeroDivisionError",
143-
&type_type,
144-
&arithmetic_error,
145-
&dict_type,
146-
);
147-
148-
let module_not_found_error =
149-
create_type("ModuleNotFoundError", &type_type, &import_error, &dict_type);
150-
151-
let not_implemented_error = create_type(
152-
"NotImplementedError",
153-
&type_type,
154-
&runtime_error,
155-
&dict_type,
156-
);
157-
158-
let file_not_found_error =
159-
create_type("FileNotFoundError", &type_type, &os_error, &dict_type);
160-
let permission_error = create_type("PermissionError", &type_type, &os_error, &dict_type);
113+
let base_exception_type = create_type("BaseException", &type_type, &object_type);
114+
let exception_type = create_type("Exception", &type_type, &base_exception_type);
115+
let arithmetic_error = create_type("ArithmeticError", &type_type, &exception_type);
116+
let assertion_error = create_type("AssertionError", &type_type, &exception_type);
117+
let attribute_error = create_type("AttributeError", &type_type, &exception_type);
118+
let import_error = create_type("ImportError", &type_type, &exception_type);
119+
let index_error = create_type("IndexError", &type_type, &exception_type);
120+
let key_error = create_type("KeyError", &type_type, &exception_type);
121+
let name_error = create_type("NameError", &type_type, &exception_type);
122+
let os_error = create_type("OSError", &type_type, &exception_type);
123+
let runtime_error = create_type("RuntimeError", &type_type, &exception_type);
124+
let stop_iteration = create_type("StopIteration", &type_type, &exception_type);
125+
let syntax_error = create_type("SyntaxError", &type_type, &exception_type);
126+
let type_error = create_type("TypeError", &type_type, &exception_type);
127+
let value_error = create_type("ValueError", &type_type, &exception_type);
128+
let overflow_error = create_type("OverflowError", &type_type, &arithmetic_error);
129+
let zero_division_error = create_type("ZeroDivisionError", &type_type, &arithmetic_error);
130+
let module_not_found_error = create_type("ModuleNotFoundError", &type_type, &import_error);
131+
let not_implemented_error = create_type("NotImplementedError", &type_type, &runtime_error);
132+
let file_not_found_error = create_type("FileNotFoundError", &type_type, &os_error);
133+
let permission_error = create_type("PermissionError", &type_type, &os_error);
161134

162135
ExceptionZoo {
163136
arithmetic_error,

vm/src/frame.rs

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
use std::cell::RefCell;
22
use std::fmt;
3-
use std::path::PathBuf;
43
use std::rc::Rc;
54

65
use num_bigint::BigInt;
@@ -9,7 +8,6 @@ use rustpython_parser::ast;
98

109
use crate::builtins;
1110
use crate::bytecode;
12-
use crate::import::{import, import_module};
1311
use crate::obj::objbool;
1412
use crate::obj::objbuiltinfunc::PyBuiltinFunction;
1513
use crate::obj::objcode;
@@ -22,8 +20,8 @@ use crate::obj::objslice::PySlice;
2220
use crate::obj::objstr;
2321
use crate::obj::objtype;
2422
use crate::pyobject::{
25-
DictProtocol, IdProtocol, PyContext, PyFuncArgs, PyObject, PyObjectRef, PyResult, PyValue,
26-
TryFromObject, TypeProtocol,
23+
AttributeProtocol, DictProtocol, IdProtocol, PyContext, PyFuncArgs, PyObject, PyObjectRef,
24+
PyResult, PyValue, TryFromObject, TypeProtocol,
2725
};
2826
use crate::vm::VirtualMachine;
2927

@@ -806,30 +804,31 @@ impl Frame {
806804
module: &str,
807805
symbol: &Option<String>,
808806
) -> FrameResult {
809-
let current_path = {
810-
let mut source_pathbuf = PathBuf::from(&self.code.source_path);
811-
source_pathbuf.pop();
812-
source_pathbuf
807+
let module = vm.import(module)?;
808+
809+
// If we're importing a symbol, look it up and use it, otherwise construct a module and return
810+
// that
811+
let obj = match symbol {
812+
Some(symbol) => module.get_attr(symbol).map_or_else(
813+
|| {
814+
let import_error = vm.context().exceptions.import_error.clone();
815+
Err(vm.new_exception(import_error, format!("cannot import name '{}'", symbol)))
816+
},
817+
Ok,
818+
),
819+
None => Ok(module),
813820
};
814821

815-
let obj = import(vm, current_path, module, symbol)?;
816-
817822
// Push module on stack:
818-
self.push_value(obj);
823+
self.push_value(obj?);
819824
Ok(None)
820825
}
821826

822827
fn import_star(&self, vm: &mut VirtualMachine, module: &str) -> FrameResult {
823-
let current_path = {
824-
let mut source_pathbuf = PathBuf::from(&self.code.source_path);
825-
source_pathbuf.pop();
826-
source_pathbuf
827-
};
828+
let module = vm.import(module)?;
828829

829830
// Grab all the names from the module and put them in the context
830-
let obj = import_module(vm, current_path, module)?;
831-
832-
for (k, v) in obj.get_key_value_pairs().iter() {
831+
for (k, v) in module.get_key_value_pairs().iter() {
833832
self.scope.store_name(&vm, &objstr::get_value(k), v.clone());
834833
}
835834
Ok(None)

vm/src/import.rs

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -63,28 +63,6 @@ pub fn import_module(
6363
Ok(module)
6464
}
6565

66-
pub fn import(
67-
vm: &mut VirtualMachine,
68-
current_path: PathBuf,
69-
module_name: &str,
70-
symbol: &Option<String>,
71-
) -> PyResult {
72-
let module = import_module(vm, current_path, module_name)?;
73-
// If we're importing a symbol, look it up and use it, otherwise construct a module and return
74-
// that
75-
if let Some(symbol) = symbol {
76-
module.get_attr(symbol).map_or_else(
77-
|| {
78-
let import_error = vm.context().exceptions.import_error.clone();
79-
Err(vm.new_exception(import_error, format!("cannot import name '{}'", symbol)))
80-
},
81-
Ok,
82-
)
83-
} else {
84-
Ok(module)
85-
}
86-
}
87-
8866
fn find_source(vm: &VirtualMachine, current_path: PathBuf, name: &str) -> Result<PathBuf, String> {
8967
let sys_path = vm.sys_module.get_attr("path").unwrap();
9068
let mut paths: Vec<PathBuf> = objsequence::get_elements(&sys_path)

vm/src/obj/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ pub mod objbool;
44
pub mod objbuiltinfunc;
55
pub mod objbytearray;
66
pub mod objbytes;
7+
pub mod objclassmethod;
78
pub mod objcode;
89
pub mod objcomplex;
910
pub mod objdict;
@@ -27,6 +28,7 @@ pub mod objrange;
2728
pub mod objsequence;
2829
pub mod objset;
2930
pub mod objslice;
31+
pub mod objstaticmethod;
3032
pub mod objstr;
3133
pub mod objsuper;
3234
pub mod objtuple;

vm/src/obj/objclassmethod.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
use crate::pyobject::{AttributeProtocol, PyContext, PyFuncArgs, PyResult, TypeProtocol};
2+
use crate::vm::VirtualMachine;
3+
4+
pub fn init(context: &PyContext) {
5+
let classmethod_type = &context.classmethod_type;
6+
extend_class!(context, classmethod_type, {
7+
"__get__" => context.new_rustfunc(classmethod_get),
8+
"__new__" => context.new_rustfunc(classmethod_new)
9+
});
10+
}
11+
12+
fn classmethod_get(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
13+
trace!("classmethod.__get__ {:?}", args.args);
14+
arg_check!(
15+
vm,
16+
args,
17+
required = [
18+
(cls, Some(vm.ctx.classmethod_type())),
19+
(_inst, None),
20+
(owner, None)
21+
]
22+
);
23+
match cls.get_attr("function") {
24+
Some(function) => {
25+
let py_obj = owner.clone();
26+
let py_method = vm.ctx.new_bound_method(function, py_obj);
27+
Ok(py_method)
28+
}
29+
None => Err(vm.new_attribute_error(
30+
"Attribute Error: classmethod must have 'function' attribute".to_string(),
31+
)),
32+
}
33+
}
34+
35+
fn classmethod_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
36+
trace!("classmethod.__new__ {:?}", args.args);
37+
arg_check!(vm, args, required = [(cls, None), (callable, None)]);
38+
39+
let py_obj = vm.ctx.new_instance(cls.clone(), None);
40+
vm.ctx.set_attr(&py_obj, "function", callable.clone());
41+
Ok(py_obj)
42+
}

vm/src/obj/objdict.rs

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -338,16 +338,27 @@ fn dict_getitem(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
338338
}
339339
}
340340

341-
pub fn create_type(type_type: PyObjectRef, object_type: PyObjectRef, dict_type: PyObjectRef) {
342-
// this is not ideal
343-
let ptr = PyObjectRef::into_raw(dict_type.clone()) as *mut PyObject;
344-
unsafe {
345-
(*ptr).payload = Box::new(objtype::PyClass {
346-
name: String::from("dict"),
347-
mro: vec![object_type],
348-
});
349-
(*ptr).dict = Some(RefCell::new(HashMap::new()));
350-
(*ptr).typ = Some(type_type.clone());
341+
fn dict_get(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
342+
arg_check!(
343+
vm,
344+
args,
345+
required = [
346+
(dict, Some(vm.ctx.dict_type())),
347+
(key, Some(vm.ctx.str_type()))
348+
],
349+
optional = [(default, None)]
350+
);
351+
352+
// What we are looking for:
353+
let key = objstr::get_value(&key);
354+
355+
let elements = get_elements(dict);
356+
if elements.contains_key(&key) {
357+
Ok(elements[&key].1.clone())
358+
} else if let Some(value) = default {
359+
Ok(value.clone())
360+
} else {
361+
Ok(vm.get_none())
351362
}
352363
}
353364

@@ -381,4 +392,5 @@ pub fn init(context: &PyContext) {
381392
context.set_attr(&dict_type, "values", context.new_rustfunc(dict_values));
382393
context.set_attr(&dict_type, "items", context.new_rustfunc(dict_items));
383394
context.set_attr(&dict_type, "keys", context.new_rustfunc(dict_iter));
395+
context.set_attr(&dict_type, "get", context.new_rustfunc(dict_get));
384396
}

vm/src/obj/objenumerate.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ fn enumerate_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
3030
optional = [(start, Some(vm.ctx.int_type()))]
3131
);
3232
let counter = if let Some(x) = start {
33-
objint::get_value(x)
33+
objint::get_value(x).clone()
3434
} else {
3535
BigInt::zero()
3636
};

0 commit comments

Comments
 (0)