Skip to content

Commit bd55bae

Browse files
committed
Optimize Wtf8Codepoints::count
1 parent a861264 commit bd55bae

File tree

2 files changed

+173
-2
lines changed

2 files changed

+173
-2
lines changed

common/src/wtf8/core_str_count.rs

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
//! Modified from core::str::count
2+
3+
use super::Wtf8;
4+
5+
const USIZE_SIZE: usize = core::mem::size_of::<usize>();
6+
const UNROLL_INNER: usize = 4;
7+
8+
#[inline]
9+
pub(super) fn count_chars(s: &Wtf8) -> usize {
10+
if s.len() < USIZE_SIZE * UNROLL_INNER {
11+
// Avoid entering the optimized implementation for strings where the
12+
// difference is not likely to matter, or where it might even be slower.
13+
// That said, a ton of thought was not spent on the particular threshold
14+
// here, beyond "this value seems to make sense".
15+
char_count_general_case(s.as_bytes())
16+
} else {
17+
do_count_chars(s)
18+
}
19+
}
20+
21+
// Word-at-a-time implementation: counts non-continuation bytes by packing
// per-byte 0/1 flags into `usize` accumulators and summing them in bulk.
fn do_count_chars(s: &Wtf8) -> usize {
    // For correctness, `CHUNK_SIZE` must be:
    //
    // - Less than or equal to 255, otherwise we'll overflow bytes in `counts`.
    // - A multiple of `UNROLL_INNER`, otherwise our `break` inside the
    //   `body.chunks(CHUNK_SIZE)` loop is incorrect.
    //
    // For performance, `CHUNK_SIZE` should be:
    // - Relatively cheap to `/` against (so some simple sum of powers of two).
    // - Large enough to avoid paying for the cost of the `sum_bytes_in_usize`
    //   too often.
    const CHUNK_SIZE: usize = 192;

    // Check the properties of `CHUNK_SIZE` and `UNROLL_INNER` that are required
    // for correctness.
    const _: () = assert!(CHUNK_SIZE < 256);
    const _: () = assert!(CHUNK_SIZE % UNROLL_INNER == 0);

    // SAFETY: transmuting `[u8]` to `[usize]` is safe except for size
    // differences which are handled by `align_to`.
    let (head, body, tail) = unsafe { s.as_bytes().align_to::<usize>() };

    // This should be quite rare, and basically exists to handle the degenerate
    // cases where align_to fails (as well as miri under symbolic alignment
    // mode).
    //
    // The `unlikely` helps discourage LLVM from inlining the body, which is
    // nice, as we would rather not mark the `char_count_general_case` function
    // as cold.
    if unlikely(body.is_empty() || head.len() > USIZE_SIZE || tail.len() > USIZE_SIZE) {
        return char_count_general_case(s.as_bytes());
    }

    // The unaligned head and tail are counted by the simple scalar loop.
    let mut total = char_count_general_case(head) + char_count_general_case(tail);
    // Split `body` into `CHUNK_SIZE` chunks to reduce the frequency with which
    // we call `sum_bytes_in_usize`.
    for chunk in body.chunks(CHUNK_SIZE) {
        // We accumulate intermediate sums in `counts`, where each byte contains
        // a subset of the sum of this chunk, like a `[u8; size_of::<usize>()]`.
        let mut counts = 0;

        let (unrolled_chunks, remainder) = slice_as_chunks::<_, UNROLL_INNER>(chunk);
        for unrolled in unrolled_chunks {
            for &word in unrolled {
                // Because `CHUNK_SIZE` is < 256, this addition can't cause the
                // count in any of the bytes to overflow into a subsequent byte.
                counts += contains_non_continuation_byte(word);
            }
        }

        // Sum the values in `counts` (which, again, is conceptually a `[u8;
        // size_of::<usize>()]`), and accumulate the result into `total`.
        total += sum_bytes_in_usize(counts);

        // If there's any data in `remainder`, then handle it. This will only
        // happen for the last `chunk` in `body.chunks()` (because `CHUNK_SIZE`
        // is divisible by `UNROLL_INNER`), so we explicitly break at the end
        // (which seems to help LLVM out).
        if !remainder.is_empty() {
            // Accumulate all the data in the remainder.
            let mut counts = 0;
            for &word in remainder {
                counts += contains_non_continuation_byte(word);
            }
            total += sum_bytes_in_usize(counts);
            break;
        }
    }
    total
}
91+
92+
// Checks each byte of `w` to see if it contains the first byte in a UTF-8
93+
// sequence. Bytes in `w` which are continuation bytes are left as `0x00` (e.g.
94+
// false), and bytes which are non-continuation bytes are left as `0x01` (e.g.
95+
// true)
96+
#[inline]
97+
fn contains_non_continuation_byte(w: usize) -> usize {
98+
const LSB: usize = usize_repeat_u8(0x01);
99+
((!w >> 7) | (w >> 6)) & LSB
100+
}
101+
102+
// Morally equivalent to `values.to_ne_bytes().into_iter().sum::<usize>()`, but
103+
// more efficient.
104+
#[inline]
105+
fn sum_bytes_in_usize(values: usize) -> usize {
106+
const LSB_SHORTS: usize = usize_repeat_u16(0x0001);
107+
const SKIP_BYTES: usize = usize_repeat_u16(0x00ff);
108+
109+
let pair_sum: usize = (values & SKIP_BYTES) + ((values >> 8) & SKIP_BYTES);
110+
pair_sum.wrapping_mul(LSB_SHORTS) >> ((USIZE_SIZE - 2) * 8)
111+
}
112+
113+
// This is the most direct implementation of the concept of "count the number of
114+
// bytes in the string which are not continuation bytes", and is used for the
115+
// head and tail of the input string (the first and last item in the tuple
116+
// returned by `slice::align_to`).
117+
fn char_count_general_case(s: &[u8]) -> usize {
118+
s.iter()
119+
.filter(|&&byte| !super::core_str::utf8_is_cont_byte(byte))
120+
.count()
121+
}
122+
123+
// polyfills of unstable library features
124+
125+
// Builds a usize in which every byte equals `x` (e.g. 0x01 -> 0x0101...01).
// Polyfill for the unstable `usize_repeat_u8`-style helpers in core.
const fn usize_repeat_u8(x: u8) -> usize {
    let mut repeated = 0usize;
    let mut remaining = size_of::<usize>();
    while remaining > 0 {
        // Shifting by 8 is always in range, since usize is at least 16 bits.
        repeated = (repeated << 8) | (x as usize);
        remaining -= 1;
    }
    repeated
}
128+
129+
// Builds a usize in which every 16-bit lane equals `x`
// (e.g. 0x0001 -> 0x0001_0001...).
const fn usize_repeat_u16(x: u16) -> usize {
    let mut repeated = 0usize;
    let mut lanes = size_of::<usize>() / 2;
    while lanes > 0 {
        // `wrapping_shl` keeps this working on targets with a 16-bit `usize`,
        // where a plain `<< 16` would overflow the shift width.
        repeated = repeated.wrapping_shl(16) | (x as usize);
        lanes -= 1;
    }
    repeated
}
139+
140+
fn slice_as_chunks<T, const N: usize>(slice: &[T]) -> (&[[T; N]], &[T]) {
141+
assert!(N != 0, "chunk size must be non-zero");
142+
let len_rounded_down = slice.len() / N * N;
143+
// SAFETY: The rounded-down value is always the same or smaller than the
144+
// original length, and thus must be in-bounds of the slice.
145+
let (multiple_of_n, remainder) = unsafe { slice.split_at_unchecked(len_rounded_down) };
146+
// SAFETY: We already panicked for zero, and ensured by construction
147+
// that the length of the subslice is a multiple of N.
148+
let array_slice = unsafe { slice_as_chunks_unchecked(multiple_of_n) };
149+
(array_slice, remainder)
150+
}
151+
152+
// Polyfill for the unstable `slice::as_chunks_unchecked`: reinterprets
// `slice` as a slice of `N`-element arrays without checking divisibility.
//
// # Safety
//
// `N` must be non-zero and `slice.len()` must be an exact multiple of `N`.
unsafe fn slice_as_chunks_unchecked<T, const N: usize>(slice: &[T]) -> &[[T; N]] {
    // Catch contract violations in debug builds; release builds rely on the
    // caller upholding the safety contract above (mirrors the
    // `assert_unsafe_precondition` in `core`).
    debug_assert!(N != 0 && slice.len() % N == 0);
    let new_len = slice.len() / N;
    // SAFETY: The caller guarantees `slice.len() == new_len * N`, so the
    // region holds exactly `new_len` chunks of `N` elements, and `[T; N]`
    // has the same alignment as `T`.
    unsafe { std::slice::from_raw_parts(slice.as_ptr().cast(), new_len) }
}
158+
159+
// Polyfill for the unstable `core::intrinsics::unlikely`. As a plain function
// it carries no actual branch-prediction hint; it exists so the code above
// can keep the same shape as the upstream `core::str::count` it was adapted
// from.
fn unlikely(x: bool) -> bool {
    x
}

