diff --git a/jit/src/instructions.rs b/jit/src/instructions.rs index 24585eba3f..2d5e990dd0 100644 --- a/jit/src/instructions.rs +++ b/jit/src/instructions.rs @@ -1,3 +1,5 @@ +use super::{JitCompileError, JitSig, JitType}; +use cranelift::codegen::ir::FuncRef; use cranelift::prelude::*; use num_traits::cast::ToPrimitive; use rustpython_compiler_core::bytecode::{ @@ -6,8 +8,6 @@ use rustpython_compiler_core::bytecode::{ }; use std::collections::HashMap; -use super::{JitCompileError, JitSig, JitType}; - #[repr(u16)] enum CustomTrapCode { /// Raised when shifting by a negative number @@ -27,6 +27,7 @@ enum JitValue { Bool(Value), None, Tuple(Vec), + FuncRef(FuncRef), } impl JitValue { @@ -43,14 +44,14 @@ impl JitValue { JitValue::Int(_) => Some(JitType::Int), JitValue::Float(_) => Some(JitType::Float), JitValue::Bool(_) => Some(JitType::Bool), - JitValue::None | JitValue::Tuple(_) => None, + JitValue::None | JitValue::Tuple(_) | JitValue::FuncRef(_) => None, } } fn into_value(self) -> Option { match self { JitValue::Int(val) | JitValue::Float(val) | JitValue::Bool(val) => Some(val), - JitValue::None | JitValue::Tuple(_) => None, + JitValue::None | JitValue::Tuple(_) | JitValue::FuncRef(_) => None, } } } @@ -68,6 +69,7 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { builder: &'a mut FunctionBuilder<'b>, num_variables: usize, arg_types: &[JitType], + ret_type: Option, entry_block: Block, ) -> FunctionCompiler<'a, 'b> { let mut compiler = FunctionCompiler { @@ -77,7 +79,7 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { label_to_block: HashMap::new(), sig: JitSig { args: arg_types.to_vec(), - ret: None, + ret: ret_type, }, }; let params = compiler.builder.func.dfg.block_params(entry_block).to_vec(); @@ -132,7 +134,7 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { } JitValue::Bool(val) => Ok(val), JitValue::None => Ok(self.builder.ins().iconst(types::I8, 0)), - JitValue::Tuple(_) => Err(JitCompileError::NotSupported), + JitValue::Tuple(_) | JitValue::FuncRef(_) => Err(JitCompileError::NotSupported), } } @@ -146,6 +148,7 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { pub fn compile( &mut self, + func_ref: FuncRef, bytecode: &CodeObject, ) -> Result<(), JitCompileError> { // TODO: figure out if this is sufficient -- previously individual labels were associated @@ -177,7 +180,7 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { continue; } - self.add_instruction(instruction, arg, &bytecode.constants)?; + self.add_instruction(func_ref, bytecode, instruction, arg)?; } Ok(()) @@ -229,9 +232,10 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { pub fn add_instruction( &mut self, + func_ref: FuncRef, + bytecode: &CodeObject, instruction: Instruction, arg: OpArg, - constants: &[C], ) -> Result<(), JitCompileError> { match instruction { Instruction::ExtendedArg => Ok(()), @@ -282,7 +286,8 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { self.store_variable(idx.get(arg), val) } Instruction::LoadConst { idx } => { - let val = self.prepare_const(constants[idx.get(arg) as usize].borrow_constant())?; + let val = self + .prepare_const(bytecode.constants[idx.get(arg) as usize].borrow_constant())?; self.stack.push(val); Ok(()) } @@ -311,7 +316,8 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { self.return_value(val) } Instruction::ReturnConst { idx } => { - let val = self.prepare_const(constants[idx.get(arg) as usize].borrow_constant())?; + let val = self + .prepare_const(bytecode.constants[idx.get(arg) as usize].borrow_constant())?; self.return_value(val) } Instruction::CompareOperation { op, .. } => { @@ -508,6 +514,36 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { // TODO: block support Ok(()) } + Instruction::LoadGlobal(idx) => { + let name = &bytecode.names[idx.get(arg) as usize]; + + if name.as_ref() != bytecode.obj_name.as_ref() { + Err(JitCompileError::NotSupported) + } else { + self.stack.push(JitValue::FuncRef(func_ref)); + Ok(()) + } + } + Instruction::CallFunctionPositional { nargs } => { + let nargs = nargs.get(arg); + + let mut args = Vec::new(); + for _ in 0..nargs { + let arg = self.stack.pop().ok_or(JitCompileError::BadBytecode)?; + args.push(arg.into_value().unwrap()); + } + + match self.stack.pop().ok_or(JitCompileError::BadBytecode)? { + JitValue::FuncRef(reference) => { + let call = self.builder.ins().call(reference, &args); + let returns = self.builder.inst_results(call); + self.stack.push(JitValue::Int(returns[0])); + + Ok(()) + } + _ => Err(JitCompileError::BadBytecode), + } + } _ => Err(JitCompileError::NotSupported), } } diff --git a/jit/src/lib.rs b/jit/src/lib.rs index 1d099635e2..99bfb45c78 100644 --- a/jit/src/lib.rs +++ b/jit/src/lib.rs @@ -49,6 +49,7 @@ impl Jit { &mut self, bytecode: &bytecode::CodeObject, args: &[JitType], + ret: Option, ) -> Result<(FuncId, JitSig), JitCompileError> { for arg in args { self.ctx @@ -58,16 +59,37 @@ impl Jit { .push(AbiParam::new(arg.to_cranelift())); } + if ret.is_some() { + self.ctx + .func + .signature + .returns + .push(AbiParam::new(ret.clone().unwrap().to_cranelift())); + } + + let id = self.module.declare_function( + &format!("jit_{}", bytecode.obj_name.as_ref()), + Linkage::Export, + &self.ctx.func.signature, + )?; + + let func_ref = self.module.declare_func_in_func(id, &mut self.ctx.func); + let mut builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context); let entry_block = builder.create_block(); builder.append_block_params_for_function_params(entry_block); builder.switch_to_block(entry_block); let sig = { - let mut compiler = - FunctionCompiler::new(&mut builder, bytecode.varnames.len(), args, entry_block); + let mut compiler = FunctionCompiler::new( + &mut builder, + bytecode.varnames.len(), + args, + ret, + entry_block, + ); - compiler.compile(bytecode)?; + compiler.compile(func_ref, bytecode)?; compiler.sig }; @@ -75,12 +97,6 @@ impl Jit { builder.seal_all_blocks(); builder.finalize(); - let id = self.module.declare_function( - &format!("jit_{}", bytecode.obj_name.as_ref()), - Linkage::Export, - &self.ctx.func.signature, - )?; - self.module.define_function(id, &mut self.ctx)?; self.module.clear_context(&mut self.ctx); @@ -92,10 +108,11 @@ impl Jit { pub fn compile( bytecode: &bytecode::CodeObject, args: &[JitType], + ret: Option, ) -> Result { let mut jit = Jit::new(); - let (id, sig) = jit.build_function(bytecode, args)?; + let (id, sig) = jit.build_function(bytecode, args, ret)?; jit.module.finalize_definitions(); diff --git a/jit/tests/common.rs b/jit/tests/common.rs index ef5e8cfbc2..a2d4fc3bc1 100644 --- a/jit/tests/common.rs +++ b/jit/tests/common.rs @@ -27,7 +27,17 @@ impl Function { arg_types.push(arg_type); } - rustpython_jit::compile(&self.code, &arg_types).expect("Compile failure") + let ret_type = match self.annotations.get("return") { + Some(StackValue::String(annotation)) => match annotation.as_str() { + "int" => Some(JitType::Int), + "float" => Some(JitType::Float), + "bool" => Some(JitType::Bool), + _ => panic!("Unrecognised jit type"), + }, + _ => None, + }; + + rustpython_jit::compile(&self.code, &arg_types, ret_type).expect("Compile failure") } } diff --git a/jit/tests/misc_tests.rs b/jit/tests/misc_tests.rs index 7c1b6c3afb..7e1174da4a 100644 --- a/jit/tests/misc_tests.rs +++ b/jit/tests/misc_tests.rs @@ -113,3 +113,15 @@ fn test_unpack_tuple() { assert_eq!(unpack_tuple(0, 1), Ok(1)); assert_eq!(unpack_tuple(1, 2), Ok(2)); } + +#[test] +fn test_recursive_fib() { + let fib = jit_function! { fib(n: i64) -> i64 => r##" + def fib(n: int) -> int: + if n == 0 or n == 1: + return 1 + return fib(n-1) + fib(n-2) + "## }; + + assert_eq!(fib(10), Ok(89)); +} diff --git a/vm/src/builtins/function.rs b/vm/src/builtins/function.rs index e4cc49f9db..eb5a142f0c 100644 --- a/vm/src/builtins/function.rs +++ b/vm/src/builtins/function.rs @@ -506,7 +506,8 @@ impl PyFunction { zelf.jitted_code .get_or_try_init(|| { let arg_types = jitfunc::get_jit_arg_types(&zelf, vm)?; - rustpython_jit::compile(&zelf.code.code, &arg_types) + let ret_type = jitfunc::jit_ret_type(&zelf, vm)?; + rustpython_jit::compile(&zelf.code.code, &arg_types, ret_type) .map_err(|err| jitfunc::new_jit_error(err.to_string(), vm)) }) .map(drop) diff --git a/vm/src/builtins/function/jitfunc.rs b/vm/src/builtins/function/jitfunc.rs index fe73c3afc0..d46458fc65 100644 --- a/vm/src/builtins/function/jitfunc.rs +++ b/vm/src/builtins/function/jitfunc.rs @@ -52,7 +52,7 @@ fn get_jit_arg_type(dict: &PyDictRef, name: &str, vm: &VirtualMachine) -> PyResu Ok(JitType::Bool) } else { Err(new_jit_error( - "Jit requires argument to be either int or float".to_owned(), + "Jit requires argument to be either int, float or bool".to_owned(), vm, )) } @@ -106,6 +106,25 @@ pub fn get_jit_arg_types(func: &Py, vm: &VirtualMachine) -> PyResult } } +pub fn jit_ret_type(func: &Py, vm: &VirtualMachine) -> PyResult> { + let func_obj: PyObjectRef = func.as_ref().to_owned(); + let annotations = func_obj.get_attr("__annotations__", vm)?; + if vm.is_none(&annotations) { + Err(new_jit_error( + "Jitting function requires return type to have annotations".to_owned(), + vm, + )) + } else if let Ok(dict) = PyDictRef::try_from_object(vm, annotations) { + if dict.contains_key("return", vm) { + get_jit_arg_type(&dict, "return", vm).map_or(Ok(None), |t| Ok(Some(t))) + } else { + Ok(None) + } + } else { + Err(vm.new_type_error("Function annotations aren't a dict".to_owned())) + } +} + fn get_jit_value(vm: &VirtualMachine, obj: &PyObject) -> Result { // This does exact type checks as subclasses of int/float can't be passed to jitted functions let cls = obj.class();