Skip to content

Commit cb66cc3

Browse files
committed
Fix set inplace operation bug
1 parent fe75ed2 commit cb66cc3

File tree

1 file changed

+39
-14
lines changed

1 file changed

+39
-14
lines changed

vm/src/builtins/set.rs

+39-14
Original file line numberDiff line numberDiff line change
@@ -375,10 +375,11 @@ impl PySetInner {
375375

376376
fn difference_update(
377377
&self,
378-
others: impl std::iter::Iterator<Item = ArgIterable>,
378+
others: impl std::iter::Iterator<Item = PyObjectRef>,
379379
vm: &VirtualMachine,
380380
) -> PyResult<()> {
381381
for iterable in others {
382+
let iterable = ArgIterable::<PyObjectRef>::try_from_object(vm, iterable)?;
382383
for item in iterable.iter(vm)? {
383384
self.content.delete_if_exists(vm, &*item?)?;
384385
}
@@ -388,11 +389,12 @@ impl PySetInner {
388389

389390
fn symmetric_difference_update(
390391
&self,
391-
others: impl std::iter::Iterator<Item = ArgIterable>,
392+
others: impl std::iter::Iterator<Item = PyObjectRef>,
392393
vm: &VirtualMachine,
393394
) -> PyResult<()> {
394395
for iterable in others {
395396
// We want to remove duplicates in iterable
397+
let iterable = ArgIterable::<PyObjectRef>::try_from_object(vm, iterable)?;
396398
let iterable_set = Self::from_iter(iterable.iter(vm)?, vm)?;
397399
for item in iterable_set.elements() {
398400
self.content.delete_or_insert(vm, &item, ())?;
@@ -681,43 +683,66 @@ impl PySet {
681683
}
682684

683685
#[pymethod]
684-
fn difference_update(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<()> {
685-
self.inner.difference_update(others.into_iter(), vm)?;
686+
fn difference_update(zelf: PyRef<Self>, others: PosArgs<PyObjectRef>, vm: &VirtualMachine) -> PyResult<()> {
687+
let arguments = others.into_vec();
688+
if arguments.len() == 1 {
689+
if let Some(iterable) = arguments.first() {
690+
if zelf.is(iterable) {
691+
zelf.inner.clear();
692+
return Ok(());
693+
}
694+
}
695+
} else if arguments.is_empty() {
696+
return Ok(());
697+
}
698+
zelf.inner.difference_update(arguments.into_iter(), vm)?;
686699
Ok(())
687700
}
688701

689702
#[pymethod(magic)]
690703
fn isub(zelf: PyRef<Self>, set: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
704+
if !set.class().is(vm.ctx.types.set_type) {
705+
return Err(vm.new_type_error("Type error".to_owned()));
706+
}
691707
if zelf.is(&set) {
692708
zelf.inner.clear();
693709
return Ok(zelf);
694710
}
695-
let set = set.try_into_value::<AnySet>(vm)?;
696-
zelf.inner
697-
.difference_update(set.into_iterable_iter(vm)?, vm)?;
711+
zelf.inner.difference_update(std::iter::once(set), vm)?;
698712
Ok(zelf)
699713
}
700714

701715
#[pymethod]
702716
fn symmetric_difference_update(
703-
&self,
704-
others: PosArgs<ArgIterable>,
717+
zelf: PyRef<Self>,
718+
others: PosArgs<PyObjectRef>,
705719
vm: &VirtualMachine,
706720
) -> PyResult<()> {
707-
self.inner
708-
.symmetric_difference_update(others.into_iter(), vm)?;
721+
let arguments = others.into_vec();
722+
if arguments.len() == 1 {
723+
if let Some(iterable) = arguments.first() {
724+
if zelf.is(iterable) {
725+
zelf.inner.clear();
726+
return Ok(());
727+
}
728+
}
729+
} else if arguments.is_empty() {
730+
return Ok(());
731+
}
732+
zelf.inner.symmetric_difference_update(arguments.into_iter(), vm)?;
709733
Ok(())
710734
}
711735

712736
#[pymethod(magic)]
713737
fn ixor(zelf: PyRef<Self>, set: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
738+
if !set.class().is(vm.ctx.types.set_type) {
739+
return Err(vm.new_type_error("Type error".to_owned()));
740+
}
714741
if zelf.is(&set) {
715742
zelf.inner.clear();
716743
return Ok(zelf);
717744
}
718-
let set = set.try_into_value::<AnySet>(vm)?;
719-
zelf.inner
720-
.symmetric_difference_update(set.into_iterable_iter(vm)?, vm)?;
745+
zelf.inner.symmetric_difference_update(std::iter::once(set), vm)?;
721746
Ok(zelf)
722747
}
723748

0 commit comments

Comments
 (0)