From 6230a25c4bfbd14b4b65017f6fe0818d2aa46c64 Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Tue, 19 Mar 2019 22:50:58 +0200 Subject: [PATCH 1/3] Use first argument in super --- tests/snippets/class.py | 13 +++++++++++++ vm/src/obj/objsuper.rs | 42 ++++++++++++++++++++++++++--------------- 2 files changed, 40 insertions(+), 15 deletions(-) diff --git a/tests/snippets/class.py b/tests/snippets/class.py index 84ed872460..ee766f38f8 100644 --- a/tests/snippets/class.py +++ b/tests/snippets/class.py @@ -75,6 +75,19 @@ def test1(self): assert c.test() == 100 assert c.test1() == 200 +class Me(): + + def test(me): + return 100 + +class Me2(Me): + + def test(me): + return super().test() + +me = Me2() +assert me.test() == 100 + a = super(bool, True) assert isinstance(a, super) assert type(a) is super diff --git a/vm/src/obj/objsuper.rs b/vm/src/obj/objsuper.rs index 390e591833..ebb9e8d950 100644 --- a/vm/src/obj/objsuper.rs +++ b/vm/src/obj/objsuper.rs @@ -7,6 +7,7 @@ https://github.com/python/cpython/blob/50b48572d9a90c5bb36e2bef6179548ea927a35a/ */ use crate::function::PyFuncArgs; +use crate::obj::objstr; use crate::obj::objtype::PyClass; use crate::pyobject::{ DictProtocol, PyContext, PyObject, PyObjectRef, PyResult, PyValue, TypeProtocol, @@ -75,7 +76,11 @@ fn super_getattribute(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { return Ok(vm.ctx.new_bound_method(item, inst.clone())); } } - Err(vm.new_attribute_error(format!("{} has no attribute '{}'", inst, name_str))) + Err(vm.new_attribute_error(format!( + "{} has no attribute '{}'", + inst, + objstr::get_value(name_str) + ))) } _ => panic!("not Class"), } @@ -94,14 +99,31 @@ fn super_new(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { return Err(vm.new_type_error(format!("{:?} is not a subtype of super", cls))); } + // Get the bound object: + let py_obj = if let Some(obj) = py_obj { + obj.clone() + } else { + let frame = vm.current_frame(); + if let Some(first_arg) = frame.code.arg_names.get(0) { + match vm.get_locals().get_item(first_arg) { + Some(obj) => obj.clone(), + _ => { + return Err(vm + .new_type_error(format!("super arguement {} was not supplied", first_arg))); + } + } + } else { + return Err(vm.new_type_error( + "super must be called with 1 argument or from inside class method".to_string(), + )); + } + }; + // Get the type: let py_type = if let Some(ty) = py_type { ty.clone() } else { - match vm.get_locals().get_item("self") { - Some(obj) => obj.typ().clone(), - _ => panic!("No self"), - } + py_obj.typ().clone() }; // Check type argument: @@ -113,16 +135,6 @@ fn super_new(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { ))); } - // Get the bound object: - let py_obj = if let Some(obj) = py_obj { - obj.clone() - } else { - match vm.get_locals().get_item("self") { - Some(obj) => obj, - _ => panic!("No self"), - } - }; - // Check obj type: if !(objtype::isinstance(&py_obj, &py_type) || objtype::issubclass(&py_obj, &py_type)) { return Err(vm.new_type_error( From 84e89d37e2a3b03f6c7a9a8a5ada99a2fc577b23 Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Thu, 21 Mar 2019 18:56:18 +0200 Subject: [PATCH 2/3] Use __class__ cell in super --- tests/snippets/class.py | 14 +++++++++++ vm/src/frame.rs | 28 ++++++++++++++-------- vm/src/obj/objsuper.rs | 51 ++++++++++++++++++++++++++--------------- 3 files changed, 65 insertions(+), 28 deletions(-) diff --git a/tests/snippets/class.py b/tests/snippets/class.py index ee766f38f8..35d043a6d4 100644 --- a/tests/snippets/class.py +++ b/tests/snippets/class.py @@ -85,6 +85,20 @@ class Me2(Me): def test(me): return super().test() +class A(): + def f(self): + pass + +class B(A): + def f(self): + super().f() + +class C(B): + def f(self): + super().f() + +C().f() + me = Me2() assert me.test() == 100 diff --git a/vm/src/frame.rs b/vm/src/frame.rs index 3a6809f0ea..e257ea21fa 100644 --- a/vm/src/frame.rs +++ b/vm/src/frame.rs @@ -106,6 +106,20 @@ impl Scope { } } + pub fn get(&self, name: &str) -> Option { + for dict in self.locals.iter() { + if let Some(value) = dict.get_item(name) { + return Some(value); + } + } + + if let Some(value) = self.globals.get_item(name) { + return Some(value); + } + + None + } + pub fn get_only_locals(&self) -> Option { self.locals.iter().next().cloned() } @@ -130,17 +144,11 @@ pub trait NameProtocol { impl NameProtocol for Scope { fn load_name(&self, vm: &VirtualMachine, name: &str) -> Option { - for dict in self.locals.iter() { - if let Some(value) = dict.get_item(name) { - return Some(value); - } - } - - if let Some(value) = self.globals.get_item(name) { - return Some(value); + if let Some(value) = self.get(name) { + Some(value) + } else { + vm.builtins.get_item(name) } - - vm.builtins.get_item(name) } fn store_name(&self, vm: &VirtualMachine, key: &str, value: PyObjectRef) { diff --git a/vm/src/obj/objsuper.rs b/vm/src/obj/objsuper.rs index ebb9e8d950..38e9cf0e7c 100644 --- a/vm/src/obj/objsuper.rs +++ b/vm/src/obj/objsuper.rs @@ -19,6 +19,7 @@ use super::objtype; #[derive(Debug)] pub struct PySuper { obj: PyObjectRef, + typ: PyObjectRef, } impl PyValue for PySuper { @@ -68,8 +69,9 @@ fn super_getattribute(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { ); let inst = super_obj.payload::().unwrap().obj.clone(); + let typ = super_obj.payload::().unwrap().typ.clone(); - match inst.typ().payload::() { + match typ.payload::() { Some(PyClass { ref mro, .. }) => { for class in mro { if let Ok(item) = vm.get_attribute(class.as_object().clone(), name_str.clone()) { @@ -99,6 +101,29 @@ fn super_new(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { return Err(vm.new_type_error(format!("{:?} is not a subtype of super", cls))); } + // Get the type: + let py_type = if let Some(ty) = py_type { + ty.clone() + } else { + match vm.current_scope().get("__class__") { + Some(obj) => obj.clone(), + _ => { + return Err(vm.new_type_error( + "super must be called with 1 argument or from inside class method".to_string(), + )); + } + } + }; + + // Check type argument: + if !objtype::isinstance(&py_type, &vm.get_type()) { + let type_name = objtype::get_type_name(&py_type.typ()); + return Err(vm.new_type_error(format!( + "super() argument 1 must be type, not {}", + type_name + ))); + } + // Get the bound object: let py_obj = if let Some(obj) = py_obj { obj.clone() @@ -119,22 +144,6 @@ fn super_new(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { } }; - // Get the type: - let py_type = if let Some(ty) = py_type { - ty.clone() - } else { - py_obj.typ().clone() - }; - - // Check type argument: - if !objtype::isinstance(&py_type, &vm.get_type()) { - let type_name = objtype::get_type_name(&py_type.typ()); - return Err(vm.new_type_error(format!( - "super() argument 1 must be type, not {}", - type_name - ))); - } - // Check obj type: if !(objtype::isinstance(&py_obj, &py_type) || objtype::issubclass(&py_obj, &py_type)) { return Err(vm.new_type_error( @@ -142,5 +151,11 @@ fn super_new(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { )); } - Ok(PyObject::new(PySuper { obj: py_obj }, cls.clone())) + Ok(PyObject::new( + PySuper { + obj: py_obj, + typ: py_type, + }, + cls.clone(), + )) } From 2c8657c3b37a8ceed583ff6a014168b3fcc82fb6 Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Fri, 22 Mar 2019 10:45:58 +0200 Subject: [PATCH 3/3] Add load_cell to NameProtocol --- vm/src/frame.rs | 38 ++++++++++++++++++++------------------ vm/src/obj/objsuper.rs | 3 ++- 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/vm/src/frame.rs b/vm/src/frame.rs index e257ea21fa..0015557fec 100644 --- a/vm/src/frame.rs +++ b/vm/src/frame.rs @@ -106,20 +106,6 @@ impl Scope { } } - pub fn get(&self, name: &str) -> Option { - for dict in self.locals.iter() { - if let Some(value) = dict.get_item(name) { - return Some(value); - } - } - - if let Some(value) = self.globals.get_item(name) { - return Some(value); - } - - None - } - pub fn get_only_locals(&self) -> Option { self.locals.iter().next().cloned() } @@ -140,15 +126,31 @@ pub trait NameProtocol { fn load_name(&self, vm: &VirtualMachine, name: &str) -> Option; fn store_name(&self, vm: &VirtualMachine, name: &str, value: PyObjectRef); fn delete_name(&self, vm: &VirtualMachine, name: &str); + fn load_cell(&self, vm: &VirtualMachine, name: &str) -> Option; } impl NameProtocol for Scope { fn load_name(&self, vm: &VirtualMachine, name: &str) -> Option { - if let Some(value) = self.get(name) { - Some(value) - } else { - vm.builtins.get_item(name) + for dict in self.locals.iter() { + if let Some(value) = dict.get_item(name) { + return Some(value); + } + } + + if let Some(value) = self.globals.get_item(name) { + return Some(value); } + + vm.builtins.get_item(name) + } + + fn load_cell(&self, _vm: &VirtualMachine, name: &str) -> Option { + for dict in self.locals.iter().skip(1) { + if let Some(value) = dict.get_item(name) { + return Some(value); + } + } + None } fn store_name(&self, vm: &VirtualMachine, key: &str, value: PyObjectRef) { diff --git a/vm/src/obj/objsuper.rs b/vm/src/obj/objsuper.rs index 38e9cf0e7c..2b2be02414 100644 --- a/vm/src/obj/objsuper.rs +++ b/vm/src/obj/objsuper.rs @@ -6,6 +6,7 @@ https://github.com/python/cpython/blob/50b48572d9a90c5bb36e2bef6179548ea927a35a/ */ +use crate::frame::NameProtocol; use crate::function::PyFuncArgs; use crate::obj::objstr; use crate::obj::objtype::PyClass; @@ -105,7 +106,7 @@ fn super_new(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { let py_type = if let Some(ty) = py_type { ty.clone() } else { - match vm.current_scope().get("__class__") { + match vm.current_scope().load_cell(vm, "__class__") { Some(obj) => obj.clone(), _ => { return Err(vm.new_type_error(