diff --git a/vm/src/protocol/iter.rs b/vm/src/protocol/iter.rs index 34f1b01ca8..d05bd26a8a 100644 --- a/vm/src/protocol/iter.rs +++ b/vm/src/protocol/iter.rs @@ -276,3 +276,20 @@ where (self.length_hint.unwrap_or(0), self.length_hint) } } + +/// Macro to handle `PyIterReturn` values in iterator implementations. +/// +/// Extracts the object from `PyIterReturn::Return(obj)` or performs early return +/// for `PyIterReturn::StopIteration(v)`. This macro should only be used within +/// functions that return `PyResult`. +#[macro_export] +macro_rules! raise_if_stop { + ($input:expr) => { + match $input { + $crate::protocol::PyIterReturn::Return(obj) => obj, + $crate::protocol::PyIterReturn::StopIteration(v) => { + return Ok($crate::protocol::PyIterReturn::StopIteration(v)) + } + } + }; +} diff --git a/vm/src/stdlib/itertools.rs b/vm/src/stdlib/itertools.rs index df6d9487c8..e8f5d21c54 100644 --- a/vm/src/stdlib/itertools.rs +++ b/vm/src/stdlib/itertools.rs @@ -18,6 +18,7 @@ mod decl { function::{ArgCallable, ArgIntoBool, FuncArgs, OptionalArg, OptionalOption, PosArgs}, identifier, protocol::{PyIter, PyIterReturn, PyNumber}, + raise_if_stop, stdlib::sys, types::{Constructor, IterNext, Iterable, Representable, SelfIter}, }; @@ -41,7 +42,7 @@ mod decl { #[pyslot] fn slot_new(cls: PyTypeRef, args: FuncArgs, vm: &VirtualMachine) -> PyResult { let args_list = PyList::from(args.args); - PyItertoolsChain { + Self { source: PyRwLock::new(Some(args_list.to_pyobject(vm).get_iter(vm)?)), active: PyRwLock::new(None), } @@ -91,17 +92,18 @@ mod decl { fn __setstate__(zelf: PyRef, state: PyTupleRef, vm: &VirtualMachine) -> PyResult<()> { let args = state.as_slice(); if args.is_empty() { - let msg = String::from("function takes at least 1 arguments (0 given)"); - return Err(vm.new_type_error(msg)); + return Err(vm.new_type_error("function takes at least 1 arguments (0 given)")); } if args.len() > 2 { - let msg = format!("function takes at most 2 arguments ({} given)", args.len()); - return Err(vm.new_type_error(msg)); + return Err(vm.new_type_error(format!( + "function takes at most 2 arguments ({} given)", + args.len() + ))); } let source = &args[0]; if args.len() == 1 { if !PyIter::check(source.as_ref()) { - return Err(vm.new_type_error(String::from("Arguments must be iterators."))); + return Err(vm.new_type_error("Arguments must be iterators.")); } *zelf.source.write() = source.to_owned().try_into_value(vm)?; return Ok(()); @@ -109,7 +111,7 @@ mod decl { let active = &args[1]; if !PyIter::check(source.as_ref()) || !PyIter::check(active.as_ref()) { - return Err(vm.new_type_error(String::from("Arguments must be iterators."))); + return Err(vm.new_type_error("Arguments must be iterators.")); } let mut source_lock = zelf.source.write(); let mut active_lock = zelf.active.write(); @@ -215,10 +217,7 @@ mod decl { impl IterNext for PyItertoolsCompress { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { loop { - let sel_obj = match zelf.selectors.next(vm)? { - PyIterReturn::Return(obj) => obj, - PyIterReturn::StopIteration(v) => return Ok(PyIterReturn::StopIteration(v)), - }; + let sel_obj = raise_if_stop!(zelf.selectors.next(vm)?); let verdict = sel_obj.clone().try_to_bool(vm)?; let data_obj = zelf.data.next(vm)?; @@ -387,7 +386,7 @@ mod decl { } None => None, }; - PyItertoolsRepeat { object, times } + Self { object, times } .into_ref_with_type(vm, cls) .map(Into::into) } @@ -466,7 +465,7 @@ mod decl { Self::Args { function, iterable }: Self::Args, vm: &VirtualMachine, ) -> PyResult { - PyItertoolsStarmap { function, iterable } + Self { function, iterable } .into_ref_with_type(vm, cls) .map(Into::into) } @@ -527,7 +526,7 @@ mod decl { }: Self::Args, vm: &VirtualMachine, ) -> PyResult { - PyItertoolsTakewhile { + Self { predicate, iterable, stop_flag: AtomicCell::new(false), @@ -569,10 +568,7 @@ mod decl { } // might be StopIteration or anything else, which is propagated upwards - let obj = match zelf.iterable.next(vm)? { - PyIterReturn::Return(obj) => obj, - PyIterReturn::StopIteration(v) => return Ok(PyIterReturn::StopIteration(v)), - }; + let obj = raise_if_stop!(zelf.iterable.next(vm)?); let predicate = &zelf.predicate; let verdict = predicate.call((obj.clone(),), vm)?; @@ -614,7 +610,7 @@ mod decl { }: Self::Args, vm: &VirtualMachine, ) -> PyResult { - PyItertoolsDropwhile { + Self { predicate, iterable, start_flag: AtomicCell::new(false), @@ -634,6 +630,7 @@ mod decl { (zelf.start_flag.load() as _), ) } + #[pymethod] fn __setstate__( zelf: PyRef, @@ -656,12 +653,7 @@ mod decl { if !zelf.start_flag.load() { loop { - let obj = match iterable.next(vm)? { - PyIterReturn::Return(obj) => obj, - PyIterReturn::StopIteration(v) => { - return Ok(PyIterReturn::StopIteration(v)); - } - }; + let obj = raise_if_stop!(iterable.next(vm)?); let pred = predicate.clone(); let pred_value = pred.invoke((obj.clone(),), vm)?; if !pred_value.try_to_bool(vm)? { @@ -734,7 +726,7 @@ mod decl { Self::Args { iterable, key }: Self::Args, vm: &VirtualMachine, ) -> PyResult { - PyItertoolsGroupBy { + Self { iterable, key_func: key.flatten(), state: PyMutex::new(GroupByState { @@ -755,10 +747,7 @@ mod decl { &self, vm: &VirtualMachine, ) -> PyResult> { - let new_value = match self.iterable.next(vm)? { - PyIterReturn::Return(obj) => obj, - PyIterReturn::StopIteration(v) => return Ok(PyIterReturn::StopIteration(v)), - }; + let new_value = raise_if_stop!(self.iterable.next(vm)?); let new_key = if let Some(ref kf) = self.key_func { kf.call((new_value.clone(),), vm)? } else { @@ -782,23 +771,13 @@ mod decl { let (value, key) = if let Some(old_key) = current_key { loop { - let (value, new_key) = match zelf.advance(vm)? { - PyIterReturn::Return(obj) => obj, - PyIterReturn::StopIteration(v) => { - return Ok(PyIterReturn::StopIteration(v)); - } - }; + let (value, new_key) = raise_if_stop!(zelf.advance(vm)?); if !vm.bool_eq(&new_key, &old_key)? { break (value, new_key); } } } else { - match zelf.advance(vm)? { - PyIterReturn::Return(obj) => obj, - PyIterReturn::StopIteration(v) => { - return Ok(PyIterReturn::StopIteration(v)); - } - } + raise_if_stop!(zelf.advance(vm)?) }; state = zelf.state.lock(); @@ -848,10 +827,7 @@ mod decl { state.current_key.as_ref().unwrap().clone() }; - let (value, key) = match zelf.groupby.advance(vm)? { - PyIterReturn::Return(obj) => obj, - PyIterReturn::StopIteration(v) => return Ok(PyIterReturn::StopIteration(v)), - }; + let (value, key) = raise_if_stop!(zelf.groupby.advance(vm)?); if vm.bool_eq(&key, &old_key)? { Ok(PyIterReturn::Return(value)) } else { @@ -952,7 +928,7 @@ mod decl { let iter = iter.get_iter(vm)?; - PyItertoolsIslice { + Self { iterable: iter, cur: AtomicCell::new(0), next: AtomicCell::new(start), @@ -980,14 +956,16 @@ mod decl { fn __setstate__(zelf: PyRef, state: PyTupleRef, vm: &VirtualMachine) -> PyResult<()> { let args = state.as_slice(); if args.len() != 1 { - let msg = format!("function takes exactly 1 argument ({} given)", args.len()); - return Err(vm.new_type_error(msg)); + return Err(vm.new_type_error(format!( + "function takes exactly 1 argument ({} given)", + args.len() + ))); } let cur = &args[0]; if let Ok(cur) = cur.try_to_value(vm) { zelf.cur.store(cur); } else { - return Err(vm.new_type_error(String::from("Argument must be usize."))); + return Err(vm.new_type_error("Argument must be usize.")); } Ok(()) } @@ -1008,10 +986,7 @@ mod decl { } } - let obj = match zelf.iterable.next(vm)? { - PyIterReturn::Return(obj) => obj, - PyIterReturn::StopIteration(v) => return Ok(PyIterReturn::StopIteration(v)), - }; + let obj = raise_if_stop!(zelf.iterable.next(vm)?); zelf.cur.fetch_add(1); // TODO is this overflow check required? attempts to copy CPython. @@ -1049,7 +1024,7 @@ mod decl { }: Self::Args, vm: &VirtualMachine, ) -> PyResult { - PyItertoolsFilterFalse { + Self { predicate, iterable, } @@ -1077,10 +1052,7 @@ mod decl { let iterable = &zelf.iterable; loop { - let obj = match iterable.next(vm)? { - PyIterReturn::Return(obj) => obj, - PyIterReturn::StopIteration(v) => return Ok(PyIterReturn::StopIteration(v)), - }; + let obj = raise_if_stop!(iterable.next(vm)?); let pred_value = if vm.is_none(predicate) { obj.clone() } else { @@ -1117,7 +1089,7 @@ mod decl { type Args = AccumulateArgs; fn py_new(cls: PyTypeRef, args: AccumulateArgs, vm: &VirtualMachine) -> PyResult { - PyItertoolsAccumulate { + Self { iterable: args.iterable, bin_op: args.func.flatten(), initial: args.initial.flatten(), @@ -1192,21 +1164,11 @@ mod decl { let next_acc_value = match acc_value { None => match &zelf.initial { - None => match iterable.next(vm)? { - PyIterReturn::Return(obj) => obj, - PyIterReturn::StopIteration(v) => { - return Ok(PyIterReturn::StopIteration(v)); - } - }, + None => raise_if_stop!(iterable.next(vm)?), Some(obj) => obj.clone(), }, Some(value) => { - let obj = match iterable.next(vm)? { - PyIterReturn::Return(obj) => obj, - PyIterReturn::StopIteration(v) => { - return Ok(PyIterReturn::StopIteration(v)); - } - }; + let obj = raise_if_stop!(iterable.next(vm)?); match &zelf.bin_op { None => vm._add(&value, &obj)?, Some(op) => op.call((value, obj), vm)?, @@ -1226,8 +1188,8 @@ mod decl { } impl PyItertoolsTeeData { - fn new(iterable: PyIter, _vm: &VirtualMachine) -> PyResult> { - Ok(PyRc::new(PyItertoolsTeeData { + fn new(iterable: PyIter, _vm: &VirtualMachine) -> PyResult> { + Ok(PyRc::new(Self { iterable, values: PyRwLock::new(vec![]), })) @@ -1235,10 +1197,7 @@ mod decl { fn get_item(&self, vm: &VirtualMachine, index: usize) -> PyResult { if self.values.read().len() == index { - let result = match self.iterable.next(vm)? { - PyIterReturn::Return(obj) => obj, - PyIterReturn::StopIteration(v) => return Ok(PyIterReturn::StopIteration(v)), - }; + let result = raise_if_stop!(self.iterable.next(vm)?); self.values.write().push(result); } Ok(PyIterReturn::Return(self.values.read()[index].clone())) @@ -1295,7 +1254,7 @@ mod decl { if iterator.class().is(PyItertoolsTee::class(&vm.ctx)) { return vm.call_special_method(&iterator, identifier!(vm, __copy__), ()); } - Ok(PyItertoolsTee { + Ok(Self { tee_data: PyItertoolsTeeData::new(iterator, vm)?, index: AtomicCell::new(0), } @@ -1314,10 +1273,7 @@ mod decl { impl SelfIter for PyItertoolsTee {} impl IterNext for PyItertoolsTee { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { - let value = match zelf.tee_data.get_item(vm, zelf.index.load())? { - PyIterReturn::Return(obj) => obj, - PyIterReturn::StopIteration(v) => return Ok(PyIterReturn::StopIteration(v)), - }; + let value = raise_if_stop!(zelf.tee_data.get_item(vm, zelf.index.load())?); zelf.index.fetch_add(1); Ok(PyIterReturn::Return(value)) } @@ -1354,7 +1310,7 @@ mod decl { let l = pools.len(); - PyItertoolsProduct { + Self { pools, idxs: PyRwLock::new(vec![0; l]), cur: AtomicCell::new(l.wrapping_sub(1)), @@ -1394,8 +1350,7 @@ mod decl { fn __setstate__(zelf: PyRef, state: PyTupleRef, vm: &VirtualMachine) -> PyResult<()> { let args = state.as_slice(); if args.len() != zelf.pools.len() { - let msg = "Invalid number of arguments".to_string(); - return Err(vm.new_type_error(msg)); + return Err(vm.new_type_error("Invalid number of arguments")); } let mut idxs: PyRwLockWriteGuard<'_, Vec> = zelf.idxs.write(); idxs.clear(); @@ -1642,7 +1597,7 @@ mod decl { let n = pool.len(); - PyItertoolsCombinationsWithReplacement { + Self { pool, indices: PyRwLock::new(vec![0; r]), r: AtomicCell::new(r), @@ -1860,7 +1815,7 @@ mod decl { fn py_new(cls: PyTypeRef, (iterators, args): Self::Args, vm: &VirtualMachine) -> PyResult { let fillvalue = args.fillvalue.unwrap_or_none(vm); let iterators = iterators.into_vec(); - PyItertoolsZipLongest { + Self { iterators, fillvalue: PyRwLock::new(fillvalue), } @@ -1975,11 +1930,7 @@ mod decl { Some(obj) => obj, }; - let new = match zelf.iterator.next(vm)? { - PyIterReturn::Return(obj) => obj, - PyIterReturn::StopIteration(v) => return Ok(PyIterReturn::StopIteration(v)), - }; - + let new = raise_if_stop!(zelf.iterator.next(vm)?); *zelf.old.write() = Some(new.clone()); Ok(PyIterReturn::Return(vm.new_tuple((old, new)).into()))