Skip to content

Support recursion in JIT-ed functions #5473

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 46 additions & 10 deletions jit/src/instructions.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use super::{JitCompileError, JitSig, JitType};
use cranelift::codegen::ir::FuncRef;
use cranelift::prelude::*;
use num_traits::cast::ToPrimitive;
use rustpython_compiler_core::bytecode::{
Expand All @@ -6,8 +8,6 @@ use rustpython_compiler_core::bytecode::{
};
use std::collections::HashMap;

use super::{JitCompileError, JitSig, JitType};

#[repr(u16)]
enum CustomTrapCode {
/// Raised when shifting by a negative number
Expand All @@ -27,6 +27,7 @@ enum JitValue {
Bool(Value),
None,
Tuple(Vec<JitValue>),
FuncRef(FuncRef),
}

impl JitValue {
Expand All @@ -43,14 +44,14 @@ impl JitValue {
JitValue::Int(_) => Some(JitType::Int),
JitValue::Float(_) => Some(JitType::Float),
JitValue::Bool(_) => Some(JitType::Bool),
JitValue::None | JitValue::Tuple(_) => None,
JitValue::None | JitValue::Tuple(_) | JitValue::FuncRef(_) => None,
}
}

fn into_value(self) -> Option<Value> {
match self {
JitValue::Int(val) | JitValue::Float(val) | JitValue::Bool(val) => Some(val),
JitValue::None | JitValue::Tuple(_) => None,
JitValue::None | JitValue::Tuple(_) | JitValue::FuncRef(_) => None,
}
}
}
Expand All @@ -68,6 +69,7 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
builder: &'a mut FunctionBuilder<'b>,
num_variables: usize,
arg_types: &[JitType],
ret_type: Option<JitType>,
entry_block: Block,
) -> FunctionCompiler<'a, 'b> {
let mut compiler = FunctionCompiler {
Expand All @@ -77,7 +79,7 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
label_to_block: HashMap::new(),
sig: JitSig {
args: arg_types.to_vec(),
ret: None,
ret: ret_type,
},
};
let params = compiler.builder.func.dfg.block_params(entry_block).to_vec();
Expand Down Expand Up @@ -132,7 +134,7 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
}
JitValue::Bool(val) => Ok(val),
JitValue::None => Ok(self.builder.ins().iconst(types::I8, 0)),
JitValue::Tuple(_) => Err(JitCompileError::NotSupported),
JitValue::Tuple(_) | JitValue::FuncRef(_) => Err(JitCompileError::NotSupported),
}
}

Expand All @@ -146,6 +148,7 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {

pub fn compile<C: bytecode::Constant>(
&mut self,
func_ref: FuncRef,
bytecode: &CodeObject<C>,
) -> Result<(), JitCompileError> {
// TODO: figure out if this is sufficient -- previously individual labels were associated
Expand Down Expand Up @@ -177,7 +180,7 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
continue;
}

self.add_instruction(instruction, arg, &bytecode.constants)?;
self.add_instruction(func_ref, bytecode, instruction, arg)?;
}

