diff --git a/Lib/test/string_tests.py b/Lib/test/string_tests.py
index 527f505c0169b3..f945afb7934abd 100644
--- a/Lib/test/string_tests.py
+++ b/Lib/test/string_tests.py
@@ -5,6 +5,7 @@
 import unittest, string, sys, struct
 from test import support
 from collections import UserList
+import random
 
 class Sequence:
     def __init__(self, seq='wxyz'): self.seq = seq
@@ -317,6 +318,49 @@ def test_rindex(self):
         else:
             self.checkraises(TypeError, 'hello', 'rindex', 42)
 
+    def test_find_periodic_pattern(self):
+        """Cover the special path for periodic patterns."""
+        def reference_find(p, s):
+            # Naive search.  Trying every start up to len(s) - len(p)
+            # (inclusive) also finds an empty p in an empty s.
+            m = len(p)
+            for i in range(len(s) - m + 1):
+                if s[i:i+m] == p:
+                    return i
+            return -1
+
+        rr = random.randrange
+        choices = random.choices
+        for _ in range(1000):
+            p0 = ''.join(choices('abcde', k=rr(10))) * rr(10, 20)
+            p = p0[:len(p0) - rr(10)] # pop off some characters
+            # The Two-Way code only runs for haystacks of at least 4000
+            # characters (and needles of at least 20), so pad both sides.
+            left = ''.join(choices('abcdef', k=2000 + rr(200)))
+            right = ''.join(choices('abcdef', k=2000 + rr(200)))
+            text = left + p + right
+            with self.subTest(p=p, text=text):
+                self.checkequal(reference_find(p, text),
+                                text, 'find', p)
+
+    def test_find_shift_table_overflow(self):
+        """When the table of 16-bit shifts overflows."""
+        N = 2**16 + 100 # Overflow the 16-bit shift table
+
+        # first check the periodic case
+        # here, the shift for 'b' is N.
+        pattern1 = 'a' * N + 'b' + 'a' * N
+        text1 = 'babbaa' * N + pattern1
+        self.checkequal(len(text1)-len(pattern1),
+                        text1, 'find', pattern1)
+
+        # now check the non-periodic case
+        # here, the shift for 'd' is 3*(N+1)
+        pattern2 = 'ddd' + 'abc' * N + "eee"
+        text2 = pattern2[:-1] + "ddeede" * 2 * N + pattern2 + "de" * N
+        self.checkequal(len(text2) - N*len("de") - len(pattern2),
+                        text2, 'find', pattern2)
+
     def test_lower(self):
         self.checkequal('hello', 'HeLLo', 'lower')
         self.checkequal('hello', 'hello', 'lower')
diff --git a/Misc/NEWS.d/next/Core and Builtins/2020-10-12-23-46-49.bpo-41972.0pHodE.rst b/Misc/NEWS.d/next/Core and Builtins/2020-10-12-23-46-49.bpo-41972.0pHodE.rst
new file mode 100644
index 00000000000000..e340d690590e50
--- /dev/null
+++ b/Misc/NEWS.d/next/Core and Builtins/2020-10-12-23-46-49.bpo-41972.0pHodE.rst
@@ -0,0 +1 @@
+Substring search functions such as ``str1 in str2`` and ``str2.find(str1)`` now use the "Two-Way" string comparison algorithm whenever ``str1`` is long enough, to avoid quadratic behavior in the worst cases.
diff --git a/Objects/stringlib/fastsearch.h b/Objects/stringlib/fastsearch.h
index 56a4467d353813..8991e3e9086323 100644
--- a/Objects/stringlib/fastsearch.h
+++ b/Objects/stringlib/fastsearch.h
@@ -6,10 +6,25 @@
    moore and horspool, with a few more bells and whistles on the top.
    for some more background, see: http://effbot.org/zone/stringlib.htm */
 
