Skip to content

Rework issubclass #5867

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions vm/src/builtins/type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,18 @@ fn downcast_qualname(value: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyRef<
}
}

fn is_subtype_with_mro(a_mro: &[PyTypeRef], a: &Py<PyType>, b: &Py<PyType>) -> bool {
if a.is(b) {
return true;
}
for item in a_mro {
if item.is(b) {
return true;
}
}
false
}

impl PyType {
pub fn new_simple_heap(
name: &str,
Expand Down Expand Up @@ -197,6 +209,12 @@ impl PyType {
Self::new_heap_inner(base, bases, attrs, slots, heaptype_ext, metaclass, ctx)
}

/// Equivalent to CPython's PyType_Check macro
/// Checks if obj is an instance of type (or its subclass)
pub(crate) fn check(obj: &PyObject) -> Option<&Py<Self>> {
obj.downcast_ref::<Self>()
}

fn resolve_mro(bases: &[PyRef<Self>]) -> Result<Vec<PyTypeRef>, String> {
// Check for duplicates in bases.
let mut unique_bases = HashSet::new();
Expand Down Expand Up @@ -439,6 +457,16 @@ impl PyType {
}

impl Py<PyType> {
pub(crate) fn is_subtype(&self, other: &Py<PyType>) -> bool {
is_subtype_with_mro(&self.mro.read(), self, other)
}

/// Equivalent to CPython's PyType_CheckExact macro
/// Checks if obj is exactly a type (not a subclass)
pub fn check_exact<'a>(obj: &'a PyObject, vm: &VirtualMachine) -> Option<&'a Py<PyType>> {
obj.downcast_ref_if_exact::<PyType>(vm)
}

/// Determines if `subclass` is actually a subclass of `cls`, this doesn't call __subclasscheck__,
/// so only use this if `cls` is known to have not overridden the base __subclasscheck__ magic
/// method.
Expand Down
150 changes: 95 additions & 55 deletions vm/src/protocol/object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -371,80 +371,120 @@ impl PyObject {
})
}

