Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 39 additions & 52 deletions py/objstr.c
Original file line number Diff line number Diff line change
Expand Up @@ -186,19 +186,26 @@ STATIC mp_obj_t bytes_make_new(mp_obj_t type_in, uint n_args, uint n_kw, const m

// like strstr but with specified length and allows \0 bytes
// TODO replace with something more efficient/standard
STATIC const byte *find_subbytes(const byte *haystack, uint hlen, const byte *needle, uint nlen) {
STATIC const byte *find_subbytes(const byte *haystack, machine_uint_t hlen, const byte *needle, machine_uint_t nlen, machine_int_t direction) {
if (hlen >= nlen) {
for (uint i = 0; i <= hlen - nlen; i++) {
bool found = true;
for (uint j = 0; j < nlen; j++) {
if (haystack[i + j] != needle[j]) {
found = false;
break;
}
machine_uint_t str_index, str_index_end;
if (direction > 0) {
str_index = 0;
str_index_end = hlen - nlen;
} else {
str_index = hlen - nlen;
str_index_end = 0;
}
for (;;) {
if (memcmp(&haystack[str_index], needle, nlen) == 0) {
//found
return haystack + str_index;
}
if (found) {
return haystack + i;
if (str_index == str_index_end) {
//not found
break;
}
str_index += direction;
}
}
return NULL;
Expand Down Expand Up @@ -260,7 +267,7 @@ STATIC mp_obj_t str_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
/* NOTE `a in b` is `b.__contains__(a)` */
if (MP_OBJ_IS_STR(rhs_in)) {
GET_STR_DATA_LEN(rhs_in, rhs_data, rhs_len);
return MP_BOOL(find_subbytes(lhs_data, lhs_len, rhs_data, rhs_len) != NULL);
return MP_BOOL(find_subbytes(lhs_data, lhs_len, rhs_data, rhs_len, 1) != NULL);
}
break;

Expand Down Expand Up @@ -382,7 +389,7 @@ STATIC mp_obj_t str_split(uint n_args, const mp_obj_t *args) {
return res;
}

STATIC mp_obj_t str_find(uint n_args, const mp_obj_t *args) {
STATIC mp_obj_t str_finder(uint n_args, const mp_obj_t *args, machine_int_t direction) {
assert(2 <= n_args && n_args <= 4);
assert(MP_OBJ_IS_STR(args[0]));
assert(MP_OBJ_IS_STR(args[1]));
Expand All @@ -392,28 +399,31 @@ STATIC mp_obj_t str_find(uint n_args, const mp_obj_t *args) {

machine_uint_t start = 0;
machine_uint_t end = haystack_len;
/* TODO use a non-exception-throwing mp_get_index */
if (n_args >= 3 && args[2] != mp_const_none) {
start = mp_get_index(&str_type, haystack_len, args[2], true);
}
if (n_args >= 4 && args[3] != mp_const_none) {
end = mp_get_index(&str_type, haystack_len, args[3], true);
}

const byte *p = find_subbytes(haystack + start, haystack_len - start, needle, needle_len);
const byte *p = find_subbytes(haystack + start, end - start, needle, needle_len, direction);
if (p == NULL) {
// not found
return MP_OBJ_NEW_SMALL_INT(-1);
} else {
// found
machine_int_t pos = p - haystack;
if (pos + needle_len > end) {
pos = -1;
}
return MP_OBJ_NEW_SMALL_INT(pos);
return MP_OBJ_NEW_SMALL_INT(p - haystack);
}
}

STATIC mp_obj_t str_find(uint n_args, const mp_obj_t *args) {
return str_finder(n_args, args, 1);
}

STATIC mp_obj_t str_rfind(uint n_args, const mp_obj_t *args) {
return str_finder(n_args, args, -1);
}

// TODO: (Much) more variety in args
STATIC mp_obj_t str_startswith(mp_obj_t self_in, mp_obj_t arg) {
GET_STR_DATA_LEN(self_in, str, str_len);
Expand All @@ -424,15 +434,6 @@ STATIC mp_obj_t str_startswith(mp_obj_t self_in, mp_obj_t arg) {
return MP_BOOL(memcmp(str, prefix, prefix_len) == 0);
}

STATIC bool chr_in_str(const byte* const str, const machine_uint_t str_len, int c) {
for (machine_uint_t i = 0; i < str_len; i++) {
if (str[i] == c) {
return true;
}
}
return false;
}

STATIC mp_obj_t str_strip(uint n_args, const mp_obj_t *args) {
assert(1 <= n_args && n_args <= 2);
assert(MP_OBJ_IS_STR(args[0]));
Expand All @@ -457,7 +458,7 @@ STATIC mp_obj_t str_strip(uint n_args, const mp_obj_t *args) {
bool first_good_char_pos_set = false;
machine_uint_t last_good_char_pos = 0;
for (machine_uint_t i = 0; i < orig_str_len; i++) {
if (!chr_in_str(chars_to_del, chars_to_del_len, orig_str[i])) {
if (find_subbytes(chars_to_del, chars_to_del_len, &orig_str[i], 1, 1) == NULL) {
last_good_char_pos = i;
if (!first_good_char_pos_set) {
first_good_char_pos = i;
Expand Down Expand Up @@ -547,7 +548,7 @@ STATIC mp_obj_t str_replace(uint n_args, const mp_obj_t *args) {
const byte *old_occurrence;
const byte *offset_ptr = str;
machine_uint_t offset_num = 0;
while ((old_occurrence = find_subbytes(offset_ptr, str_len - offset_num, old, old_len)) != NULL) {
while ((old_occurrence = find_subbytes(offset_ptr, str_len - offset_num, old, old_len, 1)) != NULL) {
// copy from just after end of last occurrence of to-be-replaced string to right before start of next occurrence
if (data != NULL) {
memcpy(data + replaced_str_index, offset_ptr, old_occurrence - offset_ptr);
Expand Down Expand Up @@ -601,7 +602,6 @@ STATIC mp_obj_t str_count(uint n_args, const mp_obj_t *args) {

machine_uint_t start = 0;
machine_uint_t end = haystack_len;
/* TODO use a non-exception-throwing mp_get_index */
if (n_args >= 3 && args[2] != mp_const_none) {
start = mp_get_index(&str_type, haystack_len, args[2], true);
}
Expand Down Expand Up @@ -648,27 +648,12 @@ STATIC mp_obj_t str_partitioner(mp_obj_t self_in, mp_obj_t arg, machine_int_t di
result[2] = self_in;
}

if (str_len >= sep_len) {
machine_uint_t str_index, str_index_end;
if (direction > 0) {
str_index = 0;
str_index_end = str_len - sep_len;
} else {
str_index = str_len - sep_len;
str_index_end = 0;
}
for (;;) {
if (memcmp(&str[str_index], sep, sep_len) == 0) {
result[0] = mp_obj_new_str(str, str_index, false);
result[1] = arg;
result[2] = mp_obj_new_str(str + str_index + sep_len, str_len - str_index - sep_len, false);
break;
}
if (str_index == str_index_end) {
break;
}
str_index += direction;
}
const byte *position_ptr = find_subbytes(str, str_len, sep, sep_len, direction);
if (position_ptr != NULL) {
machine_uint_t position = position_ptr - str;
result[0] = mp_obj_new_str(str, position, false);
result[1] = arg;
result[2] = mp_obj_new_str(str + position + sep_len, str_len - position - sep_len, false);
}

return mp_obj_new_tuple(3, result);
Expand Down Expand Up @@ -697,6 +682,7 @@ STATIC machine_int_t str_get_buffer(mp_obj_t self_in, buffer_info_t *bufinfo, in
}

STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(str_find_obj, 2, 4, str_find);
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(str_rfind_obj, 2, 4, str_rfind);
STATIC MP_DEFINE_CONST_FUN_OBJ_2(str_join_obj, str_join);
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(str_split_obj, 1, 3, str_split);
STATIC MP_DEFINE_CONST_FUN_OBJ_2(str_startswith_obj, str_startswith);
Expand All @@ -709,6 +695,7 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(str_rpartition_obj, str_rpartition);

STATIC const mp_method_t str_type_methods[] = {
{ "find", &str_find_obj },
{ "rfind", &str_rfind_obj },
{ "join", &str_join_obj },
{ "split", &str_split_obj },
{ "startswith", &str_startswith_obj },
Expand Down
23 changes: 23 additions & 0 deletions tests/basics/string_rfind.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
print("hello world".rfind("ll"))
print("hello world".rfind("ll", None))
print("hello world".rfind("ll", 1))
print("hello world".rfind("ll", 1, None))
print("hello world".rfind("ll", None, None))
print("hello world".rfind("ll", 1, -1))
print("hello world".rfind("ll", 1, 1))
print("hello world".rfind("ll", 1, 2))
print("hello world".rfind("ll", 1, 3))
print("hello world".rfind("ll", 1, 4))
print("hello world".rfind("ll", 1, 5))
print("hello world".rfind("ll", -100))
print("0000".rfind('0'))
print("0000".rfind('0', 0))
print("0000".rfind('0', 1))
print("0000".rfind('0', 2))
print("0000".rfind('0', 3))
print("0000".rfind('0', 4))
print("0000".rfind('0', 5))
print("0000".rfind('-1', 3))
print("0000".rfind('1', 3))
print("0000".rfind('1', 4))
print("0000".rfind('1', 5))