+/* When the needle/pattern is long enough during a forward search
+   or count, use the more complex Two-Way algorithm, which leverages
+   patterns in the string to ensure no worse than linear time.
+   Additionally, a Boyer-Moore bad-character shift table is computed
+   so that sublinear (as in O(n/m)) time is achieved in more cases.
+   References:
+     http://www-igm.univ-mlv.fr/~lecroq/string/node26.html#SECTION00260
+     https://en.wikipedia.org/wiki/Two-way_string-matching_algorithm
+   This implementation was largely influenced by glibc:
+     https://code.woboq.org/userspace/glibc/string/str-two-way.h.html
+     https://code.woboq.org/userspace/glibc/string/memmem.c.html
+   Discussion here:
+     https://bugs.python.org/issue41972
+ */
+
 /* note: fastsearch may access s[n], which isn't a problem when using
    Python's ordinary string types, but may cause problems if you're
    using this code in other contexts. also, the count mode returns -1
-   if there cannot possible be a match in the target string, and 0 if
+   if there cannot possibly be a match in the target string, and 0 if
    it has actually checked for matches, but didn't find any. callers
    beware! */
 
@@ -17,6 +32,15 @@
 #define FAST_SEARCH 1
 #define FAST_RSEARCH 2
 
+/* Change to a 1 to see logging comments walk through the algorithm. */
+#if 0 && STRINGLIB_SIZEOF_CHAR == 1
+#define LOG(...) printf(__VA_ARGS__)
+#define LOG_STRING(s, n) printf("\"%.*s\"", n, s)
+#else
+#define LOG(...)
+#define LOG_STRING(s, n)
+#endif
+
 #if LONG_BIT >= 128
 #define STRINGLIB_BLOOM_WIDTH 128
 #elif LONG_BIT >= 64
@@ -160,6 +184,376 @@ STRINGLIB(rfind_char)(const STRINGLIB_CHAR* s, Py_ssize_t n, STRINGLIB_CHAR ch)
 
 #undef MEMCHR_CUT_OFF
 
