Skip to content

Commit 7c6d063

Browse files
authored
Update isintance (#5868)
* PyUnion::get_args * issubclass * isinstance
1 parent 4cca8b0 commit 7c6d063

File tree

2 files changed

+33
-28
lines changed

2 files changed

+33
-28
lines changed

vm/src/builtins/union.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ impl PyUnion {
4040
Self { args, parameters }
4141
}
4242

43+
/// Direct access to args field, matching CPython's _Py_union_args
44+
#[inline]
45+
pub fn args(&self) -> &PyTupleRef {
46+
&self.args
47+
}
48+
4349
fn repr(&self, vm: &VirtualMachine) -> PyResult<String> {
4450
fn repr_item(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<String> {
4551
if obj.is(vm.ctx.types.none_type) {

vm/src/protocol/object.rs

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -497,45 +497,46 @@ impl PyObject {
497497
/// via the __subclasscheck__ magic method.
498498
/// PyObject_IsSubclass/object_issubclass
499499
pub fn is_subclass(&self, cls: &PyObject, vm: &VirtualMachine) -> PyResult<bool> {
500+
let derived = self;
500501
// PyType_CheckExact(cls)
501502
if cls.class().is(vm.ctx.types.type_type) {
502-
if self.is(cls) {
503+
if derived.is(cls) {
503504
return Ok(true);
504505
}
505-
return self.recursive_issubclass(cls, vm);
506+
return derived.recursive_issubclass(cls, vm);
506507
}
507508

508509
// Check for Union type - CPython handles this before tuple
509-
let cls_to_check = if cls.class().is(vm.ctx.types.union_type) {
510+
let cls = if cls.class().is(vm.ctx.types.union_type) {
510511
// Get the __args__ attribute which contains the union members
511-
if let Ok(args) = cls.get_attr(identifier!(vm, __args__), vm) {
512-
args
513-
} else {
514-
cls.to_owned()
515-
}
512+
// Match CPython's _Py_union_args which directly accesses the args field
513+
let union = cls
514+
.downcast_ref::<crate::builtins::PyUnion>()
515+
.expect("union is already checked");
516+
union.args().as_object()
516517
} else {
517-
cls.to_owned()
518+
cls
518519
};
519520

520-
// Check if cls_to_check is a tuple
521-
if let Ok(tuple) = cls_to_check.try_to_value::<&Py<PyTuple>>(vm) {
522-
for typ in tuple {
523-
if vm.with_recursion("in __subclasscheck__", || self.is_subclass(typ, vm))? {
521+
// Check if cls is a tuple
522+
if let Some(tuple) = cls.downcast_ref::<PyTuple>() {
523+
for item in tuple {
524+
if vm.with_recursion("in __subclasscheck__", || derived.is_subclass(item, vm))? {
524525
return Ok(true);
525526
}
526527
}
527528
return Ok(false);
528529
}
529530

530531
// Check for __subclasscheck__ method
531-
if let Some(meth) = vm.get_special_method(cls, identifier!(vm, __subclasscheck__))? {
532-
let ret = vm.with_recursion("in __subclasscheck__", || {
533-
meth.invoke((self.to_owned(),), vm)
532+
if let Some(checker) = vm.get_special_method(cls, identifier!(vm, __subclasscheck__))? {
533+
let res = vm.with_recursion("in __subclasscheck__", || {
534+
checker.invoke((derived.to_owned(),), vm)
534535
})?;
535-
return ret.try_to_bool(vm);
536+
return res.try_to_bool(vm);
536537
}
537538

538-
self.recursive_issubclass(cls, vm)
539+
derived.recursive_issubclass(cls, vm)
539540
}
540541

541542
/// Real isinstance check without going through __instancecheck__
@@ -601,16 +602,14 @@ impl PyObject {
601602

602603
// Check for Union type (e.g., int | str) - CPython checks this before tuple
603604
if cls.class().is(vm.ctx.types.union_type) {
604-
if let Ok(args) = cls.get_attr(identifier!(vm, __args__), vm) {
605-
if let Ok(tuple) = args.try_to_ref::<PyTuple>(vm) {
606-
for typ in tuple {
607-
if vm
608-
.with_recursion("in __instancecheck__", || self.is_instance(typ, vm))?
609-
{
610-
return Ok(true);
611-
}
612-
}
613-
return Ok(false);
605+
// Match CPython's _Py_union_args which directly accesses the args field
606+
let union = cls
607+
.try_to_ref::<crate::builtins::PyUnion>(vm)
608+
.expect("checked by is");
609+
let tuple = union.args();
610+
for typ in tuple.iter() {
611+
if vm.with_recursion("in __instancecheck__", || self.is_instance(typ, vm))? {
612+
return Ok(true);
614613
}
615614
}
616615
}

0 commit comments

Comments
 (0)