@@ -227,29 +227,45 @@ static mp_obj_t set_diff_update(size_t n_args, const mp_obj_t *args) {
227
227
}
228
228
static MP_DEFINE_CONST_FUN_OBJ_VAR (set_diff_update_obj , 1 , set_diff_update ) ;
229
229
230
- static mp_obj_t set_intersect_int (mp_obj_t self_in , mp_obj_t other , bool update ) {
230
+ static mp_obj_t set_intersect_int (size_t n_args , const mp_obj_t * args , bool update ) {
231
231
if (update ) {
232
- check_set (self_in );
232
+ check_set (args [ 0 ] );
233
233
} else {
234
- check_set_or_frozenset (self_in );
234
+ check_set_or_frozenset (args [ 0 ] );
235
235
}
236
236
237
- if (self_in == other ) {
238
- return update ? mp_const_none : set_copy (self_in );
237
+ if (n_args == 2 && args [ 0 ] == args [ 1 ] ) {
238
+ return update ? mp_const_none : set_copy (args [ 0 ] );
239
239
}
240
240
241
- mp_obj_set_t * self = MP_OBJ_TO_PTR (self_in );
242
241
mp_obj_set_t * out = MP_OBJ_TO_PTR (mp_obj_new_set (0 , NULL ));
243
242
244
- mp_obj_t iter = mp_getiter (other , NULL );
245
- mp_obj_t next ;
246
- while ((next = mp_iternext (iter )) != MP_OBJ_STOP_ITERATION ) {
247
- if (mp_set_lookup (& self -> set , next , MP_MAP_LOOKUP )) {
248
- set_add (MP_OBJ_FROM_PTR (out ), next );
243
+ mp_obj_t self_iter = mp_getiter (args [0 ], NULL );
244
+ mp_obj_t self_next ;
245
+ while ((self_next = mp_iternext (self_iter )) != MP_OBJ_STOP_ITERATION ) {
246
+ bool skip = false;
247
+ for (size_t i = 1 ; i < n_args ; i ++ ) {
248
+ mp_obj_t iter = mp_getiter (args [i ], NULL );
249
+ mp_obj_t next ;
250
+ bool present = false;
251
+ while ((next = mp_iternext (iter )) != MP_OBJ_STOP_ITERATION ) {
252
+ if (mp_obj_equal (self_next , next )) {
253
+ present = true;
254
+ break ;
255
+ }
256
+ }
257
+ if (!present ) {
258
+ skip = true;
259
+ break ;
260
+ }
261
+ }
262
+ if (!skip ) {
263
+ set_add (MP_OBJ_FROM_PTR (out ), self_next );
249
264
}
250
265
}
251
266
252
267
if (update ) {
268
+ mp_obj_set_t * self = MP_OBJ_TO_PTR (args [0 ]);
253
269
m_del (mp_obj_t , self -> set .table , self -> set .alloc );
254
270
self -> set .alloc = out -> set .alloc ;
255
271
self -> set .used = out -> set .used ;
@@ -259,15 +275,15 @@ static mp_obj_t set_intersect_int(mp_obj_t self_in, mp_obj_t other, bool update)
259
275
return update ? mp_const_none : MP_OBJ_FROM_PTR (out );
260
276
}
261
277
262
- static mp_obj_t set_intersect (mp_obj_t self_in , mp_obj_t other ) {
263
- return set_intersect_int (self_in , other , false);
278
+ static mp_obj_t set_intersect (size_t n_args , const mp_obj_t * args ) {
279
+ return set_intersect_int (n_args , args , false);
264
280
}
265
- static MP_DEFINE_CONST_FUN_OBJ_2 (set_intersect_obj , set_intersect ) ;
281
+ static MP_DEFINE_CONST_FUN_OBJ_VAR (set_intersect_obj , 2 , set_intersect ) ;
266
282
267
- static mp_obj_t set_intersect_update (mp_obj_t self_in , mp_obj_t other ) {
268
- return set_intersect_int (self_in , other , true);
283
+ static mp_obj_t set_intersect_update (size_t n_args , const mp_obj_t * args ) {
284
+ return set_intersect_int (n_args , args , true);
269
285
}
270
- static MP_DEFINE_CONST_FUN_OBJ_2 (set_intersect_update_obj , set_intersect_update ) ;
286
+ static MP_DEFINE_CONST_FUN_OBJ_VAR (set_intersect_update_obj , 2 , set_intersect_update ) ;
271
287
272
288
static mp_obj_t set_isdisjoint (mp_obj_t self_in , mp_obj_t other ) {
273
289
check_set_or_frozenset (self_in );
@@ -468,7 +484,7 @@ static mp_obj_t set_binary_op(mp_binary_op_t op, mp_obj_t lhs, mp_obj_t rhs) {
468
484
case MP_BINARY_OP_XOR :
469
485
return set_symmetric_difference (lhs , rhs );
470
486
case MP_BINARY_OP_AND :
471
- return set_intersect (lhs , rhs );
487
+ return set_intersect (2 , args );
472
488
case MP_BINARY_OP_SUBTRACT :
473
489
return set_diff (2 , args );
474
490
case MP_BINARY_OP_INPLACE_OR :
@@ -486,7 +502,7 @@ static mp_obj_t set_binary_op(mp_binary_op_t op, mp_obj_t lhs, mp_obj_t rhs) {
486
502
return set_symmetric_difference (lhs , rhs );
487
503
}
488
504
case MP_BINARY_OP_INPLACE_AND :
489
- rhs = set_intersect_int (lhs , rhs , update );
505
+ rhs = set_intersect_int (2 , args , update );
490
506
if (update ) {
491
507
return lhs ;
492
508
} else {
0 commit comments