@@ -11,14 +11,13 @@ mod decl {
11
11
convert:: ToPyObject ,
12
12
function:: { ArgCallable , FuncArgs , OptionalArg , OptionalOption , PosArgs } ,
13
13
identifier,
14
- protocol:: { PyIter , PyIterReturn } ,
14
+ protocol:: { PyIter , PyIterReturn , PyNumber } ,
15
15
stdlib:: sys,
16
16
types:: { Constructor , IterNext , IterNextIterable } ,
17
17
AsObject , Py , PyObjectRef , PyPayload , PyRef , PyResult , PyWeakRef , VirtualMachine ,
18
18
} ;
19
19
use crossbeam_utils:: atomic:: AtomicCell ;
20
- use num_bigint:: BigInt ;
21
- use num_traits:: { One , Signed , ToPrimitive , Zero } ;
20
+ use num_traits:: { Signed , ToPrimitive } ;
22
21
use std:: fmt;
23
22
24
23
#[ pyattr]
@@ -174,14 +173,14 @@ mod decl {
174
173
#[ pyclass( name = "count" ) ]
175
174
#[ derive( Debug , PyPayload ) ]
176
175
struct PyItertoolsCount {
177
- cur : PyRwLock < BigInt > ,
178
- step : BigInt ,
176
+ cur : PyRwLock < PyObjectRef > ,
177
+ step : PyIntRef ,
179
178
}
180
179
181
180
#[ derive( FromArgs ) ]
182
181
struct CountNewArgs {
183
182
#[ pyarg( positional, optional) ]
184
- start : OptionalArg < PyIntRef > ,
183
+ start : OptionalArg < PyObjectRef > ,
185
184
186
185
#[ pyarg( positional, optional) ]
187
186
step : OptionalArg < PyIntRef > ,
@@ -195,14 +194,11 @@ mod decl {
195
194
Self :: Args { start, step } : Self :: Args ,
196
195
vm : & VirtualMachine ,
197
196
) -> PyResult {
198
- let start = match start. into_option ( ) {
199
- Some ( int) => int. as_bigint ( ) . clone ( ) ,
200
- None => BigInt :: zero ( ) ,
201
- } ;
202
- let step = match step. into_option ( ) {
203
- Some ( int) => int. as_bigint ( ) . clone ( ) ,
204
- None => BigInt :: one ( ) ,
205
- } ;
197
+ let start: PyObjectRef = start. into_option ( ) . unwrap_or_else ( || vm. new_pyobj ( 0 ) ) ;
198
+ let step: PyIntRef = step. into_option ( ) . unwrap_or_else ( || vm. new_pyref ( 1 ) ) ;
199
+ if !PyNumber :: check ( & start, vm) {
200
+ return Err ( vm. new_value_error ( "a number is require" . to_owned ( ) ) ) ;
201
+ }
206
202
207
203
PyItertoolsCount {
208
204
cur : PyRwLock :: new ( start) ,
@@ -219,7 +215,7 @@ mod decl {
219
215
// if (lz->cnt == PY_SSIZE_T_MAX)
220
216
// return Py_BuildValue("0(00)", Py_TYPE(lz), lz->long_cnt, lz->long_step);
221
217
#[ pymethod( magic) ]
222
- fn reduce ( zelf : PyRef < Self > ) -> ( PyTypeRef , ( BigInt , ) ) {
218
+ fn reduce ( zelf : PyRef < Self > ) -> ( PyTypeRef , ( PyObjectRef , ) ) {
223
219
( zelf. class ( ) . clone ( ) , ( zelf. cur . read ( ) . clone ( ) , ) )
224
220
}
225
221
@@ -234,8 +230,9 @@ mod decl {
234
230
impl IterNext for PyItertoolsCount {
235
231
fn next ( zelf : & Py < Self > , vm : & VirtualMachine ) -> PyResult < PyIterReturn > {
236
232
let mut cur = zelf. cur . write ( ) ;
233
+ let step = zelf. step . clone ( ) ;
237
234
let result = cur. clone ( ) ;
238
- * cur += & zelf . step ;
235
+ * cur = vm . _iadd ( & * cur , step. as_object ( ) ) ? ;
239
236
Ok ( PyIterReturn :: Return ( result. to_pyobject ( vm) ) )
240
237
}
241
238
}
0 commit comments