diff --git a/extra_tests/snippets/operator_arithmetic.py b/extra_tests/snippets/operator_arithmetic.py index d44728d0ff..e1f2c4a38d 100644 --- a/extra_tests/snippets/operator_arithmetic.py +++ b/extra_tests/snippets/operator_arithmetic.py @@ -32,3 +32,23 @@ # Right shift raises value error on negative assert_raises(ValueError, lambda: 1 >> -1) + +# Bitwise or, and, xor raises value error on incompatible types +assert_raises(TypeError, lambda: "abc" | True) +assert_raises(TypeError, lambda: "abc" & True) +assert_raises(TypeError, lambda: "abc" ^ True) +assert_raises(TypeError, lambda: True | "abc") +assert_raises(TypeError, lambda: True & "abc") +assert_raises(TypeError, lambda: True ^ "abc") +assert_raises(TypeError, lambda: "abc" | 1.5) +assert_raises(TypeError, lambda: "abc" & 1.5) +assert_raises(TypeError, lambda: "abc" ^ 1.5) +assert_raises(TypeError, lambda: 1.5 | "abc") +assert_raises(TypeError, lambda: 1.5 & "abc") +assert_raises(TypeError, lambda: 1.5 ^ "abc") +assert_raises(TypeError, lambda: True | 1.5) +assert_raises(TypeError, lambda: True & 1.5) +assert_raises(TypeError, lambda: True ^ 1.5) +assert_raises(TypeError, lambda: 1.5 | True) +assert_raises(TypeError, lambda: 1.5 & True) +assert_raises(TypeError, lambda: 1.5 ^ True) diff --git a/vm/src/builtins/bool.rs b/vm/src/builtins/bool.rs index b3cb20238e..63bf6cff2d 100644 --- a/vm/src/builtins/bool.rs +++ b/vm/src/builtins/bool.rs @@ -127,8 +127,10 @@ impl PyBool { let lhs = get_value(&lhs); let rhs = get_value(&rhs); (lhs || rhs).to_pyobject(vm) + } else if let Some(lhs) = lhs.payload::<PyInt>() { + lhs.or(rhs, vm).to_pyobject(vm) } else { - get_py_int(&lhs).or(rhs, vm).to_pyobject(vm) + vm.ctx.not_implemented() } } @@ -141,8 +143,10 @@ impl PyBool { let lhs = get_value(&lhs); let rhs = get_value(&rhs); (lhs && rhs).to_pyobject(vm) + } else if let Some(lhs) = lhs.payload::<PyInt>() { + lhs.and(rhs, vm).to_pyobject(vm) } else { - get_py_int(&lhs).and(rhs, vm).to_pyobject(vm) + vm.ctx.not_implemented() } } @@ -155,8 +159,10 @@ impl PyBool { let lhs = get_value(&lhs); let rhs = get_value(&rhs); (lhs ^ rhs).to_pyobject(vm) + } else if let Some(lhs) = lhs.payload::<PyInt>() { + lhs.xor(rhs, vm).to_pyobject(vm) } else { - get_py_int(&lhs).xor(rhs, vm).to_pyobject(vm) + vm.ctx.not_implemented() } } } @@ -207,7 +213,3 @@ pub(crate) fn init(context: &Context) { pub(crate) fn get_value(obj: &PyObject) -> bool { !obj.payload::<PyInt>().unwrap().as_bigint().is_zero() } - -fn get_py_int(obj: &PyObject) -> &PyInt { - obj.payload::<PyInt>().unwrap() -}