Skip to content

Commit 53db70e

Browse files
authored
Support recursion in JIT-ed functions (RustPython#5473)
1 parent 76c699b commit 53db70e

File tree

6 files changed

+118
-23
lines changed

6 files changed

+118
-23
lines changed

jit/src/instructions.rs

+46-10
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use super::{JitCompileError, JitSig, JitType};
2+
use cranelift::codegen::ir::FuncRef;
13
use cranelift::prelude::*;
24
use num_traits::cast::ToPrimitive;
35
use rustpython_compiler_core::bytecode::{
@@ -6,8 +8,6 @@ use rustpython_compiler_core::bytecode::{
68
};
79
use std::collections::HashMap;
810

9-
use super::{JitCompileError, JitSig, JitType};
10-
1111
#[repr(u16)]
1212
enum CustomTrapCode {
1313
/// Raised when shifting by a negative number
@@ -27,6 +27,7 @@ enum JitValue {
2727
Bool(Value),
2828
None,
2929
Tuple(Vec<JitValue>),
30+
FuncRef(FuncRef),
3031
}
3132

3233
impl JitValue {
@@ -43,14 +44,14 @@ impl JitValue {
4344
JitValue::Int(_) => Some(JitType::Int),
4445
JitValue::Float(_) => Some(JitType::Float),
4546
JitValue::Bool(_) => Some(JitType::Bool),
46-
JitValue::None | JitValue::Tuple(_) => None,
47+
JitValue::None | JitValue::Tuple(_) | JitValue::FuncRef(_) => None,
4748
}
4849
}
4950

5051
fn into_value(self) -> Option<Value> {
5152
match self {
5253
JitValue::Int(val) | JitValue::Float(val) | JitValue::Bool(val) => Some(val),
53-
JitValue::None | JitValue::Tuple(_) => None,
54+
JitValue::None | JitValue::Tuple(_) | JitValue::FuncRef(_) => None,
5455
}
5556
}
5657
}
@@ -68,6 +69,7 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
6869
builder: &'a mut FunctionBuilder<'b>,
6970
num_variables: usize,
7071
arg_types: &[JitType],
72+
ret_type: Option<JitType>,
7173
entry_block: Block,
7274
) -> FunctionCompiler<'a, 'b> {
7375
let mut compiler = FunctionCompiler {
@@ -77,7 +79,7 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
7779
label_to_block: HashMap::new(),
7880
sig: JitSig {
7981
args: arg_types.to_vec(),
80-
ret: None,
82+
ret: ret_type,
8183
},
8284
};
8385
let params = compiler.builder.func.dfg.block_params(entry_block).to_vec();
@@ -132,7 +134,7 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
132134
}
133135
JitValue::Bool(val) => Ok(val),
134136
JitValue::None => Ok(self.builder.ins().iconst(types::I8, 0)),
135-
JitValue::Tuple(_) => Err(JitCompileError::NotSupported),
137+
JitValue::Tuple(_) | JitValue::FuncRef(_) => Err(JitCompileError::NotSupported),
136138
}
137139
}
138140

@@ -146,6 +148,7 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
146148

147149
pub fn compile<C: bytecode::Constant>(
148150
&mut self,
151+
func_ref: FuncRef,
149152
bytecode: &CodeObject<C>,
150153
) -> Result<(), JitCompileError> {
151154
// TODO: figure out if this is sufficient -- previously individual labels were associated
@@ -177,7 +180,7 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
177180
continue;
178181
}
179182

180-
self.add_instruction(instruction, arg, &bytecode.constants)?;
183+
self.add_instruction(func_ref, bytecode, instruction, arg)?;
181184
}
182185

183186
Ok(())
@@ -229,9 +232,10 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
229232

