Skip to content

Commit ad19a10

Browse files
committed
py/objset: Allow intersection() to take multiple sets as args.
Signed-off-by: Amirreza Hamzavi <amirrezahamzavi2000@gmail.com>
1 parent 3ca01ec commit ad19a10

File tree

2 files changed

+42
-19
lines changed

2 files changed

+42
-19
lines changed

py/objset.c

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -227,29 +227,45 @@ static mp_obj_t set_diff_update(size_t n_args, const mp_obj_t *args) {
227227
}
228228
static MP_DEFINE_CONST_FUN_OBJ_VAR(set_diff_update_obj, 1, set_diff_update);
229229

230-
static mp_obj_t set_intersect_int(mp_obj_t self_in, mp_obj_t other, bool update) {
230+
static mp_obj_t set_intersect_int(size_t n_args, const mp_obj_t *args, bool update) {
231231
if (update) {
232-
check_set(self_in);
232+
check_set(args[0]);
233233
} else {
234-
check_set_or_frozenset(self_in);
234+
check_set_or_frozenset(args[0]);
235235
}
236236

237-
if (self_in == other) {
238-
return update ? mp_const_none : set_copy(self_in);
237+
if (n_args == 2 && args[0] == args[1]) {
238+
return update ? mp_const_none : set_copy(args[0]);
239239
}
240240

241-
mp_obj_set_t *self = MP_OBJ_TO_PTR(self_in);
242241
mp_obj_set_t *out = MP_OBJ_TO_PTR(mp_obj_new_set(0, NULL));
243242

244-
mp_obj_t iter = mp_getiter(other, NULL);
245-
mp_obj_t next;
246-
while ((next = mp_iternext(iter)) != MP_OBJ_STOP_ITERATION) {
247-
if (mp_set_lookup(&self->set, next, MP_MAP_LOOKUP)) {
248-
set_add(MP_OBJ_FROM_PTR(out), next);
243+
mp_obj_t self_iter = mp_getiter(args[0], NULL);
244+
mp_obj_t self_next;
245+
while ((self_next = mp_iternext(self_iter)) != MP_OBJ_STOP_ITERATION) {
246+
bool skip = false;
247+
for (size_t i = 1; i < n_args; i++) {
248+
mp_obj_t iter = mp_getiter(args[i], NULL);
249+
mp_obj_t next;
250+
bool present = false;
251+
while ((next = mp_iternext(iter)) != MP_OBJ_STOP_ITERATION) {
252+
if (mp_obj_equal(self_next, next)) {
253+
present = true;
254+
break;
255+
}
256+
}
257+
if (!present) {
258+
skip = true;
259+
break;
260+
}
261+
}
262+
if (!skip) {
263+
set_add(MP_OBJ_FROM_PTR(out), self_next);
249264
}
250265
}
251266

252267
if (update) {
268+
mp_obj_set_t *self = MP_OBJ_TO_PTR(args[0]);
253269
m_del(mp_obj_t, self->set.table, self->set.alloc);
254270
self->set.alloc = out->set.alloc;
255271
self->set.used = out->set.used;
@@ -259,15 +275,15 @@ static mp_obj_t set_intersect_int(mp_obj_t self_in, mp_obj_t other, bool update)
259275
return update ? mp_const_none : MP_OBJ_FROM_PTR(out);
260276
}
261277

262-
static mp_obj_t set_intersect(mp_obj_t self_in, mp_obj_t other) {
263-
return set_intersect_int(self_in, other, false);
278+
static mp_obj_t set_intersect(size_t n_args, const mp_obj_t *args) {
279+
return set_intersect_int(n_args, args, false);
264280
}
265-
static MP_DEFINE_CONST_FUN_OBJ_2(set_intersect_obj, set_intersect);
281+
static MP_DEFINE_CONST_FUN_OBJ_VAR(set_intersect_obj, 2, set_intersect);
266282

267-
static mp_obj_t set_intersect_update(mp_obj_t self_in, mp_obj_t other) {
268-
return set_intersect_int(self_in, other, true);
283+
static mp_obj_t set_intersect_update(size_t n_args, const mp_obj_t *args) {
284+
return set_intersect_int(n_args, args, true);
269285
}
270-
static MP_DEFINE_CONST_FUN_OBJ_2(set_intersect_update_obj, set_intersect_update);
286+
static MP_DEFINE_CONST_FUN_OBJ_VAR(set_intersect_update_obj, 2, set_intersect_update);
271287

272288
static mp_obj_t set_isdisjoint(mp_obj_t self_in, mp_obj_t other) {
273289
check_set_or_frozenset(self_in);
@@ -468,7 +484,7 @@ static mp_obj_t set_binary_op(mp_binary_op_t op, mp_obj_t lhs, mp_obj_t rhs) {
468484
case MP_BINARY_OP_XOR:
469485
return set_symmetric_difference(lhs, rhs);
470486
case MP_BINARY_OP_AND:
471-
return set_intersect(lhs, rhs);
487+
return set_intersect(2, args);
472488
case MP_BINARY_OP_SUBTRACT:
473489
return set_diff(2, args);
474490
case MP_BINARY_OP_INPLACE_OR:
@@ -486,7 +502,7 @@ static mp_obj_t set_binary_op(mp_binary_op_t op, mp_obj_t lhs, mp_obj_t rhs) {
486502
return set_symmetric_difference(lhs, rhs);
487503
}
488504
case MP_BINARY_OP_INPLACE_AND:
489-
rhs = set_intersect_int(lhs, rhs, update);
505+
rhs = set_intersect_int(2, args, update);
490506
if (update) {
491507
return lhs;
492508
} else {

tests/basics/set_intersection.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,12 @@
33
print(sorted(s.intersection({1, 3})))
44
print(sorted(s.intersection([3, 4])))
55

6+
print(sorted(s.intersection({1, 2, 3}, {1, 4, 5}, {1})))
7+
print(sorted(s.intersection([1, 3], [1, 3, 5])))
8+
69
print(s.intersection_update([1]))
710
print(sorted(s))
11+
12+
s = {1, 2, 3}
13+
print(s.intersection_update({1, 2}, {1}))
14+
print(sorted(s))

0 commit comments

Comments
 (0)