diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index ebdf87fc1a..a8e1b52583 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -11,14 +11,13 @@ mod decl { convert::ToPyObject, function::{ArgCallable, FuncArgs, OptionalArg, OptionalOption, PosArgs}, identifier, - protocol::{PyIter, PyIterReturn}, + protocol::{PyIter, PyIterReturn, PyNumber}, stdlib::sys, types::{Constructor, IterNext, IterNextIterable}, AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, PyWeakRef, VirtualMachine, }; use crossbeam_utils::atomic::AtomicCell; - use num_bigint::BigInt; - use num_traits::{One, Signed, ToPrimitive, Zero}; + use num_traits::{Signed, ToPrimitive}; use std::fmt; #[pyattr] @@ -174,14 +173,14 @@ mod decl { #[pyclass(name = "count")] #[derive(Debug, PyPayload)] struct PyItertoolsCount { - cur: PyRwLock, - step: BigInt, + cur: PyRwLock, + step: PyIntRef, } #[derive(FromArgs)] struct CountNewArgs { #[pyarg(positional, optional)] - start: OptionalArg, + start: OptionalArg, #[pyarg(positional, optional)] step: OptionalArg, @@ -195,14 +194,11 @@ mod decl { Self::Args { start, step }: Self::Args, vm: &VirtualMachine, ) -> PyResult { - let start = match start.into_option() { - Some(int) => int.as_bigint().clone(), - None => BigInt::zero(), - }; - let step = match step.into_option() { - Some(int) => int.as_bigint().clone(), - None => BigInt::one(), - }; + let start: PyObjectRef = start.into_option().unwrap_or_else(|| vm.new_pyobj(0)); + let step: PyIntRef = step.into_option().unwrap_or_else(|| vm.new_pyref(1)); + if !PyNumber::check(&start, vm) { + return Err(vm.new_value_error("a number is require".to_owned())); + } PyItertoolsCount { cur: PyRwLock::new(start), @@ -219,7 +215,7 @@ mod decl { // if (lz->cnt == PY_SSIZE_T_MAX) // return Py_BuildValue("0(00)", Py_TYPE(lz), lz->long_cnt, lz->long_step); #[pymethod(magic)] - fn reduce(zelf: PyRef) -> (PyTypeRef, (BigInt,)) { + fn reduce(zelf: PyRef) -> (PyTypeRef, (PyObjectRef,)) { (zelf.class().clone(), (zelf.cur.read().clone(),)) } @@ -234,8 +230,9 @@ mod decl { impl IterNext for PyItertoolsCount { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { let mut cur = zelf.cur.write(); + let step = zelf.step.clone(); let result = cur.clone(); - *cur += &zelf.step; + *cur = vm._iadd(&*cur, step.as_object())?; Ok(PyIterReturn::Return(result.to_pyobject(vm))) } }