Skip to content

Commit 960e86c

Browse files
committed
Fix more surrogate crashes
1 parent e9e116b commit 960e86c

20 files changed

+126
-122
lines changed

Lib/test/test_json/test_scanstring.py

-2
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,6 @@ def test_scanstring(self):
8686
scanstring('["Bad value", truth]', 2, True),
8787
('Bad value', 12))
8888

89-
# TODO: RUSTPYTHON
90-
@unittest.expectedFailure
9189
def test_surrogates(self):
9290
scanstring = self.json.decoder.scanstring
9391
def assertScan(given, expect):

Lib/test/test_stringprep.py

-2
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
from stringprep import *
77

88
class StringprepTests(unittest.TestCase):
9-
# TODO: RUSTPYTHON
10-
@unittest.expectedFailure
119
def test(self):
1210
self.assertTrue(in_table_a1("\u0221"))
1311
self.assertFalse(in_table_a1("\u0222"))

Lib/test/test_subprocess.py

-2
Original file line numberDiff line numberDiff line change
@@ -1198,8 +1198,6 @@ def test_universal_newlines_communicate_encodings(self):
11981198
stdout, stderr = popen.communicate(input='')
11991199
self.assertEqual(stdout, '1\n2\n3\n4')
12001200

1201-
# TODO: RUSTPYTHON
1202-
@unittest.expectedFailure
12031201
def test_communicate_errors(self):
12041202
for errors, expected in [
12051203
('ignore', ''),

Lib/test/test_tarfile.py

-14
Original file line numberDiff line numberDiff line change
@@ -2086,11 +2086,6 @@ class UstarUnicodeTest(UnicodeTest, unittest.TestCase):
20862086

20872087
format = tarfile.USTAR_FORMAT
20882088

2089-
# TODO: RUSTPYTHON
2090-
@unittest.expectedFailure
2091-
def test_uname_unicode(self):
2092-
super().test_uname_unicode()
2093-
20942089
# Test whether the utf-8 encoded version of a filename exceeds the 100
20952090
# bytes name field limit (every occurrence of '\xff' will be expanded to 2
20962091
# bytes).
@@ -2170,13 +2165,6 @@ class GNUUnicodeTest(UnicodeTest, unittest.TestCase):
21702165

21712166
format = tarfile.GNU_FORMAT
21722167

2173-
# TODO: RUSTPYTHON
2174-
@unittest.expectedFailure
2175-
def test_uname_unicode(self):
2176-
super().test_uname_unicode()
2177-
2178-
# TODO: RUSTPYTHON
2179-
@unittest.expectedFailure
21802168
def test_bad_pax_header(self):
21812169
# Test for issue #8633. GNU tar <= 1.23 creates raw binary fields
21822170
# without a hdrcharset=BINARY header.
@@ -2198,8 +2186,6 @@ class PAXUnicodeTest(UnicodeTest, unittest.TestCase):
21982186
# PAX_FORMAT ignores encoding in write mode.
21992187
test_unicode_filename_error = None
22002188

2201-
# TODO: RUSTPYTHON
2202-
@unittest.expectedFailure
22032189
def test_binary_header(self):
22042190
# Test a POSIX.1-2008 compatible header with a hdrcharset=BINARY field.
22052191
for encoding, name in (

Lib/test/test_unicode.py

-8
Original file line numberDiff line numberDiff line change
@@ -608,8 +608,6 @@ def test_bytes_comparison(self):
608608
self.assertEqual('abc' == bytearray(b'abc'), False)
609609
self.assertEqual('abc' != bytearray(b'abc'), True)
610610

611-
# TODO: RUSTPYTHON
612-
@unittest.expectedFailure
613611
def test_comparison(self):
614612
# Comparisons:
615613
self.assertEqual('abc', 'abc')
@@ -830,8 +828,6 @@ def test_isidentifier_legacy(self):
830828
warnings.simplefilter('ignore', DeprecationWarning)
831829
self.assertTrue(_testcapi.unicode_legacy_string(u).isidentifier())
832830

833-
# TODO: RUSTPYTHON
834-
@unittest.expectedFailure
835831
def test_isprintable(self):
836832
self.assertTrue("".isprintable())
837833
self.assertTrue(" ".isprintable())
@@ -847,8 +843,6 @@ def test_isprintable(self):
847843
self.assertTrue('\U0001F46F'.isprintable())
848844
self.assertFalse('\U000E0020'.isprintable())
849845

850-
# TODO: RUSTPYTHON
851-
@unittest.expectedFailure
852846
def test_surrogates(self):
853847
for s in ('a\uD800b\uDFFF', 'a\uDFFFb\uD800',
854848
'a\uD800b\uDFFFa', 'a\uDFFFb\uD800a'):
@@ -1827,8 +1821,6 @@ def test_codecs_utf7(self):
18271821
'ill-formed sequence'):
18281822
b'+@'.decode('utf-7')
18291823

1830-
# TODO: RUSTPYTHON
1831-
@unittest.expectedFailure
18321824
def test_codecs_utf8(self):
18331825
self.assertEqual(''.encode('utf-8'), b'')
18341826
self.assertEqual('\u20ac'.encode('utf-8'), b'\xe2\x82\xac')

Lib/test/test_userstring.py

-4
Original file line numberDiff line numberDiff line change
@@ -53,17 +53,13 @@ def __rmod__(self, other):
5353
str3 = ustr3('TEST')
5454
self.assertEqual(fmt2 % str3, 'value is TEST')
5555

56-
# TODO: RUSTPYTHON
57-
@unittest.expectedFailure
5856
def test_encode_default_args(self):
5957
self.checkequal(b'hello', 'hello', 'encode')
6058
# Check that encoding defaults to utf-8
6159
self.checkequal(b'\xf0\xa3\x91\x96', '\U00023456', 'encode')
6260
# Check that errors defaults to 'strict'
6361
self.checkraises(UnicodeError, '\ud800', 'encode')
6462

65-
# TODO: RUSTPYTHON
66-
@unittest.expectedFailure
6763
def test_encode_explicit_none_args(self):
6864
self.checkequal(b'hello', 'hello', 'encode', None, None)
6965
# Check that encoding defaults to utf-8

common/src/wtf8/mod.rs

+34-21
Original file line numberDiff line numberDiff line change
@@ -122,18 +122,18 @@ impl CodePoint {
122122

123123
/// Returns the numeric value of the code point if it is a leading surrogate.
124124
#[inline]
125-
pub fn to_lead_surrogate(self) -> Option<u16> {
125+
pub fn to_lead_surrogate(self) -> Option<LeadSurrogate> {
126126
match self.value {
127-
lead @ 0xD800..=0xDBFF => Some(lead as u16),
127+
lead @ 0xD800..=0xDBFF => Some(LeadSurrogate(lead as u16)),
128128
_ => None,
129129
}
130130
}
131131

132132
/// Returns the numeric value of the code point if it is a trailing surrogate.
133133
#[inline]
134-
pub fn to_trail_surrogate(self) -> Option<u16> {
134+
pub fn to_trail_surrogate(self) -> Option<TrailSurrogate> {
135135
match self.value {
136-
trail @ 0xDC00..=0xDFFF => Some(trail as u16),
136+
trail @ 0xDC00..=0xDFFF => Some(TrailSurrogate(trail as u16)),
137137
_ => None,
138138
}
139139
}
@@ -216,6 +216,18 @@ impl PartialEq<CodePoint> for char {
216216
}
217217
}
218218

219+
#[derive(Clone, Copy)]
220+
pub struct LeadSurrogate(u16);
221+
222+
#[derive(Clone, Copy)]
223+
pub struct TrailSurrogate(u16);
224+
225+
impl LeadSurrogate {
226+
pub fn merge(self, trail: TrailSurrogate) -> char {
227+
decode_surrogate_pair(self.0, trail.0)
228+
}
229+
}
230+
219231
/// An owned, growable string of well-formed WTF-8 data.
220232
///
221233
/// Similar to `String`, but can additionally contain surrogate code points
@@ -291,6 +303,14 @@ impl Wtf8Buf {
291303
Wtf8Buf { bytes: value }
292304
}
293305

306+
/// Create a WTF-8 string from a WTF-8 byte vec.
307+
pub fn from_bytes(value: Vec<u8>) -> Result<Self, Vec<u8>> {
308+
match Wtf8::from_bytes(&value) {
309+
Some(_) => Ok(unsafe { Self::from_bytes_unchecked(value) }),
310+
None => Err(value),
311+
}
312+
}
313+
294314
/// Creates a WTF-8 string from a UTF-8 `String`.
295315
///
296316
/// This takes ownership of the `String` and does not copy.
@@ -750,15 +770,10 @@ impl Wtf8 {
750770
}
751771

752772
fn decode_surrogate(b: &[u8]) -> Option<CodePoint> {
753-
let [a, b, c, ..] = *b else { return None };
754-
if (a & 0xf0) == 0xe0 && (b & 0xc0) == 0x80 && (c & 0xc0) == 0x80 {
755-
// it's a three-byte code
756-
let c = ((a as u32 & 0x0f) << 12) + ((b as u32 & 0x3f) << 6) + (c as u32 & 0x3f);
757-
let 0xD800..=0xDFFF = c else { return None };
758-
Some(CodePoint { value: c })
759-
} else {
760-
None
761-
}
773+
let [0xed, b2 @ (0xa0..), b3, ..] = *b else {
774+
return None;
775+
};
776+
Some(decode_surrogate(b2, b3).into())
762777
}
763778

764779
/// Returns the length, in WTF-8 bytes.
@@ -914,14 +929,6 @@ impl Wtf8 {
914929
}
915930
}
916931

917-
#[inline]
918-
fn final_lead_surrogate(&self) -> Option<u16> {
919-
match self.bytes {
920-
[.., 0xED, b2 @ 0xA0..=0xAF, b3] => Some(decode_surrogate(b2, b3)),
921-
_ => None,
922-
}
923-
}
924-
925932
pub fn is_code_point_boundary(&self, index: usize) -> bool {
926933
is_code_point_boundary(self, index)
927934
}
@@ -1222,6 +1229,12 @@ fn decode_surrogate(second_byte: u8, third_byte: u8) -> u16 {
12221229
0xD800 | (second_byte as u16 & 0x3F) << 6 | third_byte as u16 & 0x3F
12231230
}
12241231

1232+
#[inline]
1233+
fn decode_surrogate_pair(lead: u16, trail: u16) -> char {
1234+
let code_point = 0x10000 + ((((lead - 0xD800) as u32) << 10) | (trail - 0xDC00) as u32);
1235+
unsafe { char::from_u32_unchecked(code_point) }
1236+
}
1237+
12251238
/// Copied from str::is_char_boundary
12261239
#[inline]
12271240
fn is_code_point_boundary(slice: &Wtf8, index: usize) -> bool {

stdlib/src/json.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ mod _json {
1313
types::{Callable, Constructor},
1414
};
1515
use malachite_bigint::BigInt;
16+
use rustpython_common::wtf8::Wtf8Buf;
1617
use std::str::FromStr;
1718

1819
#[pyattr(name = "make_scanner")]
@@ -253,8 +254,8 @@ mod _json {
253254
end: usize,
254255
strict: OptionalArg<bool>,
255256
vm: &VirtualMachine,
256-
) -> PyResult<(String, usize)> {
257-
machinery::scanstring(s.as_str(), end, strict.unwrap_or(true))
257+
) -> PyResult<(Wtf8Buf, usize)> {
258+
machinery::scanstring(s.as_wtf8(), end, strict.unwrap_or(true))
258259
.map_err(|e| py_decode_error(e, s, vm))
259260
}
260261
}

stdlib/src/json/machinery.rs

+34-39
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828

2929
use std::io;
3030

31+
use itertools::Itertools;
32+
use rustpython_common::wtf8::{CodePoint, Wtf8, Wtf8Buf};
33+
3134
static ESCAPE_CHARS: [&str; 0x20] = [
3235
"\\u0000", "\\u0001", "\\u0002", "\\u0003", "\\u0004", "\\u0005", "\\u0006", "\\u0007", "\\b",
3336
"\\t", "\\n", "\\u000", "\\f", "\\r", "\\u000e", "\\u000f", "\\u0010", "\\u0011", "\\u0012",
@@ -111,39 +114,39 @@ impl DecodeError {
111114
}
112115

113116
enum StrOrChar<'a> {
114-
Str(&'a str),
115-
Char(char),
117+
Str(&'a Wtf8),
118+
Char(CodePoint),
116119
}
117120
impl StrOrChar<'_> {
118121
fn len(&self) -> usize {
119122
match self {
120123
StrOrChar::Str(s) => s.len(),
121-
StrOrChar::Char(c) => c.len_utf8(),
124+
StrOrChar::Char(c) => c.len_wtf8(),
122125
}
123126
}
124127
}
125128
pub fn scanstring<'a>(
126-
s: &'a str,
129+
s: &'a Wtf8,
127130
end: usize,
128131
strict: bool,
129-
) -> Result<(String, usize), DecodeError> {
132+
) -> Result<(Wtf8Buf, usize), DecodeError> {
130133
let mut chunks: Vec<StrOrChar<'a>> = Vec::new();
131134
let mut output_len = 0usize;
132135
let mut push_chunk = |chunk: StrOrChar<'a>| {
133136
output_len += chunk.len();
134137
chunks.push(chunk);
135138
};
136139
let unterminated_err = || DecodeError::new("Unterminated string starting at", end - 1);
137-
let mut chars = s.char_indices().enumerate().skip(end).peekable();
140+
let mut chars = s.code_point_indices().enumerate().skip(end).peekable();
138141
let &(_, (mut chunk_start, _)) = chars.peek().ok_or_else(unterminated_err)?;
139142
while let Some((char_i, (byte_i, c))) = chars.next() {
140-
match c {
143+
match c.to_char_lossy() {
141144
'"' => {
142145
push_chunk(StrOrChar::Str(&s[chunk_start..byte_i]));
143-
let mut out = String::with_capacity(output_len);
146+
let mut out = Wtf8Buf::with_capacity(output_len);
144147
for x in chunks {
145148
match x {
146-
StrOrChar::Str(s) => out.push_str(s),
149+
StrOrChar::Str(s) => out.push_wtf8(s),
147150
StrOrChar::Char(c) => out.push(c),
148151
}
149152
}
@@ -152,7 +155,7 @@ pub fn scanstring<'a>(
152155
'\\' => {
153156
push_chunk(StrOrChar::Str(&s[chunk_start..byte_i]));
154157
let (_, (_, c)) = chars.next().ok_or_else(unterminated_err)?;
155-
let esc = match c {
158+
let esc = match c.to_char_lossy() {
156159
'"' => "\"",
157160
'\\' => "\\",
158161
'/' => "/",
@@ -162,41 +165,33 @@ pub fn scanstring<'a>(
162165
'r' => "\r",
163166
't' => "\t",
164167
'u' => {
165-
let surrogate_err = || DecodeError::new("unpaired surrogate", char_i);
166168
let mut uni = decode_unicode(&mut chars, char_i)?;
167169
chunk_start = byte_i + 6;
168-
if (0xd800..=0xdbff).contains(&uni) {
170+
if let Some(lead) = uni.to_lead_surrogate() {
169171
// uni is a surrogate -- try to find its pair
170-
if let Some(&(pos2, (_, '\\'))) = chars.peek() {
171-
// ok, the next char starts an escape
172-
chars.next();
173-
if let Some((_, (_, 'u'))) = chars.peek() {
174-
// ok, it's a unicode escape
175-
chars.next();
176-
let uni2 = decode_unicode(&mut chars, pos2)?;
172+
let mut chars2 = chars.clone();
173+
if let Some(((pos2, _), (_, _))) = chars2
174+
.next_tuple()
175+
.filter(|((_, (_, c1)), (_, (_, c2)))| *c1 == '\\' && *c2 == 'u')
176+
{
177+
let uni2 = decode_unicode(&mut chars2, pos2)?;
178+
if let Some(trail) = uni2.to_trail_surrogate() {
179+
// ok, we found what we were looking for -- \uXXXX\uXXXX, both surrogates
180+
uni = lead.merge(trail).into();
177181
chunk_start = pos2 + 6;
178-
if (0xdc00..=0xdfff).contains(&uni2) {
179-
// ok, we found what we were looking for -- \uXXXX\uXXXX, both surrogates
180-
uni = 0x10000 + (((uni - 0xd800) << 10) | (uni2 - 0xdc00));
181-
} else {
182-
// if we don't find a matching surrogate, error -- until str
183-
// isn't utf8 internally, we can't parse surrogates
184-
return Err(surrogate_err());
185-
}
186-
} else {
187-
return Err(surrogate_err());
182+
chars = chars2;
188183
}
189184
}
190185
}
191-
push_chunk(StrOrChar::Char(
192-
std::char::from_u32(uni).ok_or_else(surrogate_err)?,
193-
));
186+
push_chunk(StrOrChar::Char(uni));
194187
continue;
195188
}
196-
_ => return Err(DecodeError::new(format!("Invalid \\escape: {c:?}"), char_i)),
189+
_ => {
190+
return Err(DecodeError::new(format!("Invalid \\escape: {c:?}"), char_i));
191+
}
197192
};
198193
chunk_start = byte_i + 2;
199-
push_chunk(StrOrChar::Str(esc));
194+
push_chunk(StrOrChar::Str(esc.as_ref()));
200195
}
201196
'\x00'..='\x1f' if strict => {
202197
return Err(DecodeError::new(
@@ -211,16 +206,16 @@ pub fn scanstring<'a>(
211206
}
212207

213208
#[inline]
214-
fn decode_unicode<I>(it: &mut I, pos: usize) -> Result<u32, DecodeError>
209+
fn decode_unicode<I>(it: &mut I, pos: usize) -> Result<CodePoint, DecodeError>
215210
where
216-
I: Iterator<Item = (usize, (usize, char))>,
211+
I: Iterator<Item = (usize, (usize, CodePoint))>,
217212
{
218213
let err = || DecodeError::new("Invalid \\uXXXX escape", pos);
219214
let mut uni = 0;
220215
for x in (0..4).rev() {
221216
let (_, (_, c)) = it.next().ok_or_else(err)?;
222-
let d = c.to_digit(16).ok_or_else(err)?;
223-
uni += d * 16u32.pow(x);
217+
let d = c.to_char().and_then(|c| c.to_digit(16)).ok_or_else(err)? as u16;
218+
uni += d * 16u16.pow(x);
224219
}
225-
Ok(uni)
220+
Ok(uni.into())
226221
}

0 commit comments

Comments
 (0)