diff --git a/tests/snippets/subtraction.py b/tests/snippets/subtraction.py new file mode 100644 index 0000000000..3471c1545a --- /dev/null +++ b/tests/snippets/subtraction.py @@ -0,0 +1,21 @@ +assert 5 - 3 == 2 + +class Complex(): + def __init__(self, real, imag): + self.real = real + self.imag = imag + + def __repr__(self): + return "Com" + str((self.real, self.imag)) + + def __sub__(self, other): + return Complex(self.real - other, self.imag) + + def __rsub__(self, other): + return Complex(other - self.real, -self.imag) + + def __eq__(self, other): + return self.real == other.real and self.imag == other.imag + +assert Complex(4, 5) - 3 == Complex(1, 5) +assert 7 - Complex(4, 5) == Complex(3, -5) diff --git a/vm/src/obj/objint.rs b/vm/src/obj/objint.rs index 03d303e713..69925959c4 100644 --- a/vm/src/obj/objint.rs +++ b/vm/src/obj/objint.rs @@ -226,7 +226,7 @@ fn int_sub(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { .ctx .new_float(i.to_f64().unwrap() - objfloat::get_value(i2))) } else { - Err(vm.new_type_error(format!("Cannot substract {:?} and {:?}", i, i2))) + Err(vm.new_not_implemented_error(format!("Cannot substract {:?} and {:?}", i, i2))) } } diff --git a/vm/src/vm.rs b/vm/src/vm.rs index e7a110fce9..af47252cb0 100644 --- a/vm/src/vm.rs +++ b/vm/src/vm.rs @@ -78,6 +78,11 @@ impl VirtualMachine { self.new_exception(value_error, msg) } + pub fn new_not_implemented_error(&mut self, msg: String) -> PyObjectRef { + let value_error = self.ctx.exceptions.not_implemented_error.clone(); + self.new_exception(value_error, msg) + } + pub fn new_scope(&mut self, parent_scope: Option) -> PyObjectRef { // let parent_scope = self.current_frame_mut().locals.clone(); self.ctx.new_scope(parent_scope) @@ -437,23 +442,50 @@ impl VirtualMachine { } pub fn _sub(&mut self, a: PyObjectRef, b: PyObjectRef) -> PyResult { - // Try __sub__, next __rsub__, next, give up - self.call_method(&a, "__sub__", vec![b]) - /* - if a.has_attr("__sub__") { - self.call_method(&a, "__sub__", vec![b]) - } else if b.has_attr("__rsub__") { - self.call_method(&b, "__rsub__", vec![a]) - } else { - // Cannot sub a and b - let a_type_name = objtype::get_type_name(&a.typ()); - let b_type_name = objtype::get_type_name(&b.typ()); - Err(self.new_type_error(format!( - "Unsupported operand types for '-': '{}' and '{}'", - a_type_name, b_type_name - ))) + // 1. Try __sub__, next __rsub__, next, give up + if let Ok(method) = self.get_method(a.clone(), "__sub__") { + match self.invoke( + method, + PyFuncArgs { + args: vec![b.clone()], + kwargs: vec![], + }, + ) { + Ok(value) => return Ok(value), + Err(err) => { + if !objtype::isinstance(&err, &self.ctx.exceptions.not_implemented_error) { + return Err(err); + } + } + } } - */ + + // 2. try __rsub__ + if let Ok(method) = self.get_method(b.clone(), "__rsub__") { + match self.invoke( + method, + PyFuncArgs { + args: vec![a.clone()], + kwargs: vec![], + }, + ) { + Ok(value) => return Ok(value), + Err(err) => { + if !objtype::isinstance(&err, &self.ctx.exceptions.not_implemented_error) { + return Err(err); + } + } + } + } + + // 3. It all failed :( + // Cannot sub a and b + let a_type_name = objtype::get_type_name(&a.typ()); + let b_type_name = objtype::get_type_name(&b.typ()); + Err(self.new_type_error(format!( + "Unsupported operand types for '-': '{}' and '{}'", + a_type_name, b_type_name + ))) } pub fn _add(&mut self, a: PyObjectRef, b: PyObjectRef) -> PyResult {