Ok(())
Expand Down Expand Up @@ -229,9 +232,10 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {

pub fn add_instruction<C: bytecode::Constant>(
&mut self,
func_ref: FuncRef,
bytecode: &CodeObject<C>,
instruction: Instruction,
arg: OpArg,
constants: &[C],
) -> Result<(), JitCompileError> {
match instruction {
Instruction::ExtendedArg => Ok(()),
Expand Down Expand Up @@ -282,7 +286,8 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
self.store_variable(idx.get(arg), val)
}
Instruction::LoadConst { idx } => {
let val = self.prepare_const(constants[idx.get(arg) as usize].borrow_constant())?;
let val = self
.prepare_const(bytecode.constants[idx.get(arg) as usize].borrow_constant())?;
self.stack.push(val);
Ok(())
}
Expand Down Expand Up @@ -311,7 +316,8 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
self.return_value(val)
}
Instruction::ReturnConst { idx } => {
let val = self.prepare_const(constants[idx.get(arg) as usize].borrow_constant())?;
let val = self
.prepare_const(bytecode.constants[idx.get(arg) as usize].borrow_constant())?;
self.return_value(val)
}
Instruction::CompareOperation { op, .. } => {
Expand Down Expand Up @@ -508,6 +514,36 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> {
// TODO: block support
Ok(())
}
Instruction::LoadGlobal(idx) => {
let name = &bytecode.names[idx.get(arg) as usize];

if name.as_ref() != bytecode.obj_name.as_ref() {
Err(JitCompileError::NotSupported)
} else {
self.stack.push(JitValue::FuncRef(func_ref));
Ok(())
}
}
Instruction::CallFunctionPositional { nargs } => {
let nargs = nargs.get(arg);

let mut args = Vec::new();
for _ in 0..nargs {
let arg = self.stack.pop().ok_or(JitCompileError::BadBytecode)?;
args.push(arg.into_value().unwrap());
}

match self.stack.pop().ok_or(JitCompileError::BadBytecode)? {
JitValue::FuncRef(reference) => {
let call = self.builder.ins().call(reference, &args);
let returns = self.builder.inst_results(call);
self.stack.push(JitValue::Int(returns[0]));

Ok(())
}
_ => Err(JitCompileError::BadBytecode),
}
}
_ => Err(JitCompileError::NotSupported),
}
}
Expand Down
37 changes: 27 additions & 10 deletions jit/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ impl Jit {
&mut self,
bytecode: &bytecode::CodeObject<C>,
args: &[JitType],
ret: Option<JitType>,
) -> Result<(FuncId, JitSig), JitCompileError> {
for arg in args {
self.ctx
Expand All @@ -58,29 +59,44 @@ impl Jit {
.push(AbiParam::new(arg.to_cranelift()));
}

if ret.is_some() {
self.ctx
.func
.signature
.returns
.push(AbiParam::new(ret.clone().unwrap().to_cranelift()));
}

let id = self.module.declare_function(
&format!("jit_{}", bytecode.obj_name.as_ref()),
Linkage::Export,
&self.ctx.func.signature,
)?;

let func_ref = self.module.declare_func_in_func(id, &mut self.ctx.func);

let mut builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context);
let entry_block = builder.create_block();
builder.append_block_params_for_function_params(entry_block);
builder.switch_to_block(entry_block);

let sig = {
let mut compiler =
FunctionCompiler::new(&mut builder, bytecode.varnames.len(), args, entry_block);
let mut compiler = FunctionCompiler::new(
&mut builder,
bytecode.varnames.len(),
args,
ret,
entry_block,
);

compiler.compile(bytecode)?;
compiler.compile(func_ref, bytecode)?;

compiler.sig
};

builder.seal_all_blocks();
builder.finalize();

let id = self.module.declare_function(
&format!("jit_{}", bytecode.obj_name.as_ref()),
Linkage::Export,
&self.ctx.func.signature,
)?;

self.module.define_function(id, &mut self.ctx)?;

self.module.clear_context(&mut self.ctx);
Expand All @@ -92,10 +108,11 @@ impl Jit {
pub fn compile<C: bytecode::Constant>(
bytecode: &bytecode::CodeObject<C>,
args: &[JitType],
ret: Option<JitType>,
) -> Result<CompiledCode, JitCompileError> {
let mut jit = Jit::new();

let (id, sig) = jit.build_function(bytecode, args)?;
let (id, sig) = jit.build_function(bytecode, args, ret)?;

jit.module.finalize_definitions();

Expand Down
12 changes: 11 additions & 1 deletion jit/tests/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,17 @@ impl Function {
arg_types.push(arg_type);
}

rustpython_jit::compile(&self.code, &arg_types).expect("Compile failure")
let ret_type = match self.annotations.get("return") {
Some(StackValue::String(annotation)) => match annotation.as_str() {
"int" => Some(JitType::Int),
"float" => Some(JitType::Float),
"bool" => Some(JitType::Bool),
_ => panic!("Unrecognised jit type"),
},
_ => None,
};

rustpython_jit::compile(&self.code, &arg_types, ret_type).expect("Compile failure")
}
}

Expand Down
12 changes: 12 additions & 0 deletions jit/tests/misc_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,15 @@ fn test_unpack_tuple() {
assert_eq!(unpack_tuple(0, 1), Ok(1));
assert_eq!(unpack_tuple(1, 2), Ok(2));
}

#[test]
fn test_recursive_fib() {
let fib = jit_function! { fib(n: i64) -> i64 => r##"
def fib(n: int) -> int:
if n == 0 or n == 1:
return 1
return fib(n-1) + fib(n-2)
"## };

assert_eq!(fib(10), Ok(89));
}
3 changes: 2 additions & 1 deletion vm/src/builtins/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,8 @@ impl PyFunction {
zelf.jitted_code
.get_or_try_init(|| {
let arg_types = jitfunc::get_jit_arg_types(&zelf, vm)?;
rustpython_jit::compile(&zelf.code.code, &arg_types)
let ret_type = jitfunc::jit_ret_type(&zelf, vm)?;
rustpython_jit::compile(&zelf.code.code, &arg_types, ret_type)
.map_err(|err| jitfunc::new_jit_error(err.to_string(), vm))
})
.map(drop)
Expand Down
21 changes: 20 additions & 1 deletion vm/src/builtins/function/jitfunc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ fn get_jit_arg_type(dict: &PyDictRef, name: &str, vm: &VirtualMachine) -> PyResu
Ok(JitType::Bool)
} else {
Err(new_jit_error(
"Jit requires argument to be either int or float".to_owned(),
"Jit requires argument to be either int, float or bool".to_owned(),
vm,
))
}
Expand Down Expand Up @@ -106,6 +106,25 @@ pub fn get_jit_arg_types(func: &Py<PyFunction>, vm: &VirtualMachine) -> PyResult
}
}

pub fn jit_ret_type(func: &Py<PyFunction>, vm: &VirtualMachine) -> PyResult<Option<JitType>> {
let func_obj: PyObjectRef = func.as_ref().to_owned();
let annotations = func_obj.get_attr("__annotations__", vm)?;
if vm.is_none(&annotations) {
Err(new_jit_error(
"Jitting function requires return type to have annotations".to_owned(),
vm,
))
} else if let Ok(dict) = PyDictRef::try_from_object(vm, annotations) {
if dict.contains_key("return", vm) {
get_jit_arg_type(&dict, "return", vm).map_or(Ok(None), |t| Ok(Some(t)))
} else {
Ok(None)
}
} else {
Err(vm.new_type_error("Function annotations aren't a dict".to_owned()))
}
}

fn get_jit_value(vm: &VirtualMachine, obj: &PyObject) -> Result<AbiValue, ArgsError> {
// This does exact type checks as subclasses of int/float can't be passed to jitted functions
let cls = obj.class();
Expand Down
Loading