+// Preprocessing steps for the two-way algorithm.
+//
+// _lex_search walks needle[0:needle_len] comparing characters with `<`;
+// a nonzero `inverted` reverses the direction of that comparison.
+// It returns max_suffix + 1, the index where the selected suffix begins,
+// and stores the period tracked alongside it in *return_period.
+Py_LOCAL_INLINE(Py_ssize_t)
+STRINGLIB(_lex_search)(const STRINGLIB_CHAR *needle, Py_ssize_t needle_len,
+                       int inverted, Py_ssize_t *return_period)
+{
+    // We'll eventually partition needle into
+    // needle[:max_suffix + 1] + needle[max_suffix + 1:]
+    Py_ssize_t max_suffix = -1;
+
+    Py_ssize_t suffix = 0; // candidate for max_suffix
+    Py_ssize_t period = 1; // candidate for return_period
+    Py_ssize_t k = 1; // working index
+
+    while (suffix + k < needle_len) {
+        STRINGLIB_CHAR a = needle[suffix + k];
+        STRINGLIB_CHAR b = needle[max_suffix + k];
+        if (inverted ? (a < b) : (b < a)) {
+            // Suffix is smaller, period is entire prefix so far.
+            suffix += k;
+            k = 1;
+            period = suffix - max_suffix;
+        }
+        else if (a == b) {
+            // Advance through the repetition of the current period.
+            if (k != period) {
+                k++;
+            }
+            else {
+                suffix += period;
+                k = 1;
+            }
+        }
+        else {
+            // Found a bigger suffix.
+            max_suffix = suffix;
+            suffix += 1;
+            k = 1;
+            period = 1;
+        }
+    }
+    *return_period = period;
+    return max_suffix + 1;
+}
+
+
+// Run _lex_search with both orderings and keep the result whose suffix
+// starts later (the inverted one on ties); return that start index.
+Py_LOCAL_INLINE(Py_ssize_t)
+STRINGLIB(_critical_factorization)(const STRINGLIB_CHAR *needle,
+                                   Py_ssize_t needle_len,
+                                   Py_ssize_t *return_period)
+{
+    /* Morally, this is what we want to happen:
+         >>> x = "GCAGAGAG"
+         >>> suf, period = _critical_factorization(x)
+         >>> x[:suf], x[suf:]
+         ('GC', 'AGAGAG')
+         >>> period
+         2 */
+    Py_ssize_t period1, period2, max_suf1, max_suf2;
+
+    // Search using both forward and inverted character-orderings
+    max_suf1 = STRINGLIB(_lex_search)(needle, needle_len, 0, &period1);
+    max_suf2 = STRINGLIB(_lex_search)(needle, needle_len, 1, &period2);
+
+    // Choose the later suffix
+    if (max_suf2 < max_suf1) {
+        *return_period = period1;
+        return max_suf1;
+    }
+    else {
+        *return_period = period2;
+        return max_suf2;
+    }
+}
+
+
+// Shift-table entries are SHIFT_TYPE; its two largest values are reserved:
+// NOT_FOUND (all bits set) and SHIFT_OVERFLOW (NOT_FOUND - 1).
+#define SHIFT_TYPE uint16_t
+#define NOT_FOUND ((1U<<(8*sizeof(SHIFT_TYPE))) - 1U)
+#define SHIFT_OVERFLOW (NOT_FOUND - 1U)
+
+#define TABLE_SIZE_BITS 7
+#define TABLE_SIZE (1U << TABLE_SIZE_BITS)
+#define TABLE_MASK (TABLE_SIZE - 1U)
+
+// Fill table[0:TABLE_SIZE]: for each slot (character & TABLE_MASK), store
+// the distance from the last needle character mapping to that slot to the
+// end of the needle, capped at SHIFT_OVERFLOW.  Unused slots are NOT_FOUND.
+Py_LOCAL_INLINE(void)
+STRINGLIB(_init_table)(const STRINGLIB_CHAR *needle, Py_ssize_t needle_len,
+                       SHIFT_TYPE *table)
+{
+    // Fill the table with NOT_FOUND
+    memset(table, 0xff, TABLE_SIZE * sizeof(SHIFT_TYPE));
+    assert(table[0] == NOT_FOUND);
+    assert(table[TABLE_SIZE - 1] == NOT_FOUND);
+    for (Py_ssize_t j = 0; j < needle_len; j++) {
+        // NOT_FOUND means not in string
+        // SHIFT_OVERFLOW means shift at least SHIFT_OVERFLOW
+        Py_ssize_t shift = needle_len - j - 1;
+        if (shift > SHIFT_OVERFLOW) {
+            shift = SHIFT_OVERFLOW;
+        }
+        table[needle[j] & TABLE_MASK] = (SHIFT_TYPE)shift;
+    }
+}
+
+
+// Search haystack[0:haystack_len] for needle[0:needle_len].  `suffix` and
+// `period` are the values produced by _critical_factorization, and
+// `shift_table` is a table filled in by _init_table.  Returns the index
+// in haystack where a match was found, or -1 if the scan reaches the end.
+Py_LOCAL_INLINE(Py_ssize_t)
+STRINGLIB(_two_way)(const STRINGLIB_CHAR *needle, Py_ssize_t needle_len,
+                    const STRINGLIB_CHAR *haystack, Py_ssize_t haystack_len,
+                    Py_ssize_t suffix, Py_ssize_t period,
+                    SHIFT_TYPE *shift_table)
+{
+    LOG("========================\n");
+    LOG("Two-way with needle="); LOG_STRING(needle, needle_len);
+    LOG(" and haystack="); LOG_STRING(haystack, haystack_len);
+    LOG("\nSplit "); LOG_STRING(needle, needle_len);
+    LOG(" into "); LOG_STRING(needle, suffix);
+    LOG(" and "); LOG_STRING(needle + suffix, needle_len - suffix);
+    LOG(".\n");
+
+    if (memcmp(needle, needle+period, suffix * STRINGLIB_SIZEOF_CHAR) == 0) {
+        LOG("needle is completely periodic.\n");
+        // a mismatch can only advance by the period.
+        // use memory to avoid re-scanning known occurrences of the period.
+        Py_ssize_t memory = 0;
+        Py_ssize_t j = 0; // index into haystack
+        while (j <= haystack_len - needle_len) {
+
+            // Visualize the line-up:
+            LOG("> "); LOG_STRING(haystack, haystack_len);
+            LOG("\n> "); LOG("%*s", j, ""); LOG_STRING(needle, needle_len);
+            LOG("\n");
+
+            STRINGLIB_CHAR last = haystack[j + needle_len - 1];
+            int index = last & TABLE_MASK;
+            SHIFT_TYPE shift = shift_table[index];
+
+            switch (shift)
+            {
+                case 0: {
+                    break;
+                }
+                case NOT_FOUND: {
+                    LOG("Last character not found in string.\n");
+                    memory = 0;
+                    j += needle_len;
+                    continue;
+                }
+                case SHIFT_OVERFLOW: {
+                    LOG("Shift overflowed.\n");
+                    memory = 0;
+                    j += SHIFT_OVERFLOW;
+                    continue;
+                }
+                default: {
+                    if (memory && shift < period) {
+                        LOG("Shifting through multiple periods.\n");
+                        j += needle_len - period;
+                    } else {
+                        LOG("Table says shift by %d.\n", shift);
+                        j += shift;
+                    }
+                    memory = 0;
+                    continue;
+                }
+            }
+
+            LOG("Scanning right half.\n");
+            Py_ssize_t i = Py_MAX(suffix, memory);
+            while (i < needle_len && needle[i] == haystack[j+i]) {
+                i++;
+            }
+            if (i >= needle_len) {
+                LOG("Right half matched. Scanning left half.\n");
+                i = suffix - 1;
+                while (memory < i + 1 && needle[i] == haystack[j+i]) {
+                    i--;
+                }
+                if (i + 1 < memory + 1) {
+                    LOG("Left half matches. Returning %d.\n", j);
+                    return j;
+                }
+                LOG("No match.\n");
+                // Remember how many periods were scanned on the right
+                j += period;
+                memory = needle_len - period;
+            }
+            else {
+                LOG("Skip without checking left half.\n");
+                j += i - suffix + 1;
+                memory = 0;
+            }
+        }
+    }
+    else {
+        LOG("needle is NOT completely periodic.\n");
+        // The two halves are distinct;
+        // no extra memory is required,
+        // and a mismatch results in a maximal shift.
+        period = 1 + Py_MAX(suffix, needle_len - suffix);
+        LOG("Using period %d.\n", period);
+
+        Py_ssize_t j = 0;
+        while (j <= haystack_len - needle_len) {
+            LOG("> "); LOG_STRING(haystack, haystack_len);
+            LOG("\n> "); LOG("%*s", j, ""); LOG_STRING(needle, needle_len);
+            LOG("\n");
+
+            STRINGLIB_CHAR last = haystack[j + needle_len - 1];
+            int index = last & TABLE_MASK;
+            SHIFT_TYPE shift = shift_table[index];
+            switch (shift)
+            {
+                case 0: {
+                    break;
+                }
+                case NOT_FOUND: {
+                    LOG("Last character not found in string.\n");
+                    j += needle_len;
+                    continue;
+                }
+                default: {
+                    LOG("Table says shift by %d.\n", shift);
+                    j += shift;
+                    continue;
+                }
+            }
+
+            assert((haystack[j + needle_len - 1] & TABLE_MASK)
+                   == (needle[needle_len - 1] & TABLE_MASK));
+
+            LOG("Checking the right half.\n");
+            Py_ssize_t i = suffix;
+            for (; i < needle_len; i++) {
+                if (needle[i] != haystack[j + i]){
+                    LOG("No match.\n");
+                    break;
+                }
+            }
+
+            if (i >= needle_len) {
+                LOG("Matches. Checking the left half.\n");
+                i = suffix - 1;
+                for (i = suffix - 1; i >= 0; i--) {
+                    if (needle[i] != haystack[j + i]) {
+                        break;
+                    }
+                }
+                if (i == -1) {
+                    LOG("Match! (at %d)\n", j);
+                    return j;
+                }
+                j += period;
+            }
+            else {
+                LOG("Jump forward without checking left half.\n");
+                j += i - suffix + 1;
+            }
+        }
+
+    }
+    LOG("Reached end. Returning -1.\n");
+    return -1;
+}
+
+
+// Forward search.  First tries a plain memcmp at the first position where
+// needle[0] occurs; only if that fails does it pay for the factorization
+// and shift table and run _two_way on the rest.  Returns an index or -1.
+Py_LOCAL_INLINE(Py_ssize_t)
+STRINGLIB(_fastsearch)(const STRINGLIB_CHAR *needle, Py_ssize_t needle_len,
+                       const STRINGLIB_CHAR *haystack, Py_ssize_t haystack_len)
+{
+    Py_ssize_t index = STRINGLIB(find_char)(haystack,
+                                            haystack_len - needle_len + 1,
+                                            needle[0]);
+    if (index == -1) {
+        return -1;
+    }
+
+    // Do a fast compare in all cases to maybe avoid the initialization overhead
+    if (memcmp(haystack+index, needle, needle_len*STRINGLIB_SIZEOF_CHAR) == 0) {
+        return index;
+    }
+    else {
+        // Start later.
+        index++;
+    }
+
+    // Prework: factorize.
+    Py_ssize_t period, suffix;
+    suffix = STRINGLIB(_critical_factorization)(needle, needle_len, &period);
+
+    // Prework: make a skip table.
+    SHIFT_TYPE shift_table[TABLE_SIZE];
+    STRINGLIB(_init_table)(needle, needle_len, shift_table);
+
+    Py_ssize_t result = STRINGLIB(_two_way)(needle, needle_len,
+                                            haystack + index,
+                                            haystack_len - index,
+                                            suffix, period,
+                                            shift_table);
+
+    if (result == -1) {
+        return -1;
+    }
+    return index + result;
+}
+
+
+// Count mode.  Calls _two_way repeatedly, restarting just past the end of
+// each match, and returns the number of matches found, stopping early
+// with maxcount as soon as that many have been seen.
+Py_LOCAL_INLINE(Py_ssize_t)
+STRINGLIB(_fastcount)(const STRINGLIB_CHAR *needle, Py_ssize_t needle_len,
+                      const STRINGLIB_CHAR *haystack, Py_ssize_t haystack_len,
+                      Py_ssize_t maxcount)
+{
+    Py_ssize_t index = STRINGLIB(find_char)(haystack,
+                                            haystack_len - needle_len + 1,
+                                            needle[0]);
+    if (index == -1) {
+        return 0;
+    }
+    Py_ssize_t suffix, period;
+    suffix = STRINGLIB(_critical_factorization)(needle, needle_len, &period);
+    SHIFT_TYPE shift_table[TABLE_SIZE];
+    STRINGLIB(_init_table)(needle, needle_len, shift_table);
+    Py_ssize_t count = 0;
+    while (1) {
+        Py_ssize_t result = STRINGLIB(_two_way)(needle, needle_len,
+                                                haystack + index,
+                                                haystack_len - index,
+                                                suffix, period,
+                                                shift_table);
+        if (result == -1) {
+            return count;
+        }
+        else {
+            count++;
+            if (count == maxcount) {
+                return maxcount;
+            }
+            index += result + needle_len;
+        }
+    }
+
+}
+
+#undef SHIFT_TYPE
+#undef NOT_FOUND
+#undef SHIFT_OVERFLOW
+#undef TABLE_SIZE_BITS
+#undef TABLE_SIZE
+#undef TABLE_MASK
+
+
 Py_LOCAL_INLINE(Py_ssize_t)
 FASTSEARCH(const STRINGLIB_CHAR* s, Py_ssize_t n,
            const STRINGLIB_CHAR* p, Py_ssize_t m,
@@ -195,10 +589,22 @@ FASTSEARCH(const STRINGLIB_CHAR* s, Py_ssize_t n,
     }
 
     mlast = m - 1;
-    skip = mlast - 1;
+    skip = mlast;
     mask = 0;
 
     if (mode != FAST_RSEARCH) {
+        if (n >= 4000 && m >= 20) {
+            /* long needles/haystacks get the two-way algorithm. */
+            if (mode == FAST_SEARCH) {
+                return STRINGLIB(_fastsearch)(p, m, s, n);
+            }
+            else {
+                return STRINGLIB(_fastcount)(p, m, s, n, maxcount);
+            }
+        }
+
+        /* Short needles use Fredrik Lundh's Horspool/Sunday hybrid
+           algorithm for less overhead. */
         const STRINGLIB_CHAR *ss = s + m - 1;
         const STRINGLIB_CHAR *pp = p + m - 1;
 
@@ -281,3 +687,5 @@ FASTSEARCH(const STRINGLIB_CHAR* s, Py_ssize_t n,
     return count;
 }
 
+#undef LOG
+#undef LOG_STRING