Skip to content

Commit 8d376a3

Browse files
authored
Merge pull request adafruit#693 from jepler/issue236
Implement * and *= for array.array
2 parents ce8f7e6 + cdb83b1 commit 8d376a3

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

py/objarray.c

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,39 @@ STATIC mp_obj_t array_unary_op(mp_unary_op_t op, mp_obj_t o_in) {
241241
STATIC mp_obj_t array_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
242242
mp_obj_array_t *lhs = MP_OBJ_TO_PTR(lhs_in);
243243
switch (op) {
244+
case MP_BINARY_OP_MULTIPLY:
245+
case MP_BINARY_OP_INPLACE_MULTIPLY: {
246+
if (!MP_OBJ_IS_INT(rhs_in)) {
247+
return MP_OBJ_NULL; // op not supported
248+
}
249+
mp_uint_t repeat = mp_obj_get_int(rhs_in);
250+
bool inplace = (op == MP_BINARY_OP_INPLACE_MULTIPLY);
251+
mp_buffer_info_t lhs_bufinfo;
252+
array_get_buffer(lhs_in, &lhs_bufinfo, MP_BUFFER_READ);
253+
mp_obj_array_t *res;
254+
byte *ptr;
255+
size_t orig_lhs_bufinfo_len = lhs_bufinfo.len;
256+
if(inplace) {
257+
res = lhs;
258+
size_t item_sz = mp_binary_get_size('@', lhs->typecode, NULL);
259+
lhs->items = m_renew(byte, lhs->items, (lhs->len + lhs->free) * item_sz, lhs->len * repeat * item_sz);
260+
lhs->len = lhs->len * repeat;
261+
lhs->free = 0;
262+
if (!repeat)
263+
return MP_OBJ_FROM_PTR(res);
264+
repeat--;
265+
ptr = (byte*)res->items + orig_lhs_bufinfo_len;
266+
} else {
267+
res = array_new(lhs_bufinfo.typecode, lhs->len * repeat);
268+
ptr = (byte*)res->items;
269+
}
270+
if(orig_lhs_bufinfo_len) {
271+
for(;repeat--; ptr += orig_lhs_bufinfo_len) {
272+
memcpy(ptr, lhs_bufinfo.buf, orig_lhs_bufinfo_len);
273+
}
274+
}
275+
return MP_OBJ_FROM_PTR(res);
276+
}
244277
case MP_BINARY_OP_ADD: {
245278
// allow to add anything that has the buffer protocol (extension to CPython)
246279
mp_buffer_info_t lhs_bufinfo;

tests/basics/array_mul.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
try:
2+
import array
3+
except ImportError:
4+
print("SKIP")
5+
raise SystemExit
6+
7+
a1 = array.array('I', [1])
8+
a2 = array.array('I', [2]) * 2
9+
a3 = (a1 + a2)
10+
print(a3)
11+
12+
a3 *= 5
13+
print(a3)
14+
15+
a3 *= 0
16+
print(a3)
17+
18+
a4 = a2 * 0
19+
print(a4)
20+
21+
a4 *= 0
22+
print(a4)
23+
24+
a4 = a4 * 2
25+
print(a4)
26+
27+
a4 *= 2
28+
print(a4)

0 commit comments

Comments
 (0)