diff --git a/jit/src/instructions.rs b/jit/src/instructions.rs index 86dc58aca4..ef099026e8 100644 --- a/jit/src/instructions.rs +++ b/jit/src/instructions.rs @@ -315,18 +315,34 @@ impl<'a, 'b> FunctionCompiler<'a, 'b> { let b = self.stack.pop().ok_or(JitCompileError::BadBytecode)?; let a = self.stack.pop().ok_or(JitCompileError::BadBytecode)?; + let a_type: Option = a.to_jit_type(); + let b_type: Option = b.to_jit_type(); + match (a, b) { - (JitValue::Int(a), JitValue::Int(b)) => { + (JitValue::Int(a), JitValue::Int(b)) + | (JitValue::Bool(a), JitValue::Bool(b)) + | (JitValue::Bool(a), JitValue::Int(b)) + | (JitValue::Int(a), JitValue::Bool(b)) => { + let operand_one = match a_type.unwrap() { + JitType::Bool => self.builder.ins().uextend(types::I64, a), + _ => a, + }; + + let operand_two = match b_type.unwrap() { + JitType::Bool => self.builder.ins().uextend(types::I64, b), + _ => b, + }; + let cond = match op { ComparisonOperator::Equal => IntCC::Equal, ComparisonOperator::NotEqual => IntCC::NotEqual, ComparisonOperator::Less => IntCC::SignedLessThan, ComparisonOperator::LessOrEqual => IntCC::SignedLessThanOrEqual, ComparisonOperator::Greater => IntCC::SignedGreaterThan, - ComparisonOperator::GreaterOrEqual => IntCC::SignedLessThanOrEqual, + ComparisonOperator::GreaterOrEqual => IntCC::SignedGreaterThanOrEqual, }; - let val = self.builder.ins().icmp(cond, a, b); + let val = self.builder.ins().icmp(cond, operand_one, operand_two); // TODO: Remove this `bint` in cranelift 0.90 as icmp now returns i8 self.stack .push(JitValue::Bool(self.builder.ins().bint(types::I8, val))); diff --git a/jit/tests/bool_tests.rs b/jit/tests/bool_tests.rs index ed25ddb83f..191993938d 100644 --- a/jit/tests/bool_tests.rs +++ b/jit/tests/bool_tests.rs @@ -50,3 +50,153 @@ fn test_if_not() { assert_eq!(if_not(true), Ok(1)); assert_eq!(if_not(false), Ok(0)); } + +#[test] +fn test_eq() { + let eq = jit_function! { eq(a:bool, b:bool) -> i64 => r##" + def eq(a: bool, b: bool): + if a == b: + return 1 + return 0 + "## }; + + assert_eq!(eq(false, false), Ok(1)); + assert_eq!(eq(true, true), Ok(1)); + assert_eq!(eq(false, true), Ok(0)); + assert_eq!(eq(true, false), Ok(0)); +} + +#[test] +fn test_eq_with_integers() { + let eq = jit_function! { eq(a:bool, b:i64) -> i64 => r##" + def eq(a: bool, b: int): + if a == b: + return 1 + return 0 + "## }; + + assert_eq!(eq(false, 0), Ok(1)); + assert_eq!(eq(true, 1), Ok(1)); + assert_eq!(eq(false, 1), Ok(0)); + assert_eq!(eq(true, 0), Ok(0)); +} + +#[test] +fn test_gt() { + let gt = jit_function! { gt(a:bool, b:bool) -> i64 => r##" + def gt(a: bool, b: bool): + if a > b: + return 1 + return 0 + "## }; + + assert_eq!(gt(false, false), Ok(0)); + assert_eq!(gt(true, true), Ok(0)); + assert_eq!(gt(false, true), Ok(0)); + assert_eq!(gt(true, false), Ok(1)); +} + +#[test] +fn test_gt_with_integers() { + let gt = jit_function! { gt(a:i64, b:bool) -> i64 => r##" + def gt(a: int, b: bool): + if a > b: + return 1 + return 0 + "## }; + + assert_eq!(gt(0, false), Ok(0)); + assert_eq!(gt(1, true), Ok(0)); + assert_eq!(gt(0, true), Ok(0)); + assert_eq!(gt(1, false), Ok(1)); +} + +#[test] +fn test_lt() { + let lt = jit_function! { lt(a:bool, b:bool) -> i64 => r##" + def lt(a: bool, b: bool): + if a < b: + return 1 + return 0 + "## }; + + assert_eq!(lt(false, false), Ok(0)); + assert_eq!(lt(true, true), Ok(0)); + assert_eq!(lt(false, true), Ok(1)); + assert_eq!(lt(true, false), Ok(0)); +} + +#[test] +fn test_lt_with_integers() { + let lt = jit_function! { lt(a:i64, b:bool) -> i64 => r##" + def lt(a: int, b: bool): + if a < b: + return 1 + return 0 + "## }; + + assert_eq!(lt(0, false), Ok(0)); + assert_eq!(lt(1, true), Ok(0)); + assert_eq!(lt(0, true), Ok(1)); + assert_eq!(lt(1, false), Ok(0)); +} + +#[test] +fn test_gte() { + let gte = jit_function! { gte(a:bool, b:bool) -> i64 => r##" + def gte(a: bool, b: bool): + if a >= b: + return 1 + return 0 + "## }; + + assert_eq!(gte(false, false), Ok(1)); + assert_eq!(gte(true, true), Ok(1)); + assert_eq!(gte(false, true), Ok(0)); + assert_eq!(gte(true, false), Ok(1)); +} + +#[test] +fn test_gte_with_integers() { + let gte = jit_function! { gte(a:bool, b:i64) -> i64 => r##" + def gte(a: bool, b: int): + if a >= b: + return 1 + return 0 + "## }; + + assert_eq!(gte(false, 0), Ok(1)); + assert_eq!(gte(true, 1), Ok(1)); + assert_eq!(gte(false, 1), Ok(0)); + assert_eq!(gte(true, 0), Ok(1)); +} + +#[test] +fn test_lte() { + let lte = jit_function! { lte(a:bool, b:bool) -> i64 => r##" + def lte(a: bool, b: bool): + if a <= b: + return 1 + return 0 + "## }; + + assert_eq!(lte(false, false), Ok(1)); + assert_eq!(lte(true, true), Ok(1)); + assert_eq!(lte(false, true), Ok(1)); + assert_eq!(lte(true, false), Ok(0)); +} + +#[test] +fn test_lte_with_integers() { + let lte = jit_function! { lte(a:bool, b:i64) -> i64 => r##" + def lte(a: bool, b: int): + if a <= b: + return 1 + return 0 + "## }; + + assert_eq!(lte(false, 0), Ok(1)); + assert_eq!(lte(true, 1), Ok(1)); + assert_eq!(lte(false, 1), Ok(1)); + assert_eq!(lte(true, 0), Ok(0)); +} diff --git a/jit/tests/int_tests.rs b/jit/tests/int_tests.rs index 314849a06e..9ce3f3b4a6 100644 --- a/jit/tests/int_tests.rs +++ b/jit/tests/int_tests.rs @@ -160,6 +160,52 @@ fn test_gt() { assert_eq!(gt(1, -1), Ok(1)); } +#[test] +fn test_lt() { + let lt = jit_function! { lt(a:i64, b:i64) -> i64 => r##" + def lt(a: int, b: int): + if a < b: + return 1 + return 0 + "## }; + + assert_eq!(lt(-1, -5), Ok(0)); + assert_eq!(lt(10, 0), Ok(0)); + assert_eq!(lt(0, 1), Ok(1)); + assert_eq!(lt(-10, -1), Ok(1)); + assert_eq!(lt(100, 100), Ok(0)); +} + +#[test] +fn test_gte() { + let gte = jit_function! { gte(a:i64, b:i64) -> i64 => r##" + def gte(a: int, b: int): + if a >= b: + return 1 + return 0 + "## }; + + assert_eq!(gte(-64, -64), Ok(1)); + assert_eq!(gte(100, -1), Ok(1)); + assert_eq!(gte(1, 2), Ok(0)); + assert_eq!(gte(1, 0), Ok(1)); +} + +#[test] +fn test_lte() { + let lte = jit_function! { lte(a:i64, b:i64) -> i64 => r##" + def lte(a: int, b: int): + if a <= b: + return 1 + return 0 + "## }; + + assert_eq!(lte(-100, -100), Ok(1)); + assert_eq!(lte(-100, 100), Ok(1)); + assert_eq!(lte(10, 1), Ok(0)); + assert_eq!(lte(0, -2), Ok(0)); +} + #[test] fn test_minus() { let minus = jit_function! { minus(a:i64) -> i64 => r##"