diff --git a/parser/src/parser.rs b/parser/src/parser.rs index 853d7cd844..607c6e7a45 100644 --- a/parser/src/parser.rs +++ b/parser/src/parser.rs @@ -253,6 +253,40 @@ mod tests { ) } + #[test] + fn test_parse_tuples() { + let source = String::from("a, b = 4, 5\n"); + + assert_eq!( + parse_statement(&source), + Ok(ast::LocatedStatement { + location: ast::Location::new(1, 1), + node: ast::Statement::Assign { + targets: vec![ast::Expression::Tuple { + elements: vec![ + ast::Expression::Identifier { + name: "a".to_string() + }, + ast::Expression::Identifier { + name: "b".to_string() + } + ] + }], + value: ast::Expression::Tuple { + elements: vec![ + ast::Expression::Number { + value: ast::Number::Integer { value: 4 } + }, + ast::Expression::Number { + value: ast::Number::Integer { value: 5 } + } + ] + } + } + }) + ) + } + #[test] fn test_parse_class() { let source = String::from("class Foo(A, B):\n def __init__(self):\n pass\n def method_with_default(self, arg='default'):\n pass\n"); diff --git a/parser/src/python.lalrpop b/parser/src/python.lalrpop index b7238e8b72..98170ca048 100644 --- a/parser/src/python.lalrpop +++ b/parser/src/python.lalrpop @@ -46,32 +46,38 @@ SmallStatement: ast::LocatedStatement = { }; ExpressionStatement: ast::LocatedStatement = { - => { - //match e2 { - // None => ast::Statement::Expression { expression: e }, - // Some(e3) => ast::Statement::Expression { expression: e }, - //} - if e2.len() > 0 { - // Dealing with assignment here - // TODO: for rhs in e2 { - let rhs = e2.into_iter().next().unwrap(); - // ast::Expression::Tuple { elements: e2.into_iter().next().unwrap() - let v = rhs.into_iter().next().unwrap(); - let lhs = ast::LocatedStatement { - location: loc.clone(), - node: ast::Statement::Assign { targets: e, value: v }, - }; - lhs - } else { - if e.len() > 1 { - panic!("Not good?"); - // ast::Statement::Expression { expression: e[0] } + => { + // Just an expression, no assignment: + if suffix.is_empty() { + if expr.len() > 1 { + ast::LocatedStatement { + location: loc.clone(), + node: ast::Statement::Expression { expression: ast::Expression::Tuple { elements: expr } } + } } else { ast::LocatedStatement { location: loc.clone(), - node: ast::Statement::Expression { expression: e.into_iter().next().unwrap() }, + node: ast::Statement::Expression { expression: expr[0].clone() }, } } + } else { + let mut targets = vec![if expr.len() > 1 { + ast::Expression::Tuple { elements: expr } + } else { + expr[0].clone() + }]; + let mut values : Vec = suffix.into_iter().map(|test_list| if test_list.len() > 1 { ast::Expression::Tuple { elements: test_list }} else { test_list[0].clone() }).collect(); + + while values.len() > 1 { + targets.push(values.remove(0)); + } + + let value = values[0].clone(); + + ast::LocatedStatement { + location: loc.clone(), + node: ast::Statement::Assign { targets, value }, + } } }, => { @@ -120,7 +126,7 @@ FlowStatement: ast::LocatedStatement = { "return" => { ast::LocatedStatement { location: loc, - node: ast::Statement::Return { value: t}, + node: ast::Statement::Return { value: t }, } }, "raise" => { diff --git a/tests/snippets/assignment.py b/tests/snippets/assignment.py new file mode 100644 index 0000000000..457da92005 --- /dev/null +++ b/tests/snippets/assignment.py @@ -0,0 +1,28 @@ +x = 1 +assert x == 1 + +x = 1, 2, 3 +assert x == (1, 2, 3) + +x, y = 1, 2 +assert x == 1 +assert y == 2 + +x, y = (y, x) + +assert x == 2 +assert y == 1 + +((x, y), z) = ((1, 2), 3) + +assert (x, y, z) == (1, 2, 3) + +q = (1, 2, 3) +(x, y, z) = q +assert y == q[1] + +x = (a, b, c) = y = q + +assert (a, b, c) == q +assert x == q +assert y == q diff --git a/vm/src/bytecode.rs b/vm/src/bytecode.rs index c0d5d286d8..22f0fb2a80 100644 --- a/vm/src/bytecode.rs +++ b/vm/src/bytecode.rs @@ -137,6 +137,9 @@ pub enum Instruction { PrintExpr, LoadBuildClass, StoreLocals, + UnpackSequence { + size: usize, + }, } #[derive(Debug, Clone, PartialEq)] diff --git a/vm/src/compile.rs b/vm/src/compile.rs index 3208f97007..702057fb9d 100644 --- a/vm/src/compile.rs +++ b/vm/src/compile.rs @@ -508,7 +508,10 @@ impl Compiler { ast::Statement::Assign { targets, value } => { self.compile_expression(value); - for target in targets { + for (i, target) in targets.into_iter().enumerate() { + if i + 1 != targets.len() { + self.emit(Instruction::Duplicate); + } self.compile_store(target); } } @@ -548,6 +551,14 @@ impl Compiler { name: name.to_string(), }); } + ast::Expression::Tuple { elements } => { + self.emit(Instruction::UnpackSequence { + size: elements.len(), + }); + for element in elements { + self.compile_store(element); + } + } _ => { panic!("WTF: {:?}", target); } diff --git a/vm/src/vm.rs b/vm/src/vm.rs index a50b2eb4b8..3730362f82 100644 --- a/vm/src/vm.rs +++ b/vm/src/vm.rs @@ -18,6 +18,7 @@ use super::import::import; use super::obj::objlist; use super::obj::objobject; use super::obj::objstr; +use super::obj::objtuple; use super::obj::objtype; use super::objbool; use super::pyobject::{ @@ -1045,6 +1046,19 @@ impl VirtualMachine { } None } + bytecode::Instruction::UnpackSequence { size } => { + let value = self.pop_value(); + + let elements = objtuple::get_elements(&value); + if elements.len() != *size { + panic!("Wrong number of values to unpack"); + } + + for element in elements.into_iter().rev() { + self.push_value(element); + } + None + } } }