// Equivalent to check_class. Masks Attribute errors (into TypeErrors) and lets everything
// else go through.
fn check_cls<F>(&self, cls: &PyObject, vm: &VirtualMachine, msg: F) -> PyResult
// Equivalent to CPython's check_class. Returns Ok(()) if cls is a valid class,
// Err with TypeError if not. Uses abstract_get_bases internally.
fn check_class<F>(&self, vm: &VirtualMachine, msg: F) -> PyResult<()>
where
F: Fn() -> String,
{
cls.get_attr(identifier!(vm, __bases__), vm).map_err(|e| {
// Only mask AttributeErrors.
if e.class().is(vm.ctx.exceptions.attribute_error) {
vm.new_type_error(msg())
} else {
e
let cls = self;
match cls.abstract_get_bases(vm)? {
Some(_bases) => Ok(()), // Has __bases__, it's a valid class
None => {
// No __bases__ or __bases__ is not a tuple
Err(vm.new_type_error(msg()))
}
})
}
}

fn abstract_issubclass(&self, cls: &PyObject, vm: &VirtualMachine) -> PyResult<bool> {
let mut derived = self;
let mut first_item: PyObjectRef;
loop {
if derived.is(cls) {
return Ok(true);
/// abstract_get_bases() has logically 4 return states:
/// 1. getattr(cls, '__bases__') could raise an AttributeError
/// 2. getattr(cls, '__bases__') could raise some other exception
/// 3. getattr(cls, '__bases__') could return a tuple
/// 4. getattr(cls, '__bases__') could return something other than a tuple
///
/// Only state #3 returns Some(tuple). AttributeErrors are masked by returning None.
/// If an object other than a tuple comes out of __bases__, then again, None is returned.
/// Other exceptions are propagated.
fn abstract_get_bases(&self, vm: &VirtualMachine) -> PyResult<Option<PyTupleRef>> {
match vm.get_attribute_opt(self.to_owned(), identifier!(vm, __bases__))? {
Some(bases) => {
// Check if it's a tuple
match PyTupleRef::try_from_object(vm, bases) {
Ok(tuple) => Ok(Some(tuple)),
Err(_) => Ok(None), // Not a tuple, return None
}
}
None => Ok(None), // AttributeError was masked
}
}

let bases = derived.get_attr(identifier!(vm, __bases__), vm)?;
let tuple = PyTupleRef::try_from_object(vm, bases)?;
fn abstract_issubclass(&self, cls: &PyObject, vm: &VirtualMachine) -> PyResult<bool> {
// # Safety: The lifetime of `derived` is forced to be ignored
let bases = unsafe {
let mut derived = self;
// First loop: handle single inheritance without recursion
loop {
if derived.is(cls) {
return Ok(true);
}

let n = tuple.len();
match n {
0 => {
let Some(bases) = derived.abstract_get_bases(vm)? else {
return Ok(false);
}
1 => {
first_item = tuple[0].clone();
derived = &first_item;
continue;
}
_ => {
for i in 0..n {
let check = vm.with_recursion("in abstract_issubclass", || {
tuple[i].abstract_issubclass(cls, vm)
})?;
if check {
return Ok(true);
}
};
let n = bases.len();
match n {
0 => return Ok(false),
1 => {
// Avoid recursion in the single inheritance case
// # safety
// Intention:
// ```
// derived = bases.as_slice()[0].as_object();
// ```
// Though type-system cannot guarantee, derived does live long enough in the loop.
derived = &*(bases.as_slice()[0].as_object() as *const _);
continue;
}
_ => {
// Multiple inheritance - break out to handle recursively
break bases;
}
}
}
};

return Ok(false);
// Second loop: handle multiple inheritance with recursion
// At this point we know n >= 2
let n = bases.len();
debug_assert!(n >= 2);

for i in 0..n {
let result = vm.with_recursion("in __issubclass__", || {
bases.as_slice()[i].abstract_issubclass(cls, vm)
})?;
if result {
return Ok(true);
}
}

Ok(false)
}

fn recursive_issubclass(&self, cls: &PyObject, vm: &VirtualMachine) -> PyResult<bool> {
if let (Ok(obj), Ok(cls)) = (self.try_to_ref::<PyType>(vm), cls.try_to_ref::<PyType>(vm)) {
Ok(obj.fast_issubclass(cls))
} else {
// Check if derived is a class
self.check_cls(self, vm, || {
format!("issubclass() arg 1 must be a class, not {}", self.class())
// Fast path for both being types (matches CPython's PyType_Check)
if let Some(cls) = PyType::check(cls)
&& let Some(derived) = PyType::check(self)
{
// PyType_IsSubtype equivalent
return Ok(derived.is_subtype(cls));
}
// Check if derived is a class
self.check_class(vm, || {
format!("issubclass() arg 1 must be a class, not {}", self.class())
})?;

// Check if cls is a class, tuple, or union (matches CPython's order and message)
if !cls.class().is(vm.ctx.types.union_type) {
cls.check_class(vm, || {
format!(
"issubclass() arg 2 must be a class, a tuple of classes, or a union, not {}",
cls.class()
)
})?;

// Check if cls is a class, tuple, or union
if !cls.class().is(vm.ctx.types.union_type) {
self.check_cls(cls, vm, || {
format!(
"issubclass() arg 2 must be a class, a tuple of classes, or a union, not {}",
cls.class()
)
})?;
}

self.abstract_issubclass(cls, vm)
}

self.abstract_issubclass(cls, vm)
}

/// Real issubclass check without going through __subclasscheck__
Expand Down Expand Up @@ -520,7 +560,7 @@ impl PyObject {
Ok(retval)
} else {
// Not a type object, check if it's a valid class
self.check_cls(cls, vm, || {
cls.check_class(vm, || {
format!(
"isinstance() arg 2 must be a type, a tuple of types, or a union, not {}",
cls.class()
Expand Down
Loading