Skip to content

Commit 6ad0e54

Browse files
authored
Fix itertools.count to take PyNumber instead of PyInt (RustPython#3822)
1 parent 174c026 commit 6ad0e54

File tree

1 file changed

+13
-16
lines changed

1 file changed

+13
-16
lines changed

vm/src/stdlib/itertools.rs

+13-16
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,13 @@ mod decl {
1111
convert::ToPyObject,
1212
function::{ArgCallable, FuncArgs, OptionalArg, OptionalOption, PosArgs},
1313
identifier,
14-
protocol::{PyIter, PyIterReturn},
14+
protocol::{PyIter, PyIterReturn, PyNumber},
1515
stdlib::sys,
1616
types::{Constructor, IterNext, IterNextIterable},
1717
AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, PyWeakRef, VirtualMachine,
1818
};
1919
use crossbeam_utils::atomic::AtomicCell;
20-
use num_bigint::BigInt;
21-
use num_traits::{One, Signed, ToPrimitive, Zero};
20+
use num_traits::{Signed, ToPrimitive};
2221
use std::fmt;
2322

2423
#[pyattr]
@@ -174,14 +173,14 @@ mod decl {
174173
#[pyclass(name = "count")]
175174
#[derive(Debug, PyPayload)]
176175
struct PyItertoolsCount {
177-
cur: PyRwLock<BigInt>,
178-
step: BigInt,
176+
cur: PyRwLock<PyObjectRef>,
177+
step: PyIntRef,
179178
}
180179

181180
#[derive(FromArgs)]
182181
struct CountNewArgs {
183182
#[pyarg(positional, optional)]
184-
start: OptionalArg<PyIntRef>,
183+
start: OptionalArg<PyObjectRef>,
185184

186185
#[pyarg(positional, optional)]
187186
step: OptionalArg<PyIntRef>,
@@ -195,14 +194,11 @@ mod decl {
195194
Self::Args { start, step }: Self::Args,
196195
vm: &VirtualMachine,
197196
) -> PyResult {
198-
let start = match start.into_option() {
199-
Some(int) => int.as_bigint().clone(),
200-
None => BigInt::zero(),
201-
};
202-
let step = match step.into_option() {
203-
Some(int) => int.as_bigint().clone(),
204-
None => BigInt::one(),
205-
};
197+
let start: PyObjectRef = start.into_option().unwrap_or_else(|| vm.new_pyobj(0));
198+
let step: PyIntRef = step.into_option().unwrap_or_else(|| vm.new_pyref(1));
199+
if !PyNumber::check(&start, vm) {
200+
return Err(vm.new_value_error("a number is require".to_owned()));
201+
}
206202

207203
PyItertoolsCount {
208204
cur: PyRwLock::new(start),
@@ -219,7 +215,7 @@ mod decl {
219215
// if (lz->cnt == PY_SSIZE_T_MAX)
220216
// return Py_BuildValue("0(00)", Py_TYPE(lz), lz->long_cnt, lz->long_step);
221217
#[pymethod(magic)]
222-
fn reduce(zelf: PyRef<Self>) -> (PyTypeRef, (BigInt,)) {
218+
fn reduce(zelf: PyRef<Self>) -> (PyTypeRef, (PyObjectRef,)) {
223219
(zelf.class().clone(), (zelf.cur.read().clone(),))
224220
}
225221

@@ -234,8 +230,9 @@ mod decl {
234230
impl IterNext for PyItertoolsCount {
235231
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
236232
let mut cur = zelf.cur.write();
233+
let step = zelf.step.clone();
237234
let result = cur.clone();
238-
*cur += &zelf.step;
235+
*cur = vm._iadd(&*cur, step.as_object())?;
239236
Ok(PyIterReturn::Return(result.to_pyobject(vm)))
240237
}
241238
}

0 commit comments

Comments
 (0)