Skip to content

Commit 577cea9

Browse files
committed
Add coroutines, async/await functionality, and gen.close()
1 parent 52f1965 commit 577cea9

File tree

10 files changed

+253
-24
lines changed

10 files changed

+253
-24
lines changed

bytecode/src/bytecode.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ bitflags! {
5555
const HAS_ANNOTATIONS = 0x04;
5656
const NEW_LOCALS = 0x08;
5757
const IS_GENERATOR = 0x10;
58+
const IS_COROUTINE = 0x20;
5859
}
5960
}
6061

@@ -273,6 +274,7 @@ pub enum Instruction {
273274
Reverse {
274275
amount: usize,
275276
},
277+
GetAwaitable,
276278
}
277279

278280
use self::Instruction::*;
@@ -549,6 +551,7 @@ impl Instruction {
549551
FormatValue { spec, .. } => w!(FormatValue, spec), // TODO: write conversion
550552
PopException => w!(PopException),
551553
Reverse { amount } => w!(Reverse, amount),
554+
GetAwaitable => w!(GetAwaitable),
552555
}
553556
}
554557
}

compiler/src/compile.rs

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ struct Compiler<O: OutputStream = BasicOutputStream> {
2929
current_qualified_path: Option<String>,
3030
in_loop: bool,
3131
in_function_def: bool,
32+
in_async_func: bool,
3233
optimize: u8,
3334
}
3435

@@ -126,6 +127,7 @@ impl<O: OutputStream> Compiler<O> {
126127
current_qualified_path: None,
127128
in_loop: false,
128129
in_function_def: false,
130+
in_async_func: false,
129131
optimize,
130132
}
131133
}
@@ -451,11 +453,7 @@ impl<O: OutputStream> Compiler<O> {
451453
decorator_list,
452454
returns,
453455
} => {
454-
if *is_async {
455-
unimplemented!("async def");
456-
} else {
457-
self.compile_function_def(name, args, body, decorator_list, returns)?
458-
}
456+
self.compile_function_def(name, args, body, decorator_list, returns, *is_async)?;
459457
}
460458
ClassDef {
461459
name,
@@ -797,11 +795,15 @@ impl<O: OutputStream> Compiler<O> {
797795
body: &[ast::Statement],
798796
decorator_list: &[ast::Expression],
799797
returns: &Option<ast::Expression>, // TODO: use type hint somehow..
798+
is_async: bool,
800799
) -> Result<(), CompileError> {
801800
// Create bytecode for this function:
802801
// remember to restore self.in_loop to the original after the function is compiled
803802
let was_in_loop = self.in_loop;
804803
let was_in_function_def = self.in_function_def;
804+
805+
let was_in_async_func = self.in_async_func;
806+
self.in_async_func = is_async;
805807
self.in_loop = false;
806808
self.in_function_def = true;
807809

@@ -870,6 +872,10 @@ impl<O: OutputStream> Compiler<O> {
870872
});
871873
}
872874