230233
pub fn add_instruction<C: bytecode::Constant>(
231234
&mut self,
235+
func_ref: FuncRef,
236+
bytecode: &CodeObject<C>,
232237
instruction: Instruction,
233238
arg: OpArg,
234-
constants: &[C],
235239
) -> Result<(), JitCompileError> {
236240
match instruction {
237241
Instruction::ExtendedArg => Ok(()),
@@ -282,7 +286,8 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
282286
self.store_variable(idx.get(arg), val)
283287
}
284288
Instruction::LoadConst { idx } => {
285-
let val = self.prepare_const(constants[idx.get(arg) as usize].borrow_constant())?;
289+
let val = self
290+
.prepare_const(bytecode.constants[idx.get(arg) as usize].borrow_constant())?;
286291
self.stack.push(val);
287292
Ok(())
288293
}
@@ -311,7 +316,8 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
311316
self.return_value(val)
312317
}
313318
Instruction::ReturnConst { idx } => {
314-
let val = self.prepare_const(constants[idx.get(arg) as usize].borrow_constant())?;
319+
let val = self
320+
.prepare_const(bytecode.constants[idx.get(arg) as usize].borrow_constant())?;
315321
self.return_value(val)
316322
}
317323
Instruction::CompareOperation { op, .. } => {
@@ -508,6 +514,36 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
508514
// TODO: block support
509515
Ok(())
510516
}
517+
Instruction::LoadGlobal(idx) => {
518+
let name = &bytecode.names[idx.get(arg) as usize];
519+
520+
if name.as_ref() != bytecode.obj_name.as_ref() {
521+
Err(JitCompileError::NotSupported)
522+
} else {
523+
self.stack.push(JitValue::FuncRef(func_ref));
524+
Ok(())
525+
}
526+
}
527+
Instruction::CallFunctionPositional { nargs } => {
528+
let nargs = nargs.get(arg);
529+
530+
let mut args = Vec::new();
531+
for _ in 0..nargs {
532+
let arg = self.stack.pop().ok_or(JitCompileError::BadBytecode)?;
533+
args.push(arg.into_value().unwrap());
534+
}
535+
536+
match self.stack.pop().ok_or(JitCompileError::BadBytecode)? {
537+
JitValue::FuncRef(reference) => {
538+
let call = self.builder.ins().call(reference, &args);
539+
let returns = self.builder.inst_results(call);
540+
self.stack.push(JitValue::Int(returns[0]));
541+
542+
Ok(())
543+
}
544+
_ => Err(JitCompileError::BadBytecode),
545+
}
546+
}
511547
_ => Err(JitCompileError::NotSupported),
512548
}
513549
}

jit/src/lib.rs

