diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index cd5b49d16d..fea9a96a00 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -533,8 +533,6 @@ def test_compress(self): next(testIntermediate) self.assertEqual(list(op(testIntermediate)), list(result2)) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_count(self): self.assertEqual(lzip('abc',count()), [('a', 0), ('b', 1), ('c', 2)]) self.assertEqual(lzip('abc',count(3)), [('a', 3), ('b', 4), ('c', 5)]) diff --git a/vm/src/builtins/complex.rs b/vm/src/builtins/complex.rs index 9e4f5af904..8805c62471 100644 --- a/vm/src/builtins/complex.rs +++ b/vm/src/builtins/complex.rs @@ -1,14 +1,15 @@ use super::{float, PyStr, PyType, PyTypeRef}; use crate::{ class::PyClassImpl, - convert::ToPyObject, + convert::{ToPyObject, ToPyResult}, function::{ OptionalArg, OptionalOption, PyArithmeticValue::{self, *}, PyComparisonValue, }, identifier, - types::{Comparable, Constructor, Hashable, PyComparisonOp}, + protocol::{PyNumber, PyNumberMethods}, + types::{AsNumber, Comparable, Constructor, Hashable, PyComparisonOp}, AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine, }; use num_complex::Complex64; @@ -203,7 +204,7 @@ impl PyComplex { } } -#[pyimpl(flags(BASETYPE), with(Comparable, Hashable, Constructor))] +#[pyimpl(flags(BASETYPE), with(Comparable, Hashable, Constructor, AsNumber))] impl PyComplex { #[pymethod(magic)] fn complex(zelf: PyRef, vm: &VirtualMachine) -> PyRef { @@ -419,6 +420,72 @@ impl Hashable for PyComplex { } } +impl AsNumber for PyComplex { + const AS_NUMBER: PyNumberMethods = PyNumberMethods { + add: Some(|number, other, vm| Self::number_complex_op(number, other, |a, b| a + b, vm)), + subtract: Some(|number, other, vm| { + Self::number_complex_op(number, other, |a, b| a - b, vm) + }), + multiply: Some(|number, other, vm| { + Self::number_complex_op(number, other, |a, b| a * b, vm) + }), + power: Some(|number, other, vm| Self::number_general_op(number, other, inner_pow, vm)), + negative: Some(|number, vm| { + let value = Self::number_downcast(number).value; + (-value).to_pyresult(vm) + }), + positive: Some(|number, vm| Self::number_complex(number, vm).to_pyresult(vm)), + absolute: Some(|number, vm| { + let value = Self::number_downcast(number).value; + value.norm().to_pyresult(vm) + }), + boolean: Some(|number, _vm| Ok(Self::number_downcast(number).value.is_zero())), + true_divide: Some(|number, other, vm| { + Self::number_general_op(number, other, inner_div, vm) + }), + ..PyNumberMethods::NOT_IMPLEMENTED + }; +} + +impl PyComplex { + fn number_general_op( + number: &PyNumber, + other: &PyObject, + op: F, + vm: &VirtualMachine, + ) -> PyResult + where + F: FnOnce(Complex64, Complex64, &VirtualMachine) -> R, + R: ToPyResult, + { + if let (Some(a), Some(b)) = (number.obj.payload::(), other.payload::()) { + op(a.value, b.value, vm).to_pyresult(vm) + } else { + Ok(vm.ctx.not_implemented()) + } + } + + fn number_complex_op( + number: &PyNumber, + other: &PyObject, + op: F, + vm: &VirtualMachine, + ) -> PyResult + where + F: FnOnce(Complex64, Complex64) -> Complex64, + { + Self::number_general_op(number, other, |a, b, _vm| op(a, b), vm) + } + + fn number_complex(number: &PyNumber, vm: &VirtualMachine) -> PyRef { + if let Some(zelf) = number.obj.downcast_ref_if_exact::(vm) { + zelf.to_owned() + } else { + vm.ctx.new_complex(Self::number_downcast(number).value) + } + } +} + #[derive(FromArgs)] pub struct ComplexArgs { #[pyarg(any, optional)]