diff --git a/vm/src/builtins/pytype.rs b/vm/src/builtins/pytype.rs index 20cb9046b4..1c8e5a50a1 100644 --- a/vm/src/builtins/pytype.rs +++ b/vm/src/builtins/pytype.rs @@ -49,6 +49,61 @@ impl PyValue for PyType { } impl PyType { + pub fn new( + metaclass: PyRef, + name: &str, + base: PyRef, + bases: Vec>, + attrs: PyAttributes, + mut slots: PyTypeSlots, + ) -> Result, String> { + // Check for duplicates in bases. + let mut unique_bases = HashSet::new(); + for base in bases.iter() { + if !unique_bases.insert(base.get_id()) { + return Err(format!("duplicate base class {}", base.name())); + } + } + + let mros = bases + .iter() + .map(|x| x.iter_mro().cloned().collect()) + .collect(); + let mro = linearise_mro(mros)?; + + if base.slots.flags.has_feature(PyTpFlags::HAS_DICT) { + slots.flags |= PyTpFlags::HAS_DICT + } + + *slots.name.write() = Some(String::from(name)); + + let new_type = PyRef::new_ref( + PyType { + base: Some(base), + bases, + mro, + subclasses: PyRwLock::default(), + attributes: PyRwLock::new(attrs), + slots, + }, + metaclass, + None, + ); + + for attr_name in new_type.attributes.read().keys() { + if attr_name.starts_with("__") && attr_name.ends_with("__") { + new_type.update_slot(attr_name, true); + } + } + for base in &new_type.bases { + base.subclasses + .write() + .push(PyWeak::downgrade(new_type.as_object())); + } + + Ok(new_type) + } + pub fn tp_name(&self) -> String { self.slots.name.read().as_ref().unwrap().to_string() } @@ -493,7 +548,7 @@ impl PyType { let flags = PyTpFlags::heap_type_flags() | PyTpFlags::HAS_DICT; let slots = PyTypeSlots::from_flags(flags); - let typ = new(metatype, name.as_str(), base, bases, attributes, slots) + let typ = Self::new(metatype, name.as_str(), base, bases, attributes, slots) .map_err(|e| vm.new_type_error(e))?; // avoid deadlock @@ -844,61 +899,6 @@ fn linearise_mro(mut bases: Vec>) -> Result, Strin Ok(result) } -pub fn new( - typ: PyTypeRef, - name: &str, - base: PyTypeRef, - bases: Vec, - attrs: PyAttributes, - mut slots: PyTypeSlots, -) -> Result { - // Check for duplicates in bases. - let mut unique_bases = HashSet::new(); - for base in bases.iter() { - if !unique_bases.insert(base.get_id()) { - return Err(format!("duplicate base class {}", base.name())); - } - } - - let mros = bases - .iter() - .map(|x| x.iter_mro().cloned().collect()) - .collect(); - let mro = linearise_mro(mros)?; - - if base.slots.flags.has_feature(PyTpFlags::HAS_DICT) { - slots.flags |= PyTpFlags::HAS_DICT - } - - *slots.name.write() = Some(String::from(name)); - - let new_type = PyRef::new_ref( - PyType { - base: Some(base), - bases, - mro, - subclasses: PyRwLock::default(), - attributes: PyRwLock::new(attrs), - slots, - }, - typ, - None, - ); - - for attr_name in new_type.attributes.read().keys() { - if attr_name.starts_with("__") && attr_name.ends_with("__") { - new_type.update_slot(attr_name, true); - } - } - for base in &new_type.bases { - base.subclasses - .write() - .push(PyWeak::downgrade(new_type.as_object())); - } - - Ok(new_type) -} - fn calculate_meta_class( metatype: PyTypeRef, bases: &[PyTypeRef], @@ -986,7 +986,7 @@ mod tests { let object = &context.types.object_type; let type_type = &context.types.type_type; - let a = new( + let a = PyType::new( type_type.clone(), "A", object.clone(), @@ -995,7 +995,7 @@ mod tests { Default::default(), ) .unwrap(); - let b = new( + let b = PyType::new( type_type.clone(), "B", object.clone(), diff --git a/vm/src/stdlib/io.rs b/vm/src/stdlib/io.rs index e7c565f782..06ac0c1813 100644 --- a/vm/src/stdlib/io.rs +++ b/vm/src/stdlib/io.rs @@ -78,7 +78,7 @@ mod _io { }; use crate::{ builtins::{ - pytype, PyByteArray, PyBytes, PyBytesRef, PyMemoryView, PyStr, PyStrRef, PyTypeRef, + PyByteArray, PyBytes, PyBytesRef, PyMemoryView, PyStr, PyStrRef, PyType, PyTypeRef, }, byteslike::{ArgBytesLike, ArgMemoryBuffer}, exceptions::{self, PyBaseExceptionRef}, @@ -3621,7 +3621,7 @@ mod _io { } pub(super) fn make_unsupportedop(ctx: &PyContext) -> PyTypeRef { - pytype::new( + PyType::new( ctx.types.type_type.clone(), "UnsupportedOperation", ctx.exceptions.os_error.clone(), diff --git a/vm/src/stdlib/ssl.rs b/vm/src/stdlib/ssl.rs index c59fc02f0f..9b765e7e12 100644 --- a/vm/src/stdlib/ssl.rs +++ b/vm/src/stdlib/ssl.rs @@ -1,7 +1,7 @@ use super::socket::{self, PySocketRef}; use crate::common::lock::{PyRwLock, PyRwLockWriteGuard}; use crate::{ - builtins::{pytype, weakref::PyWeak, PyStrRef, PyTypeRef}, + builtins::{PyStrRef, PyType, PyTypeRef, PyWeak}, byteslike::{ArgBytesLike, ArgMemoryBuffer, ArgStrOrBytesLike}, exceptions::{create_exception_type, IntoPyException, PyBaseExceptionRef}, function::{ArgCallable, OptionalArg}, @@ -1112,7 +1112,7 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef { let ssl_error = create_exception_type("SSLError", &vm.ctx.exceptions.os_error); - let ssl_cert_verification_error = pytype::new( + let ssl_cert_verification_error = PyType::new( ctx.types.type_type.clone(), "SSLCertVerificationError", ssl_error.clone(), diff --git a/vm/src/types.rs b/vm/src/types.rs index 2215b3df13..d905dc1cfd 100644 --- a/vm/src/types.rs +++ b/vm/src/types.rs @@ -227,7 +227,7 @@ pub fn create_type_with_slots( slots: PyTypeSlots, ) -> PyTypeRef { let dict = PyAttributes::default(); - pytype::new( + PyType::new( type_type.clone(), name, base.clone(),