diff --git a/tests/snippets/builtin_complex.py b/tests/snippets/builtin_complex.py new file mode 100644 index 0000000000..ef644b4da0 --- /dev/null +++ b/tests/snippets/builtin_complex.py @@ -0,0 +1,24 @@ +# __abs__ + +assert abs(complex(3, 4)) == 5 +assert abs(complex(3, -4)) == 5 +assert abs(complex(1.5, 2.5)) == 2.9154759474226504 + +# __eq__ + +assert complex(1, -1) == complex(1, -1) +assert complex(1, 0) == 1 +assert not complex(1, 1) == 1 +assert complex(1, 0) == 1.0 +assert not complex(1, 1) == 1.0 +assert not complex(1, 0) == 1.5 +assert bool(complex(1, 0)) +assert not complex(1, 2) == complex(1, 1) +# Currently broken - see issue #419 +# assert complex(1, 2) != 'foo' +assert complex(1, 2).__eq__('foo') == NotImplemented + +# __neg__ + +assert -complex(1, -1) == complex(-1, 1) +assert -complex(0, 0) == complex(0, 0) diff --git a/vm/src/builtins.rs b/vm/src/builtins.rs index 1321a96e46..33dde2256c 100644 --- a/vm/src/builtins.rs +++ b/vm/src/builtins.rs @@ -718,6 +718,9 @@ pub fn make_module(ctx: &PyContext) -> PyObjectRef { ctx.set_attr(&py_mod, "type", ctx.type_type()); ctx.set_attr(&py_mod, "zip", ctx.zip_type()); + // Constants + ctx.set_attr(&py_mod, "NotImplemented", ctx.not_implemented.clone()); + // Exceptions: ctx.set_attr( &py_mod, diff --git a/vm/src/obj/objcomplex.rs b/vm/src/obj/objcomplex.rs index 353b89b7f5..77f200f7ee 100644 --- a/vm/src/obj/objcomplex.rs +++ b/vm/src/obj/objcomplex.rs @@ -3,8 +3,10 @@ use super::super::pyobject::{ }; use super::super::vm::VirtualMachine; use super::objfloat; +use super::objint; use super::objtype; use num_complex::Complex64; +use num_traits::ToPrimitive; pub fn init(context: &PyContext) { let complex_type = &context.complex_type; @@ -13,7 +15,10 @@ pub fn init(context: &PyContext) { "Create a complex number from a real part and an optional imaginary part.\n\n\ This is equivalent to (real + imag*1j) where imag defaults to 0."; + context.set_attr(&complex_type, "__abs__", context.new_rustfunc(complex_abs)); context.set_attr(&complex_type, "__add__", context.new_rustfunc(complex_add)); + context.set_attr(&complex_type, "__eq__", context.new_rustfunc(complex_eq)); + context.set_attr(&complex_type, "__neg__", context.new_rustfunc(complex_neg)); context.set_attr(&complex_type, "__new__", context.new_rustfunc(complex_new)); context.set_attr( &complex_type, @@ -70,6 +75,13 @@ fn complex_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { )) } +fn complex_abs(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(zelf, Some(vm.ctx.complex_type()))]); + + let Complex64 { re, im } = get_value(zelf); + Ok(vm.ctx.new_float(re.hypot(im))) +} + fn complex_add(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!( vm, @@ -92,6 +104,36 @@ fn complex_conjugate(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { Ok(vm.ctx.new_complex(v1.conj())) } +fn complex_eq(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!( + vm, + args, + required = [(zelf, Some(vm.ctx.complex_type())), (other, None)] + ); + + let z = get_value(zelf); + + let result = if objtype::isinstance(other, &vm.ctx.complex_type()) { + z == get_value(other) + } else if objtype::isinstance(other, &vm.ctx.int_type()) { + match objint::get_value(other).to_f64() { + Some(f) => z.im == 0.0f64 && z.re == f, + None => false, + } + } else if objtype::isinstance(other, &vm.ctx.float_type()) { + z.im == 0.0 && z.re == objfloat::get_value(other) + } else { + return Ok(vm.ctx.not_implemented()); + }; + + Ok(vm.ctx.new_bool(result)) +} + +fn complex_neg(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(zelf, Some(vm.ctx.complex_type()))]); + Ok(vm.ctx.new_complex(-get_value(zelf))) +} + fn complex_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(obj, Some(vm.ctx.complex_type()))]); let v = get_value(obj); diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index c5968f1dc6..239e3bbd73 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -131,6 +131,7 @@ pub struct PyContext { pub map_type: PyObjectRef, pub memoryview_type: PyObjectRef, pub none: PyObjectRef, + pub not_implemented: PyObjectRef, pub tuple_type: PyObjectRef, pub set_type: PyObjectRef, pub staticmethod_type: PyObjectRef, @@ -226,6 +227,11 @@ impl PyContext { create_type("NoneType", &type_type, &object_type, &dict_type), ); + let not_implemented = PyObject::new( + PyObjectPayload::NotImplemented, + create_type("NotImplementedType", &type_type, &object_type, &dict_type), + ); + let true_value = PyObject::new( PyObjectPayload::Integer { value: One::one() }, bool_type.clone(), @@ -261,6 +267,7 @@ impl PyContext { zip_type, dict_type, none, + not_implemented, str_type, range_type, slice_type, @@ -432,6 +439,9 @@ impl PyContext { pub fn none(&self) -> PyObjectRef { self.none.clone() } + pub fn not_implemented(&self) -> PyObjectRef { + self.not_implemented.clone() + } pub fn object(&self) -> PyObjectRef { self.object.clone() } @@ -965,6 +975,7 @@ pub enum PyObjectPayload { dict: PyObjectRef, }, None, + NotImplemented, Class { name: String, dict: RefCell, @@ -1011,6 +1022,7 @@ impl fmt::Debug for PyObjectPayload { PyObjectPayload::Module { .. } => write!(f, "module"), PyObjectPayload::Scope { .. } => write!(f, "scope"), PyObjectPayload::None => write!(f, "None"), + PyObjectPayload::NotImplemented => write!(f, "NotImplemented"), PyObjectPayload::Class { ref name, .. } => write!(f, "class {:?}", name), PyObjectPayload::Instance { .. } => write!(f, "instance"), PyObjectPayload::RustFunction { .. } => write!(f, "rust function"),