@@ -537,6 +537,66 @@ impl PyItertoolsFilterFalse {
537
537
}
538
538
}
539
539
540
+ #[ pyclass]
541
+ #[ derive( Debug ) ]
542
+ struct PyItertoolsAccumulate {
543
+ iterable : PyObjectRef ,
544
+ binop : PyObjectRef ,
545
+ acc_value : RefCell < Option < PyObjectRef > > ,
546
+ }
547
+
548
+ impl PyValue for PyItertoolsAccumulate {
549
+ fn class ( vm : & VirtualMachine ) -> PyClassRef {
550
+ vm. class ( "itertools" , "accumulate" )
551
+ }
552
+ }
553
+
554
+ #[ pyimpl]
555
+ impl PyItertoolsAccumulate {
556
+ #[ pymethod( name = "__new__" ) ]
557
+ #[ allow( clippy:: new_ret_no_self) ]
558
+ fn new (
559
+ cls : PyClassRef ,
560
+ iterable : PyObjectRef ,
561
+ binop : OptionalArg < PyObjectRef > ,
562
+ vm : & VirtualMachine ,
563
+ ) -> PyResult < PyRef < PyItertoolsAccumulate > > {
564
+ let iter = get_iter ( vm, & iterable) ?;
565
+
566
+ PyItertoolsAccumulate {
567
+ iterable : iter,
568
+ binop : binop. unwrap_or_else ( || vm. get_none ( ) ) ,
569
+ acc_value : RefCell :: from ( Option :: None ) ,
570
+ }
571
+ . into_ref_with_type ( vm, cls)
572
+ }
573
+
574
+ #[ pymethod( name = "__next__" ) ]
575
+ fn next ( & self , vm : & VirtualMachine ) -> PyResult {
576
+ let iterable = & self . iterable ;
577
+ let obj = call_next ( vm, iterable) ?;
578
+
579
+ let next_acc_value = match & * self . acc_value . borrow ( ) {
580
+ Option :: None => obj. clone ( ) ,
581
+ Option :: Some ( value) => {
582
+ if self . binop . is ( & vm. get_none ( ) ) {
583
+ vm. _add ( value. clone ( ) , obj. clone ( ) ) ?
584
+ } else {
585
+ vm. invoke ( & self . binop , vec ! [ value. clone( ) , obj. clone( ) ] ) ?
586
+ }
587
+ }
588
+ } ;
589
+ self . acc_value . replace ( Option :: from ( next_acc_value. clone ( ) ) ) ;
590
+
591
+ Ok ( next_acc_value)
592
+ }
593
+
594
+ #[ pymethod( name = "__iter__" ) ]
595
+ fn iter ( zelf : PyRef < Self > , _vm : & VirtualMachine ) -> PyRef < Self > {
596
+ zelf
597
+ }
598
+ }
599
+
540
600
pub fn make_module ( vm : & VirtualMachine ) -> PyObjectRef {
541
601
let ctx = & vm. ctx ;
542
602
@@ -561,6 +621,9 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
561
621
let filterfalse = ctx. new_class ( "filterfalse" , ctx. object ( ) ) ;
562
622
PyItertoolsFilterFalse :: extend_class ( ctx, & filterfalse) ;
563
623
624
+ let accumulate = ctx. new_class ( "accumulate" , ctx. object ( ) ) ;
625
+ PyItertoolsAccumulate :: extend_class ( ctx, & accumulate) ;
626
+
564
627
py_module ! ( vm, "itertools" , {
565
628
"chain" => chain,
566
629
"count" => count,
@@ -570,5 +633,6 @@ pub fn make_module(vm: &VirtualMachine) -> PyObjectRef {
570
633
"takewhile" => takewhile,
571
634
"islice" => islice,
572
635
"filterfalse" => filterfalse,
636
+ "accumulate" => accumulate,
573
637
} )
574
638
}
0 commit comments