diff --git a/tests/snippets/bools.py b/tests/snippets/bools.py index e9d41c5d58..b3f52df062 100644 --- a/tests/snippets/bools.py +++ b/tests/snippets/bools.py @@ -142,4 +142,39 @@ def __len__(self): with assertRaises(TypeError): - bool(TestLenThrowError()) \ No newline at end of file + bool(TestLenThrowError()) + +# Verify that TypeError occurs when bad things are returned +# from __bool__(). This isn't really a bool test, but +# it's related. +def check(o): + with assertRaises(TypeError): + bool(o) + +class Foo(object): + def __bool__(self): + return self +check(Foo()) + +class Bar(object): + def __bool__(self): + return "Yes" +check(Bar()) + +class Baz(int): + def __bool__(self): + return self +check(Baz()) + +# __bool__() must return a bool not an int +class Spam(int): + def __bool__(self): + return 1 +check(Spam()) + +class Eggs: + def __len__(self): + return -1 + +with assertRaises(ValueError): + bool(Eggs()) \ No newline at end of file diff --git a/vm/src/obj/objbool.rs b/vm/src/obj/objbool.rs index 91d47c6266..eef179353d 100644 --- a/vm/src/obj/objbool.rs +++ b/vm/src/obj/objbool.rs @@ -1,3 +1,4 @@ +use num_bigint::Sign; use num_traits::Zero; use crate::function::PyFuncArgs; @@ -29,22 +30,30 @@ pub fn boolval(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { // If descriptor returns Error, propagate it further let method = method_or_err?; let bool_obj = vm.invoke(&method, PyFuncArgs::default())?; - match bool_obj.payload::() { - Some(int_obj) => !int_obj.as_bigint().is_zero(), - None => { - return Err(vm.new_type_error(format!( - "__bool__ should return bool, returned type {}", - bool_obj.class().name - ))) - } + if !objtype::isinstance(&bool_obj, &vm.ctx.bool_type()) { + return Err(vm.new_type_error(format!( + "__bool__ should return bool, returned type {}", + bool_obj.class().name + ))); } + + get_value(&bool_obj) } None => match vm.get_method(obj.clone(), "__len__") { Some(method_or_err) => { let method = method_or_err?; let bool_obj = vm.invoke(&method, PyFuncArgs::default())?; match bool_obj.payload::() { - Some(int_obj) => !int_obj.as_bigint().is_zero(), + Some(int_obj) => { + let len_val = int_obj.as_bigint(); + if len_val.sign() == Sign::Minus { + return Err( + vm.new_value_error("__len__() should return >= 0".to_string()) + ); + } + + !len_val.is_zero() + } None => { return Err(vm.new_type_error(format!( "{} object cannot be interpreted as integer",