diff --git a/vm/src/obj/objdict.rs b/vm/src/obj/objdict.rs index 4fd99ae7ea..df85e75791 100644 --- a/vm/src/obj/objdict.rs +++ b/vm/src/obj/objdict.rs @@ -340,11 +340,13 @@ pub fn create_type(type_type: PyObjectRef, object_type: PyObjectRef, dict_type: // this is not ideal let ptr = PyObjectRef::into_raw(dict_type.clone()) as *mut PyObject; unsafe { - (*ptr).payload = PyObjectPayload::Class { - name: String::from("dict"), - dict: RefCell::new(HashMap::new()), - mro: vec![object_type], + (*ptr).payload = PyObjectPayload::AnyRustValue { + value: Box::new(objtype::PyClass { + name: String::from("dict"), + mro: vec![object_type], + }), }; + (*ptr).dict = Some(RefCell::new(HashMap::new())); (*ptr).typ = Some(type_type.clone()); } } diff --git a/vm/src/obj/objobject.rs b/vm/src/obj/objobject.rs index bf659b2c72..e2bf4463e2 100644 --- a/vm/src/obj/objobject.rs +++ b/vm/src/obj/objobject.rs @@ -1,5 +1,6 @@ use super::objstr; use super::objtype; +use crate::function::PyRef; use crate::pyobject::{ AttributeProtocol, DictProtocol, IdProtocol, PyAttributes, PyContext, PyFuncArgs, PyObject, PyObjectPayload, PyObjectRef, PyResult, TypeProtocol, @@ -8,6 +9,10 @@ use crate::vm::VirtualMachine; use std::cell::RefCell; use std::collections::HashMap; +#[derive(Clone, Debug)] +pub struct PyInstance; +pub type PyInstanceRef = PyRef; + pub fn new_instance(vm: &mut VirtualMachine, mut args: PyFuncArgs) -> PyResult { // more or less __new__ operator let type_ref = args.shift(); @@ -19,11 +24,13 @@ pub fn create_object(type_type: PyObjectRef, object_type: PyObjectRef, _dict_typ // this is not ideal let ptr = PyObjectRef::into_raw(object_type.clone()) as *mut PyObject; unsafe { - (*ptr).payload = PyObjectPayload::Class { - name: String::from("object"), - dict: RefCell::new(HashMap::new()), - mro: vec![], + (*ptr).payload = PyObjectPayload::AnyRustValue { + value: Box::new(objtype::PyClass { + name: String::from("object"), + mro: vec![], + }), }; + (*ptr).dict = Some(RefCell::new(HashMap::new())); (*ptr).typ = Some(type_type.clone()); } } @@ -105,13 +112,13 @@ fn object_delattr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { ] ); - match zelf.payload { - PyObjectPayload::Class { ref dict, .. } | PyObjectPayload::Instance { ref dict, .. } => { + match zelf.dict { + Some(ref dict) => { let attr_name = objstr::get_value(attr); dict.borrow_mut().remove(&attr_name); Ok(vm.get_none()) } - _ => Err(vm.new_type_error("TypeError: no dictionary.".to_string())), + None => Err(vm.new_type_error("TypeError: no dictionary.".to_string())), } } @@ -210,15 +217,14 @@ fn object_class_setter( } fn object_dict(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { - match args.args[0].payload { - PyObjectPayload::Class { ref dict, .. } | PyObjectPayload::Instance { ref dict, .. } => { - let new_dict = vm.new_dict(); - for (attr, value) in dict.borrow().iter() { - new_dict.set_item(&vm.ctx, &attr, value.clone()); - } - Ok(new_dict) + if let Some(ref dict) = args.args[0].dict { + let new_dict = vm.new_dict(); + for (attr, value) in dict.borrow().iter() { + new_dict.set_item(&vm.ctx, &attr, value.clone()); } - _ => Err(vm.new_type_error("TypeError: no dictionary.".to_string())), + Ok(new_dict) + } else { + Err(vm.new_type_error("TypeError: no dictionary.".to_string())) } } @@ -264,7 +270,7 @@ pub fn get_attributes(obj: &PyObjectRef) -> PyAttributes { let mut attributes = objtype::get_attributes(&obj.typ()); // Get instance attributes: - if let PyObjectPayload::Instance { dict } = &obj.payload { + if let Some(dict) = &obj.dict { for (name, value) in dict.borrow().iter() { attributes.insert(name.to_string(), value.clone()); } diff --git a/vm/src/obj/objtype.rs b/vm/src/obj/objtype.rs index dfe11510bc..b82e39465a 100644 --- a/vm/src/obj/objtype.rs +++ b/vm/src/obj/objtype.rs @@ -1,13 +1,27 @@ use super::objdict; use super::objstr; +use crate::function::PyRef; use crate::pyobject::{ AttributeProtocol, IdProtocol, PyAttributes, PyContext, PyFuncArgs, PyObject, PyObjectPayload, - PyObjectRef, PyResult, TypeProtocol, + PyObjectPayload2, PyObjectRef, PyResult, TypeProtocol, }; use crate::vm::VirtualMachine; use std::cell::RefCell; use std::collections::HashMap; +#[derive(Clone, Debug)] +pub struct PyClass { + pub name: String, + pub mro: Vec, +} +pub type PyClassRef = PyRef; + +impl PyObjectPayload2 for PyClass { + fn required_type(ctx: &PyContext) -> PyObjectRef { + ctx.type_type() + } +} + /* * The magical type type */ @@ -16,11 +30,13 @@ pub fn create_type(type_type: PyObjectRef, object_type: PyObjectRef, _dict_type: // this is not ideal let ptr = PyObjectRef::into_raw(type_type.clone()) as *mut PyObject; unsafe { - (*ptr).payload = PyObjectPayload::Class { - name: String::from("type"), - dict: RefCell::new(PyAttributes::new()), - mro: vec![object_type], + (*ptr).payload = PyObjectPayload::AnyRustValue { + value: Box::new(PyClass { + name: String::from("type"), + mro: vec![object_type], + }), }; + (*ptr).dict = Some(RefCell::new(PyAttributes::new())); (*ptr).typ = Some(type_type); } } @@ -80,13 +96,12 @@ fn type_mro(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { } fn _mro(cls: PyObjectRef) -> Option> { - match cls.payload { - PyObjectPayload::Class { ref mro, .. } => { - let mut mro = mro.clone(); - mro.insert(0, cls.clone()); - Some(mro) - } - _ => None, + if let Some(PyClass { ref mro, .. }) = cls.payload::() { + let mut mro = mro.clone(); + mro.insert(0, cls.clone()); + Some(mro) + } else { + None } } @@ -127,7 +142,7 @@ fn type_subclass_check(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { } pub fn get_type_name(typ: &PyObjectRef) -> String { - if let PyObjectPayload::Class { name, .. } = &typ.payload { + if let Some(PyClass { name, .. }) = &typ.payload::() { name.clone() } else { panic!("Cannot get type_name of non-type type {:?}", typ); @@ -248,7 +263,7 @@ pub fn get_attributes(obj: &PyObjectRef) -> PyAttributes { let mut base_classes = _mro(obj.clone()).expect("Type get_attributes on non-type"); base_classes.reverse(); for bc in base_classes { - if let PyObjectPayload::Class { dict, .. } = &bc.payload { + if let Some(ref dict) = &bc.dict { for (name, value) in dict.borrow().iter() { attributes.insert(name.to_string(), value.clone()); } @@ -313,14 +328,17 @@ pub fn new( ) -> PyResult { let mros = bases.into_iter().map(|x| _mro(x).unwrap()).collect(); let mro = linearise_mro(mros).unwrap(); - Ok(PyObject::new( - PyObjectPayload::Class { - name: String::from(name), - dict: RefCell::new(dict), - mro, + Ok(PyObject { + payload: PyObjectPayload::AnyRustValue { + value: Box::new(PyClass { + name: String::from(name), + mro, + }), }, - typ, - )) + dict: Some(RefCell::new(dict)), + typ: Some(typ), + } + .into_ref()) } fn type_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index 785355266a..bf38802d09 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -41,7 +41,7 @@ use crate::obj::objslice; use crate::obj::objstr; use crate::obj::objsuper; use crate::obj::objtuple::{self, PyTuple}; -use crate::obj::objtype; +use crate::obj::objtype::{self, PyClass}; use crate::obj::objzip; use crate::vm::VirtualMachine; @@ -81,18 +81,19 @@ pub type PyAttributes = HashMap; impl fmt::Display for PyObject { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { use self::TypeProtocol; + if let Some(PyClass { ref name, .. }) = self.payload::() { + let type_name = objtype::get_type_name(&self.typ()); + // We don't have access to a vm, so just assume that if its parent's name + // is type, it's a type + if type_name == "type" { + return write!(f, "type object '{}'", name); + } else { + return write!(f, "'{}' object", type_name); + } + } + match &self.payload { PyObjectPayload::Module { name, .. } => write!(f, "module '{}'", name), - PyObjectPayload::Class { name, .. } => { - let type_name = objtype::get_type_name(&self.typ()); - // We don't have access to a vm, so just assume that if its parent's name - // is type, it's a type - if type_name == "type" { - write!(f, "type object '{}'", name) - } else { - write!(f, "'{}' object", type_name) - } - } _ => write!(f, "'{}' object", objtype::get_type_name(&self.typ())), } } @@ -156,13 +157,8 @@ pub struct PyContext { } fn _nothing() -> PyObjectRef { - PyObject { - payload: PyObjectPayload::AnyRustValue { - value: Box::new(()), - }, - typ: None, - } - .into_ref() + let obj: PyObject = Default::default(); + obj.into_ref() } pub fn create_type( @@ -706,12 +702,14 @@ impl PyContext { } else { PyAttributes::new() }; - PyObject::new( - PyObjectPayload::Instance { - dict: RefCell::new(dict), + PyObject { + payload: PyObjectPayload::AnyRustValue { + value: Box::new(objobject::PyInstance), }, - class, - ) + typ: Some(class), + dict: Some(RefCell::new(dict)), + } + .into_ref() } // Item set/get: @@ -732,14 +730,12 @@ impl PyContext { } pub fn set_attr(&self, obj: &PyObjectRef, attr_name: &str, value: PyObjectRef) { - match obj.payload { - PyObjectPayload::Module { ref scope, .. } => { - scope.locals.set_item(self, attr_name, value) - } - PyObjectPayload::Instance { ref dict } | PyObjectPayload::Class { ref dict, .. } => { - dict.borrow_mut().insert(attr_name.to_string(), value); - } - ref payload => unimplemented!("set_attr unimplemented for: {:?}", payload), + if let PyObjectPayload::Module { ref scope, .. } = obj.payload { + scope.locals.set_item(self, attr_name, value) + } else if let Some(ref dict) = obj.dict { + dict.borrow_mut().insert(attr_name.to_string(), value); + } else { + unimplemented!("set_attr unimplemented for: {:?}", obj); }; } @@ -774,10 +770,11 @@ impl Default for PyContext { /// This is an actual python object. It consists of a `typ` which is the /// python class, and carries some rust payload optionally. This rust /// payload can be a rust float or rust int in case of float and int objects. +#[derive(Default)] pub struct PyObject { pub payload: PyObjectPayload, pub typ: Option, - // pub dict: HashMap, // __dict__ member + pub dict: Option>, // __dict__ member } pub trait IdProtocol { @@ -824,47 +821,62 @@ pub trait AttributeProtocol { } fn class_get_item(class: &PyObjectRef, attr_name: &str) -> Option { - match class.payload { - PyObjectPayload::Class { ref dict, .. } => dict.borrow().get(attr_name).cloned(), - _ => panic!("Only classes should be in MRO!"), + if let Some(ref dict) = class.dict { + dict.borrow().get(attr_name).cloned() + } else { + panic!("Only classes should be in MRO!"); } } fn class_has_item(class: &PyObjectRef, attr_name: &str) -> bool { - match class.payload { - PyObjectPayload::Class { ref dict, .. } => dict.borrow().contains_key(attr_name), - _ => panic!("Only classes should be in MRO!"), + if let Some(ref dict) = class.dict { + dict.borrow().contains_key(attr_name) + } else { + panic!("Only classes should be in MRO!"); } } impl AttributeProtocol for PyObjectRef { fn get_attr(&self, attr_name: &str) -> Option { - match self.payload { - PyObjectPayload::Module { ref scope, .. } => scope.locals.get_item(attr_name), - PyObjectPayload::Class { ref mro, .. } => { - if let Some(item) = class_get_item(self, attr_name) { + if let Some(PyClass { ref mro, .. }) = self.payload::() { + if let Some(item) = class_get_item(self, attr_name) { + return Some(item); + } + for class in mro { + if let Some(item) = class_get_item(class, attr_name) { return Some(item); } - for class in mro { - if let Some(item) = class_get_item(class, attr_name) { - return Some(item); - } + } + return None; + } + + match self.payload { + PyObjectPayload::Module { ref scope, .. } => scope.locals.get_item(attr_name), + _ => { + if let Some(ref dict) = self.dict { + dict.borrow().get(attr_name).cloned() + } else { + None } - None } - PyObjectPayload::Instance { ref dict } => dict.borrow().get(attr_name).cloned(), - _ => None, } } fn has_attr(&self, attr_name: &str) -> bool { + if let Some(PyClass { ref mro, .. }) = self.payload::() { + return class_has_item(self, attr_name) + || mro.iter().any(|d| class_has_item(d, attr_name)); + } + match self.payload { PyObjectPayload::Module { ref scope, .. } => scope.locals.contains_key(attr_name), - PyObjectPayload::Class { ref mro, .. } => { - class_has_item(self, attr_name) || mro.iter().any(|d| class_has_item(d, attr_name)) + _ => { + if let Some(ref dict) = self.dict { + dict.borrow().contains_key(attr_name) + } else { + false + } } - PyObjectPayload::Instance { ref dict } => dict.borrow().contains_key(attr_name), - _ => false, } } } @@ -1534,17 +1546,9 @@ pub enum PyObjectPayload { name: String, scope: ScopeRef, }, - Class { - name: String, - dict: RefCell, - mro: Vec, - }, WeakRef { referent: PyObjectWeakRef, }, - Instance { - dict: RefCell, - }, RustFunction { function: PyNativeFunc, }, @@ -1553,6 +1557,14 @@ pub enum PyObjectPayload { }, } +impl Default for PyObjectPayload { + fn default() -> Self { + PyObjectPayload::AnyRustValue { + value: Box::new(()), + } + } +} + impl fmt::Debug for PyObjectPayload { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -1571,8 +1583,6 @@ impl fmt::Debug for PyObjectPayload { ref object, } => write!(f, "bound-method: {:?} of {:?}", function, object), PyObjectPayload::Module { .. } => write!(f, "module"), - PyObjectPayload::Class { ref name, .. } => write!(f, "class {:?}", name), - PyObjectPayload::Instance { .. } => write!(f, "instance"), PyObjectPayload::RustFunction { .. } => write!(f, "rust function"), PyObjectPayload::Frame { .. } => write!(f, "frame"), PyObjectPayload::AnyRustValue { value } => value.fmt(f), @@ -1581,14 +1591,11 @@ impl fmt::Debug for PyObjectPayload { } impl PyObject { - pub fn new( - payload: PyObjectPayload, - /* dict: PyObjectRef,*/ typ: PyObjectRef, - ) -> PyObjectRef { + pub fn new(payload: PyObjectPayload, typ: PyObjectRef) -> PyObjectRef { PyObject { payload, typ: Some(typ), - // dict: HashMap::new(), // dict, + dict: None, } .into_ref() } diff --git a/vm/src/vm.rs b/vm/src/vm.rs index f67c931d30..953e016e24 100644 --- a/vm/src/vm.rs +++ b/vm/src/vm.rs @@ -300,12 +300,10 @@ impl VirtualMachine { ref scope, ref defaults, } => self.invoke_python_function(code, scope, defaults, args), - PyObjectPayload::Class { .. } => self.call_method(&func_ref, "__call__", args), PyObjectPayload::BoundMethod { ref function, ref object, } => self.invoke(function.clone(), args.insert(object.clone())), - PyObjectPayload::Instance { .. } => self.call_method(&func_ref, "__call__", args), ref payload => { // TODO: is it safe to just invoke __call__ otherwise? trace!("invoke __call__ for: {:?}", payload);