Skip to content

Commit 0a07cd9

Browse files
committed
Fix more surrogate crashes
1 parent c6cab4c commit 0a07cd9

23 files changed

+142
-140
lines changed

Lib/test/test_codecs.py

+15-16
Original file line numberDiff line numberDiff line change
@@ -869,6 +869,11 @@ def test_bug691291(self):
869869
with reader:
870870
self.assertEqual(reader.read(), s1)
871871

872+
# TODO: RUSTPYTHON
873+
@unittest.expectedFailure
874+
def test_incremental_surrogatepass(self):
875+
super().test_incremental_surrogatepass()
876+
872877
class UTF16LETest(ReadTest, unittest.TestCase):
873878
encoding = "utf-16-le"
874879
ill_formed_sequence = b"\x80\xdc"
@@ -917,6 +922,11 @@ def test_nonbmp(self):
917922
self.assertEqual(b'\x00\xd8\x03\xde'.decode(self.encoding),
918923
"\U00010203")
919924

925+
# TODO: RUSTPYTHON
926+
@unittest.expectedFailure
927+
def test_incremental_surrogatepass(self):
928+
super().test_incremental_surrogatepass()
929+
920930
class UTF16BETest(ReadTest, unittest.TestCase):
921931
encoding = "utf-16-be"
922932
ill_formed_sequence = b"\xdc\x80"
@@ -965,6 +975,11 @@ def test_nonbmp(self):
965975
self.assertEqual(b'\xd8\x00\xde\x03'.decode(self.encoding),
966976
"\U00010203")
967977

978+
# TODO: RUSTPYTHON
979+
@unittest.expectedFailure
980+
def test_incremental_surrogatepass(self):
981+
super().test_incremental_surrogatepass()
982+
968983
class UTF8Test(ReadTest, unittest.TestCase):
969984
encoding = "utf-8"
970985
ill_formed_sequence = b"\xed\xb2\x80"
@@ -998,8 +1013,6 @@ def test_decoder_state(self):
9981013
self.check_state_handling_decode(self.encoding,
9991014
u, u.encode(self.encoding))
10001015

