Skip to content

use PyIterReturn for gen internal #3183

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 34 additions & 28 deletions vm/src/builtins/asyncgenerator.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::{PyCode, PyStrRef, PyTypeRef};
use crate::{
builtins::PyBaseExceptionRef,
coroutine::{Coro, Variant},
coroutine::Coro,
frame::FrameRef,
function::OptionalArg,
protocol::PyIterReturn,
Expand Down Expand Up @@ -34,7 +34,7 @@ impl PyAsyncGen {

pub fn new(frame: FrameRef, name: PyStrRef) -> Self {
PyAsyncGen {
inner: Coro::new(frame, Variant::AsyncGen, name),
inner: Coro::new(frame, name),
running_async: AtomicCell::new(false),
}
}
Expand All @@ -50,8 +50,8 @@ impl PyAsyncGen {
}

#[pymethod(magic)]
fn repr(zelf: PyRef<Self>) -> String {
zelf.inner.repr(zelf.get_id())
fn repr(zelf: PyRef<Self>, vm: &VirtualMachine) -> String {
zelf.inner.repr(zelf.as_object(), zelf.get_id(), vm)
}

#[pymethod(magic)]
Expand Down Expand Up @@ -138,17 +138,20 @@ impl PyValue for PyAsyncGenWrappedValue {
impl PyAsyncGenWrappedValue {}

impl PyAsyncGenWrappedValue {
fn unbox(ag: &PyAsyncGen, val: PyResult, vm: &VirtualMachine) -> PyResult {
if let Err(ref e) = val {
if e.isinstance(&vm.ctx.exceptions.stop_async_iteration)
|| e.isinstance(&vm.ctx.exceptions.generator_exit)
{
ag.inner.closed.store(true);
}
fn unbox(ag: &PyAsyncGen, val: PyResult<PyIterReturn>, vm: &VirtualMachine) -> PyResult {
let (closed, async_done) = match &val {
Ok(PyIterReturn::StopIteration(_)) => (true, true),
Err(e) if e.isinstance(&vm.ctx.exceptions.generator_exit) => (true, true),
Err(_) => (false, true),
_ => (false, false),
};
if closed {
ag.inner.closed.store(true);
}
if async_done {
ag.running_async.store(false);
}
let val = val?;

let val = val?.into_async_pyresult(vm)?;
match_class!(match val {
val @ Self => {
ag.running_async.store(false);
Expand Down Expand Up @@ -214,7 +217,7 @@ impl PyAsyncGenASend {
}
}
};
let res = self.ag.inner.send(val, vm);
let res = self.ag.inner.send(self.ag.as_object(), val, vm);
let res = PyAsyncGenWrappedValue::unbox(&self.ag, res, vm);
if res.is_err() {
self.close();
Expand All @@ -237,6 +240,7 @@ impl PyAsyncGenASend {
}

let res = self.ag.inner.throw(
self.ag.as_object(),
exc_type,
exc_val.unwrap_or_none(vm),
exc_tb.unwrap_or_none(vm),
Expand All @@ -258,8 +262,7 @@ impl PyAsyncGenASend {
impl IteratorIterable for PyAsyncGenASend {}
impl SlotIterator for PyAsyncGenASend {
fn next(zelf: &PyRef<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
// TODO: Fix zelf.send to return PyIterReturn
PyIterReturn::from_result(zelf.send(vm.ctx.none(), vm), vm)
PyIterReturn::from_pyresult(zelf.send(vm.ctx.none(), vm), vm)
}
}

Expand Down Expand Up @@ -315,27 +318,28 @@ impl PyAsyncGenAThrow {
self.ag.running_async.store(true);

let (ty, val, tb) = self.value.clone();
let ret = self.ag.inner.throw(ty, val, tb, vm);
let ret = self.ag.inner.throw(self.ag.as_object(), ty, val, tb, vm);
let ret = if self.aclose {
if self.ignored_close(&ret) {
Err(self.yield_close(vm))
} else {
ret
ret.and_then(|o| o.into_async_pyresult(vm))
}
} else {
PyAsyncGenWrappedValue::unbox(&self.ag, ret, vm)
};
ret.map_err(|e| self.check_error(e, vm))
}
AwaitableState::Iter => {
let ret = self.ag.inner.send(val, vm);
let ret = self.ag.inner.send(self.ag.as_object(), val, vm);
if self.aclose {
match ret {
Ok(v) if v.payload_is::<PyAsyncGenWrappedValue>() => {
Ok(PyIterReturn::Return(v)) if v.payload_is::<PyAsyncGenWrappedValue>() => {
Err(self.yield_close(vm))
}
Ok(v) => Ok(v),
Err(e) => Err(self.check_error(e, vm)),
other => other
.and_then(|o| o.into_async_pyresult(vm))
.map_err(|e| self.check_error(e, vm)),
}
} else {
PyAsyncGenWrappedValue::unbox(&self.ag, ret, vm)
Expand All @@ -353,6 +357,7 @@ impl PyAsyncGenAThrow {
vm: &VirtualMachine,
) -> PyResult {
let ret = self.ag.inner.throw(
self.ag.as_object(),
exc_type,
exc_val.unwrap_or_none(vm),
exc_tb.unwrap_or_none(vm),
Expand All @@ -362,7 +367,7 @@ impl PyAsyncGenAThrow {
if self.ignored_close(&ret) {
Err(self.yield_close(vm))
} else {
ret
ret.and_then(|o| o.into_async_pyresult(vm))
}
} else {
PyAsyncGenWrappedValue::unbox(&self.ag, ret, vm)
Expand All @@ -375,9 +380,11 @@ impl PyAsyncGenAThrow {
self.state.store(AwaitableState::Closed);
}

fn ignored_close(&self, res: &PyResult) -> bool {
res.as_ref()
.map_or(false, |v| v.payload_is::<PyAsyncGenWrappedValue>())
fn ignored_close(&self, res: &PyResult<PyIterReturn>) -> bool {
res.as_ref().map_or(false, |v| match v {
PyIterReturn::Return(obj) => obj.payload_is::<PyAsyncGenWrappedValue>(),
PyIterReturn::StopIteration(_) => false,
})
}
fn yield_close(&self, vm: &VirtualMachine) -> PyBaseExceptionRef {
self.ag.running_async.store(false);
Expand All @@ -401,8 +408,7 @@ impl PyAsyncGenAThrow {
impl IteratorIterable for PyAsyncGenAThrow {}
impl SlotIterator for PyAsyncGenAThrow {
fn next(zelf: &PyRef<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
// TODO: Fix zelf.send to return PyIterReturn
PyIterReturn::from_result(zelf.send(vm.ctx.none(), vm), vm)
PyIterReturn::from_pyresult(zelf.send(vm.ctx.none(), vm), vm)
}
}

Expand Down
39 changes: 20 additions & 19 deletions vm/src/builtins/coroutine.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::{PyCode, PyStrRef, PyTypeRef};
use crate::{
coroutine::{Coro, Variant},
coroutine::Coro,
frame::FrameRef,
function::OptionalArg,
protocol::PyIterReturn,
Expand All @@ -10,6 +10,7 @@ use crate::{

#[pyclass(module = false, name = "coroutine")]
#[derive(Debug)]
// PyCoro_Type in CPython
pub struct PyCoroutine {
inner: Coro,
}
Expand All @@ -28,7 +29,7 @@ impl PyCoroutine {

pub fn new(frame: FrameRef, name: PyStrRef) -> Self {
PyCoroutine {
inner: Coro::new(frame, Variant::Coroutine, name),
inner: Coro::new(frame, name),
}
}

Expand All @@ -43,24 +44,25 @@ impl PyCoroutine {
}

#[pymethod(magic)]
fn repr(zelf: PyRef<Self>) -> String {
zelf.inner.repr(zelf.get_id())
fn repr(zelf: PyRef<Self>, vm: &VirtualMachine) -> String {
zelf.inner.repr(zelf.as_object(), zelf.get_id(), vm)
}

#[pymethod]
fn send(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult {
self.inner.send(value, vm)
fn send(zelf: PyRef<Self>, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
zelf.inner.send(zelf.as_object(), value, vm)
}

#[pymethod]
fn throw(
&self,
zelf: PyRef<Self>,
exc_type: PyObjectRef,
exc_val: OptionalArg,
exc_tb: OptionalArg,
vm: &VirtualMachine,
) -> PyResult {
self.inner.throw(
) -> PyResult<PyIterReturn> {
zelf.inner.throw(
zelf.as_object(),
exc_type,
exc_val.unwrap_or_none(vm),
exc_tb.unwrap_or_none(vm),
Expand All @@ -69,8 +71,8 @@ impl PyCoroutine {
}

#[pymethod]
fn close(&self, vm: &VirtualMachine) -> PyResult<()> {
self.inner.close(vm)
fn close(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult<()> {
zelf.inner.close(zelf.as_object(), vm)
}

#[pymethod(name = "__await__")]
Expand Down Expand Up @@ -105,13 +107,13 @@ impl PyCoroutine {
impl IteratorIterable for PyCoroutine {}
impl SlotIterator for PyCoroutine {
fn next(zelf: &PyRef<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
// TODO: Fix zelf.send to return PyIterReturn
PyIterReturn::from_result(zelf.send(vm.ctx.none(), vm), vm)
Self::send(zelf.clone(), vm.ctx.none(), vm)
}
}

#[pyclass(module = false, name = "coroutine_wrapper")]
#[derive(Debug)]
// PyCoroWrapper_Type in CPython
pub struct PyCoroutineWrapper {
coro: PyRef<PyCoroutine>,
}
Expand All @@ -125,8 +127,8 @@ impl PyValue for PyCoroutineWrapper {
#[pyimpl(with(SlotIterator))]
impl PyCoroutineWrapper {
#[pymethod]
fn send(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult {
self.coro.send(val, vm)
fn send(zelf: PyRef<Self>, val: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
PyCoroutine::send(zelf.coro.clone(), val, vm)
}

#[pymethod]
Expand All @@ -136,16 +138,15 @@ impl PyCoroutineWrapper {
exc_val: OptionalArg,
exc_tb: OptionalArg,
vm: &VirtualMachine,
) -> PyResult {
self.coro.throw(exc_type, exc_val, exc_tb, vm)
) -> PyResult<PyIterReturn> {
PyCoroutine::throw(self.coro.clone(), exc_type, exc_val, exc_tb, vm)
}
}

impl IteratorIterable for PyCoroutineWrapper {}
impl SlotIterator for PyCoroutineWrapper {
fn next(zelf: &PyRef<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
// TODO: Fix zelf.send to return PyIterReturn
PyIterReturn::from_result(zelf.send(vm.ctx.none(), vm), vm)
Self::send(zelf.clone(), vm.ctx.none(), vm)
}
}

Expand Down
3 changes: 2 additions & 1 deletion vm/src/builtins/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ impl SlotIterator for PyFilter {
} else {
// the predicate itself can raise StopIteration which does stop the filter
// iteration
match PyIterReturn::from_result(vm.invoke(predicate, vec![next_obj.clone()]), vm)? {
match PyIterReturn::from_pyresult(vm.invoke(predicate, vec![next_obj.clone()]), vm)?
{
PyIterReturn::Return(obj) => obj,
PyIterReturn::StopIteration(v) => return Ok(PyIterReturn::StopIteration(v)),
}
Expand Down
26 changes: 13 additions & 13 deletions vm/src/builtins/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

use super::{PyCode, PyStrRef, PyTypeRef};
use crate::{
coroutine::{Coro, Variant},
coroutine::Coro,
frame::FrameRef,
function::OptionalArg,
protocol::PyIterReturn,
Expand Down Expand Up @@ -32,7 +32,7 @@ impl PyGenerator {

pub fn new(frame: FrameRef, name: PyStrRef) -> Self {
PyGenerator {
inner: Coro::new(frame, Variant::Gen, name),
inner: Coro::new(frame, name),
}
}

Expand All @@ -47,24 +47,25 @@ impl PyGenerator {
}

#[pymethod(magic)]
fn repr(zelf: PyRef<Self>) -> String {
zelf.inner.repr(zelf.get_id())
fn repr(zelf: PyRef<Self>, vm: &VirtualMachine) -> String {
zelf.inner.repr(zelf.as_object(), zelf.get_id(), vm)
}

#[pymethod]
fn send(&self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult {
self.inner.send(value, vm)
fn send(zelf: PyRef<Self>, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
zelf.inner.send(zelf.as_object(), value, vm)
}

#[pymethod]
fn throw(
&self,
zelf: PyRef<Self>,
exc_type: PyObjectRef,
exc_val: OptionalArg,
exc_tb: OptionalArg,
vm: &VirtualMachine,
) -> PyResult {
self.inner.throw(
) -> PyResult<PyIterReturn> {
zelf.inner.throw(
zelf.as_object(),
exc_type,
exc_val.unwrap_or_none(vm),
exc_tb.unwrap_or_none(vm),
Expand All @@ -73,8 +74,8 @@ impl PyGenerator {
}

#[pymethod]
fn close(&self, vm: &VirtualMachine) -> PyResult<()> {
self.inner.close(vm)
fn close(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult<()> {
zelf.inner.close(zelf.as_object(), vm)
}

#[pyproperty]
Expand All @@ -98,8 +99,7 @@ impl PyGenerator {
impl IteratorIterable for PyGenerator {}
impl SlotIterator for PyGenerator {
fn next(zelf: &PyRef<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
// TODO: Fix zelf.send to return PyIterReturn
PyIterReturn::from_result(zelf.send(vm.ctx.none(), vm), vm)
Self::send(zelf.clone(), vm.ctx.none(), vm)
}
}

Expand Down
2 changes: 1 addition & 1 deletion vm/src/builtins/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ impl SlotIterator for PyMap {
}

// the mapper itself can raise StopIteration which does stop the map iteration
PyIterReturn::from_result(vm.invoke(&zelf.mapper, next_objs), vm)
PyIterReturn::from_pyresult(vm.invoke(&zelf.mapper, next_objs), vm)
}
}

Expand Down
2 changes: 1 addition & 1 deletion vm/src/builtins/pytype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ impl PyType {
}
"__next__" => {
let func: slots::IterNextFunc = |zelf, vm| {
PyIterReturn::from_result(
PyIterReturn::from_pyresult(
vm.call_special_method(zelf.clone(), "__next__", ()),
vm,
)
Expand Down
Loading