diff --git a/tests/snippets/ast_snippet.py b/tests/snippets/ast_snippet.py index 033d596966..43bf74756b 100644 --- a/tests/snippets/ast_snippet.py +++ b/tests/snippets/ast_snippet.py @@ -5,13 +5,19 @@ source = """ def foo(): print('bar') + pass """ n = ast.parse(source) print(n) print(n.body) print(n.body[0].name) assert n.body[0].name == 'foo' -print(n.body[0].body) -print(n.body[0].body[0]) -print(n.body[0].body[0].value.func.id) -assert n.body[0].body[0].value.func.id == 'print' +foo = n.body[0] +assert foo.lineno == 2 +print(foo.body) +assert len(foo.body) == 2 +print(foo.body[0]) +print(foo.body[0].value.func.id) +assert foo.body[0].value.func.id == 'print' +assert foo.body[0].lineno == 3 +assert foo.body[1].lineno == 4 diff --git a/tests/snippets/iterations.py b/tests/snippets/iterations.py new file mode 100644 index 0000000000..98031a935e --- /dev/null +++ b/tests/snippets/iterations.py @@ -0,0 +1,11 @@ + + +ls = [1, 2, 3] + +i = iter(ls) +assert i.__next__() == 1 +assert i.__next__() == 2 +assert next(i) == 3 + +assert next(i, 'w00t') == 'w00t' + diff --git a/vm/src/builtins.rs b/vm/src/builtins.rs index 4ecdc2b1b6..297187d2ed 100644 --- a/vm/src/builtins.rs +++ b/vm/src/builtins.rs @@ -4,6 +4,7 @@ use std::collections::HashMap; use std::io::{self, Write}; use super::compile; +use super::obj::objiter; use super::obj::objstr; use super::obj::objtype; use super::objbool; @@ -221,7 +222,10 @@ fn builtin_issubclass(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.context().new_bool(objtype::issubclass(cls1, cls2))) } -// builtin_iter +fn builtin_iter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(iter_target, None)]); + objiter::get_iter(vm, iter_target) +} fn builtin_len(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(obj, None)]); @@ -254,7 +258,30 @@ fn builtin_locals(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { // builtin_max // builtin_memoryview // builtin_min -// builtin_next + +fn builtin_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(iterator, None)], + optional = [(default_value, None)] + ); + + match vm.call_method(iterator.clone(), "__next__", vec![]) { + Ok(value) => Ok(value), + Err(value) => { + if objtype::isinstance(&value, vm.ctx.exceptions.stop_iteration.clone()) { + match default_value { + None => Err(value), + Some(value) => Ok(value.clone()), + } + } else { + Err(value) + } + } + } +} + // builtin_object // builtin_oct // builtin_open @@ -378,9 +405,11 @@ pub fn make_module(ctx: &PyContext) -> PyObjectRef { String::from("issubclass"), ctx.new_rustfunc(builtin_issubclass), ); + dict.insert(String::from("iter"), ctx.new_rustfunc(builtin_iter)); dict.insert(String::from("len"), ctx.new_rustfunc(builtin_len)); dict.insert(String::from("list"), ctx.list_type()); dict.insert(String::from("locals"), ctx.new_rustfunc(builtin_locals)); + dict.insert(String::from("next"), ctx.new_rustfunc(builtin_next)); dict.insert(String::from("pow"), ctx.new_rustfunc(builtin_pow)); dict.insert(String::from("print"), ctx.new_rustfunc(builtin_print)); dict.insert(String::from("range"), ctx.new_rustfunc(builtin_range)); diff --git a/vm/src/exceptions.rs b/vm/src/exceptions.rs index f962515d4e..6de078f0e6 100644 --- a/vm/src/exceptions.rs +++ b/vm/src/exceptions.rs @@ -89,6 +89,7 @@ pub struct ExceptionZoo { pub name_error: PyObjectRef, pub runtime_error: PyObjectRef, pub not_implemented_error: PyObjectRef, + pub stop_iteration: PyObjectRef, pub type_error: PyObjectRef, pub value_error: PyObjectRef, } @@ -138,6 +139,12 @@ impl ExceptionZoo { &runtime_error, &dict_type, ); + let stop_iteration = create_type( + &String::from("StopIteration"), + &type_type, + &exception_type, + &dict_type, + ); let type_error = create_type( &String::from("TypeError"), &type_type, @@ -159,6 +166,7 @@ impl ExceptionZoo { name_error: name_error, runtime_error: runtime_error, not_implemented_error: not_implemented_error, + stop_iteration: stop_iteration, type_error: type_error, value_error: value_error, } diff --git a/vm/src/obj/mod.rs b/vm/src/obj/mod.rs index 8dfa798ec7..f49b4ac358 100644 --- a/vm/src/obj/mod.rs +++ b/vm/src/obj/mod.rs @@ -3,6 +3,7 @@ pub mod objdict; pub mod objfloat; pub mod objfunction; pub mod objint; +pub mod objiter; pub mod objlist; pub mod objobject; pub mod objsequence; diff --git a/vm/src/obj/objiter.rs b/vm/src/obj/objiter.rs new file mode 100644 index 0000000000..09afd8ce37 --- /dev/null +++ b/vm/src/obj/objiter.rs @@ -0,0 +1,92 @@ +/* + * Various types to support iteration. + */ + +use super::super::pyobject::{ + AttributeProtocol, PyContext, PyFuncArgs, PyObject, PyObjectKind, PyObjectRef, PyResult, + TypeProtocol, +}; +use super::super::vm::VirtualMachine; +use super::objstr; +use super::objtype; // Required for arg_check! to use isinstance + +/* + * This helper function is called at multiple places. First, it is called + * in the vm when a for loop is entered. Next, it is used when the builtin + * function 'iter' is called. + */ +pub fn get_iter(vm: &mut VirtualMachine, iter_target: &PyObjectRef) -> PyResult { + // Check what we are going to iterate over: + let iterated_obj = if objtype::isinstance(iter_target, vm.ctx.iter_type()) { + // If object is already an iterator, return that one. + return Ok(iter_target.clone()); + } else if objtype::isinstance(iter_target, vm.ctx.list_type()) { + iter_target.clone() + } else { + let type_str = objstr::get_value(&vm.to_str(iter_target.typ()).unwrap()); + let type_error = vm.new_type_error(format!("Cannot iterate over {}", type_str)); + return Err(type_error); + }; + + let iter_obj = PyObject::new( + PyObjectKind::Iterator { + position: 0, + iterated_obj: iterated_obj, + }, + vm.ctx.iter_type(), + ); + + // We are all good here: + Ok(iter_obj) +} + +// Sequence iterator: +fn iter_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(iter_target, None)]); + + get_iter(vm, iter_target) +} + +fn iter_iter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(iter, Some(vm.ctx.iter_type()))]); + // Return self: + Ok(iter.clone()) +} + +fn iter_next(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(iter, Some(vm.ctx.iter_type()))]); + + if let PyObjectKind::Iterator { + ref mut position, + iterated_obj: ref iterated_obj_ref, + } = iter.borrow_mut().kind + { + let iterated_obj = &*iterated_obj_ref.borrow_mut(); + match iterated_obj.kind { + PyObjectKind::List { ref elements } => { + if *position < elements.len() { + let obj_ref = elements[*position].clone(); + *position += 1; + Ok(obj_ref) + } else { + let stop_iteration_type = vm.ctx.exceptions.stop_iteration.clone(); + let stop_iteration = + vm.new_exception(stop_iteration_type, "End of iterator".to_string()); + Err(stop_iteration) + } + } + _ => { + panic!("NOT IMPL"); + } + } + } else { + panic!("NOT IMPL"); + } +} + +pub fn init(context: &PyContext) { + let ref iter_type = context.iter_type; + iter_type.set_attr("__new__", context.new_rustfunc(iter_new)); + iter_type.set_attr("__iter__", context.new_rustfunc(iter_iter)); + iter_type.set_attr("__next__", context.new_rustfunc(iter_next)); +} diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index cff8cca82a..5a66135cec 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -5,6 +5,7 @@ use super::obj::objdict; use super::obj::objfloat; use super::obj::objfunction; use super::obj::objint; +use super::obj::objiter; use super::obj::objlist; use super::obj::objobject; use super::obj::objstr; @@ -60,6 +61,7 @@ pub struct PyContext { pub false_value: PyObjectRef, pub list_type: PyObjectRef, pub tuple_type: PyObjectRef, + pub iter_type: PyObjectRef, pub str_type: PyObjectRef, pub function_type: PyObjectRef, pub module_type: PyObjectRef, @@ -123,6 +125,7 @@ impl PyContext { let float_type = create_type("float", &type_type, &object_type, &dict_type); let bytes_type = create_type("bytes", &type_type, &object_type, &dict_type); let tuple_type = create_type("tuple", &type_type, &object_type, &dict_type); + let iter_type = create_type("iter", &type_type, &object_type, &dict_type); let bool_type = create_type("bool", &type_type, &int_type, &dict_type); let exceptions = exceptions::ExceptionZoo::new(&type_type, &object_type, &dict_type); @@ -142,6 +145,7 @@ impl PyContext { true_value: true_value, false_value: false_value, tuple_type: tuple_type, + iter_type: iter_type, dict_type: dict_type, none: none, str_type: str_type, @@ -164,6 +168,7 @@ impl PyContext { objbytes::init(&context); objstr::init(&context); objtuple::init(&context); + objiter::init(&context); objbool::init(&context); exceptions::init(&context); context @@ -190,6 +195,9 @@ impl PyContext { pub fn tuple_type(&self) -> PyObjectRef { self.tuple_type.clone() } + pub fn iter_type(&self) -> PyObjectRef { + self.iter_type.clone() + } pub fn dict_type(&self) -> PyObjectRef { self.dict_type.clone() } @@ -750,35 +758,6 @@ impl PyObject { } } - // Implement iterator protocol: - pub fn nxt(&mut self) -> Option { - match self.kind { - PyObjectKind::Iterator { - ref mut position, - iterated_obj: ref iterated_obj_ref, - } => { - let iterated_obj = &*iterated_obj_ref.borrow_mut(); - match iterated_obj.kind { - PyObjectKind::List { ref elements } => { - if *position < elements.len() { - let obj_ref = elements[*position].clone(); - *position += 1; - Some(obj_ref) - } else { - None - } - } - _ => { - panic!("NOT IMPL"); - } - } - } - _ => { - panic!("NOT IMPL"); - } - } - } - // Move this object into a reference object, transferring ownership. pub fn into_ref(self) -> PyObjectRef { Rc::new(RefCell::new(self)) diff --git a/vm/src/vm.rs b/vm/src/vm.rs index a50b2eb4b8..eba749fbb6 100644 --- a/vm/src/vm.rs +++ b/vm/src/vm.rs @@ -7,14 +7,13 @@ extern crate rustpython_parser; use self::rustpython_parser::ast; -use std::cell::RefMut; use std::collections::hash_map::HashMap; -use std::ops::Deref; use super::builtins; use super::bytecode; use super::frame::{copy_code, Block, Frame}; use super::import::import; +use super::obj::objiter; use super::obj::objlist; use super::obj::objobject; use super::obj::objstr; @@ -829,46 +828,50 @@ impl VirtualMachine { } bytecode::Instruction::GetIter => { let iterated_obj = self.pop_value(); - let iter_obj = PyObject::new( - PyObjectKind::Iterator { - position: 0, - iterated_obj: iterated_obj, - }, - self.ctx.type_type(), - ); - self.push_value(iter_obj); - None + match objiter::get_iter(self, &iterated_obj) { + Ok(iter_obj) => { + self.push_value(iter_obj); + None + } + Err(err) => Some(Err(err)), + } } bytecode::Instruction::ForIter => { // The top of stack contains the iterator, lets push it forward: - let next_obj: Option = { + let next_obj: PyResult = { let top_of_stack = self.last_value(); - let mut ref_mut: RefMut = top_of_stack.deref().borrow_mut(); - // We require a mutable pyobject here to update the iterator: - let mut iterator = ref_mut; // &mut PyObject = ref_mut.; - // let () = iterator; - iterator.nxt() + self.call_method(top_of_stack, "__next__", vec![]) }; // Check the next object: match next_obj { - Some(value) => { + Ok(value) => { self.push_value(value); + None } - None => { - // Pop iterator from stack: - self.pop_value(); - - // End of for loop - let end_label = if let Block::Loop { start: _, end } = self.last_block() { - *end + Err(next_error) => { + // Check if we have stopiteration, or something else: + if objtype::isinstance( + &next_error, + self.ctx.exceptions.stop_iteration.clone(), + ) { + // Pop iterator from stack: + self.pop_value(); + + // End of for loop + let end_label = if let Block::Loop { start: _, end } = self.last_block() + { + *end + } else { + panic!("Wrong block type") + }; + self.jump(&end_label); + None } else { - panic!("Wrong block type") - }; - self.jump(&end_label); + Some(Err(next_error)) + } } - }; - None + } } bytecode::Instruction::MakeFunction { flags } => { let _qualified_name = self.pop_value();