Skip to content

Commit f9557a0

Browse files
committed
fix
1 parent 2372920 commit f9557a0

File tree

1 file changed

+91
-45
lines changed

1 file changed

+91
-45
lines changed

vm/src/builtins/type.rs

Lines changed: 91 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,7 +1019,7 @@ impl Constructor for PyType {
10191019
attributes.insert(identifier!(vm, __hash__), vm.ctx.none.clone().into());
10201020
}
10211021

1022-
let heaptype_slots: Option<PyRef<PyTuple<PyStrRef>>> =
1022+
let (heaptype_slots, add_dict): (Option<PyRef<PyTuple<PyStrRef>>>, bool) =
10231023
if let Some(x) = attributes.get(identifier!(vm, __slots__)) {
10241024
let slots = if x.class().is(vm.ctx.types.str_type) {
10251025
let x = unsafe { x.downcast_unchecked_ref::<PyStr>() };
@@ -1036,9 +1036,26 @@ impl Constructor for PyType {
10361036
let tuple = elements.into_pytuple(vm);
10371037
tuple.try_into_typed(vm)?
10381038
};
1039-
Some(slots)
1039+
1040+
// Check if __dict__ is in slots
1041+
let dict_name = "__dict__";
1042+
let has_dict = slots.iter().any(|s| s.as_str() == dict_name);
1043+
1044+
// Filter out __dict__ from slots
1045+
let filtered_slots = if has_dict {
1046+
let filtered: Vec<PyStrRef> = slots
1047+
.iter()
1048+
.filter(|s| s.as_str() != dict_name)
1049+
.cloned()
1050+
.collect();
1051+
PyTuple::new_ref_typed(filtered, &vm.ctx)
1052+
} else {
1053+
slots
1054+
};
1055+
1056+
(Some(filtered_slots), has_dict)
10401057
} else {
1041-
None
1058+
(None, false)
10421059
};
10431060

10441061
// FIXME: this is a temporary fix. multi bases with multiple slots will break object
@@ -1051,8 +1068,10 @@ impl Constructor for PyType {
10511068
let member_count: usize = base_member_count + heaptype_member_count;
10521069

10531070
let mut flags = PyTypeFlags::heap_type_flags();
1054-
// Only add HAS_DICT and MANAGED_DICT if __slots__ is not defined.
1055-
if heaptype_slots.is_none() {
1071+
// Add HAS_DICT and MANAGED_DICT if:
1072+
// 1. __slots__ is not defined, OR
1073+
// 2. __dict__ is in __slots__
1074+
if heaptype_slots.is_none() || add_dict {
10561075
flags |= PyTypeFlags::HAS_DICT | PyTypeFlags::MANAGED_DICT;
10571076
}
10581077

@@ -1130,13 +1149,14 @@ impl Constructor for PyType {
11301149

11311150
// Add __dict__ descriptor after type creation to ensure correct __objclass__
11321151
if !base_is_type {
1133-
unsafe {
1134-
let descriptor =
1135-
vm.ctx
1136-
.new_getset("__dict__", &typ, subtype_get_dict, subtype_set_dict);
1137-
typ.attributes
1138-
.write()
1139-
.insert(identifier!(vm, __dict__), descriptor.into());
1152+
let __dict__ = identifier!(vm, __dict__);
1153+
if !typ.attributes.read().contains_key(&__dict__) {
1154+
unsafe {
1155+
let descriptor =
1156+
vm.ctx
1157+
.new_getset("__dict__", &typ, subtype_get_dict, subtype_set_dict);
1158+
typ.attributes.write().insert(__dict__, descriptor.into());
1159+
}
11401160
}
11411161
}
11421162

@@ -1445,51 +1465,77 @@ impl Representable for PyType {
14451465
}
14461466
}
14471467

1448-
fn find_base_dict_descr(cls: &Py<PyType>, vm: &VirtualMachine) -> Option<PyObjectRef> {
1449-
cls.iter_base_chain().skip(1).find_map(|cls| {
1450-
// TODO: should actually be some translation of:
1451-
// cls.slot_dictoffset != 0 && !cls.flags.contains(HEAPTYPE)
1452-
if cls.is(vm.ctx.types.type_type) {
1453-
cls.get_attr(identifier!(vm, __dict__))
1454-
} else {
1455-
None
1468+
// = get_builtin_base_with_dict
1469+
fn get_builtin_base_with_dict(typ: &Py<PyType>, vm: &VirtualMachine) -> Option<PyTypeRef> {
1470+
let mut current = Some(typ.to_owned());
1471+
while let Some(t) = current {
1472+
// In CPython: type->tp_dictoffset != 0 && !(type->tp_flags & Py_TPFLAGS_HEAPTYPE)
1473+
// Special case: type itself is a builtin with dict support
1474+
if t.is(vm.ctx.types.type_type) {
1475+
return Some(t);
1476+
}
1477+
// We check HAS_DICT flag (equivalent to tp_dictoffset != 0) and HEAPTYPE
1478+
if t.slots.flags.contains(PyTypeFlags::HAS_DICT)
1479+
&& !t.slots.flags.contains(PyTypeFlags::HEAPTYPE)
1480+
{
1481+
return Some(t);
14561482
}
1457-
})
1483+
current = t.__base__();
1484+
}
1485+
None
1486+
}
1487+
1488+
// = get_dict_descriptor
1489+
fn get_dict_descriptor(base: &Py<PyType>, vm: &VirtualMachine) -> Option<PyObjectRef> {
1490+
let dict_attr = identifier!(vm, __dict__);
1491+
// Use _PyType_Lookup (which is lookup_ref in RustPython)
1492+
base.lookup_ref(dict_attr, vm)
1493+
}
1494+
1495+
// = raise_dict_descr_error
1496+
fn raise_dict_descriptor_error(obj: &PyObject, vm: &VirtualMachine) -> PyBaseExceptionRef {
1497+
vm.new_type_error(format!(
1498+
"this __dict__ descriptor does not support '{}' objects",
1499+
obj.class().name()
1500+
))
14581501
}
14591502

14601503
fn subtype_get_dict(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult {
1461-
// TODO: obj.class().as_pyref() need to be supported
1462-
let ret = match find_base_dict_descr(obj.class(), vm) {
1463-
Some(descr) => vm.call_get_descriptor(&descr, obj).unwrap_or_else(|| {
1464-
Err(vm.new_type_error(format!(
1465-
"this __dict__ descriptor does not support '{}' objects",
1466-
descr.class()
1467-
)))
1468-
})?,
1469-
None => object::object_get_dict(obj, vm)?.into(),
1470-
};
1471-
Ok(ret)
1504+
let base = get_builtin_base_with_dict(obj.class(), vm);
1505+
1506+
if let Some(base_type) = base {
1507+
if let Some(descr) = get_dict_descriptor(&base_type, vm) {
1508+
// Call the descriptor's tp_descr_get
1509+
vm.call_get_descriptor(&descr, obj.clone())
1510+
.unwrap_or_else(|| Err(raise_dict_descriptor_error(&obj, vm)))
1511+
} else {
1512+
Err(raise_dict_descriptor_error(&obj, vm))
1513+
}
1514+
} else {
1515+
// PyObject_GenericGetDict
1516+
object::object_get_dict(obj, vm).map(Into::into)
1517+
}
14721518
}
14731519

1520+
// = subtype_setdict
14741521
fn subtype_set_dict(obj: PyObjectRef, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
1475-
let cls = obj.class();
1476-
match find_base_dict_descr(cls, vm) {
1477-
Some(descr) => {
1522+
let base = get_builtin_base_with_dict(obj.class(), vm);
1523+
1524+
if let Some(base_type) = base {
1525+
if let Some(descr) = get_dict_descriptor(&base_type, vm) {
1526+
// Call the descriptor's tp_descr_set
14781527
let descr_set = descr
14791528
.class()
14801529
.mro_find_map(|cls| cls.slots.descr_set.load())
1481-
.ok_or_else(|| {
1482-
vm.new_type_error(format!(
1483-
"this __dict__ descriptor does not support '{}' objects",
1484-
cls.name()
1485-
))
1486-
})?;
1530+
.ok_or_else(|| raise_dict_descriptor_error(&obj, vm))?;
14871531
descr_set(&descr, obj, PySetterValue::Assign(value), vm)
1532+
} else {
1533+
Err(raise_dict_descriptor_error(&obj, vm))
14881534
}
1489-
None => {
1490-
object::object_set_dict(obj, value.try_into_value(vm)?, vm)?;
1491-
Ok(())
1492-
}
1535+
} else {
1536+
// PyObject_GenericSetDict
1537+
object::object_set_dict(obj, value.try_into_value(vm)?, vm)?;
1538+
Ok(())
14931539
}
14941540
}
14951541

0 commit comments

Comments
 (0)