diff --git a/vm/src/builtins/type.rs b/vm/src/builtins/type.rs index 94c4f2f668..35bf6d3eba 100644 --- a/vm/src/builtins/type.rs +++ b/vm/src/builtins/type.rs @@ -158,6 +158,18 @@ fn downcast_qualname(value: PyObjectRef, vm: &VirtualMachine) -> PyResult, b: &Py) -> bool { + if a.is(b) { + return true; + } + for item in a_mro { + if item.is(b) { + return true; + } + } + false +} + impl PyType { pub fn new_simple_heap( name: &str, @@ -197,6 +209,12 @@ impl PyType { Self::new_heap_inner(base, bases, attrs, slots, heaptype_ext, metaclass, ctx) } + /// Equivalent to CPython's PyType_Check macro + /// Checks if obj is an instance of type (or its subclass) + pub(crate) fn check(obj: &PyObject) -> Option<&Py> { + obj.downcast_ref::() + } + fn resolve_mro(bases: &[PyRef]) -> Result, String> { // Check for duplicates in bases. let mut unique_bases = HashSet::new(); @@ -439,6 +457,16 @@ impl PyType { } impl Py { + pub(crate) fn is_subtype(&self, other: &Py) -> bool { + is_subtype_with_mro(&self.mro.read(), self, other) + } + + /// Equivalent to CPython's PyType_CheckExact macro + /// Checks if obj is exactly a type (not a subclass) + pub fn check_exact<'a>(obj: &'a PyObject, vm: &VirtualMachine) -> Option<&'a Py> { + obj.downcast_ref_if_exact::(vm) + } + /// Determines if `subclass` is actually a subclass of `cls`, this doesn't call __subclasscheck__, /// so only use this if `cls` is known to have not overridden the base __subclasscheck__ magic /// method. diff --git a/vm/src/protocol/object.rs b/vm/src/protocol/object.rs index 804918abb3..61973def4b 100644 --- a/vm/src/protocol/object.rs +++ b/vm/src/protocol/object.rs @@ -371,80 +371,120 @@ impl PyObject { }) } - // Equivalent to check_class. Masks Attribute errors (into TypeErrors) and lets everything - // else go through. - fn check_cls(&self, cls: &PyObject, vm: &VirtualMachine, msg: F) -> PyResult + // Equivalent to CPython's check_class. Returns Ok(()) if cls is a valid class, + // Err with TypeError if not. Uses abstract_get_bases internally. + fn check_class(&self, vm: &VirtualMachine, msg: F) -> PyResult<()> where F: Fn() -> String, { - cls.get_attr(identifier!(vm, __bases__), vm).map_err(|e| { - // Only mask AttributeErrors. - if e.class().is(vm.ctx.exceptions.attribute_error) { - vm.new_type_error(msg()) - } else { - e + let cls = self; + match cls.abstract_get_bases(vm)? { + Some(_bases) => Ok(()), // Has __bases__, it's a valid class + None => { + // No __bases__ or __bases__ is not a tuple + Err(vm.new_type_error(msg())) } - }) + } } - fn abstract_issubclass(&self, cls: &PyObject, vm: &VirtualMachine) -> PyResult { - let mut derived = self; - let mut first_item: PyObjectRef; - loop { - if derived.is(cls) { - return Ok(true); + /// abstract_get_bases() has logically 4 return states: + /// 1. getattr(cls, '__bases__') could raise an AttributeError + /// 2. getattr(cls, '__bases__') could raise some other exception + /// 3. getattr(cls, '__bases__') could return a tuple + /// 4. getattr(cls, '__bases__') could return something other than a tuple + /// + /// Only state #3 returns Some(tuple). AttributeErrors are masked by returning None. + /// If an object other than a tuple comes out of __bases__, then again, None is returned. + /// Other exceptions are propagated. + fn abstract_get_bases(&self, vm: &VirtualMachine) -> PyResult> { + match vm.get_attribute_opt(self.to_owned(), identifier!(vm, __bases__))? { + Some(bases) => { + // Check if it's a tuple + match PyTupleRef::try_from_object(vm, bases) { + Ok(tuple) => Ok(Some(tuple)), + Err(_) => Ok(None), // Not a tuple, return None + } } + None => Ok(None), // AttributeError was masked + } + } - let bases = derived.get_attr(identifier!(vm, __bases__), vm)?; - let tuple = PyTupleRef::try_from_object(vm, bases)?; + fn abstract_issubclass(&self, cls: &PyObject, vm: &VirtualMachine) -> PyResult { + // # Safety: The lifetime of `derived` is forced to be ignored + let bases = unsafe { + let mut derived = self; + // First loop: handle single inheritance without recursion + loop { + if derived.is(cls) { + return Ok(true); + } - let n = tuple.len(); - match n { - 0 => { + let Some(bases) = derived.abstract_get_bases(vm)? else { return Ok(false); - } - 1 => { - first_item = tuple[0].clone(); - derived = &first_item; - continue; - } - _ => { - for i in 0..n { - let check = vm.with_recursion("in abstract_issubclass", || { - tuple[i].abstract_issubclass(cls, vm) - })?; - if check { - return Ok(true); - } + }; + let n = bases.len(); + match n { + 0 => return Ok(false), + 1 => { + // Avoid recursion in the single inheritance case + // # safety + // Intention: + // ``` + // derived = bases.as_slice()[0].as_object(); + // ``` + // Though type-system cannot guarantee, derived does live long enough in the loop. + derived = &*(bases.as_slice()[0].as_object() as *const _); + continue; + } + _ => { + // Multiple inheritance - break out to handle recursively + break bases; } } } + }; - return Ok(false); + // Second loop: handle multiple inheritance with recursion + // At this point we know n >= 2 + let n = bases.len(); + debug_assert!(n >= 2); + + for i in 0..n { + let result = vm.with_recursion("in __issubclass__", || { + bases.as_slice()[i].abstract_issubclass(cls, vm) + })?; + if result { + return Ok(true); + } } + + Ok(false) } fn recursive_issubclass(&self, cls: &PyObject, vm: &VirtualMachine) -> PyResult { - if let (Ok(obj), Ok(cls)) = (self.try_to_ref::(vm), cls.try_to_ref::(vm)) { - Ok(obj.fast_issubclass(cls)) - } else { - // Check if derived is a class - self.check_cls(self, vm, || { - format!("issubclass() arg 1 must be a class, not {}", self.class()) + // Fast path for both being types (matches CPython's PyType_Check) + if let Some(cls) = PyType::check(cls) + && let Some(derived) = PyType::check(self) + { + // PyType_IsSubtype equivalent + return Ok(derived.is_subtype(cls)); + } + // Check if derived is a class + self.check_class(vm, || { + format!("issubclass() arg 1 must be a class, not {}", self.class()) + })?; + + // Check if cls is a class, tuple, or union (matches CPython's order and message) + if !cls.class().is(vm.ctx.types.union_type) { + cls.check_class(vm, || { + format!( + "issubclass() arg 2 must be a class, a tuple of classes, or a union, not {}", + cls.class() + ) })?; - - // Check if cls is a class, tuple, or union - if !cls.class().is(vm.ctx.types.union_type) { - self.check_cls(cls, vm, || { - format!( - "issubclass() arg 2 must be a class, a tuple of classes, or a union, not {}", - cls.class() - ) - })?; - } - - self.abstract_issubclass(cls, vm) } + + self.abstract_issubclass(cls, vm) } /// Real issubclass check without going through __subclasscheck__ @@ -520,7 +560,7 @@ impl PyObject { Ok(retval) } else { // Not a type object, check if it's a valid class - self.check_cls(cls, vm, || { + cls.check_class(vm, || { format!( "isinstance() arg 2 must be a type, a tuple of types, or a union, not {}", cls.class()