From 4230efae2ec992e38bc942197b5413cf35d664ec Mon Sep 17 00:00:00 2001 From: Jeong Yunwon Date: Sat, 16 Jul 2022 01:55:49 +0900 Subject: [PATCH] SetIterable -> AnySet --- vm/src/builtins/set.rs | 295 +++++++++++++++++++++++++---------------- 1 file changed, 180 insertions(+), 115 deletions(-) diff --git a/vm/src/builtins/set.rs b/vm/src/builtins/set.rs index a637b3b6ef..606ecc1900 100644 --- a/vm/src/builtins/set.rs +++ b/vm/src/builtins/set.rs @@ -35,9 +35,39 @@ pub struct PySet { } impl PySet { + pub fn new_ref(ctx: &Context) -> PyRef { + // Initialized empty, as calling __hash__ is required for adding each object to the set + // which requires a VM context - this is done in the set code itself. + PyRef::new_ref(Self::default(), ctx.types.set_type.to_owned(), None) + } + pub fn elements(&self) -> Vec { self.inner.elements() } + + fn fold_op( + &self, + others: impl std::iter::Iterator, + op: fn(&PySetInner, ArgIterable, &VirtualMachine) -> PyResult, + vm: &VirtualMachine, + ) -> PyResult { + Ok(Self { + inner: self.inner.fold_op(others, op, vm)?, + }) + } + + fn op( + &self, + other: AnySet, + op: fn(&PySetInner, ArgIterable, &VirtualMachine) -> PyResult, + vm: &VirtualMachine, + ) -> PyResult { + Ok(Self { + inner: self + .inner + .fold_op(std::iter::once(other.into_iterable(vm)?), op, vm)?, + }) + } } /// frozenset() -> empty frozenset object @@ -51,9 +81,46 @@ pub struct PyFrozenSet { } impl PyFrozenSet { + // Also used by ssl.rs windows. + pub fn from_iter( + vm: &VirtualMachine, + it: impl IntoIterator, + ) -> PyResult { + let inner = PySetInner::default(); + for elem in it { + inner.add(elem, vm)?; + } + // FIXME: empty set check + Ok(Self { inner }) + } + pub fn elements(&self) -> Vec { self.inner.elements() } + + fn fold_op( + &self, + others: impl std::iter::Iterator, + op: fn(&PySetInner, ArgIterable, &VirtualMachine) -> PyResult, + vm: &VirtualMachine, + ) -> PyResult { + Ok(Self { + inner: self.inner.fold_op(others, op, vm)?, + }) + } + + fn op( + &self, + other: AnySet, + op: fn(&PySetInner, ArgIterable, &VirtualMachine) -> PyResult, + vm: &VirtualMachine, + ) -> PyResult { + Ok(Self { + inner: self + .inner + .fold_op(std::iter::once(other.into_iterable(vm)?), op, vm)?, + }) + } } impl fmt::Debug for PySet { @@ -99,6 +166,19 @@ impl PySetInner { Ok(set) } + fn fold_op( + &self, + others: impl std::iter::Iterator, + op: fn(&Self, O, &VirtualMachine) -> PyResult, + vm: &VirtualMachine, + ) -> PyResult { + let mut res = self.copy(); + for other in others { + res = op(&res, other, vm)?; + } + Ok(res) + } + fn len(&self) -> usize { self.content.len() } @@ -259,7 +339,11 @@ impl PySetInner { } } - fn update(&self, others: PosArgs, vm: &VirtualMachine) -> PyResult<()> { + fn update( + &self, + others: impl std::iter::Iterator, + vm: &VirtualMachine, + ) -> PyResult<()> { for iterable in others { for item in iterable.iter(vm)? { self.add(item?, vm)?; @@ -270,7 +354,7 @@ impl PySetInner { fn intersection_update( &self, - others: PosArgs, + others: impl std::iter::Iterator, vm: &VirtualMachine, ) -> PyResult<()> { let mut temp_inner = self.copy(); @@ -287,7 +371,11 @@ impl PySetInner { Ok(()) } - fn difference_update(&self, others: PosArgs, vm: &VirtualMachine) -> PyResult<()> { + fn difference_update( + &self, + others: impl std::iter::Iterator, + vm: &VirtualMachine, + ) -> PyResult<()> { for iterable in others { for item in iterable.iter(vm)? { self.content.delete_if_exists(vm, &*item?)?; @@ -298,7 +386,7 @@ impl PySetInner { fn symmetric_difference_update( &self, - others: PosArgs, + others: impl std::iter::Iterator, vm: &VirtualMachine, ) -> PyResult<()> { for iterable in others { @@ -373,24 +461,6 @@ fn reduce_set( )) } -macro_rules! multi_args_set { - ($vm:expr, $others:expr, $zelf:expr, $op:tt) => {{ - let mut res = $zelf.inner.copy(); - for other in $others { - res = res.$op(other, $vm)? - } - Ok(Self { inner: res }) - }}; -} - -impl PySet { - pub fn new_ref(ctx: &Context) -> PyRef { - // Initialized empty, as calling __hash__ is required for adding each object to the set - // which requires a VM context - this is done in the set code itself. - PyRef::new_ref(Self::default(), ctx.types.set_type.to_owned(), None) - } -} - #[pyimpl( with(Constructor, Initializer, AsSequence, Hashable, Comparable, Iterable), flags(BASETYPE) @@ -420,17 +490,17 @@ impl PySet { #[pymethod] fn union(&self, others: PosArgs, vm: &VirtualMachine) -> PyResult { - multi_args_set!(vm, others, self, union) + self.fold_op(others.into_iter(), PySetInner::union, vm) } #[pymethod] fn intersection(&self, others: PosArgs, vm: &VirtualMachine) -> PyResult { - multi_args_set!(vm, others, self, intersection) + self.fold_op(others.into_iter(), PySetInner::intersection, vm) } #[pymethod] fn difference(&self, others: PosArgs, vm: &VirtualMachine) -> PyResult { - multi_args_set!(vm, others, self, difference) + self.fold_op(others.into_iter(), PySetInner::difference, vm) } #[pymethod] @@ -439,7 +509,7 @@ impl PySet { others: PosArgs, vm: &VirtualMachine, ) -> PyResult { - multi_args_set!(vm, others, self, symmetric_difference) + self.fold_op(others.into_iter(), PySetInner::symmetric_difference, vm) } #[pymethod] @@ -460,10 +530,12 @@ impl PySet { #[pymethod(name = "__ror__")] #[pymethod(magic)] fn or(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult> { - if let Ok(set_iter) = SetIterable::try_from_object(vm, other) { - Ok(PyArithmeticValue::Implemented( - self.union(set_iter.iterable, vm)?, - )) + if let Ok(other) = AnySet::try_from_object(vm, other) { + Ok(PyArithmeticValue::Implemented(self.op( + other, + PySetInner::union, + vm, + )?)) } else { Ok(PyArithmeticValue::NotImplemented) } @@ -472,10 +544,12 @@ impl PySet { #[pymethod(name = "__rand__")] #[pymethod(magic)] fn and(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult> { - if let Ok(set_iter) = SetIterable::try_from_object(vm, other) { - Ok(PyArithmeticValue::Implemented( - self.intersection(set_iter.iterable, vm)?, - )) + if let Ok(other) = AnySet::try_from_object(vm, other) { + Ok(PyArithmeticValue::Implemented(self.op( + other, + PySetInner::intersection, + vm, + )?)) } else { Ok(PyArithmeticValue::NotImplemented) } @@ -483,10 +557,12 @@ impl PySet { #[pymethod(magic)] fn sub(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult> { - if let Ok(set_iter) = SetIterable::try_from_object(vm, other) { - Ok(PyArithmeticValue::Implemented( - self.difference(set_iter.iterable, vm)?, - )) + if let Ok(other) = AnySet::try_from_object(vm, other) { + Ok(PyArithmeticValue::Implemented(self.op( + other, + PySetInner::difference, + vm, + )?)) } else { Ok(PyArithmeticValue::NotImplemented) } @@ -500,10 +576,12 @@ impl PySet { #[pymethod(name = "__rxor__")] #[pymethod(magic)] fn xor(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult> { - if let Ok(set_iter) = SetIterable::try_from_object(vm, other) { - Ok(PyArithmeticValue::Implemented( - self.symmetric_difference(set_iter.iterable, vm)?, - )) + if let Ok(other) = AnySet::try_from_object(vm, other) { + Ok(PyArithmeticValue::Implemented(self.op( + other, + PySetInner::symmetric_difference, + vm, + )?)) } else { Ok(PyArithmeticValue::NotImplemented) } @@ -557,14 +635,14 @@ impl PySet { } #[pymethod(magic)] - fn ior(zelf: PyRef, iterable: SetIterable, vm: &VirtualMachine) -> PyResult> { - zelf.inner.update(iterable.iterable, vm)?; + fn ior(zelf: PyRef, set: AnySet, vm: &VirtualMachine) -> PyResult> { + zelf.inner.update(set.into_iterable_iter(vm)?, vm)?; Ok(zelf) } #[pymethod] fn update(&self, others: PosArgs, vm: &VirtualMachine) -> PyResult<()> { - self.inner.update(others, vm)?; + self.inner.update(others.into_iter(), vm)?; Ok(()) } @@ -574,33 +652,27 @@ impl PySet { others: PosArgs, vm: &VirtualMachine, ) -> PyResult<()> { - self.inner.intersection_update(others, vm)?; + self.inner.intersection_update(others.into_iter(), vm)?; Ok(()) } #[pymethod(magic)] - fn iand( - zelf: PyRef, - iterable: SetIterable, - vm: &VirtualMachine, - ) -> PyResult> { - zelf.inner.intersection_update(iterable.iterable, vm)?; + fn iand(zelf: PyRef, set: AnySet, vm: &VirtualMachine) -> PyResult> { + zelf.inner + .intersection_update(std::iter::once(set.into_iterable(vm)?), vm)?; Ok(zelf) } #[pymethod] fn difference_update(&self, others: PosArgs, vm: &VirtualMachine) -> PyResult<()> { - self.inner.difference_update(others, vm)?; + self.inner.difference_update(others.into_iter(), vm)?; Ok(()) } #[pymethod(magic)] - fn isub( - zelf: PyRef, - iterable: SetIterable, - vm: &VirtualMachine, - ) -> PyResult> { - zelf.inner.difference_update(iterable.iterable, vm)?; + fn isub(zelf: PyRef, set: AnySet, vm: &VirtualMachine) -> PyResult> { + zelf.inner + .difference_update(set.into_iterable_iter(vm)?, vm)?; Ok(zelf) } @@ -610,18 +682,15 @@ impl PySet { others: PosArgs, vm: &VirtualMachine, ) -> PyResult<()> { - self.inner.symmetric_difference_update(others, vm)?; + self.inner + .symmetric_difference_update(others.into_iter(), vm)?; Ok(()) } #[pymethod(magic)] - fn ixor( - zelf: PyRef, - iterable: SetIterable, - vm: &VirtualMachine, - ) -> PyResult> { + fn ixor(zelf: PyRef, set: AnySet, vm: &VirtualMachine) -> PyResult> { zelf.inner - .symmetric_difference_update(iterable.iterable, vm)?; + .symmetric_difference_update(set.into_iterable_iter(vm)?, vm)?; Ok(zelf) } @@ -690,16 +759,6 @@ impl Iterable for PySet { } } -macro_rules! multi_args_frozenset { - ($vm:expr, $others:expr, $zelf:expr, $op:tt) => {{ - let mut res = $zelf.inner.copy(); - for other in $others { - res = res.$op(other, $vm)? - } - Ok(Self { inner: res }) - }}; -} - impl Constructor for PyFrozenSet { type Args = OptionalArg; @@ -733,19 +792,6 @@ impl Constructor for PyFrozenSet { with(Constructor, AsSequence, Hashable, Comparable, Iterable) )] impl PyFrozenSet { - // Also used by ssl.rs windows. - pub fn from_iter( - vm: &VirtualMachine, - it: impl IntoIterator, - ) -> PyResult { - let inner = PySetInner::default(); - for elem in it { - inner.add(elem, vm)?; - } - // FIXME: empty set check - Ok(Self { inner }) - } - #[pymethod(magic)] fn len(&self) -> usize { self.inner.len() @@ -775,17 +821,17 @@ impl PyFrozenSet { #[pymethod] fn union(&self, others: PosArgs, vm: &VirtualMachine) -> PyResult { - multi_args_frozenset!(vm, others, self, union) + self.fold_op(others.into_iter(), PySetInner::union, vm) } #[pymethod] fn intersection(&self, others: PosArgs, vm: &VirtualMachine) -> PyResult { - multi_args_frozenset!(vm, others, self, intersection) + self.fold_op(others.into_iter(), PySetInner::intersection, vm) } #[pymethod] fn difference(&self, others: PosArgs, vm: &VirtualMachine) -> PyResult { - multi_args_frozenset!(vm, others, self, difference) + self.fold_op(others.into_iter(), PySetInner::difference, vm) } #[pymethod] @@ -794,7 +840,7 @@ impl PyFrozenSet { others: PosArgs, vm: &VirtualMachine, ) -> PyResult { - multi_args_frozenset!(vm, others, self, symmetric_difference) + self.fold_op(others.into_iter(), PySetInner::symmetric_difference, vm) } #[pymethod] @@ -815,10 +861,12 @@ impl PyFrozenSet { #[pymethod(name = "__ror__")] #[pymethod(magic)] fn or(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult> { - if let Ok(set_iter) = SetIterable::try_from_object(vm, other) { - Ok(PyArithmeticValue::Implemented( - self.union(set_iter.iterable, vm)?, - )) + if let Ok(set) = AnySet::try_from_object(vm, other) { + Ok(PyArithmeticValue::Implemented(self.op( + set, + PySetInner::union, + vm, + )?)) } else { Ok(PyArithmeticValue::NotImplemented) } @@ -827,10 +875,12 @@ impl PyFrozenSet { #[pymethod(name = "__rand__")] #[pymethod(magic)] fn and(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult> { - if let Ok(set_iter) = SetIterable::try_from_object(vm, other) { - Ok(PyArithmeticValue::Implemented( - self.intersection(set_iter.iterable, vm)?, - )) + if let Ok(other) = AnySet::try_from_object(vm, other) { + Ok(PyArithmeticValue::Implemented(self.op( + other, + PySetInner::intersection, + vm, + )?)) } else { Ok(PyArithmeticValue::NotImplemented) } @@ -838,10 +888,12 @@ impl PyFrozenSet { #[pymethod(magic)] fn sub(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult> { - if let Ok(set_iter) = SetIterable::try_from_object(vm, other) { - Ok(PyArithmeticValue::Implemented( - self.difference(set_iter.iterable, vm)?, - )) + if let Ok(other) = AnySet::try_from_object(vm, other) { + Ok(PyArithmeticValue::Implemented(self.op( + other, + PySetInner::difference, + vm, + )?)) } else { Ok(PyArithmeticValue::NotImplemented) } @@ -855,10 +907,12 @@ impl PyFrozenSet { #[pymethod(name = "__rxor__")] #[pymethod(magic)] fn xor(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult> { - if let Ok(set_iter) = SetIterable::try_from_object(vm, other) { - Ok(PyArithmeticValue::Implemented( - self.symmetric_difference(set_iter.iterable, vm)?, - )) + if let Ok(other) = AnySet::try_from_object(vm, other) { + Ok(PyArithmeticValue::Implemented(self.op( + other, + PySetInner::symmetric_difference, + vm, + )?)) } else { Ok(PyArithmeticValue::NotImplemented) } @@ -927,11 +981,24 @@ impl Iterable for PyFrozenSet { } } -struct SetIterable { - iterable: PosArgs, +struct AnySet { + object: PyObjectRef, +} + +impl AnySet { + fn into_iterable(self, vm: &VirtualMachine) -> PyResult { + self.object.try_into_value(vm) + } + + fn into_iterable_iter( + self, + vm: &VirtualMachine, + ) -> PyResult> { + Ok(std::iter::once(self.into_iterable(vm)?)) + } } -impl TryFromObject for SetIterable { +impl TryFromObject for AnySet { fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { let class = obj.class(); if class.fast_issubclass(vm.ctx.types.set_type) @@ -939,9 +1006,7 @@ impl TryFromObject for SetIterable { { // the class lease needs to be drop to be able to return the object drop(class); - Ok(SetIterable { - iterable: PosArgs::new(vec![ArgIterable::try_from_object(vm, obj)?]), - }) + Ok(AnySet { object: obj }) } else { Err(vm.new_type_error(format!("{} is not a subtype of set or frozenset", class))) }