diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index 8709948b92..072279ea3a 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -1761,8 +1761,6 @@ def test_tee_del_backward(self): del forward, backward raise - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_tee_reenter(self): class I: first = True diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index a018fe382d..63b2d390ca 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -1184,23 +1184,28 @@ mod decl { #[derive(Debug)] struct PyItertoolsTeeData { iterable: PyIter, - values: PyRwLock>, + values: PyMutex>, } impl PyItertoolsTeeData { fn new(iterable: PyIter, _vm: &VirtualMachine) -> PyResult> { Ok(PyRc::new(Self { iterable, - values: PyRwLock::new(vec![]), + values: PyMutex::new(vec![]), })) } fn get_item(&self, vm: &VirtualMachine, index: usize) -> PyResult { - if self.values.read().len() == index { - let result = raise_if_stop!(self.iterable.next(vm)?); - self.values.write().push(result); + let Some(mut values) = self.values.try_lock() else { + return Err(vm.new_runtime_error("cannot re-enter the tee iterator")); + }; + + if values.len() == index { + let obj = raise_if_stop!(self.iterable.next(vm)?); + values.push(obj); } - Ok(PyIterReturn::Return(self.values.read()[index].clone())) + + Ok(PyIterReturn::Return(values[index].clone())) } }