Skip to content

Commit 6212c81

Browse files
authored
Merge pull request #5276 from youknowone/async-for-comprehension
Async for comprehension
2 parents 3949ecc + 408459b commit 6212c81

File tree

4 files changed

+147
-83
lines changed

4 files changed

+147
-83
lines changed

Lib/test/test_asyncgen.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -513,16 +513,15 @@ def __anext__(self):
513513
return self.yielded
514514
self.check_async_iterator_anext(MyAsyncIterWithTypesCoro)
515515

516-
# TODO: RUSTPYTHON: async for gen expression compilation
517-
# def test_async_gen_aiter(self):
518-
# async def gen():
519-
# yield 1
520-
# yield 2
521-
# g = gen()
522-
# async def consume():
523-
# return [i async for i in aiter(g)]
524-
# res = self.loop.run_until_complete(consume())
525-
# self.assertEqual(res, [1, 2])
516+
def test_async_gen_aiter(self):
517+
async def gen():
518+
yield 1
519+
yield 2
520+
g = gen()
521+
async def consume():
522+
return [i async for i in aiter(g)]
523+
res = self.loop.run_until_complete(consume())
524+
self.assertEqual(res, [1, 2])
526525

527526
# TODO: RUSTPYTHON, NameError: name 'aiter' is not defined
528527
@unittest.expectedFailure
@@ -1569,22 +1568,23 @@ async def main():
15691568
self.assertIn('unhandled exception during asyncio.run() shutdown',
15701569
message['message'])
15711570

1572-
# TODO: RUSTPYTHON: async for gen expression compilation
1573-
# def test_async_gen_expression_01(self):
1574-
# async def arange(n):
1575-
# for i in range(n):
1576-
# await asyncio.sleep(0.01)
1577-
# yield i
1571+
# TODO: RUSTPYTHON; TypeError: object async_generator can't be used in 'await' expression
1572+
@unittest.expectedFailure
1573+
def test_async_gen_expression_01(self):
1574+
async def arange(n):
1575+
for i in range(n):
1576+
await asyncio.sleep(0.01)
1577+
yield i
15781578

1579-
# def make_arange(n):
1580-
# # This syntax is legal starting with Python 3.7
1581-
# return (i * 2 async for i in arange(n))
1579+
def make_arange(n):
1580+
# This syntax is legal starting with Python 3.7
1581+
return (i * 2 async for i in arange(n))
15821582

1583-
# async def run():
1584-
# return [i async for i in make_arange(10)]
1583+
async def run():
1584+
return [i async for i in make_arange(10)]
15851585

1586-
# res = self.loop.run_until_complete(run())
1587-
# self.assertEqual(res, [i * 2 for i in range(10)])
1586+
res = self.loop.run_until_complete(run())
1587+
self.assertEqual(res, [i * 2 for i in range(10)])
15881588

15891589
# TODO: RUSTPYTHON: async for gen expression compilation
15901590
# def test_async_gen_expression_02(self):

Lib/test/test_grammar.py

Lines changed: 38 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -418,44 +418,46 @@ def test_var_annot_simple_exec(self):
418418
gns['__annotations__']
419419