common/src/wtf8/mod.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ use bstr::{ByteSlice, ByteVec};
5353

5454
mod core_char;
5555
mod core_str;
56+
mod core_str_count;
5657

5758
const UTF8_REPLACEMENT_CHARACTER: &str = "\u{FFFD}";
5859

@@ -1256,6 +1257,10 @@ impl Iterator for Wtf8CodePoints<'_> {
12561257
fn last(mut self) -> Option<Self::Item> {
12571258
self.next_back()
12581259
}
1260+
1261+
fn count(self) -> usize {
1262+
core_str_count::count_chars(self.as_wtf8())
1263+
}
12591264
}
12601265

12611266
impl DoubleEndedIterator for Wtf8CodePoints<'_> {
@@ -1277,8 +1282,8 @@ impl<'a> Wtf8CodePoints<'a> {
12771282

12781283
#[derive(Clone)]
12791284
pub struct Wtf8CodePointIndices<'a> {
1280-
pub(super) front_offset: usize,
1281-
pub(super) iter: Wtf8CodePoints<'a>,
1285+
front_offset: usize,
1286+
iter: Wtf8CodePoints<'a>,
12821287
}
12831288

12841289
impl Iterator for Wtf8CodePointIndices<'_> {
@@ -1308,6 +1313,11 @@ impl Iterator for Wtf8CodePointIndices<'_> {
13081313
// No need to go through the entire string.
13091314
self.next_back()
13101315
}
1316+
1317+
#[inline]
1318+
fn count(self) -> usize {
1319+
self.iter.count()
1320+
}
13111321
}
13121322

13131323
impl DoubleEndedIterator for Wtf8CodePointIndices<'_> {

0 commit comments

Comments
 (0)