From a71e39473be69650c472526231adb11bb5bd79bb Mon Sep 17 00:00:00 2001 From: Amirreza Hamzavi Date: Mon, 9 Sep 2024 13:32:35 +0330 Subject: [PATCH] py/objset: Allow intersection() to take multiple sets as args. Signed-off-by: Amirreza Hamzavi --- py/objset.c | 58 ++++++++++++++++---------------- tests/basics/set_intersection.py | 14 ++++++++ 2 files changed, 43 insertions(+), 29 deletions(-) diff --git a/py/objset.c b/py/objset.c index c8fa12a7ec58..9c94e1341ba5 100644 --- a/py/objset.c +++ b/py/objset.c @@ -227,47 +227,47 @@ static mp_obj_t set_diff_update(size_t n_args, const mp_obj_t *args) { } static MP_DEFINE_CONST_FUN_OBJ_VAR(set_diff_update_obj, 1, set_diff_update); -static mp_obj_t set_intersect_int(mp_obj_t self_in, mp_obj_t other, bool update) { +static mp_obj_t set_intersect_int(size_t n_args, const mp_obj_t *args, bool update) { + mp_obj_set_t *in_ptr; if (update) { - check_set(self_in); + check_set(args[0]); + in_ptr = MP_OBJ_TO_PTR(args[0]); } else { - check_set_or_frozenset(self_in); + check_set_or_frozenset(args[0]); + in_ptr = MP_OBJ_TO_PTR(set_copy(args[0])); } - if (self_in == other) { - return update ? mp_const_none : set_copy(self_in); - } + for (size_t i = 1; i < n_args; i++) { + if (MP_OBJ_FROM_PTR(in_ptr) != args[i]) { + mp_obj_set_t *out_ptr = MP_OBJ_TO_PTR(mp_obj_new_set(0, NULL)); - mp_obj_set_t *self = MP_OBJ_TO_PTR(self_in); - mp_obj_set_t *out = MP_OBJ_TO_PTR(mp_obj_new_set(0, NULL)); + mp_obj_t iter = mp_getiter(args[i], NULL); + mp_obj_t next; + while ((next = mp_iternext(iter)) != MP_OBJ_STOP_ITERATION) { + if (mp_set_lookup(&in_ptr->set, next, MP_MAP_LOOKUP)) { + set_add(MP_OBJ_FROM_PTR(out_ptr), next); + } + } - mp_obj_t iter = mp_getiter(other, NULL); - mp_obj_t next; - while ((next = mp_iternext(iter)) != MP_OBJ_STOP_ITERATION) { - if (mp_set_lookup(&self->set, next, MP_MAP_LOOKUP)) { - set_add(MP_OBJ_FROM_PTR(out), next); + m_del(mp_obj_t, in_ptr->set.table, in_ptr->set.alloc); + in_ptr->set.alloc = out_ptr->set.alloc; + in_ptr->set.used = out_ptr->set.used; + in_ptr->set.table = out_ptr->set.table; } } - if (update) { - m_del(mp_obj_t, self->set.table, self->set.alloc); - self->set.alloc = out->set.alloc; - self->set.used = out->set.used; - self->set.table = out->set.table; - } - - return update ? mp_const_none : MP_OBJ_FROM_PTR(out); + return update ? mp_const_none : MP_OBJ_FROM_PTR(in_ptr); } -static mp_obj_t set_intersect(mp_obj_t self_in, mp_obj_t other) { - return set_intersect_int(self_in, other, false); +static mp_obj_t set_intersect(size_t n_args, const mp_obj_t *args) { + return set_intersect_int(n_args, args, false); } -static MP_DEFINE_CONST_FUN_OBJ_2(set_intersect_obj, set_intersect); +static MP_DEFINE_CONST_FUN_OBJ_VAR(set_intersect_obj, 1, set_intersect); -static mp_obj_t set_intersect_update(mp_obj_t self_in, mp_obj_t other) { - return set_intersect_int(self_in, other, true); +static mp_obj_t set_intersect_update(size_t n_args, const mp_obj_t *args) { + return set_intersect_int(n_args, args, true); } -static MP_DEFINE_CONST_FUN_OBJ_2(set_intersect_update_obj, set_intersect_update); +static MP_DEFINE_CONST_FUN_OBJ_VAR(set_intersect_update_obj, 1, set_intersect_update); static mp_obj_t set_isdisjoint(mp_obj_t self_in, mp_obj_t other) { check_set_or_frozenset(self_in); @@ -468,7 +468,7 @@ static mp_obj_t set_binary_op(mp_binary_op_t op, mp_obj_t lhs, mp_obj_t rhs) { case MP_BINARY_OP_XOR: return set_symmetric_difference(lhs, rhs); case MP_BINARY_OP_AND: - return set_intersect(lhs, rhs); + return set_intersect_int(2, args, false); case MP_BINARY_OP_SUBTRACT: return set_diff(2, args); case MP_BINARY_OP_INPLACE_OR: @@ -486,7 +486,7 @@ static mp_obj_t set_binary_op(mp_binary_op_t op, mp_obj_t lhs, mp_obj_t rhs) { return set_symmetric_difference(lhs, rhs); } case MP_BINARY_OP_INPLACE_AND: - rhs = set_intersect_int(lhs, rhs, update); + rhs = set_intersect_int(2, args, update); if (update) { return lhs; } else { diff --git a/tests/basics/set_intersection.py b/tests/basics/set_intersection.py index 73804c840d61..ee1ee724723a 100644 --- a/tests/basics/set_intersection.py +++ b/tests/basics/set_intersection.py @@ -1,7 +1,21 @@ s = {1, 2, 3, 4} print(sorted(s)) +print(sorted(s.intersection())) +print(sorted(s)) +print(s.intersection_update()) +print(sorted(s)) print(sorted(s.intersection({1, 3}))) +print(sorted(s)) print(sorted(s.intersection([3, 4]))) +print(sorted(s)) +print(sorted(s.intersection({1, 2, 3}, {1, 4, 5}, {1}))) +print(sorted(s)) +print(sorted(s.intersection([1, 3], [1, 3, 5]))) +print(sorted(s)) print(s.intersection_update([1])) print(sorted(s)) + +s = {1, 2, 3} +print(s.intersection_update({1, 2}, {1})) +print(sorted(s))