Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions Lib/test/test_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,8 +582,6 @@ def test_ixor(self):
else:
self.assertNotIn(c, self.s)

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_inplace_on_self(self):
t = self.s.copy()
t |= t
Expand Down
97 changes: 77 additions & 20 deletions vm/src/builtins/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,12 +380,13 @@ impl PySetInner {

fn intersection_update(
&self,
others: impl std::iter::Iterator<Item = ArgIterable>,
others: impl std::iter::Iterator<Item = PyObjectRef>,
vm: &VirtualMachine,
) -> PyResult<()> {
let mut temp_inner = self.copy();
self.clear();
for iterable in others {
let iterable = ArgIterable::<PyObjectRef>::try_from_object(vm, iterable)?;
for item in iterable.iter(vm)? {
let obj = item?;
if temp_inner.contains(&obj, vm)? {
Expand All @@ -399,10 +400,11 @@ impl PySetInner {

fn difference_update(
&self,
others: impl std::iter::Iterator<Item = ArgIterable>,
others: impl std::iter::Iterator<Item = PyObjectRef>,
vm: &VirtualMachine,
) -> PyResult<()> {
for iterable in others {
let iterable = ArgIterable::<PyObjectRef>::try_from_object(vm, iterable)?;
for item in iterable.iter(vm)? {
self.content.delete_if_exists(vm, &*item?)?;
}
Expand All @@ -412,11 +414,12 @@ impl PySetInner {

fn symmetric_difference_update(
&self,
others: impl std::iter::Iterator<Item = ArgIterable>,
others: impl std::iter::Iterator<Item = PyObjectRef>,
vm: &VirtualMachine,
) -> PyResult<()> {
for iterable in others {
// We want to remove duplicates in iterable
let iterable = ArgIterable::<PyObjectRef>::try_from_object(vm, iterable)?;
let iterable_set = Self::from_iter(iterable.iter(vm)?, vm)?;
for item in iterable_set.elements() {
self.content.delete_or_insert(vm, &item, ())?;
Expand Down Expand Up @@ -676,49 +679,103 @@ impl PySet {

#[pymethod]
fn intersection_update(
&self,
others: PosArgs<ArgIterable>,
zelf: PyRef<Self>,
others: PosArgs<PyObjectRef>,
vm: &VirtualMachine,
) -> PyResult<()> {
self.inner.intersection_update(others.into_iter(), vm)?;
let arguments = others.into_vec();
if arguments.len() == 1 {
if let Some(iterable) = arguments.first() {
if zelf.is(iterable) {
return Ok(());
}
}
Comment on lines +688 to +692
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this check is not only adaptable for len() == 1 case but also do for every argument.

For example, how about s = set(['a']); s.intersection_update(s, s, s, s, s, s)?

} else if arguments.is_empty() {
return Ok(());
}
zelf.inner.intersection_update(arguments.into_iter(), vm)?;
Ok(())
}

#[pymethod(magic)]
fn iand(zelf: PyRef<Self>, set: AnySet, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
zelf.inner
.intersection_update(std::iter::once(set.into_iterable(vm)?), vm)?;
fn iand(zelf: PyRef<Self>, set: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
if !set.class().is(vm.ctx.types.set_type) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about frozenset? if frozenset is allowed, maybe AnySet check?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, frozen sets are allowed here. Similarly for isub and ixor.

return Err(vm.new_type_error("Type error".to_owned()));
Copy link
Member

@DimitrisJim DimitrisJim Nov 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be "unsupported operand type(s) for +=: 'type1' and 'type2'", similarly for the rest of the type errors.

}
if zelf.is(&set) {
return Ok(zelf);
}
zelf.inner.intersection_update(std::iter::once(set), vm)?;
Ok(zelf)
}

#[pymethod]
fn difference_update(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<()> {
self.inner.difference_update(others.into_iter(), vm)?;
fn difference_update(
zelf: PyRef<Self>,
others: PosArgs<PyObjectRef>,
vm: &VirtualMachine,
) -> PyResult<()> {
let arguments = others.into_vec();
if arguments.len() == 1 {
if let Some(iterable) = arguments.first() {
if zelf.is(iterable) {
zelf.inner.clear();
return Ok(());
}
}
} else if arguments.is_empty() {
return Ok(());
}
zelf.inner.difference_update(arguments.into_iter(), vm)?;
Ok(())
}

#[pymethod(magic)]
fn isub(zelf: PyRef<Self>, set: AnySet, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
zelf.inner
.difference_update(set.into_iterable_iter(vm)?, vm)?;
fn isub(zelf: PyRef<Self>, set: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
if !set.class().is(vm.ctx.types.set_type) {
return Err(vm.new_type_error("Type error".to_owned()));
}
if zelf.is(&set) {
zelf.inner.clear();
return Ok(zelf);
}
zelf.inner.difference_update(std::iter::once(set), vm)?;
Ok(zelf)
}

#[pymethod]
fn symmetric_difference_update(
&self,
others: PosArgs<ArgIterable>,
zelf: PyRef<Self>,
others: PosArgs<PyObjectRef>,
vm: &VirtualMachine,
) -> PyResult<()> {
self.inner
.symmetric_difference_update(others.into_iter(), vm)?;
let arguments = others.into_vec();
if arguments.len() == 1 {
if let Some(iterable) = arguments.first() {
if zelf.is(iterable) {
zelf.inner.clear();
return Ok(());
}
}
} else if arguments.is_empty() {
return Ok(());
}
zelf.inner
.symmetric_difference_update(arguments.into_iter(), vm)?;
Ok(())
}

#[pymethod(magic)]
fn ixor(zelf: PyRef<Self>, set: AnySet, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
fn ixor(zelf: PyRef<Self>, set: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyRef<Self>> {
if !set.class().is(vm.ctx.types.set_type) {
return Err(vm.new_type_error("Type error".to_owned()));
}
if zelf.is(&set) {
zelf.inner.clear();
return Ok(zelf);
}
zelf.inner
.symmetric_difference_update(set.into_iterable_iter(vm)?, vm)?;
.symmetric_difference_update(std::iter::once(set), vm)?;
Ok(zelf)
}

Expand Down