+27-10
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ impl Jit {
4949
&mut self,
5050
bytecode: &bytecode::CodeObject<C>,
5151
args: &[JitType],
52+
ret: Option<JitType>,
5253
) -> Result<(FuncId, JitSig), JitCompileError> {
5354
for arg in args {
5455
self.ctx
@@ -58,29 +59,44 @@ impl Jit {
5859
.push(AbiParam::new(arg.to_cranelift()));
5960
}
6061

62+
if ret.is_some() {
63+
self.ctx
64+
.func
65+
.signature
66+
.returns
67+
.push(AbiParam::new(ret.clone().unwrap().to_cranelift()));
68+
}
69+
70+
let id = self.module.declare_function(
71+
&format!("jit_{}", bytecode.obj_name.as_ref()),
72+
Linkage::Export,
73+
&self.ctx.func.signature,
74+
)?;
75+
76+
let func_ref = self.module.declare_func_in_func(id, &mut self.ctx.func);
77+
6178
let mut builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context);
6279
let entry_block = builder.create_block();
6380
builder.append_block_params_for_function_params(entry_block);
6481
builder.switch_to_block(entry_block);
6582

6683
let sig = {
67-
let mut compiler =
68-
FunctionCompiler::new(&mut builder, bytecode.varnames.len(), args, entry_block);
84+
let mut compiler = FunctionCompiler::new(
85+
&mut builder,
86+
bytecode.varnames.len(),
87+
args,
88+
ret,
89+
entry_block,
90+
);
6991

70-
compiler.compile(bytecode)?;
92+
compiler.compile(func_ref, bytecode)?;
7193

7294
compiler.sig
7395
};
7496

7597
builder.seal_all_blocks();
7698
builder.finalize();
7799

78-
let id = self.module.declare_function(
79-
&format!("jit_{}", bytecode.obj_name.as_ref()),
80-
Linkage::Export,
81-
&self.ctx.func.signature,
82-
)?;
83-
84100
self.module.define_function(id, &mut self.ctx)?;
85101

86102
self.module.clear_context(&mut self.ctx);
@@ -92,10 +108,11 @@ impl Jit {
92108
pub fn compile<C: bytecode::Constant>(
93109
bytecode: &bytecode::CodeObject<C>,
94110
args: &[JitType],
111+
ret: Option<JitType>,
95112
) -> Result<CompiledCode, JitCompileError> {
96113
let mut jit = Jit::new();
97114

98-
let (id, sig) = jit.build_function(bytecode, args)?;
115+
let (id, sig) = jit.build_function(bytecode, args, ret)?;
99116

100117
jit.module.finalize_definitions();
101118

jit/tests/common.rs

+11-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,17 @@ impl Function {
2727
arg_types.push(arg_type);
2828
}
2929

30-
rustpython_jit::compile(&self.code, &arg_types).expect("Compile failure")
30+
let ret_type = match self.annotations.get("return") {
31+
Some(StackValue::String(annotation)) => match annotation.as_str() {
32+
"int" => Some(JitType::Int),
33+
"float" => Some(JitType::Float),
34+
"bool" => Some(JitType::Bool),
35+
_ => panic!("Unrecognised jit type"),
36+
},
37+
_ => None,
38+
};
39+
40+
rustpython_jit::compile(&self.code, &arg_types, ret_type).expect("Compile failure")
3141
}
3242
}
3343

jit/tests/misc_tests.rs

+12
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,15 @@ fn test_unpack_tuple() {
113113
assert_eq!(unpack_tuple(0, 1), Ok(1));
114114
assert_eq!(unpack_tuple(1, 2), Ok(2));
115115
}
116+
117+
#[test]
118+
fn test_recursive_fib() {
119+
let fib = jit_function! { fib(n: i64) -> i64 => r##"
120+
def fib(n: int) -> int:
121+
if n == 0 or n == 1:
122+
return 1
123+
return fib(n-1) + fib(n-2)
124+
"## };
125+
126+
assert_eq!(fib(10), Ok(89));
127+
}

vm/src/builtins/function.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,8 @@ impl PyFunction {
506506
zelf.jitted_code
507507
.get_or_try_init(|| {
508508
let arg_types = jitfunc::get_jit_arg_types(&zelf, vm)?;
509-
rustpython_jit::compile(&zelf.code.code, &arg_types)
509+
let ret_type = jitfunc::jit_ret_type(&zelf, vm)?;
510+
rustpython_jit::compile(&zelf.code.code, &arg_types, ret_type)
510511
.map_err(|err| jitfunc::new_jit_error(err.to_string(), vm))
511512
})
512513
.map(drop)

vm/src/builtins/function/jitfunc.rs

+20-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ fn get_jit_arg_type(dict: &PyDictRef, name: &str, vm: &VirtualMachine) -> PyResu
5252
Ok(JitType::Bool)
5353
} else {
5454
Err(new_jit_error(
55-
"Jit requires argument to be either int or float".to_owned(),
55+
"Jit requires argument to be either int, float or bool".to_owned(),
5656
vm,
5757
))
5858
}
@@ -106,6 +106,25 @@ pub fn get_jit_arg_types(func: &Py<PyFunction>, vm: &VirtualMachine) -> PyResult
106106
}
107107
}
108108

109+
pub fn jit_ret_type(func: &Py<PyFunction>, vm: &VirtualMachine) -> PyResult<Option<JitType>> {
110+
let func_obj: PyObjectRef = func.as_ref().to_owned();
111+
let annotations = func_obj.get_attr("__annotations__", vm)?;
112+
if vm.is_none(&annotations) {
113+
Err(new_jit_error(
114+
"Jitting function requires return type to have annotations".to_owned(),
115+
vm,
116+
))
117+
} else if let Ok(dict) = PyDictRef::try_from_object(vm, annotations) {
118+
if dict.contains_key("return", vm) {
119+
get_jit_arg_type(&dict, "return", vm).map_or(Ok(None), |t| Ok(Some(t)))
120+
} else {
121+
Ok(None)
122+
}
123+
} else {
124+
Err(vm.new_type_error("Function annotations aren't a dict".to_owned()))
125+
}
126+
}
127+
109128
fn get_jit_value(vm: &VirtualMachine, obj: &PyObject) -> Result<AbiValue, ArgsError> {
110129
// This does exact type checks as subclasses of int/float can't be passed to jitted functions
111130
let cls = obj.class();

0 commit comments

Comments
 (0)