1001-
# TODO: RUSTPYTHON
1002-
@unittest.expectedFailure
10031016
def test_decode_error(self):
10041017
for data, error_handler, expected in (
10051018
(b'[\x80\xff]', 'ignore', '[]'),
@@ -1026,8 +1039,6 @@ def test_lone_surrogates(self):
10261039
exc = cm.exception
10271040
self.assertEqual(exc.object[exc.start:exc.end], '\uD800\uDFFF')
10281041

1029-
# TODO: RUSTPYTHON
1030-
@unittest.expectedFailure
10311042
def test_surrogatepass_handler(self):
10321043
self.assertEqual("abc\ud800def".encode(self.encoding, "surrogatepass"),
10331044
self.BOM + b"abc\xed\xa0\x80def")
@@ -2884,8 +2895,6 @@ def test_escape_encode(self):
28842895

28852896
class SurrogateEscapeTest(unittest.TestCase):
28862897

2887-
# TODO: RUSTPYTHON
2888-
@unittest.expectedFailure
28892898
def test_utf8(self):
28902899
# Bad byte
28912900
self.assertEqual(b"foo\x80bar".decode("utf-8", "surrogateescape"),
@@ -2898,8 +2907,6 @@ def test_utf8(self):
28982907
self.assertEqual("\udced\udcb0\udc80".encode("utf-8", "surrogateescape"),
28992908
b"\xed\xb0\x80")
29002909

2901-
# TODO: RUSTPYTHON
2902-
@unittest.expectedFailure
29032910
def test_ascii(self):
29042911
# bad byte
29052912
self.assertEqual(b"foo\x80bar".decode("ascii", "surrogateescape"),
@@ -2916,8 +2923,6 @@ def test_charmap(self):
29162923
self.assertEqual("foo\udca5bar".encode("iso-8859-3", "surrogateescape"),
29172924
b"foo\xa5bar")
29182925

2919-
# TODO: RUSTPYTHON
2920-
@unittest.expectedFailure
29212926
def test_latin1(self):
29222927
# Issue6373
29232928
self.assertEqual("\udce4\udceb\udcef\udcf6\udcfc".encode("latin-1", "surrogateescape"),
@@ -3561,8 +3566,6 @@ class ASCIITest(unittest.TestCase):
35613566
def test_encode(self):
35623567
self.assertEqual('abc123'.encode('ascii'), b'abc123')
35633568

3564-
# TODO: RUSTPYTHON
3565-
@unittest.expectedFailure
35663569
def test_encode_error(self):
35673570
for data, error_handler, expected in (
35683571
('[\x80\xff\u20ac]', 'ignore', b'[]'),
@@ -3585,8 +3588,6 @@ def test_encode_surrogateescape_error(self):
35853588
def test_decode(self):
35863589
self.assertEqual(b'abc'.decode('ascii'), 'abc')
35873590

3588-
# TODO: RUSTPYTHON
3589-
@unittest.expectedFailure
35903591
def test_decode_error(self):
35913592
for data, error_handler, expected in (
35923593
(b'[\x80\xff]', 'ignore', '[]'),
@@ -3609,8 +3610,6 @@ def test_encode(self):
36093610
with self.subTest(data=data, expected=expected):
36103611
self.assertEqual(data.encode('latin1'), expected)
36113612

3612-
# TODO: RUSTPYTHON
3613-
@unittest.expectedFailure
36143613
def test_encode_errors(self):
36153614
for data, error_handler, expected in (
36163615
('[\u20ac\udc80]', 'ignore', b'[]'),

Lib/test/test_json/test_scanstring.py

-2
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,6 @@ def test_scanstring(self):
8686
scanstring('["Bad value", truth]', 2, True),
8787
('Bad value', 12))
8888

89-
# TODO: RUSTPYTHON
90-
@unittest.expectedFailure
9189
def test_surrogates(self):
9290
scanstring = self.json.decoder.scanstring
9391
def assertScan(given, expect):

Lib/test/test_regrtest.py

-2
Original file line numberDiff line numberDiff line change
@@ -945,15 +945,13 @@ def test_leak(self):
945945
""")
946946
self.check_leak(code, 'file descriptors')
947947

948-
@unittest.expectedFailureIfWindows('TODO: RUSTPYTHON Windows')
949948
def test_list_tests(self):
950949
# test --list-tests
951950
tests = [self.create_test() for i in range(5)]
952951
output = self.run_tests('--list-tests', *tests)
953952
self.assertEqual(output.rstrip().splitlines(),
954953
tests)
955954

956-
@unittest.expectedFailureIfWindows('TODO: RUSTPYTHON Windows')
957955
def test_list_cases(self):
958956
# test --list-cases
959957
code = textwrap.dedent("""

Lib/test/test_stringprep.py

-2
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
from stringprep import *
77

88
class StringprepTests(unittest.TestCase):
9-
# TODO: RUSTPYTHON
10-
@unittest.expectedFailure
119
def test(self):
1210
self.assertTrue(in_table_a1("\u0221"))
1311
self.assertFalse(in_table_a1("\u0222"))

Lib/test/test_subprocess.py

-2
Original file line numberDiff line numberDiff line change
@@ -1198,8 +1198,6 @@ def test_universal_newlines_communicate_encodings(self):
11981198
stdout, stderr = popen.communicate(input='')
11991199
self.assertEqual(stdout, '1\n2\n3\n4')
12001200

1201-
# TODO: RUSTPYTHON
1202-
@unittest.expectedFailure
12031201
def test_communicate_errors(self):
12041202
for errors, expected in [
12051203
('ignore', ''),

Lib/test/test_tarfile.py

-14
Original file line numberDiff line numberDiff line change
@@ -2086,11 +2086,6 @@ class UstarUnicodeTest(UnicodeTest, unittest.TestCase):
20862086

20872087
format = tarfile.USTAR_FORMAT
20882088

2089-
# TODO: RUSTPYTHON
2090-
@unittest.expectedFailure
2091-
def test_uname_unicode(self):
2092-
super().test_uname_unicode()
2093-
20942089
# Test whether the utf-8 encoded version of a filename exceeds the 100
20952090
# bytes name field limit (every occurrence of '\xff' will be expanded to 2
20962091
# bytes).
@@ -2170,13 +2165,6 @@ class GNUUnicodeTest(UnicodeTest, unittest.TestCase):
21702165

21712166
format = tarfile.GNU_FORMAT
21722167

2173-
# TODO: RUSTPYTHON
2174-
@unittest.expectedFailure
2175-
def test_uname_unicode(self):
2176-
super().test_uname_unicode()
2177-
2178-
# TODO: RUSTPYTHON
2179-
@unittest.expectedFailure
21802168
def test_bad_pax_header(self):
21812169
# Test for issue #8633. GNU tar <= 1.23 creates raw binary fields
21822170
# without a hdrcharset=BINARY header.
@@ -2198,8 +2186,6 @@ class PAXUnicodeTest(UnicodeTest, unittest.TestCase):
21982186
# PAX_FORMAT ignores encoding in write mode.
21992187
test_unicode_filename_error = None
22002188

2201-
# TODO: RUSTPYTHON
2202-
@unittest.expectedFailure
22032189
def test_binary_header(self):
22042190
# Test a POSIX.1-2008 compatible header with a hdrcharset=BINARY field.
22052191
for encoding, name in (

Lib/test/test_unicode.py

-8
Original file line numberDiff line numberDiff line change
@@ -608,8 +608,6 @@ def test_bytes_comparison(self):
608608
self.assertEqual('abc' == bytearray(b'abc'), False)
609609
self.assertEqual('abc' != bytearray(b'abc'), True)
610610

611-
# TODO: RUSTPYTHON
612-
@unittest.expectedFailure
613611
def test_comparison(self):
614612
# Comparisons:
615613
self.assertEqual('abc', 'abc')
@@ -830,8 +828,6 @@ def test_isidentifier_legacy(self):
830828
warnings.simplefilter('ignore', DeprecationWarning)
831829
self.assertTrue(_testcapi.unicode_legacy_string(u).isidentifier())
832830

833-
# TODO: RUSTPYTHON
834-
@unittest.expectedFailure
835831
def test_isprintable(self):
836832
self.assertTrue("".isprintable())
837833
self.assertTrue(" ".isprintable())
@@ -847,8 +843,6 @@ def test_isprintable(self):
847843
self.assertTrue('\U0001F46F'.isprintable())
848844
self.assertFalse('\U000E0020'.isprintable())
849845

850-
# TODO: RUSTPYTHON
851-
@unittest.expectedFailure
852846
def test_surrogates(self):
853847
for s in ('a\uD800b\uDFFF', 'a\uDFFFb\uD800',
854848
'a\uD800b\uDFFFa', 'a\uDFFFb\uD800a'):
@@ -1827,8 +1821,6 @@ def test_codecs_utf7(self):
18271821
'ill-formed sequence'):
18281822
b'+@'.decode('utf-7')
18291823

1830-
# TODO: RUSTPYTHON
1831-
@unittest.expectedFailure
18321824
def test_codecs_utf8(self):
18331825
self.assertEqual(''.encode('utf-8'), b'')
18341826
self.assertEqual('\u20ac'.encode('utf-8'), b'\xe2\x82\xac')

Lib/test/test_userstring.py

-4
Original file line numberDiff line numberDiff line change
@@ -53,17 +53,13 @@ def __rmod__(self, other):
5353
str3 = ustr3('TEST')
5454
self.assertEqual(fmt2 % str3, 'value is TEST')
5555

56-
# TODO: RUSTPYTHON
57-
@unittest.expectedFailure
5856
def test_encode_default_args(self):
5957
self.checkequal(b'hello', 'hello', 'encode')
6058
# Check that encoding defaults to utf-8
6159
self.checkequal(b'\xf0\xa3\x91\x96', '\U00023456', 'encode')
6260
# Check that errors defaults to 'strict'
6361
self.checkraises(UnicodeError, '\ud800', 'encode')
6462

65-
# TODO: RUSTPYTHON
66-
@unittest.expectedFailure
6763
def test_encode_explicit_none_args(self):
6864
self.checkequal(b'hello', 'hello', 'encode', None, None)
6965
# Check that encoding defaults to utf-8

Lib/test/test_zipimport.py

+1
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,7 @@ def testTraceback(self):
730730

731731
@unittest.skipIf(os_helper.TESTFN_UNENCODABLE is None,
732732
"need an unencodable filename")
733+
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
733734
def testUnencodable(self):
734735
filename = os_helper.TESTFN_UNENCODABLE + ".zip"
735736
self.addCleanup(os_helper.unlink, filename)

common/src/wtf8/mod.rs

+34-21
Original file line numberDiff line numberDiff line change
@@ -122,18 +122,18 @@ impl CodePoint {
122122

123123
/// Returns the numeric value of the code point if it is a leading surrogate.
124124
#[inline]
125-
pub fn to_lead_surrogate(self) -> Option<u16> {
125+
pub fn to_lead_surrogate(self) -> Option<LeadSurrogate> {
126126
match self.value {
127-
lead @ 0xD800..=0xDBFF => Some(lead as u16),
127+
lead @ 0xD800..=0xDBFF => Some(LeadSurrogate(lead as u16)),
128128
_ => None,
129129
}
130130
}
131131

132132
/// Returns the numeric value of the code point if it is a trailing surrogate.
133133
#[inline]
134-
pub fn to_trail_surrogate(self) -> Option<u16> {
134+
pub fn to_trail_surrogate(self) -> Option<TrailSurrogate> {
135135
match self.value {
136-
trail @ 0xDC00..=0xDFFF => Some(trail as u16),
136+
trail @ 0xDC00..=0xDFFF => Some(TrailSurrogate(trail as u16)),
137137
_ => None,
138138
}
139139
}
@@ -216,6 +216,18 @@ impl PartialEq<CodePoint> for char {
216216
}
217217
}
218218

219+
#[derive(Clone, Copy)]
220+
pub struct LeadSurrogate(u16);
221+
222+
#[derive(Clone, Copy)]
223+
pub struct TrailSurrogate(u16);
224+
225+
impl LeadSurrogate {
226+
pub fn merge(self, trail: TrailSurrogate) -> char {
227+
decode_surrogate_pair(self.0, trail.0)
228+
}
229+
}
230+
219231
/// An owned, growable string of well-formed WTF-8 data.
220232
///
221233
/// Similar to `String`, but can additionally contain surrogate code points
@@ -291,6 +303,14 @@ impl Wtf8Buf {
291303
Wtf8Buf { bytes: value }
292304
}
293305

306+
/// Create a WTF-8 string from a WTF-8 byte vec.
307+
pub fn from_bytes(value: Vec<u8>) -> Result<Self, Vec<u8>> {
308+
match Wtf8::from_bytes(&value) {
309+
Some(_) => Ok(unsafe { Self::from_bytes_unchecked(value) }),
310+
None => Err(value),
311+
}
312+
}
313+
294314
/// Creates a WTF-8 string from a UTF-8 `String`.
295315
///
296316
/// This takes ownership of the `String` and does not copy.
@@ -750,15 +770,10 @@ impl Wtf8 {
750770
}
751771

752772
fn decode_surrogate(b: &[u8]) -> Option<CodePoint> {
753-
let [a, b, c, ..] = *b else { return None };
754-
if (a & 0xf0) == 0xe0 && (b & 0xc0) == 0x80 && (c & 0xc0) == 0x80 {
755-
// it's a three-byte code
756-
let c = ((a as u32 & 0x0f) << 12) + ((b as u32 & 0x3f) << 6) + (c as u32 & 0x3f);
757-
let 0xD800..=0xDFFF = c else { return None };
758-
Some(CodePoint { value: c })
759-
} else {
760-
None
761-
}
773+
let [0xed, b2 @ (0xa0..), b3, ..] = *b else {
774+
return None;
775+
};
776+
Some(decode_surrogate(b2, b3).into())
762777
}
763778

764779
/// Returns the length, in WTF-8 bytes.
@@ -914,14 +929,6 @@ impl Wtf8 {
914929
}
915930
}
916931

917-
#[inline]
918-
fn final_lead_surrogate(&self) -> Option<u16> {
919-
match self.bytes {
920-
[.., 0xED, b2 @ 0xA0..=0xAF, b3] => Some(decode_surrogate(b2, b3)),
921-
_ => None,
922-
}
923-
}
924-
925932
pub fn is_code_point_boundary(&self, index: usize) -> bool {
926933
is_code_point_boundary(self, index)
927934
}
@@ -1222,6 +1229,12 @@ fn decode_surrogate(second_byte: u8, third_byte: u8) -> u16 {
12221229
0xD800 | (second_byte as u16 & 0x3F) << 6 | third_byte as u16 & 0x3F
12231230
}
12241231

1232+
#[inline]
1233+
fn decode_surrogate_pair(lead: u16, trail: u16) -> char {
1234+
let code_point = 0x10000 + ((((lead - 0xD800) as u32) << 10) | (trail - 0xDC00) as u32);
1235+
unsafe { char::from_u32_unchecked(code_point) }
1236+
}
1237+
12251238
/// Copied from str::is_char_boundary
12261239
#[inline]
12271240
fn is_code_point_boundary(slice: &Wtf8, index: usize) -> bool {

stdlib/src/json.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ mod _json {
1313
types::{Callable, Constructor},
1414
};
1515
use malachite_bigint::BigInt;
16+
use rustpython_common::wtf8::Wtf8Buf;
1617
use std::str::FromStr;
1718

1819
#[pyattr(name = "make_scanner")]
@@ -253,8 +254,8 @@ mod _json {
253254
end: usize,
254255
strict: OptionalArg<bool>,
255256
vm: &VirtualMachine,
256-
) -> PyResult<(String, usize)> {
257-
machinery::scanstring(s.as_str(), end, strict.unwrap_or(true))
257+
) -> PyResult<(Wtf8Buf, usize)> {
258+
machinery::scanstring(s.as_wtf8(), end, strict.unwrap_or(true))
258259
.map_err(|e| py_decode_error(e, s, vm))
259260
}
260261
}

0 commit comments

Comments
 (0)