875+
if is_async {
876+
code.flags |= bytecode::CodeFlags::IS_COROUTINE;
877+
}
878+
873879
self.emit(Instruction::LoadConst {
874880
value: bytecode::Constant::Code {
875881
code: Box::new(code),
@@ -891,6 +897,7 @@ impl<O: OutputStream> Compiler<O> {
891897
self.current_qualified_path = old_qualified_path;
892898
self.in_loop = was_in_loop;
893899
self.in_function_def = was_in_function_def;
900+
self.in_async_func = was_in_async_func;
894901
Ok(())
895902
}
896903

@@ -1551,7 +1558,7 @@ impl<O: OutputStream> Compiler<O> {
15511558
self.emit(Instruction::BuildSlice { size });
15521559
}
15531560
Yield { value } => {
1554-
if !self.in_function_def {
1561+
if !self.in_function_def || self.in_async_func {
15551562
return Err(CompileError {
15561563
error: CompileErrorType::InvalidYield,
15571564
location: self.current_source_location.clone(),
@@ -1566,8 +1573,13 @@ impl<O: OutputStream> Compiler<O> {
15661573
};
15671574
self.emit(Instruction::YieldValue);
15681575
}
1569-
Await { .. } => {
1570-
unimplemented!("await");
1576+
Await { value } => {
1577+
self.compile_expression(value)?;
1578+
self.emit(Instruction::GetAwaitable);
1579+
self.emit(Instruction::LoadConst {
1580+
value: bytecode::Constant::None,
1581+
});
1582+
self.emit(Instruction::YieldFrom);
15711583
}
15721584
YieldFrom { value } => {
15731585
self.mark_generator();
@@ -1612,8 +1624,14 @@ impl<O: OutputStream> Compiler<O> {
16121624
self.load_name(name);
16131625
}
16141626
Lambda { args, body } => {
1627+
let was_in_loop = self.in_loop;
1628+
let was_in_function_def = self.in_function_def;
1629+
let was_in_async_func = self.in_async_func;
1630+
self.in_async_func = false;
1631+
self.in_loop = false;
1632+
self.in_function_def = true;
1633+
16151634
let name = "<lambda>".to_string();
1616-
// no need to worry about the self.loop_depth because there are no loops in lambda expressions
16171635
self.enter_function(&name, args)?;
16181636
self.compile_expression(body)?;
16191637
self.emit(Instruction::ReturnValue);
@@ -1629,6 +1647,10 @@ impl<O: OutputStream> Compiler<O> {
16291647
});
16301648
// Turn code object into function object:
16311649
self.emit(Instruction::MakeFunction);
1650+
1651+
self.in_loop = was_in_loop;
1652+
self.in_function_def = was_in_function_def;
1653+
self.in_async_func = was_in_async_func;
16321654
}
16331655
Comprehension { kind, generators } => {
16341656
self.compile_comprehension(kind, generators)?;

vm/src/builtins.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -900,6 +900,7 @@ pub fn make_module(vm: &VirtualMachine, module: PyObjectRef) {
900900
"UserWarning" => ctx.exceptions.user_warning.clone(),
901901

902902
"KeyboardInterrupt" => ctx.exceptions.keyboard_interrupt.clone(),
903+
"GeneratorExit" => ctx.exceptions.generator_exit.clone(),
903904
"SystemExit" => ctx.exceptions.system_exit.clone(),
904905
});
905906
}

vm/src/exceptions.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ pub struct ExceptionZoo {
271271
pub user_warning: PyClassRef,
272272

273273
pub keyboard_interrupt: PyClassRef,
274+
pub generator_exit: PyClassRef,
274275
pub system_exit: PyClassRef,
275276
}
276277

@@ -327,6 +328,7 @@ impl ExceptionZoo {
327328
let user_warning = create_type("UserWarning", &type_type, &warning);
328329

329330
let keyboard_interrupt = create_type("KeyboardInterrupt", &type_type, &base_exception_type);
331+
let generator_exit = create_type("GeneratorExit", &type_type, &base_exception_type);
330332
let system_exit = create_type("SystemExit", &type_type, &base_exception_type);
331333

332334
ExceptionZoo {
@@ -376,6 +378,7 @@ impl ExceptionZoo {
376378
reference_error,
377379
user_warning,
378380
keyboard_interrupt,
381+
generator_exit,
379382
system_exit,
380383
}
381384
}

vm/src/frame.rs

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,16 @@ use crate::bytecode;
88
use crate::function::PyFuncArgs;
99
use crate::obj::objbool;
1010
use crate::obj::objcode::PyCodeRef;
11+
use crate::obj::objcoroutine::PyCoroutine;
1112
use crate::obj::objdict::{PyDict, PyDictRef};
13+
use crate::obj::objgenerator::PyGenerator;
1214
use crate::obj::objiter;
1315
use crate::obj::objlist;
1416
use crate::obj::objslice::PySlice;
15-
use crate::obj::objstr;
16-
use crate::obj::objstr::PyString;
17+
use crate::obj::objstr::{self, PyString};
1718
use crate::obj::objtraceback::{PyTraceback, PyTracebackRef};
1819
use crate::obj::objtuple::PyTuple;
19-
use crate::obj::objtype;
20-
use crate::obj::objtype::PyClassRef;
20+
use crate::obj::objtype::{self, PyClassRef};
2121
use crate::pyobject::{
2222
IdProtocol, ItemProtocol, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol,
2323
};
@@ -230,7 +230,7 @@ impl Frame {
230230
exc_type: PyClassRef,
231231
exc_val: PyObjectRef,
232232
exc_tb: PyObjectRef,
233-
) -> PyResult {
233+
) -> PyResult<ExecutionResult> {
234234
if let bytecode::Instruction::YieldFrom = self.code.instructions[self.lasti.get()] {
235235
let coro = self.last_value();
236236
vm.call_method(
@@ -243,15 +243,15 @@ impl Frame {
243243
self.lasti.set(self.lasti.get() + 1);
244244
let val = objiter::stop_iter_value(vm, &err)?;
245245
self._send(coro, val, vm)
246-
})
246+
})?;
247+
Ok(ExecutionResult::Return(vm.get_none()))
247248
} else {
248249
let exception = vm.new_exception_obj(exc_type, vec![exc_val])?;
249250
match self.unwind_blocks(vm, UnwindReason::Raising { exception }) {
250251
Ok(None) => self.run(vm),
251252
Ok(Some(result)) => Ok(result),
252253
Err(exception) => Err(exception),
253254
}
254-
.and_then(|res| res.into_result(vm))
255255
}
256256
}
257257

@@ -465,6 +465,24 @@ impl Frame {
465465
self.push_value(iter_obj);
466466
Ok(None)
467467
}
468+
bytecode::Instruction::GetAwaitable => {
469+
let awaited_obj = self.pop_value();
470+
let awaitable = if awaited_obj.payload_is::<crate::obj::objcoroutine::PyCoroutine>()
471+
{
472+
awaited_obj
473+
} else {
474+
let await_method =
475+
vm.get_method_or_type_error(awaited_obj.clone(), "__await__", || {
476+
format!(
477+
"object {} can't be used in 'await' expression",
478+
awaited_obj.class().name,
479+
)
480+
})?;
481+
vm.invoke(&await_method, vec![])?
482+
};
483+
self.push_value(awaitable);
484+
Ok(None)
485+
}
468486
bytecode::Instruction::ForIter { target } => self.execute_for_iter(vm, *target),
469487
bytecode::Instruction::MakeFunction => self.execute_make_function(vm),
470488
bytecode::Instruction::CallFunction { typ } => self.execute_call_function(vm, typ),
@@ -1016,7 +1034,11 @@ impl Frame {
10161034
}
10171035

10181036
fn _send(&self, coro: PyObjectRef, val: PyObjectRef, vm: &VirtualMachine) -> PyResult {
1019-
if vm.is_none(&val) {
1037+
if let Some(gen) = coro.payload::<PyGenerator>() {
1038+
gen.send(val, vm)
1039+
} else if let Some(coro) = coro.payload::<PyCoroutine>() {
1040+
coro.send(val, vm)
1041+
} else if vm.is_none(&val) {
10201042
objiter::call_next(vm, &coro)
10211043
} else {
10221044
vm.call_method(&coro, "send", vec![val])

vm/src/obj/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ pub mod objbytes;
88
pub mod objclassmethod;
99
pub mod objcode;
1010
pub mod objcomplex;
11+
pub mod objcoroutine;
1112
pub mod objdict;
1213
pub mod objellipsis;
1314
pub mod objenumerate;

0 commit comments

Comments
 (0)