From aaf18058f825c6a9328b9d62f2eb70b0674d2285 Mon Sep 17 00:00:00 2001 From: Jiseok CHOI Date: Thu, 10 Jul 2025 15:27:41 +0900 Subject: [PATCH 1/2] fix(itertools): add re-entrancy guard to tee object --- Lib/test/test_itertools.py | 2 -- vm/src/stdlib/itertools.rs | 14 ++++++++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index 8709948b92..072279ea3a 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -1761,8 +1761,6 @@ def test_tee_del_backward(self): del forward, backward raise - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_tee_reenter(self): class I: first = True diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index a018fe382d..cc36c2e14b 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -1185,6 +1185,7 @@ mod decl { struct PyItertoolsTeeData { iterable: PyIter, values: PyRwLock>, + locked: AtomicCell, } impl PyItertoolsTeeData { @@ -1192,14 +1193,23 @@ mod decl { Ok(PyRc::new(Self { iterable, values: PyRwLock::new(vec![]), + locked: AtomicCell::new(false), })) } fn get_item(&self, vm: &VirtualMachine, index: usize) -> PyResult { if self.values.read().len() == index { - let result = raise_if_stop!(self.iterable.next(vm)?); - self.values.write().push(result); + if self.locked.swap(true) { + return Err(vm.new_runtime_error("cannot re-enter the tee iterator")); + } + + let result = self.iterable.next(vm); + self.locked.store(false); + + let obj = raise_if_stop!(result?); + self.values.write().push(obj); } + Ok(PyIterReturn::Return(self.values.read()[index].clone())) } } From 17011c45cd2361cc80bc42f8c0a0d7de41f39746 Mon Sep 17 00:00:00 2001 From: Jiseok CHOI Date: Thu, 10 Jul 2025 17:37:18 +0900 Subject: [PATCH 2/2] apply feedback PyRwLock -> PyMutex & remove AtomicCell lock field --- vm/src/stdlib/itertools.rs | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index cc36c2e14b..63b2d390ca 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -1184,33 +1184,28 @@ mod decl { #[derive(Debug)] struct PyItertoolsTeeData { iterable: PyIter, - values: PyRwLock>, - locked: AtomicCell, + values: PyMutex>, } impl PyItertoolsTeeData { fn new(iterable: PyIter, _vm: &VirtualMachine) -> PyResult> { Ok(PyRc::new(Self { iterable, - values: PyRwLock::new(vec![]), - locked: AtomicCell::new(false), + values: PyMutex::new(vec![]), })) } fn get_item(&self, vm: &VirtualMachine, index: usize) -> PyResult { - if self.values.read().len() == index { - if self.locked.swap(true) { - return Err(vm.new_runtime_error("cannot re-enter the tee iterator")); - } - - let result = self.iterable.next(vm); - self.locked.store(false); + let Some(mut values) = self.values.try_lock() else { + return Err(vm.new_runtime_error("cannot re-enter the tee iterator")); + }; - let obj = raise_if_stop!(result?); - self.values.write().push(obj); + if values.len() == index { + let obj = raise_if_stop!(self.iterable.next(vm)?); + values.push(obj); } - Ok(PyIterReturn::Return(self.values.read()[index].clone())) + Ok(PyIterReturn::Return(values[index].clone())) } }