420420
# TODO: RUSTPYTHON
421-
# def test_var_annot_custom_maps(self):
422-
# # tests with custom locals() and __annotations__
423-
# ns = {'__annotations__': CNS()}
424-
# exec('X: int; Z: str = "Z"; (w): complex = 1j', ns)
425-
# self.assertEqual(ns['__annotations__']['x'], int)
426-
# self.assertEqual(ns['__annotations__']['z'], str)
427-
# with self.assertRaises(KeyError):
428-
# ns['__annotations__']['w']
429-
# nonloc_ns = {}
430-
# class CNS2:
431-
# def __init__(self):
432-
# self._dct = {}
433-
# def __setitem__(self, item, value):
434-
# nonlocal nonloc_ns
435-
# self._dct[item] = value
436-
# nonloc_ns[item] = value
437-
# def __getitem__(self, item):
438-
# return self._dct[item]
439-
# exec('x: int = 1', {}, CNS2())
440-
# self.assertEqual(nonloc_ns['__annotations__']['x'], int)
421+
@unittest.expectedFailure
422+
def test_var_annot_custom_maps(self):
423+
# tests with custom locals() and __annotations__
424+
ns = {'__annotations__': CNS()}
425+
exec('X: int; Z: str = "Z"; (w): complex = 1j', ns)
426+
self.assertEqual(ns['__annotations__']['x'], int)
427+
self.assertEqual(ns['__annotations__']['z'], str)
428+
with self.assertRaises(KeyError):
429+
ns['__annotations__']['w']
430+
nonloc_ns = {}
431+
class CNS2:
432+
def __init__(self):
433+
self._dct = {}
434+
def __setitem__(self, item, value):
435+
nonlocal nonloc_ns
436+
self._dct[item] = value
437+
nonloc_ns[item] = value
438+
def __getitem__(self, item):
439+
return self._dct[item]
440+
exec('x: int = 1', {}, CNS2())
441+
self.assertEqual(nonloc_ns['__annotations__']['x'], int)
441442

442443
# TODO: RUSTPYTHON
443-
# def test_var_annot_refleak(self):
444-
# # complex case: custom locals plus custom __annotations__
445-
# # this was causing refleak
446-
# cns = CNS()
447-
# nonloc_ns = {'__annotations__': cns}
448-
# class CNS2:
449-
# def __init__(self):
450-
# self._dct = {'__annotations__': cns}
451-
# def __setitem__(self, item, value):
452-
# nonlocal nonloc_ns
453-
# self._dct[item] = value
454-
# nonloc_ns[item] = value
455-
# def __getitem__(self, item):
456-
# return self._dct[item]
457-
# exec('X: str', {}, CNS2())
458-
# self.assertEqual(nonloc_ns['__annotations__']['x'], str)
444+
@unittest.expectedFailure
445+
def test_var_annot_refleak(self):
446+
# complex case: custom locals plus custom __annotations__
447+
# this was causing refleak
448+
cns = CNS()
449+
nonloc_ns = {'__annotations__': cns}
450+
class CNS2:
451+
def __init__(self):
452+
self._dct = {'__annotations__': cns}
453+
def __setitem__(self, item, value):
454+
nonlocal nonloc_ns
455+
self._dct[item] = value
456+
nonloc_ns[item] = value
457+
def __getitem__(self, item):
458+
return self._dct[item]
459+
exec('X: str', {}, CNS2())
460+
self.assertEqual(nonloc_ns['__annotations__']['x'], str)
459461

460462

461463
def test_var_annot_rhs(self):

compiler/codegen/src/compile.rs

