@@ -1019,7 +1019,7 @@ impl Constructor for PyType {
1019
1019
attributes. insert ( identifier ! ( vm, __hash__) , vm. ctx . none . clone ( ) . into ( ) ) ;
1020
1020
}
1021
1021
1022
- let heaptype_slots: Option < PyRef < PyTuple < PyStrRef > > > =
1022
+ let ( heaptype_slots, add_dict ) : ( Option < PyRef < PyTuple < PyStrRef > > > , bool ) =
1023
1023
if let Some ( x) = attributes. get ( identifier ! ( vm, __slots__) ) {
1024
1024
let slots = if x. class ( ) . is ( vm. ctx . types . str_type ) {
1025
1025
let x = unsafe { x. downcast_unchecked_ref :: < PyStr > ( ) } ;
@@ -1036,9 +1036,26 @@ impl Constructor for PyType {
1036
1036
let tuple = elements. into_pytuple ( vm) ;
1037
1037
tuple. try_into_typed ( vm) ?
1038
1038
} ;
1039
- Some ( slots)
1039
+
1040
+ // Check if __dict__ is in slots
1041
+ let dict_name = "__dict__" ;
1042
+ let has_dict = slots. iter ( ) . any ( |s| s. as_str ( ) == dict_name) ;
1043
+
1044
+ // Filter out __dict__ from slots
1045
+ let filtered_slots = if has_dict {
1046
+ let filtered: Vec < PyStrRef > = slots
1047
+ . iter ( )
1048
+ . filter ( |s| s. as_str ( ) != dict_name)
1049
+ . cloned ( )
1050
+ . collect ( ) ;
1051
+ PyTuple :: new_ref_typed ( filtered, & vm. ctx )
1052
+ } else {
1053
+ slots
1054
+ } ;
1055
+
1056
+ ( Some ( filtered_slots) , has_dict)
1040
1057
} else {
1041
- None
1058
+ ( None , false )
1042
1059
} ;
1043
1060
1044
1061
// FIXME: this is a temporary fix. multi bases with multiple slots will break object
@@ -1051,8 +1068,10 @@ impl Constructor for PyType {
1051
1068
let member_count: usize = base_member_count + heaptype_member_count;
1052
1069
1053
1070
let mut flags = PyTypeFlags :: heap_type_flags ( ) ;
1054
- // Only add HAS_DICT and MANAGED_DICT if __slots__ is not defined.
1055
- if heaptype_slots. is_none ( ) {
1071
+ // Add HAS_DICT and MANAGED_DICT if:
1072
+ // 1. __slots__ is not defined, OR
1073
+ // 2. __dict__ is in __slots__
1074
+ if heaptype_slots. is_none ( ) || add_dict {
1056
1075
flags |= PyTypeFlags :: HAS_DICT | PyTypeFlags :: MANAGED_DICT ;
1057
1076
}
1058
1077
@@ -1130,13 +1149,14 @@ impl Constructor for PyType {
1130
1149
1131
1150
// Add __dict__ descriptor after type creation to ensure correct __objclass__
1132
1151
if !base_is_type {
1133
- unsafe {
1134
- let descriptor =
1135
- vm. ctx
1136
- . new_getset ( "__dict__" , & typ, subtype_get_dict, subtype_set_dict) ;
1137
- typ. attributes
1138
- . write ( )
1139
- . insert ( identifier ! ( vm, __dict__) , descriptor. into ( ) ) ;
1152
+ let __dict__ = identifier ! ( vm, __dict__) ;
1153
+ if !typ. attributes . read ( ) . contains_key ( & __dict__) {
1154
+ unsafe {
1155
+ let descriptor =
1156
+ vm. ctx
1157
+ . new_getset ( "__dict__" , & typ, subtype_get_dict, subtype_set_dict) ;
1158
+ typ. attributes . write ( ) . insert ( __dict__, descriptor. into ( ) ) ;
1159
+ }
1140
1160
}
1141
1161
}
1142
1162
@@ -1445,51 +1465,77 @@ impl Representable for PyType {
1445
1465
}
1446
1466
}
1447
1467
1448
- fn find_base_dict_descr ( cls : & Py < PyType > , vm : & VirtualMachine ) -> Option < PyObjectRef > {
1449
- cls. iter_base_chain ( ) . skip ( 1 ) . find_map ( |cls| {
1450
- // TODO: should actually be some translation of:
1451
- // cls.slot_dictoffset != 0 && !cls.flags.contains(HEAPTYPE)
1452
- if cls. is ( vm. ctx . types . type_type ) {
1453
- cls. get_attr ( identifier ! ( vm, __dict__) )
1454
- } else {
1455
- None
1468
+ // = get_builtin_base_with_dict
1469
+ fn get_builtin_base_with_dict ( typ : & Py < PyType > , vm : & VirtualMachine ) -> Option < PyTypeRef > {
1470
+ let mut current = Some ( typ. to_owned ( ) ) ;
1471
+ while let Some ( t) = current {
1472
+ // In CPython: type->tp_dictoffset != 0 && !(type->tp_flags & Py_TPFLAGS_HEAPTYPE)
1473
+ // Special case: type itself is a builtin with dict support
1474
+ if t. is ( vm. ctx . types . type_type ) {
1475
+ return Some ( t) ;
1476
+ }
1477
+ // We check HAS_DICT flag (equivalent to tp_dictoffset != 0) and HEAPTYPE
1478
+ if t. slots . flags . contains ( PyTypeFlags :: HAS_DICT )
1479
+ && !t. slots . flags . contains ( PyTypeFlags :: HEAPTYPE )
1480
+ {
1481
+ return Some ( t) ;
1456
1482
}
1457
- } )
1483
+ current = t. __base__ ( ) ;
1484
+ }
1485
+ None
1486
+ }
1487
+
1488
+ // = get_dict_descriptor
1489
+ fn get_dict_descriptor ( base : & Py < PyType > , vm : & VirtualMachine ) -> Option < PyObjectRef > {
1490
+ let dict_attr = identifier ! ( vm, __dict__) ;
1491
+ // Use _PyType_Lookup (which is lookup_ref in RustPython)
1492
+ base. lookup_ref ( dict_attr, vm)
1493
+ }
1494
+
1495
+ // = raise_dict_descr_error
1496
+ fn raise_dict_descriptor_error ( obj : & PyObject , vm : & VirtualMachine ) -> PyBaseExceptionRef {
1497
+ vm. new_type_error ( format ! (
1498
+ "this __dict__ descriptor does not support '{}' objects" ,
1499
+ obj. class( ) . name( )
1500
+ ) )
1458
1501
}
1459
1502
1460
1503
fn subtype_get_dict ( obj : PyObjectRef , vm : & VirtualMachine ) -> PyResult {
1461
- // TODO: obj.class().as_pyref() need to be supported
1462
- let ret = match find_base_dict_descr ( obj. class ( ) , vm) {
1463
- Some ( descr) => vm. call_get_descriptor ( & descr, obj) . unwrap_or_else ( || {
1464
- Err ( vm. new_type_error ( format ! (
1465
- "this __dict__ descriptor does not support '{}' objects" ,
1466
- descr. class( )
1467
- ) ) )
1468
- } ) ?,
1469
- None => object:: object_get_dict ( obj, vm) ?. into ( ) ,
1470
- } ;
1471
- Ok ( ret)
1504
+ let base = get_builtin_base_with_dict ( obj. class ( ) , vm) ;
1505
+
1506
+ if let Some ( base_type) = base {
1507
+ if let Some ( descr) = get_dict_descriptor ( & base_type, vm) {
1508
+ // Call the descriptor's tp_descr_get
1509
+ vm. call_get_descriptor ( & descr, obj. clone ( ) )
1510
+ . unwrap_or_else ( || Err ( raise_dict_descriptor_error ( & obj, vm) ) )
1511
+ } else {
1512
+ Err ( raise_dict_descriptor_error ( & obj, vm) )
1513
+ }
1514
+ } else {
1515
+ // PyObject_GenericGetDict
1516
+ object:: object_get_dict ( obj, vm) . map ( Into :: into)
1517
+ }
1472
1518
}
1473
1519
1520
+ // = subtype_setdict
1474
1521
fn subtype_set_dict ( obj : PyObjectRef , value : PyObjectRef , vm : & VirtualMachine ) -> PyResult < ( ) > {
1475
- let cls = obj. class ( ) ;
1476
- match find_base_dict_descr ( cls, vm) {
1477
- Some ( descr) => {
1522
+ let base = get_builtin_base_with_dict ( obj. class ( ) , vm) ;
1523
+
1524
+ if let Some ( base_type) = base {
1525
+ if let Some ( descr) = get_dict_descriptor ( & base_type, vm) {
1526
+ // Call the descriptor's tp_descr_set
1478
1527
let descr_set = descr
1479
1528
. class ( )
1480
1529
. mro_find_map ( |cls| cls. slots . descr_set . load ( ) )
1481
- . ok_or_else ( || {
1482
- vm. new_type_error ( format ! (
1483
- "this __dict__ descriptor does not support '{}' objects" ,
1484
- cls. name( )
1485
- ) )
1486
- } ) ?;
1530
+ . ok_or_else ( || raise_dict_descriptor_error ( & obj, vm) ) ?;
1487
1531
descr_set ( & descr, obj, PySetterValue :: Assign ( value) , vm)
1532
+ } else {
1533
+ Err ( raise_dict_descriptor_error ( & obj, vm) )
1488
1534
}
1489
- None => {
1490
- object :: object_set_dict ( obj , value . try_into_value ( vm ) ? , vm ) ? ;
1491
- Ok ( ( ) )
1492
- }
1535
+ } else {
1536
+ // PyObject_GenericSetDict
1537
+ object :: object_set_dict ( obj , value . try_into_value ( vm ) ? , vm ) ? ;
1538
+ Ok ( ( ) )
1493
1539
}
1494
1540
}
1495
1541
0 commit comments