Skip to content

Commit 5d2e79b

Browse files
committed
py/objset: Allow intersection() to take multiple sets as args.
Signed-off-by: Amirreza Hamzavi <amirrezahamzavi2000@gmail.com>
1 parent 3ca01ec commit 5d2e79b

File tree

2 files changed

+39
-29
lines changed

2 files changed

+39
-29
lines changed

py/objset.c

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -227,47 +227,47 @@ static mp_obj_t set_diff_update(size_t n_args, const mp_obj_t *args) {
227227
}
228228
static MP_DEFINE_CONST_FUN_OBJ_VAR(set_diff_update_obj, 1, set_diff_update);
229229

230-
static mp_obj_t set_intersect_int(mp_obj_t self_in, mp_obj_t other, bool update) {
230+
static mp_obj_t set_intersect_int(size_t n_args, const mp_obj_t *args, bool update) {
231+
mp_obj_set_t *in_ptr;
231232
if (update) {
232-
check_set(self_in);
233+
check_set(args[0]);
234+
in_ptr = MP_OBJ_TO_PTR(args[0]);
233235
} else {
234-
check_set_or_frozenset(self_in);
236+
check_set_or_frozenset(args[0]);
237+
in_ptr = MP_OBJ_TO_PTR(set_copy(args[0]));
235238
}
236239

237-
if (self_in == other) {
238-
return update ? mp_const_none : set_copy(self_in);
239-
}
240+
for (size_t i = 1; i < n_args; i++) {
241+
if (MP_OBJ_FROM_PTR(in_ptr) != args[i]) {
242+
mp_obj_set_t *out_ptr = MP_OBJ_TO_PTR(mp_obj_new_set(0, NULL));
240243

241-
mp_obj_set_t *self = MP_OBJ_TO_PTR(self_in);
242-
mp_obj_set_t *out = MP_OBJ_TO_PTR(mp_obj_new_set(0, NULL));
244+
mp_obj_t iter = mp_getiter(args[i], NULL);
245+
mp_obj_t next;
246+
while ((next = mp_iternext(iter)) != MP_OBJ_STOP_ITERATION) {
247+
if (mp_set_lookup(&in_ptr->set, next, MP_MAP_LOOKUP)) {
248+
set_add(MP_OBJ_FROM_PTR(out_ptr), next);
249+
}
250+
}
243251

244-
mp_obj_t iter = mp_getiter(other, NULL);
245-
mp_obj_t next;
246-
while ((next = mp_iternext(iter)) != MP_OBJ_STOP_ITERATION) {
247-
if (mp_set_lookup(&self->set, next, MP_MAP_LOOKUP)) {
248-
set_add(MP_OBJ_FROM_PTR(out), next);
252+
m_del(mp_obj_t, in_ptr->set.table, in_ptr->set.alloc);
253+
in_ptr->set.alloc = out_ptr->set.alloc;
254+
in_ptr->set.used = out_ptr->set.used;
255+
in_ptr->set.table = out_ptr->set.table;
249256
}
250257
}
251258

252-
if (update) {
253-
m_del(mp_obj_t, self->set.table, self->set.alloc);
254-
self->set.alloc = out->set.alloc;
255-
self->set.used = out->set.used;
256-
self->set.table = out->set.table;
257-
}
258-
259-
return update ? mp_const_none : MP_OBJ_FROM_PTR(out);
259+
return update ? mp_const_none : MP_OBJ_FROM_PTR(in_ptr);
260260
}
261261

262-
static mp_obj_t set_intersect(mp_obj_t self_in, mp_obj_t other) {
263-
return set_intersect_int(self_in, other, false);
262+
static mp_obj_t set_intersect(size_t n_args, const mp_obj_t *args) {
263+
return set_intersect_int(n_args, args, false);
264264
}
265-
static MP_DEFINE_CONST_FUN_OBJ_2(set_intersect_obj, set_intersect);
265+
static MP_DEFINE_CONST_FUN_OBJ_VAR(set_intersect_obj, 2, set_intersect);
266266

267-
static mp_obj_t set_intersect_update(mp_obj_t self_in, mp_obj_t other) {
268-
return set_intersect_int(self_in, other, true);
267+
static mp_obj_t set_intersect_update(size_t n_args, const mp_obj_t *args) {
268+
return set_intersect_int(n_args, args, true);
269269
}
270-
static MP_DEFINE_CONST_FUN_OBJ_2(set_intersect_update_obj, set_intersect_update);
270+
static MP_DEFINE_CONST_FUN_OBJ_VAR(set_intersect_update_obj, 2, set_intersect_update);
271271

272272
static mp_obj_t set_isdisjoint(mp_obj_t self_in, mp_obj_t other) {
273273
check_set_or_frozenset(self_in);
@@ -468,7 +468,7 @@ static mp_obj_t set_binary_op(mp_binary_op_t op, mp_obj_t lhs, mp_obj_t rhs) {
468468
case MP_BINARY_OP_XOR:
469469
return set_symmetric_difference(lhs, rhs);
470470
case MP_BINARY_OP_AND:
471-
return set_intersect(lhs, rhs);
471+
return set_intersect_int(2, args, false);
472472
case MP_BINARY_OP_SUBTRACT:
473473
return set_diff(2, args);
474474
case MP_BINARY_OP_INPLACE_OR:
@@ -486,7 +486,7 @@ static mp_obj_t set_binary_op(mp_binary_op_t op, mp_obj_t lhs, mp_obj_t rhs) {
486486
return set_symmetric_difference(lhs, rhs);
487487
}
488488
case MP_BINARY_OP_INPLACE_AND:
489-
rhs = set_intersect_int(lhs, rhs, update);
489+
rhs = set_intersect_int(2, args, update);
490490
if (update) {
491491
return lhs;
492492
} else {

tests/basics/set_intersection.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,17 @@
11
s = {1, 2, 3, 4}
22
print(sorted(s))
33
print(sorted(s.intersection({1, 3})))
4+
print(sorted(s))
45
print(sorted(s.intersection([3, 4])))
6+
print(sorted(s))
7+
print(sorted(s.intersection({1, 2, 3}, {1, 4, 5}, {1})))
8+
print(sorted(s))
9+
print(sorted(s.intersection([1, 3], [1, 3, 5])))
10+
print(sorted(s))
511

612
print(s.intersection_update([1]))
713
print(sorted(s))
14+
15+
s = {1, 2, 3}
16+
print(s.intersection_update({1, 2}, {1}))
17+
print(sorted(s))

0 commit comments

Comments
 (0)