Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion py/compile.c
Original file line number Diff line number Diff line change
Expand Up @@ -2401,8 +2401,17 @@ STATIC void compile_trailer_paren_helper(compiler_t *comp, mp_parse_node_t pn_ar
}
star_flags |= MP_EMIT_STAR_FLAG_SINGLE;
star_args |= (mp_uint_t)1 << i;

if (n_keyword == 0) {
// star-args before kwargs encoded as positional arg
n_positional++;
} else {
// star-args after kwargs encoded as kw arg with key=NULL
EMIT(load_null);
n_keyword++;
}

compile_node(comp, pns_arg->nodes[0]);
n_positional++;
} else if (MP_PARSE_NODE_STRUCT_KIND(pns_arg) == PN_arglist_dbl_star) {
star_flags |= MP_EMIT_STAR_FLAG_DOUBLE;
// double-star args are stored as kw arg with key of None
Expand Down
52 changes: 41 additions & 11 deletions py/runtime.c
Original file line number Diff line number Diff line change
Expand Up @@ -735,17 +735,30 @@ void mp_call_prepare_args_n_kw_var(bool have_self, size_t n_args_n_kw, const mp_
mp_obj_t *args2;
size_t args2_alloc;
size_t args2_len = 0;
size_t n_args_star_args = n_args;

// Try to get a hint for unpacked * args length
ssize_t list_len = 0;

if (star_args != 0) {
for (size_t i = 0; i < n_args; i++) {
if ((star_args >> i) & 1) {
mp_obj_t len = mp_obj_len_maybe(args[i]);
if (len != MP_OBJ_NULL) {
if (star_args) {
// kw can also contain star args.
n_args_star_args += n_kw;

for (size_t i = 0; i < n_args_star_args; i++) {
if (!((star_args >> i) & 1)) {
continue;
}

mp_obj_t arg = i >= n_args ? args[n_args + 2 * (i - n_args) + 1] : args[i];

mp_obj_t len = mp_obj_len_maybe(arg);

if (len != MP_OBJ_NULL) {
list_len += mp_obj_get_int(len);

if (i < n_args) {
// -1 accounts for 1 of n_args occupied by this arg
list_len += mp_obj_get_int(len) - 1;
list_len--;
}
}
}
Expand All @@ -757,9 +770,20 @@ void mp_call_prepare_args_n_kw_var(bool have_self, size_t n_args_n_kw, const mp_
for (size_t i = 0; i < n_kw; i++) {
mp_obj_t key = args[n_args + i * 2];
mp_obj_t value = args[n_args + i * 2 + 1];
if (key == MP_OBJ_NULL && value != MP_OBJ_NULL && mp_obj_is_type(value, &mp_type_dict)) {

if (key == MP_OBJ_NULL) {
// -1 accounts for 1 of n_kw occupied by this arg
kw_dict_len += mp_obj_dict_len(value) - 1;
kw_dict_len--;

if (((star_args >> (n_args + i)) & 1)) {
// star args were already handled above
continue;
}

// double-star args
if (mp_obj_is_type(value, &mp_type_dict)) {
kw_dict_len += mp_obj_dict_len(value);
}
}
}

Expand Down Expand Up @@ -792,8 +816,9 @@ void mp_call_prepare_args_n_kw_var(bool have_self, size_t n_args_n_kw, const mp_
args2[args2_len++] = self;
}

for (size_t i = 0; i < n_args; i++) {
mp_obj_t arg = args[i];
for (size_t i = 0; i < n_args_star_args; i++) {
mp_obj_t arg = i >= n_args ? args[n_args + 2 * (i - n_args) + 1] : args[i];

if ((star_args >> i) & 1) {
// star arg
if (mp_obj_is_type(arg, &mp_type_tuple) || mp_obj_is_type(arg, &mp_type_list)) {
Expand Down Expand Up @@ -824,7 +849,7 @@ void mp_call_prepare_args_n_kw_var(bool have_self, size_t n_args_n_kw, const mp_
args2[args2_len++] = item;
}
}
} else {
} else if (i < n_args) {
// normal argument
assert(args2_len < args2_alloc);
args2[args2_len++] = arg;
Expand All @@ -848,6 +873,11 @@ void mp_call_prepare_args_n_kw_var(bool have_self, size_t n_args_n_kw, const mp_
mp_obj_t kw_key = args[n_args + i * 2];
mp_obj_t kw_value = args[n_args + i * 2 + 1];
if (kw_key == MP_OBJ_NULL) {
if ((star_args >> (n_args + i)) & 1) {
// star args have already been handled above
continue;
}

// double-star args
if (mp_obj_is_type(kw_value, &mp_type_dict)) {
// dictionary
Expand Down
4 changes: 4 additions & 0 deletions tests/basics/fun_callstar.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ def foo(a, b, c):
# pos then iterator
foo(1, *range(2, 4))

# star after kw
foo(1, 2, c=3, *())
foo(b=2, *(1,), c=3)

# an iterator with many elements
def foo(*rest):
print(rest)
Expand Down