-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Fix some set operation to compare self with args #3912
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -380,12 +380,13 @@ impl PySetInner { | |
|
||
fn intersection_update( | ||
&self, | ||
others: impl std::iter::Iterator<Item = ArgIterable>, | ||
others: impl std::iter::Iterator<Item = PyObjectRef>, | ||
vm: &VirtualMachine, | ||
) -> PyResult<()> { | ||
let mut temp_inner = self.copy(); | ||
self.clear(); | ||
for iterable in others { | ||
let iterable = ArgIterable::<PyObjectRef>::try_from_object(vm, iterable)?; | ||
for item in iterable.iter(vm)? { | ||
let obj = item?; | ||
if temp_inner.contains(&obj, vm)? { | ||
|
@@ -399,10 +400,11 @@ impl PySetInner { | |
|
||
fn difference_update( | ||
&self, | ||
others: impl std::iter::Iterator<Item = ArgIterable>, | ||
others: impl std::iter::Iterator<Item = PyObjectRef>, | ||
vm: &VirtualMachine, | ||
) -> PyResult<()> { | ||
for iterable in others { | ||
let iterable = ArgIterable::<PyObjectRef>::try_from_object(vm, iterable)?; | ||
for item in iterable.iter(vm)? { | ||
self.content.delete_if_exists(vm, &*item?)?; | ||
} | ||
|
@@ -412,11 +414,12 @@ impl PySetInner { | |
|
||
fn symmetric_difference_update( | ||
&self, | ||
others: impl std::iter::Iterator<Item = ArgIterable>, | ||
others: impl std::iter::Iterator<Item = PyObjectRef>, | ||
vm: &VirtualMachine, | ||
) -> PyResult<()> { | ||
for iterable in others { | ||
// We want to remove duplicates in iterable | ||
let iterable = ArgIterable::<PyObjectRef>::try_from_object(vm, iterable)?; | ||
let iterable_set = Self::from_iter(iterable.iter(vm)?, vm)?; | ||
for item in iterable_set.elements() { | ||
self.content.delete_or_insert(vm, &item, ())?; | ||
|
@@ -676,49 +679,103 @@ impl PySet { | |
|
||
#[pymethod] | ||
fn intersection_update( | ||
&self, | ||
others: PosArgs<ArgIterable>, | ||
zelf: PyRef<Self>, | ||
others: PosArgs<PyObjectRef>, | ||
vm: &VirtualMachine, | ||
) -> PyResult<()> { | ||
self.inner.intersection_update(others.into_iter(), vm)?; | ||
let arguments = others.into_vec(); | ||
if arguments.len() == 1 { | ||
if let Some(iterable) = arguments.first() { | ||
if zelf.is(iterable) { | ||
return Ok(()); | ||
} | ||
} | ||
} else if arguments.is_empty() { | ||
return Ok(()); | ||
} | ||
zelf.inner.intersection_update(arguments.into_iter(), vm)?; | ||
Ok(()) | ||
} | ||
|
||
#[pymethod(magic)] | ||
fn iand(zelf: PyRef<Self>, set: AnySet, vm: &VirtualMachine) -> PyResult<PyRef<Self>> { | ||
zelf.inner | ||
.intersection_update(std::iter::once(set.into_iterable(vm)?), vm)?; | ||
fn iand(zelf: PyRef<Self>, set: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyRef<Self>> { | ||
if !set.class().is(vm.ctx.types.set_type) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how about frozenset? if frozenset is allowed, maybe AnySet check? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah, frozen sets are allowed here. Similarly for |
||
return Err(vm.new_type_error("Type error".to_owned())); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should be |
||
} | ||
if zelf.is(&set) { | ||
return Ok(zelf); | ||
} | ||
zelf.inner.intersection_update(std::iter::once(set), vm)?; | ||
Ok(zelf) | ||
} | ||
|
||
#[pymethod] | ||
fn difference_update(&self, others: PosArgs<ArgIterable>, vm: &VirtualMachine) -> PyResult<()> { | ||
self.inner.difference_update(others.into_iter(), vm)?; | ||
fn difference_update( | ||
zelf: PyRef<Self>, | ||
others: PosArgs<PyObjectRef>, | ||
vm: &VirtualMachine, | ||
) -> PyResult<()> { | ||
let arguments = others.into_vec(); | ||
if arguments.len() == 1 { | ||
if let Some(iterable) = arguments.first() { | ||
if zelf.is(iterable) { | ||
zelf.inner.clear(); | ||
return Ok(()); | ||
} | ||
} | ||
} else if arguments.is_empty() { | ||
return Ok(()); | ||
} | ||
zelf.inner.difference_update(arguments.into_iter(), vm)?; | ||
Ok(()) | ||
} | ||
|
||
#[pymethod(magic)] | ||
fn isub(zelf: PyRef<Self>, set: AnySet, vm: &VirtualMachine) -> PyResult<PyRef<Self>> { | ||
zelf.inner | ||
.difference_update(set.into_iterable_iter(vm)?, vm)?; | ||
fn isub(zelf: PyRef<Self>, set: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyRef<Self>> { | ||
if !set.class().is(vm.ctx.types.set_type) { | ||
return Err(vm.new_type_error("Type error".to_owned())); | ||
} | ||
if zelf.is(&set) { | ||
zelf.inner.clear(); | ||
return Ok(zelf); | ||
} | ||
zelf.inner.difference_update(std::iter::once(set), vm)?; | ||
Ok(zelf) | ||
} | ||
|
||
#[pymethod] | ||
fn symmetric_difference_update( | ||
&self, | ||
others: PosArgs<ArgIterable>, | ||
zelf: PyRef<Self>, | ||
others: PosArgs<PyObjectRef>, | ||
vm: &VirtualMachine, | ||
) -> PyResult<()> { | ||
self.inner | ||
.symmetric_difference_update(others.into_iter(), vm)?; | ||
let arguments = others.into_vec(); | ||
if arguments.len() == 1 { | ||
if let Some(iterable) = arguments.first() { | ||
if zelf.is(iterable) { | ||
zelf.inner.clear(); | ||
return Ok(()); | ||
} | ||
} | ||
} else if arguments.is_empty() { | ||
return Ok(()); | ||
} | ||
zelf.inner | ||
.symmetric_difference_update(arguments.into_iter(), vm)?; | ||
Ok(()) | ||
} | ||
|
||
#[pymethod(magic)] | ||
fn ixor(zelf: PyRef<Self>, set: AnySet, vm: &VirtualMachine) -> PyResult<PyRef<Self>> { | ||
fn ixor(zelf: PyRef<Self>, set: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyRef<Self>> { | ||
if !set.class().is(vm.ctx.types.set_type) { | ||
return Err(vm.new_type_error("Type error".to_owned())); | ||
} | ||
if zelf.is(&set) { | ||
zelf.inner.clear(); | ||
return Ok(zelf); | ||
} | ||
zelf.inner | ||
.symmetric_difference_update(set.into_iterable_iter(vm)?, vm)?; | ||
.symmetric_difference_update(std::iter::once(set), vm)?; | ||
Ok(zelf) | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this check is not only adaptable for
len() == 1
case but also do for every argument.For example, how about
s = set(['a']); s.intersection_update(s, s, s, s, s, s)
?