From 7cf25440fa4d45e72a5cc601990410c595bcca87 Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Fri, 22 Feb 2019 16:16:06 +0200 Subject: [PATCH 1/2] Add set.__iter__ --- tests/snippets/set.py | 7 +++++++ vm/src/obj/objset.rs | 17 +++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/tests/snippets/set.py b/tests/snippets/set.py index be64d92dbb..d3583a324f 100644 --- a/tests/snippets/set.py +++ b/tests/snippets/set.py @@ -99,6 +99,13 @@ def __hash__(self): assert a == set([1,2,3,4,5]) assert_raises(TypeError, lambda: a.update(1)) +a = set([1,2,3]) +b = set() +for e in a: + assert e == 1 or e == 2 or e == 3 + b.add(e) +assert a == b + a = set([1,2,3]) a.intersection_update([2,3,4,5]) assert a == set([2,3]) diff --git a/vm/src/obj/objset.rs b/vm/src/obj/objset.rs index 5854c94d78..3db9b1d1f3 100644 --- a/vm/src/obj/objset.rs +++ b/vm/src/obj/objset.rs @@ -521,6 +521,22 @@ fn set_symmetric_difference_update(vm: &mut VirtualMachine, args: PyFuncArgs) -> } } +fn set_iter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + arg_check!(vm, args, required = [(zelf, Some(vm.ctx.set_type()))]); + + let items = get_elements(zelf).values().map(|x| x.clone()).collect(); + let set_list = vm.ctx.new_list(items); + let iter_obj = PyObject::new( + PyObjectPayload::Iterator { + position: 0, + iterated_obj: set_list, + }, + vm.ctx.iter_type(), + ); + + Ok(iter_obj) +} + fn frozenset_repr(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!(vm, args, required = [(o, Some(vm.ctx.frozenset_type()))]); @@ -608,6 +624,7 @@ pub fn init(context: &PyContext) { "symmetric_difference_update", context.new_rustfunc(set_symmetric_difference_update), ); + context.set_attr(&set_type, "__iter__", context.new_rustfunc(set_iter)); let frozenset_type = &context.frozenset_type; From fc10560308ed31a78e2137a21a022470076c9276 Mon Sep 17 00:00:00 2001 From: Aviv Palivoda Date: Fri, 22 Feb 2019 12:59:24 +0200 Subject: [PATCH 2/2] Add set.{__ior__,__iand__,__isub__,__ixor__} --- tests/snippets/set.py | 40 ++++++++++++++++++++++++++++++++++++++++ vm/src/obj/objset.rs | 38 +++++++++++++++++++++++++++++++------- 2 files changed, 71 insertions(+), 7 deletions(-) diff --git a/tests/snippets/set.py b/tests/snippets/set.py index d3583a324f..84754bca81 100644 --- a/tests/snippets/set.py +++ b/tests/snippets/set.py @@ -106,17 +106,57 @@ def __hash__(self): b.add(e) assert a == b +a = set([1,2,3]) +a |= set([3,4,5]) +assert a == set([1,2,3,4,5]) +try: + a |= 1 +except TypeError: + pass +else: + assert False, "TypeError not raised" + a = set([1,2,3]) a.intersection_update([2,3,4,5]) assert a == set([2,3]) assert_raises(TypeError, lambda: a.intersection_update(1)) +a = set([1,2,3]) +a &= set([2,3,4,5]) +assert a == set([2,3]) +try: + a &= 1 +except TypeError: + pass +else: + assert False, "TypeError not raised" + a = set([1,2,3]) a.difference_update([3,4,5]) assert a == set([1,2]) assert_raises(TypeError, lambda: a.difference_update(1)) +a = set([1,2,3]) +a -= set([3,4,5]) +assert a == set([1,2]) +try: + a -= 1 +except TypeError: + pass +else: + assert False, "TypeError not raised" + a = set([1,2,3]) a.symmetric_difference_update([3,4,5]) assert a == set([1,2,4,5]) assert_raises(TypeError, lambda: a.difference_update(1)) + +a = set([1,2,3]) +a ^= set([3,4,5]) +assert a == set([1,2,4,5]) +try: + a ^= 1 +except TypeError: + pass +else: + assert False, "TypeError not raised" diff --git a/vm/src/obj/objset.rs b/vm/src/obj/objset.rs index 3db9b1d1f3..ebae2b44cd 100644 --- a/vm/src/obj/objset.rs +++ b/vm/src/obj/objset.rs @@ -433,6 +433,11 @@ fn set_pop(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { } fn set_update(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + set_ior(vm, args)?; + Ok(vm.get_none()) +} + +fn set_ior(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!( vm, args, @@ -447,17 +452,27 @@ fn set_update(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { while let Ok(v) = vm.call_method(&iterator, "__next__", vec![]) { insert_into_set(vm, elements, &v)?; } - Ok(vm.get_none()) } - _ => Err(vm.new_type_error("set.update is called with no other".to_string())), + _ => return Err(vm.new_type_error("set.update is called with no other".to_string())), } + Ok(zelf.clone()) } fn set_intersection_update(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + set_combine_update_inner(vm, args, SetCombineOperation::Intersection)?; + Ok(vm.get_none()) +} + +fn set_iand(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { set_combine_update_inner(vm, args, SetCombineOperation::Intersection) } fn set_difference_update(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + set_combine_update_inner(vm, args, SetCombineOperation::Difference)?; + Ok(vm.get_none()) +} + +fn set_isub(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { set_combine_update_inner(vm, args, SetCombineOperation::Difference) } @@ -486,13 +501,18 @@ fn set_combine_update_inner( elements.remove(&element.0.clone()); } } - Ok(vm.get_none()) } - _ => Err(vm.new_type_error("".to_string())), + _ => return Err(vm.new_type_error("".to_string())), } + Ok(zelf.clone()) } fn set_symmetric_difference_update(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { + set_ixor(vm, args)?; + Ok(vm.get_none()) +} + +fn set_ixor(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { arg_check!( vm, args, @@ -514,11 +534,11 @@ fn set_symmetric_difference_update(vm: &mut VirtualMachine, args: PyFuncArgs) -> elements.remove(&element.0.clone()); } } - - Ok(vm.get_none()) } - _ => Err(vm.new_type_error("".to_string())), + _ => return Err(vm.new_type_error("".to_string())), } + + Ok(zelf.clone()) } fn set_iter(vm: &mut VirtualMachine, args: PyFuncArgs) -> PyResult { @@ -609,21 +629,25 @@ pub fn init(context: &PyContext) { context.set_attr(&set_type, "copy", context.new_rustfunc(set_copy)); context.set_attr(&set_type, "pop", context.new_rustfunc(set_pop)); context.set_attr(&set_type, "update", context.new_rustfunc(set_update)); + context.set_attr(&set_type, "__ior__", context.new_rustfunc(set_ior)); context.set_attr( &set_type, "intersection_update", context.new_rustfunc(set_intersection_update), ); + context.set_attr(&set_type, "__iand__", context.new_rustfunc(set_iand)); context.set_attr( &set_type, "difference_update", context.new_rustfunc(set_difference_update), ); + context.set_attr(&set_type, "__isub__", context.new_rustfunc(set_isub)); context.set_attr( &set_type, "symmetric_difference_update", context.new_rustfunc(set_symmetric_difference_update), ); + context.set_attr(&set_type, "__ixor__", context.new_rustfunc(set_ixor)); context.set_attr(&set_type, "__iter__", context.new_rustfunc(set_iter)); let frozenset_type = &context.frozenset_type;