@@ -3,7 +3,10 @@ pub(crate) use decl::make_module;
3
3
#[ pymodule( name = "itertools" ) ]
4
4
mod decl {
5
5
use crate :: {
6
- builtins:: { int, PyGenericAlias , PyInt , PyIntRef , PyList , PyTuple , PyTupleRef , PyTypeRef } ,
6
+ builtins:: {
7
+ int, tuple:: IntoPyTuple , PyGenericAlias , PyInt , PyIntRef , PyList , PyTuple , PyTupleRef ,
8
+ PyTypeRef ,
9
+ } ,
7
10
common:: {
8
11
lock:: { PyMutex , PyRwLock , PyRwLockWriteGuard } ,
9
12
rc:: PyRc ,
@@ -1308,6 +1311,7 @@ mod decl {
1308
1311
struct PyItertoolsCombinations {
1309
1312
pool : Vec < PyObjectRef > ,
1310
1313
indices : PyRwLock < Vec < usize > > ,
1314
+ result : PyRwLock < Option < Vec < PyObjectRef > > > ,
1311
1315
r : AtomicCell < usize > ,
1312
1316
exhausted : AtomicCell < bool > ,
1313
1317
}
@@ -1341,6 +1345,7 @@ mod decl {
1341
1345
PyItertoolsCombinations {
1342
1346
pool,
1343
1347
indices : PyRwLock :: new ( ( 0 ..r) . collect ( ) ) ,
1348
+ result : PyRwLock :: new ( None ) ,
1344
1349
r : AtomicCell :: new ( r) ,
1345
1350
exhausted : AtomicCell :: new ( r > n) ,
1346
1351
}
@@ -1350,7 +1355,36 @@ mod decl {
1350
1355
}
1351
1356
1352
1357
#[ pyclass( with( IterNext , Constructor ) ) ]
1353
- impl PyItertoolsCombinations { }
1358
+ impl PyItertoolsCombinations {
1359
+ #[ pymethod( magic) ]
1360
+ fn reduce ( zelf : PyRef < Self > , vm : & VirtualMachine ) -> PyTupleRef {
1361
+ let r = zelf. r . load ( ) ;
1362
+
1363
+ let class = zelf. class ( ) . to_owned ( ) ;
1364
+
1365
+ if zelf. exhausted . load ( ) {
1366
+ return vm. new_tuple ( (
1367
+ class,
1368
+ vm. new_tuple ( ( vm. ctx . empty_tuple . clone ( ) , vm. ctx . new_int ( r) ) ) ,
1369
+ ) ) ;
1370
+ }
1371
+
1372
+ let tup = vm. new_tuple ( ( zelf. pool . clone ( ) . into_pytuple ( vm) , vm. ctx . new_int ( r) ) ) ;
1373
+
1374
+ if zelf. result . read ( ) . is_none ( ) {
1375
+ vm. new_tuple ( ( class, tup) )
1376
+ } else {
1377
+ let mut indices: Vec < PyObjectRef > = Vec :: new ( ) ;
1378
+
1379
+ for item in & zelf. indices . read ( ) [ ..r] {
1380
+ indices. push ( vm. new_pyobj ( * item) ) ;
1381
+ }
1382
+
1383
+ vm. new_tuple ( ( class, tup, indices. into_pytuple ( vm) ) )
1384
+ }
1385
+ }
1386
+ }
1387
+
1354
1388
impl IterNextIterable for PyItertoolsCombinations { }
1355
1389
impl IterNext for PyItertoolsCombinations {
1356
1390
fn next ( zelf : & Py < Self > , vm : & VirtualMachine ) -> PyResult < PyIterReturn > {
@@ -1367,38 +1401,48 @@ mod decl {
1367
1401
return Ok ( PyIterReturn :: Return ( vm. new_tuple ( ( ) ) . into ( ) ) ) ;
1368
1402
}
1369
1403
1370
- let res = vm. ctx . new_tuple (
1371
- zelf. indices
1372
- . read ( )
1373
- . iter ( )
1374
- . map ( |& i| zelf. pool [ i] . clone ( ) )
1375
- . collect ( ) ,
1376
- ) ;
1404
+ let mut result_lock = zelf. result . write ( ) ;
1405
+ let result = if let Some ( ref mut result) = * result_lock {
1406
+ let mut indices = zelf. indices . write ( ) ;
1377
1407
1378
- let mut indices = zelf. indices . write ( ) ;
1408
+ // Scan indices right-to-left until finding one that is not at its maximum (i + n - r).
1409
+ let mut idx = r as isize - 1 ;
1410
+ while idx >= 0 && indices[ idx as usize ] == idx as usize + n - r {
1411
+ idx -= 1 ;
1412
+ }
1379
1413
1380
- // Scan indices right-to-left until finding one that is not at its maximum (i + n - r).
1381
- let mut idx = r as isize - 1 ;
1382
- while idx >= 0 && indices[ idx as usize ] == idx as usize + n - r {
1383
- idx -= 1 ;
1384
- }
1414
+ // If no suitable index is found, then the indices are all at
1415
+ // their maximum value and we're done.
1416
+ if idx < 0 {
1417
+ zelf. exhausted . store ( true ) ;
1418
+ return Ok ( PyIterReturn :: StopIteration ( None ) ) ;
1419
+ } else {
1420
+ // Increment the current index which we know is not at its
1421
+ // maximum. Then move back to the right setting each index
1422
+ // to its lowest possible value (one higher than the index
1423
+ // to its left -- this maintains the sort order invariant).
1424
+ indices[ idx as usize ] += 1 ;
1425
+ for j in idx as usize + 1 ..r {
1426
+ indices[ j] = indices[ j - 1 ] + 1 ;
1427
+ }
1385
1428
1386
- // If no suitable index is found, then the indices are all at
1387
- // their maximum value and we're done.
1388
- if idx < 0 {
1389
- zelf. exhausted . store ( true ) ;
1390
- } else {
1391
- // Increment the current index which we know is not at its
1392
- // maximum. Then move back to the right setting each index
1393
- // to its lowest possible value (one higher than the index
1394
- // to its left -- this maintains the sort order invariant).
1395
- indices[ idx as usize ] += 1 ;
1396
- for j in idx as usize + 1 ..r {
1397
- indices[ j] = indices[ j - 1 ] + 1 ;
1429
+ // Update the result tuple for the new indices
1430
+ // starting with i, the leftmost index that changed
1431
+ for i in idx as usize ..r {
1432
+ let index = indices[ i] ;
1433
+ let elem = & zelf. pool [ index] ;
1434
+ result[ i] = elem. to_owned ( ) ;
1435
+ }
1436
+
1437
+ result. to_vec ( )
1398
1438
}
1399
- }
1439
+ } else {
1440
+ let res = zelf. pool [ 0 ..r] . to_vec ( ) ;
1441
+ * result_lock = Some ( res. clone ( ) ) ;
1442
+ res
1443
+ } ;
1400
1444
1401
- Ok ( PyIterReturn :: Return ( res . into ( ) ) )
1445
+ Ok ( PyIterReturn :: Return ( vm . ctx . new_tuple ( result ) . into ( ) ) )
1402
1446
}
1403
1447
}
1404
1448
0 commit comments