diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index 2223bd619e..c33653d07f 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -1578,8 +1578,29 @@ impl PyObject { // The intention is for this to replace `PyObjectPayload` once everything is // converted to use `PyObjectPayload::AnyRustvalue`. -pub trait PyValue: Any + fmt::Debug { +pub trait PyValue: Any + fmt::Debug + Sized { fn required_type(ctx: &PyContext) -> PyObjectRef; + + fn into_ref(self, ctx: &PyContext) -> PyRef { + PyRef { + obj: PyObject::new(self, Self::required_type(ctx)), + _payload: PhantomData, + } + } + + fn into_ref_with_type(self, vm: &mut VirtualMachine, cls: PyClassRef) -> PyResult> { + let required_type = Self::required_type(&vm.ctx); + if objtype::issubclass(&cls.obj, &required_type) { + Ok(PyRef { + obj: PyObject::new(self, cls.obj), + _payload: PhantomData, + }) + } else { + let subtype = vm.to_pystr(&cls.obj)?; + let basetype = vm.to_pystr(&required_type)?; + Err(vm.new_type_error(format!("{} is not a subtype of {}", subtype, basetype))) + } + } } impl FromPyObjectRef for PyRef {