diff --git a/tests/snippets/class.py b/tests/snippets/class.py index 06e125f217..9dbce566dc 100644 --- a/tests/snippets/class.py +++ b/tests/snippets/class.py @@ -37,8 +37,7 @@ def kungfu(x): assert x == 3 -# TODO: -# assert Bar.__doc__ == " W00t " +assert Bar.__doc__ == " W00t " bar = Bar(42) @@ -117,3 +116,31 @@ def f(self): assert type(a) is super assert a.conjugate() == 1 + +class T1: + "test1" + +assert T1.__doc__ == "test1" + +class T2: + '''test2''' + +assert T2.__doc__ == "test2" + +class T3: + """ + test3 + """ + +assert T3.__doc__ == "\n test3\n " + +class T4: + + """test4""" + + def t1(self): + """t1""" + pass + +assert T4.__doc__ == "test4" +assert T4.t1.__doc__ == "t1" diff --git a/tests/snippets/function.py b/tests/snippets/function.py index b1ea968609..7d3f25709a 100644 --- a/tests/snippets/function.py +++ b/tests/snippets/function.py @@ -1,9 +1,39 @@ def foo(): + """test""" return 42 assert foo() == 42 +assert foo.__doc__ == "test" def my_func(a,): return a+2 assert my_func(2) == 4 + + +def f1(): + + """test1""" + pass + +assert f1.__doc__ == "test1" + +def f2(): + '''test2''' + pass + +assert f2.__doc__ == "test2" + +def f3(): + """ + test3 + """ + pass + +assert f3.__doc__ == "\n test3\n " + +def f4(): + "test4" + pass + +assert f4.__doc__ == "test4" diff --git a/vm/src/compile.rs b/vm/src/compile.rs index 00bafe7278..f79980af9f 100644 --- a/vm/src/compile.rs +++ b/vm/src/compile.rs @@ -597,7 +597,10 @@ impl Compiler { self.in_loop = false; self.in_function_def = true; let mut flags = self.enter_function(name, args)?; - self.compile_statements(body)?; + + let (new_body, doc_str) = get_doc(body); + + self.compile_statements(new_body)?; // Emit None at end: self.emit(Instruction::LoadConst { @@ -662,6 +665,20 @@ impl Compiler { self.emit(Instruction::StoreName { name: name.to_string(), }); + + if let Some(doc_string) = doc_str { + self.emit(Instruction::LoadConst { + value: bytecode::Constant::String { + value: doc_string.to_string(), + }, + }); + self.emit(Instruction::LoadName { + name: name.to_string(), + }); + self.emit(Instruction::StoreAttr { + name: "__doc__".to_string(), + }); + } self.in_loop = was_in_loop; self.in_function_def = was_in_function_def; Ok(()) @@ -689,13 +706,17 @@ impl Compiler { line_number, name.to_string(), )); - self.compile_statements(body)?; + + let (new_body, doc_str) = get_doc(body); + + self.compile_statements(new_body)?; self.emit(Instruction::LoadConst { value: bytecode::Constant::None, }); self.emit(Instruction::ReturnValue); let code = self.pop_code_object(); + self.emit(Instruction::LoadConst { value: bytecode::Constant::Code { code: Box::new(code), @@ -755,6 +776,19 @@ impl Compiler { self.emit(Instruction::StoreName { name: name.to_string(), }); + if let Some(doc_string) = doc_str { + self.emit(Instruction::LoadConst { + value: bytecode::Constant::String { + value: doc_string.to_string(), + }, + }); + self.emit(Instruction::LoadName { + name: name.to_string(), + }); + self.emit(Instruction::StoreAttr { + name: "__doc__".to_string(), + }); + } self.in_loop = was_in_loop; Ok(()) } @@ -1511,6 +1545,21 @@ impl Compiler { } } +fn get_doc(body: &[ast::LocatedStatement]) -> (&[ast::LocatedStatement], Option) { + if let Some(val) = body.get(0) { + if let ast::Statement::Expression { ref expression } = val.node { + if let ast::Expression::String { ref value } = expression { + if let ast::StringGroup::Constant { ref value } = value { + if let Some((_, body_rest)) = body.split_first() { + return (body_rest, Some(value.to_string())); + } + } + } + } + } + (body, None) +} + #[cfg(test)] mod tests { use super::Compiler;