Skip to content

Commit aaf1805

Browse files
committed
fix(itertools): add re-entrancy guard to tee object
1 parent 18d7c1b commit aaf1805

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

Lib/test/test_itertools.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1761,8 +1761,6 @@ def test_tee_del_backward(self):
17611761
del forward, backward
17621762
raise
17631763

1764-
# TODO: RUSTPYTHON
1765-
@unittest.expectedFailure
17661764
def test_tee_reenter(self):
17671765
class I:
17681766
first = True

vm/src/stdlib/itertools.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1185,21 +1185,31 @@ mod decl {
11851185
struct PyItertoolsTeeData {
11861186
iterable: PyIter,
11871187
values: PyRwLock<Vec<PyObjectRef>>,
1188+
locked: AtomicCell<bool>,
11881189
}
11891190

11901191
impl PyItertoolsTeeData {
11911192
fn new(iterable: PyIter, _vm: &VirtualMachine) -> PyResult<PyRc<Self>> {
11921193
Ok(PyRc::new(Self {
11931194
iterable,
11941195
values: PyRwLock::new(vec![]),
1196+
locked: AtomicCell::new(false),
11951197
}))
11961198
}
11971199

11981200
fn get_item(&self, vm: &VirtualMachine, index: usize) -> PyResult<PyIterReturn> {
11991201
if self.values.read().len() == index {
1200-
let result = raise_if_stop!(self.iterable.next(vm)?);
1201-
self.values.write().push(result);
1202+
if self.locked.swap(true) {
1203+
return Err(vm.new_runtime_error("cannot re-enter the tee iterator"));
1204+
}
1205+
1206+
let result = self.iterable.next(vm);
1207+
self.locked.store(false);
1208+
1209+
let obj = raise_if_stop!(result?);
1210+
self.values.write().push(obj);
12021211
}
1212+
12031213
Ok(PyIterReturn::Return(self.values.read()[index].clone()))
12041214
}
12051215
}

0 commit comments

Comments
 (0)