diff --git a/vm/src/compile.rs b/vm/src/compile.rs index 866b0dad74..9c455581e4 100644 --- a/vm/src/compile.rs +++ b/vm/src/compile.rs @@ -17,6 +17,7 @@ struct Compiler { source_path: Option, current_source_location: ast::Location, in_loop: bool, + in_function_def: bool, } /// Compile a given sourcecode into a bytecode object. @@ -75,6 +76,7 @@ impl Compiler { source_path: None, current_source_location: ast::Location::default(), in_loop: false, + in_function_def: false, } } @@ -232,9 +234,10 @@ impl Compiler { self.compile_test(test, None, Some(else_label), EvalContext::Statement)?; + let was_in_loop = self.in_loop; self.in_loop = true; self.compile_statements(body)?; - self.in_loop = false; + self.in_loop = was_in_loop; self.emit(Instruction::Jump { target: start_label, }); @@ -295,10 +298,10 @@ impl Compiler { // Start of loop iteration, set targets: self.compile_store(target)?; - // Body of loop: + let was_in_loop = self.in_loop; self.in_loop = true; self.compile_statements(body)?; - self.in_loop = false; + self.in_loop = was_in_loop; self.emit(Instruction::Jump { target: start_label, @@ -431,6 +434,11 @@ impl Compiler { decorator_list, } => { // Create bytecode for this function: + // remember to restore self.in_loop to the original after the function is compiled + let was_in_loop = self.in_loop; + let was_in_function_def = self.in_function_def; + self.in_loop = false; + self.in_function_def = true; let flags = self.enter_function(name, args)?; self.compile_statements(body)?; @@ -458,6 +466,8 @@ impl Compiler { self.emit(Instruction::StoreName { name: name.to_string(), }); + self.in_loop = was_in_loop; + self.in_function_def = was_in_function_def; } ast::Statement::ClassDef { name, @@ -466,6 +476,8 @@ impl Compiler { keywords, decorator_list, } => { + let was_in_loop = self.in_loop; + self.in_loop = false; self.prepare_decorators(decorator_list)?; self.emit(Instruction::LoadBuildClass); let line_number = self.get_source_line_number(); @@ -546,6 +558,7 @@ impl Compiler { self.emit(Instruction::StoreName { name: name.to_string(), }); + self.in_loop = was_in_loop; } ast::Statement::Assert { test, msg } => { // TODO: if some flag, ignore all assert statements! @@ -584,6 +597,9 @@ impl Compiler { self.emit(Instruction::Continue); } ast::Statement::Return { value } => { + if !self.in_function_def { + return Err(CompileError::InvalidReturn); + } match value { Some(e) => { let size = e.len(); @@ -663,7 +679,6 @@ impl Compiler { name: &str, args: &ast::Parameters, ) -> Result { - self.in_loop = false; let have_kwargs = !args.defaults.is_empty(); if have_kwargs { // Construct a tuple: @@ -971,6 +986,9 @@ impl Compiler { self.emit(Instruction::BuildSlice { size }); } ast::Expression::Yield { value } => { + if !self.in_function_def { + return Err(CompileError::InvalidYield); + } self.mark_generator(); match value { Some(expression) => self.compile_expression(expression)?, @@ -1021,6 +1039,7 @@ impl Compiler { } ast::Expression::Lambda { args, body } => { let name = "".to_string(); + // no need to worry about the self.loop_depth because there are no loops in lambda expressions let flags = self.enter_function(&name, args)?; self.compile_expression(body)?; self.emit(Instruction::ReturnValue); @@ -1362,10 +1381,11 @@ impl Compiler { // Low level helper functions: fn emit(&mut self, instruction: Instruction) { - self.current_code_object().instructions.push(instruction); - // TODO: insert source filename let location = self.current_source_location.clone(); - self.current_code_object().locations.push(location); + let mut cur_code_obj = self.current_code_object(); + cur_code_obj.instructions.push(instruction); + cur_code_obj.locations.push(location); + // TODO: insert source filename } fn current_code_object(&mut self) -> &mut CodeObject { @@ -1406,6 +1426,7 @@ mod tests { use crate::bytecode::Constant::*; use crate::bytecode::Instruction::*; use rustpython_parser::parser; + fn compile_exec(source: &str) -> CodeObject { let mut compiler = Compiler::new(); compiler.source_path = Some("source_path".to_string()); diff --git a/vm/src/error.rs b/vm/src/error.rs index 70c8ff71ec..06eb387e2e 100644 --- a/vm/src/error.rs +++ b/vm/src/error.rs @@ -19,6 +19,8 @@ pub enum CompileError { InvalidBreak, /// Continue statement outside of loop. InvalidContinue, + InvalidReturn, + InvalidYield, } impl fmt::Display for CompileError { @@ -29,8 +31,10 @@ impl fmt::Display for CompileError { CompileError::ExpectExpr => write!(f, "Expecting expression, got statement"), CompileError::Parse(err) => write!(f, "{}", err), CompileError::StarArgs => write!(f, "Two starred expressions in assignment"), - CompileError::InvalidBreak => write!(f, "break outside loop"), - CompileError::InvalidContinue => write!(f, "continue outside loop"), + CompileError::InvalidBreak => write!(f, "'break' outside loop"), + CompileError::InvalidContinue => write!(f, "'continue' outside loop"), + CompileError::InvalidReturn => write!(f, "'return' outside function"), + CompileError::InvalidYield => write!(f, "'yield' outside function"), } } }