From 429c35bf44e76f31660a92beb145256cce871c27 Mon Sep 17 00:00:00 2001 From: Dan Nasman Date: Fri, 13 Oct 2023 12:51:33 +0300 Subject: [PATCH] Add object protocol correspoinding to PyObject_GetAIter --- vm/src/protocol/object.rs | 11 +++++++++-- vm/src/stdlib/builtins.rs | 7 +------ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/vm/src/protocol/object.rs b/vm/src/protocol/object.rs index 639e24dda5..0b87bbf9d3 100644 --- a/vm/src/protocol/object.rs +++ b/vm/src/protocol/object.rs @@ -3,8 +3,8 @@ use crate::{ builtins::{ - pystr::AsPyStr, PyBytes, PyDict, PyDictRef, PyGenericAlias, PyInt, PyList, PyStr, PyStrRef, - PyTuple, PyTupleRef, PyType, PyTypeRef, + pystr::AsPyStr, PyAsyncGen, PyBytes, PyDict, PyDictRef, PyGenericAlias, PyInt, PyList, + PyStr, PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef, }, bytesinner::ByteInnerNewOptions, common::{hash::PyHash, str::to_ascii}, @@ -92,6 +92,13 @@ impl PyObject { } // PyObject *PyObject_GetAIter(PyObject *o) + pub fn get_aiter(&self, vm: &VirtualMachine) -> PyResult { + if self.payload_is::() { + vm.call_special_method(self, identifier!(vm, __aiter__), ()) + } else { + Err(vm.new_type_error("wrong argument type".to_owned())) + } + } pub fn has_attr<'a>(&self, attr_name: impl AsPyStr<'a>, vm: &VirtualMachine) -> PyResult { self.get_attr(attr_name, vm).map(|o| !vm.is_none(&o)) diff --git a/vm/src/stdlib/builtins.rs b/vm/src/stdlib/builtins.rs index 835f4152ea..4cc3a53634 100644 --- a/vm/src/stdlib/builtins.rs +++ b/vm/src/stdlib/builtins.rs @@ -9,7 +9,6 @@ pub use builtins::{ascii, print, reversed}; mod builtins { use crate::{ builtins::{ - asyncgenerator::PyAsyncGen, enumerate::PyReverseSequenceIterator, function::{PyCellRef, PyFunction}, int::PyIntRef, @@ -459,11 +458,7 @@ mod builtins { #[pyfunction] fn aiter(iter_target: PyObjectRef, vm: &VirtualMachine) -> PyResult { - if iter_target.payload_is::() { - vm.call_special_method(&iter_target, identifier!(vm, __aiter__), ()) - } else { - Err(vm.new_type_error("wrong argument type".to_owned())) - } + iter_target.get_aiter(vm) } #[pyfunction]