diff --git a/Lib/test/test_bool.py b/Lib/test/test_bool.py index f413b7aaaf..4b32aad241 100644 --- a/Lib/test/test_bool.py +++ b/Lib/test/test_bool.py @@ -7,8 +7,6 @@ class BoolTest(unittest.TestCase): - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_subclass(self): try: class C(bool): diff --git a/vm/src/builtins/type.rs b/vm/src/builtins/type.rs index a4db53225b..015811ac4d 100644 --- a/vm/src/builtins/type.rs +++ b/vm/src/builtins/type.rs @@ -358,6 +358,7 @@ impl Py { /// Determines if `subclass` is actually a subclass of `cls`, this doesn't call __subclasscheck__, /// so only use this if `cls` is known to have not overridden the base __subclasscheck__ magic /// method. + /// Similar to CPython PyType_IsSubtype pub fn fast_issubclass(&self, cls: &impl Borrow) -> bool { self.as_object().is(cls.borrow()) || self.mro.iter().any(|c| c.is(cls.borrow())) } @@ -383,7 +384,36 @@ impl PyType { subtype = subtype.name(), ))); } - call_slot_new(zelf, subtype, args, vm) + + let typ = zelf; + // Check that the use doesn't do something silly and unsafe like + // object.__new__(dict). To do this, we check that the + // most derived base that's not a heap type is this type. */ + let mut static_base = subtype.clone(); + loop { + if static_base + .slots + .new + .load() + .map_or(false, |f| f as usize == crate::types::new_wrapper as usize) + { + static_base = static_base.base().unwrap(); + } else { + break; + } + } + if static_base.slots.new.load().map_or(0, |f| f as usize) + != typ.slots.new.load().map_or(0, |f| f as usize) + { + return Err(vm.new_type_error(format!( + "{}.__new__({}) is not safe, use {}.__new__()", + typ.name(), + subtype.name(), + static_base.name() + ))); + } + + call_slot_new(typ, subtype, args, vm) } #[pygetset(name = "__mro__")] @@ -1114,6 +1144,7 @@ pub(crate) fn init(ctx: &Context) { PyType::extend_class(ctx, ctx.types.type_type); } +// part of tp_new_wrapper *after* subtype check pub(crate) fn call_slot_new( typ: PyTypeRef, subtype: PyTypeRef, diff --git a/vm/src/types/slot.rs b/vm/src/types/slot.rs index 1be87b4f14..8521a719e2 100644 --- a/vm/src/types/slot.rs +++ b/vm/src/types/slot.rs @@ -334,7 +334,8 @@ fn init_wrapper(obj: PyObjectRef, args: FuncArgs, vm: &VirtualMachine) -> PyResu Ok(()) } -fn new_wrapper(cls: PyTypeRef, mut args: FuncArgs, vm: &VirtualMachine) -> PyResult { +// = slot_tp_new +pub(crate) fn new_wrapper(cls: PyTypeRef, mut args: FuncArgs, vm: &VirtualMachine) -> PyResult { let new = cls.get_attr(identifier!(vm, __new__)).unwrap(); args.prepend_arg(cls.into()); vm.invoke(&new, args)