Lines changed: 55 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2629,24 +2629,30 @@ impl Compiler {
26292629
compile_element: &dyn Fn(&mut Self) -> CompileResult<()>,
26302630
) -> CompileResult<()> {
26312631
let prev_ctx = self.ctx;
2632+
let is_async = generators.iter().any(|g| g.is_async);
26322633

26332634
self.ctx = CompileContext {
26342635
loop_data: None,
26352636
in_class: prev_ctx.in_class,
2636-
func: FunctionContext::Function,
2637+
func: if is_async {
2638+
FunctionContext::AsyncFunction
2639+
} else {
2640+
FunctionContext::Function
2641+
},
26372642
};
26382643

26392644
// We must have at least one generator:
26402645
assert!(!generators.is_empty());
26412646

2647+
let flags = bytecode::CodeFlags::NEW_LOCALS | bytecode::CodeFlags::IS_OPTIMIZED;
2648+
let flags = if is_async {
2649+
flags | bytecode::CodeFlags::IS_COROUTINE
2650+
} else {
2651+
flags
2652+
};
2653+
26422654
// Create magnificent function <listcomp>:
2643-
self.push_output(
2644-
bytecode::CodeFlags::NEW_LOCALS | bytecode::CodeFlags::IS_OPTIMIZED,
2645-
1,
2646-
1,
2647-
0,
2648-
name.to_owned(),
2649-
);
2655+
self.push_output(flags, 1, 1, 0, name.to_owned());
26502656
let arg0 = self.varname(".0")?;
26512657

26522658
let return_none = init_collection.is_none();
@@ -2657,13 +2663,11 @@ impl Compiler {
26572663

26582664
let mut loop_labels = vec![];
26592665
for generator in generators {
2660-
if generator.is_async {
2661-
unimplemented!("async for comprehensions");
2662-
}
2663-
26642666
let loop_block = self.new_block();
26652667
let after_block = self.new_block();
26662668

2669+
// emit!(self, Instruction::SetupLoop);
2670+
26672671
if loop_labels.is_empty() {
26682672
// Load iterator onto stack (passed as first argument):
26692673
emit!(self, Instruction::LoadFast(arg0));
@@ -2672,20 +2676,36 @@ impl Compiler {
26722676
self.compile_expression(&generator.iter)?;
26732677

26742678
// Get iterator / turn item into an iterator
2675-
emit!(self, Instruction::GetIter);
2679+
if generator.is_async {
2680+
emit!(self, Instruction::GetAIter);
2681+
} else {
2682+
emit!(self, Instruction::GetIter);
2683+
}
26762684
}
26772685

26782686
loop_labels.push((loop_block, after_block));
2679-
26802687
self.switch_to_block(loop_block);
2681-
emit!(
2682-
self,
2683-
Instruction::ForIter {
2684-
target: after_block,
2685-
}
2686-
);
2687-
2688-
self.compile_store(&generator.target)?;
2688+
if generator.is_async {
2689+
emit!(
2690+
self,
2691+
Instruction::SetupExcept {
2692+
handler: after_block,
2693+
}
2694+
);
2695+
emit!(self, Instruction::GetANext);
2696+
self.emit_constant(ConstantData::None);
2697+
emit!(self, Instruction::YieldFrom);
2698+
self.compile_store(&generator.target)?;
2699+
emit!(self, Instruction::PopBlock);
2700+
} else {
2701+
emit!(
2702+
self,
2703+
Instruction::ForIter {
2704+
target: after_block,
2705+
}
2706+
);
2707+
self.compile_store(&generator.target)?;
2708+
}
26892709

26902710
// Now evaluate the ifs:
26912711
for if_condition in &generator.ifs {
@@ -2701,6 +2721,9 @@ impl Compiler {
27012721

27022722
// End of for loop:
27032723
self.switch_to_block(after_block);
2724+
if is_async {
2725+
emit!(self, Instruction::EndAsyncFor);
2726+
}
27042727
}
27052728

27062729
if return_none {
@@ -2737,10 +2760,19 @@ impl Compiler {
27372760
self.compile_expression(&generators[0].iter)?;
27382761

27392762
// Get iterator / turn item into an iterator
2740-
emit!(self, Instruction::GetIter);
2763+
if is_async {
2764+
emit!(self, Instruction::GetAIter);
2765+
} else {
2766+
emit!(self, Instruction::GetIter);
2767+
};
27412768

27422769
// Call just created <listcomp> function:
27432770
emit!(self, Instruction::CallFunctionPositional { nargs: 1 });
2771+
if is_async {
2772+
emit!(self, Instruction::GetAwaitable);
2773+
self.emit_constant(ConstantData::None);
2774+
emit!(self, Instruction::YieldFrom);
2775+
}
27442776
Ok(())
27452777
}
27462778

vm/src/frame.rs

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,10 @@ impl ExecutingFrame<'_> {
351351
let mut arg_state = bytecode::OpArgState::default();
352352
loop {
353353
let idx = self.lasti() as usize;
354+
// eprintln!(
355+
// "location: {:?} {}",
356+
// self.code.locations[idx], self.code.source_path
357+
// );
354358
self.update_lasti(|i| *i += 1);
355359
let bytecode::CodeUnit { op, arg } = instrs[idx];
356360
let arg = arg_state.extend(arg);
@@ -993,6 +997,9 @@ impl ExecutingFrame<'_> {
993997
Ok(None)
994998
}
995999
bytecode::Instruction::GetANext => {
1000+
#[cfg(debug_assertions)] // remove when GetANext is fully implemented
1001+
let orig_stack_len = self.state.stack.len();
1002+
9961003
let aiter = self.top_value();
9971004
let awaitable = if aiter.class().is(vm.ctx.types.async_generator) {
9981005
vm.call_special_method(aiter, identifier!(vm, __anext__), ())?
@@ -1030,6 +1037,8 @@ impl ExecutingFrame<'_> {
10301037
})?
10311038
};
10321039
self.push_value(awaitable);
1040+
#[cfg(debug_assertions)]
1041+
debug_assert_eq!(orig_stack_len + 1, self.state.stack.len());
10331042
Ok(None)
10341043
}
10351044
bytecode::Instruction::EndAsyncFor => {
@@ -1238,6 +1247,7 @@ impl ExecutingFrame<'_> {
12381247
fn unwind_blocks(&mut self, vm: &VirtualMachine, reason: UnwindReason) -> FrameResult {
12391248
// First unwind all existing blocks on the block stack:
12401249
while let Some(block) = self.current_block() {
1250+
// eprintln!("unwinding block: {:.60?} {:.60?}", block.typ, reason);
12411251
match block.typ {
12421252
BlockType::Loop => match reason {
12431253
UnwindReason::Break { target } => {
@@ -1935,6 +1945,7 @@ impl ExecutingFrame<'_> {
19351945
}
19361946

19371947
fn push_block(&mut self, typ: BlockType) {
1948+
// eprintln!("block pushed: {:.60?} {}", typ, self.state.stack.len());
19381949
self.state.blocks.push(Block {
19391950
typ,
19401951
level: self.state.stack.len(),
@@ -1944,6 +1955,12 @@ impl ExecutingFrame<'_> {
19441955
#[track_caller]
19451956
fn pop_block(&mut self) -> Block {
19461957
let block = self.state.blocks.pop().expect("No more blocks to pop!");
1958+
// eprintln!(
1959+
// "block popped: {:.60?} {} -> {} ",
1960+
// block.typ,
1961+
// self.state.stack.len(),
1962+
// block.level
1963+
// );
19471964
#[cfg(debug_assertions)]
19481965
if self.state.stack.len() < block.level {
19491966
dbg!(&self);
@@ -1965,6 +1982,11 @@ impl ExecutingFrame<'_> {
19651982
#[inline]
19661983
#[track_caller] // not a real track_caller but push_value is not very useful
19671984
fn push_value(&mut self, obj: PyObjectRef) {
1985+
// eprintln!(
1986+
// "push_value {} / len: {} +1",
1987+
// obj.class().name(),
1988+
// self.state.stack.len()
1989+
// );
19681990
match self.state.stack.try_push(obj) {
19691991
Ok(()) => {}
19701992
Err(_e) => self.fatal("tried to push value onto stack but overflowed max_stackdepth"),
@@ -1975,7 +1997,14 @@ impl ExecutingFrame<'_> {
19751997
#[track_caller] // not a real track_caller but pop_value is not very useful
19761998
fn pop_value(&mut self) -> PyObjectRef {
19771999
match self.state.stack.pop() {
1978-
Some(x) => x,
2000+
Some(x) => {
2001+
// eprintln!(
2002+
// "pop_value {} / len: {}",
2003+
// x.class().name(),
2004+
// self.state.stack.len()
2005+
// );
2006+
x
2007+
}
19792008
None => self.fatal("tried to pop value but there was nothing on the stack"),
19802009
}
19812010
}
@@ -2002,6 +2031,7 @@ impl ExecutingFrame<'_> {
20022031
}
20032032

20042033
#[inline]
2034+
#[track_caller]
20052035
fn nth_value(&self, depth: u32) -> &PyObject {
20062036
let stack = &self.state.stack;
20072037
&stack[stack.len() - depth as usize - 1]

0 commit comments

Comments
 (0)