Skip to content

Add complex.{__abs__, __eq__, __neg__} #410

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Feb 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions tests/snippets/builtin_complex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# __abs__

assert abs(complex(3, 4)) == 5
assert abs(complex(3, -4)) == 5
assert abs(complex(1.5, 2.5)) == 2.9154759474226504

# __eq__

assert complex(1, -1) == complex(1, -1)
assert complex(1, 0) == 1
assert not complex(1, 1) == 1
assert complex(1, 0) == 1.0
assert not complex(1, 1) == 1.0
assert not complex(1, 0) == 1.5
assert bool(complex(1, 0))
assert not complex(1, 2) == complex(1, 1)
# Currently broken - see issue #419
# assert complex(1, 2) != 'foo'
assert complex(1, 2).__eq__('foo') == NotImplemented

# __neg__

assert -complex(1, -1) == complex(-1, 1)
assert -complex(0, 0) == complex(0, 0)
3 changes: 3 additions & 0 deletions vm/src/builtins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,9 @@ pub fn make_module(ctx: &PyContext) -> PyObjectRef {
ctx.set_attr(&py_mod, "type", ctx.type_type());
ctx.set_attr(&py_mod, "zip", ctx.zip_type());

// Constants
ctx.set_attr(&py_mod, "NotImplemented", ctx.not_implemented.clone());

// Exceptions:
ctx.set_attr(
&py_mod,
Expand Down
42 changes: 42 additions & 0 deletions vm/src/obj/objcomplex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ use super::super::pyobject::{
};
use super::super::vm::VirtualMachine;
use super::objfloat;
use super::objint;
use super::objtype;
use num_complex::Complex64;
use num_traits::ToPrimitive;

pub fn init(context: &PyContext) {
let complex_type = &context.complex_type;
Expand All @@ -13,7 +15,10 @@ pub fn init(context: &PyContext) {
"Create a complex number from a real part and an optional imaginary part.\n\n\
This is equivalent to (real + imag*1j) where imag defaults to 0.";

context.set_attr(&complex_type, "__abs__", context.new_rustfunc(complex_abs));
context.set_attr(&complex_type, "__add__", context.new_rustfunc(complex_add));
context.set_attr(&complex_type, "__eq__", context.new_rustfunc(complex_eq));
context.set_attr(&complex_type, "__neg__", context.new_rustfunc(complex_neg));
context.set_attr(&complex_type, "__new__", context.new_rustfunc(complex_new));
context.set_attr(
&complex_type,
Expand Down Expand Up @@ -70,6 +75,13 @@ fn complex_new(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
))
}

fn complex_abs(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(vm, args, required = [(zelf, Some(vm.ctx.complex_type()))]);

let Complex64 { re, im } = get_value(zelf);
Ok(vm.ctx.new_float(re.hypot(im)))
}

fn complex_add(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(
vm,
Expand All @@ -92,6 +104,36 @@ fn complex_conjugate(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
Ok(vm.ctx.new_complex(v1.conj()))
}

fn complex_eq(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(
vm,
args,
required = [(zelf, Some(vm.ctx.complex_type())), (other, None)]
);

let z = get_value(zelf);

let result = if objtype::isinstance(other, &vm.ctx.complex_type()) {
z == get_value(other)
} else if objtype::isinstance(other, &vm.ctx.int_type()) {
match objint::get_value(other).to_f64() {
Some(f) => z.im == 0.0f64 && z.re == f,
None => false,
}
} else if objtype::isinstance(other, &vm.ctx.float_type()) {
z.im == 0.0 && z.re == objfloat::get_value(other)
} else {
return Ok(vm.ctx.not_implemented());
};

Ok(vm.ctx.new_bool(result))
}

fn complex_neg(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(vm, args, required = [(zelf, Some(vm.ctx.complex_type()))]);
Ok(vm.ctx.new_complex(-get_value(zelf)))
}

fn complex_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult {
arg_check!(vm, args, required = [(obj, Some(vm.ctx.complex_type()))]);
let v = get_value(obj);
Expand Down
12 changes: 12 additions & 0 deletions vm/src/pyobject.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ pub struct PyContext {
pub map_type: PyObjectRef,
pub memoryview_type: PyObjectRef,
pub none: PyObjectRef,
pub not_implemented: PyObjectRef,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wow, I cannot be believe we did not implement this yet ;)

pub tuple_type: PyObjectRef,
pub set_type: PyObjectRef,
pub staticmethod_type: PyObjectRef,
Expand Down Expand Up @@ -226,6 +227,11 @@ impl PyContext {
create_type("NoneType", &type_type, &object_type, &dict_type),
);

let not_implemented = PyObject::new(
PyObjectPayload::NotImplemented,
create_type("NotImplementedType", &type_type, &object_type, &dict_type),
);

let true_value = PyObject::new(
PyObjectPayload::Integer { value: One::one() },
bool_type.clone(),
Expand Down Expand Up @@ -261,6 +267,7 @@ impl PyContext {
zip_type,
dict_type,
none,
not_implemented,
str_type,
range_type,
slice_type,
Expand Down Expand Up @@ -432,6 +439,9 @@ impl PyContext {
pub fn none(&self) -> PyObjectRef {
self.none.clone()
}
pub fn not_implemented(&self) -> PyObjectRef {
self.not_implemented.clone()
}
pub fn object(&self) -> PyObjectRef {
self.object.clone()
}
Expand Down Expand Up @@ -965,6 +975,7 @@ pub enum PyObjectPayload {
dict: PyObjectRef,
},
None,
NotImplemented,
Class {
name: String,
dict: RefCell<PyAttributes>,
Expand Down Expand Up @@ -1011,6 +1022,7 @@ impl fmt::Debug for PyObjectPayload {
PyObjectPayload::Module { .. } => write!(f, "module"),
PyObjectPayload::Scope { .. } => write!(f, "scope"),
PyObjectPayload::None => write!(f, "None"),
PyObjectPayload::NotImplemented => write!(f, "NotImplemented"),
PyObjectPayload::Class { ref name, .. } => write!(f, "class {:?}", name),
PyObjectPayload::Instance { .. } => write!(f, "instance"),
PyObjectPayload::RustFunction { .. } => write!(f, "rust function"),
Expand Down