Skip to content

Commit cc6f3d3

Browse files
committed
Make TextIOWrapper wtf8-compatible
1 parent b36b32b commit cc6f3d3

File tree

4 files changed

+53
-38
lines changed

4 files changed

+53
-38
lines changed

common/src/wtf8/mod.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,12 @@ impl ToOwned for Wtf8 {
613613
}
614614
}
615615

616+
// Allows a `Wtf8` value to be compared directly against a `str` with `==`.
impl PartialEq<str> for Wtf8 {
617+
/// Returns `true` when the bytes of `self` are equal to the bytes of `other`.
fn eq(&self, other: &str) -> bool {
618+
// Compares the two underlying byte slices; neither side is decoded.
self.as_bytes().eq(other.as_bytes())
619+
}
620+
}
621+
616622
/// Formats the string in double quotes, with characters escaped according to
617623
/// [`char::escape_debug`] and unpaired surrogates represented as `\u{xxxx}`,
618624
/// where each `x` is a hexadecimal digit.
@@ -1046,6 +1052,11 @@ impl Wtf8 {
10461052
.strip_suffix(w.as_bytes())
10471053
.map(|w| unsafe { Wtf8::from_bytes_unchecked(w) })
10481054
}
1055+
1056+
/// Builds a new [`Wtf8Buf`] from the result of `self.bytes.replace(from, to)`.
///
/// The replacement is carried out on the underlying byte slice, not on code points.
pub fn replace(&self, from: &Wtf8, to: &Wtf8) -> Wtf8Buf {
1057+
let w = self.bytes.replace(from, to);
1058+
// SAFETY: NOTE(review) - the replaced bytes are wrapped without any validation, so this
// relies on byte-level replacement always yielding well-formed WTF-8. Please confirm this
// holds when the replacement leaves a lead surrogate directly before a trail surrogate
// (e.g. removing the text between two unpaired halves), and when `from` is empty.
unsafe { Wtf8Buf::from_bytes_unchecked(w) }
1059+
}
10491060
}
10501061

10511062
impl AsRef<Wtf8> for str {

vm/src/builtins/str.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1408,14 +1408,14 @@ impl PyStrRef {
14081408
(**self).is_empty()
14091409
}
14101410

1411-
/// Appends `other` to this string reference.
///
/// Despite the name, nothing is mutated in place yet (see the TODO below): a fresh buffer
/// is filled with the current contents followed by `other`, turned into a new `PyStr`
/// through `vm.ctx`, and `*self` is rebound to that new object.
///
/// In this diff the parameter type changes from `&str` to `&Wtf8`, and the temporary
/// buffer changes from `String` to `Wtf8Buf` accordingly.
pub fn concat_in_place(&mut self, other: &str, vm: &VirtualMachine) {
1411+
pub fn concat_in_place(&mut self, other: &Wtf8, vm: &VirtualMachine) {
14121412
// TODO: call [A]Rc::get_mut on the str to try to mutate the data in place
14131413
// Nothing to append: leave `self` pointing at the same object.
if other.is_empty() {
14141414
return;
14151415
}
1416-
let mut s = String::with_capacity(self.byte_len() + other.len());
1417-
s.push_str(self.as_ref());
1418-
s.push_str(other);
1416+
// Reserve room for both parts up front (sum of the two byte lengths).
let mut s = Wtf8Buf::with_capacity(self.byte_len() + other.len());
1417+
s.push_wtf8(self.as_ref());
1418+
s.push_wtf8(other);
14191419
// Rebind `self` to a newly created string object holding the concatenation.
*self = PyStr::from(s).into_ref(&vm.ctx);
14201420
}
14211421
}

vm/src/dictdatatype.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -835,7 +835,7 @@ impl DictKey for str {
835835

836836
fn key_eq(&self, vm: &VirtualMachine, other_key: &PyObject) -> PyResult<bool> {
837837
if let Some(pystr) = other_key.payload_if_exact::<PyStr>(vm) {
838-
Ok(pystr.as_wtf8() == self.as_ref())
838+
Ok(pystr.as_wtf8() == self)
839839
} else {
840840
// Fall back to PyObjectRef implementation.
841841
let s = vm.ctx.new_str(self);

vm/src/stdlib/io.rs

Lines changed: 37 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ mod _io {
129129
PyMappedThreadMutexGuard, PyRwLock, PyRwLockReadGuard, PyRwLockWriteGuard,
130130
PyThreadMutex, PyThreadMutexGuard,
131131
},
132+
common::wtf8::{Wtf8, Wtf8Buf},
132133
convert::ToPyObject,
133134
function::{
134135
ArgBytesLike, ArgIterable, ArgMemoryBuffer, ArgSize, Either, FuncArgs, IntoFuncArgs,
@@ -147,7 +148,6 @@ mod _io {
147148
use crossbeam_utils::atomic::AtomicCell;
148149
use malachite_bigint::{BigInt, BigUint};
149150
use num_traits::ToPrimitive;
150-
use rustpython_common::wtf8::Wtf8Buf;
151151
use std::{
152152
borrow::Cow,
153153
io::{self, Cursor, SeekFrom, prelude::*},
@@ -1910,10 +1910,12 @@ mod _io {
19101910
impl Newlines {
19111911
/// returns position where the new line starts if found, otherwise position at which to
19121912
/// continue the search after more is read into the buffer
1913-
fn find_newline(&self, s: &str) -> Result<usize, usize> {
1913+
fn find_newline(&self, s: &Wtf8) -> Result<usize, usize> {
19141914
let len = s.len();
19151915
match self {
1916-
Newlines::Universal | Newlines::Lf => s.find('\n').map(|p| p + 1).ok_or(len),
1916+
Newlines::Universal | Newlines::Lf => {
1917+
s.find("\n".as_ref()).map(|p| p + 1).ok_or(len)
1918+
}
19171919
Newlines::Passthrough => {
19181920
let bytes = s.as_bytes();
19191921
memchr::memchr2(b'\n', b'\r', bytes)
@@ -1928,7 +1930,7 @@ mod _io {
19281930
})
19291931
.ok_or(len)
19301932
}
1931-
Newlines::Cr => s.find('\n').map(|p| p + 1).ok_or(len),
1933+
Newlines::Cr => s.find("\n".as_ref()).map(|p| p + 1).ok_or(len),
19321934
Newlines::Crlf => {
19331935
// s[searched..] == remaining
19341936
let mut searched = 0;
@@ -1993,10 +1995,10 @@ mod _io {
19931995
}
19941996
}
19951997

1996-
/// Measures `s` and returns a `Utf8size` holding both of its lengths:
/// `bytes` is `s.len()` and `chars` is the number of items yielded by iterating `s`
/// (`chars()` in the removed `&str` version, `code_points()` in the added `&Wtf8` version).
fn len_str(s: &str) -> Self {
1998+
fn len_str(s: &Wtf8) -> Self {
19971999
Utf8size {
19982000
bytes: s.len(),
1999-
chars: s.chars().count(),
2001+
// Counting walks the whole string, unlike `len()` above.
chars: s.code_points().count(),
20002002
}
20012003
}
20022004
}
@@ -2224,7 +2226,7 @@ mod _io {
22242226

22252227
let encoding = match args.encoding {
22262228
None if vm.state.settings.utf8_mode > 0 => PyStr::from("utf-8").into_ref(&vm.ctx),
2227-
Some(enc) if enc.as_str() != "locale" => enc,
2229+
Some(enc) if enc.as_wtf8() != "locale" => enc,
22282230
_ => {
22292231
// None without utf8_mode or "locale" encoding
22302232
vm.import("locale", 0)?
@@ -2534,9 +2536,10 @@ mod _io {
25342536
*snapshot = Some((cookie.dec_flags, input_chunk.clone()));
25352537
let decoded = vm.call_method(decoder, "decode", (input_chunk, cookie.need_eof))?;
25362538
let decoded = check_decoded(decoded, vm)?;
2537-
let pos_is_valid = decoded
2538-
.as_str()
2539-
.is_char_boundary(cookie.bytes_to_skip as usize);
2539+
let pos_is_valid = crate::common::wtf8::is_code_point_boundary(
2540+
decoded.as_wtf8(),
2541+
cookie.bytes_to_skip as usize,
2542+
);
25402543
textio.set_decoded_chars(Some(decoded));
25412544
if !pos_is_valid {
25422545
return Err(vm.new_os_error("can't restore logical file position".to_owned()));
@@ -2715,9 +2718,9 @@ mod _io {
27152718
} else if chunks.len() == 1 {
27162719
chunks.pop().unwrap()
27172720
} else {
2718-
let mut ret = String::with_capacity(chunks_bytes);
2721+
let mut ret = Wtf8Buf::with_capacity(chunks_bytes);
27192722
for chunk in chunks {
2720-
ret.push_str(chunk.as_str())
2723+
ret.push_wtf8(chunk.as_wtf8())
27212724
}
27222725
PyStr::from(ret).into_ref(&vm.ctx)
27232726
}
@@ -2744,7 +2747,7 @@ mod _io {
27442747

27452748
let char_len = obj.char_len();
27462749

2747-
let data = obj.as_str();
2750+
let data = obj.as_wtf8();
27482751

27492752
let replace_nl = match textio.newline {
27502753
Newlines::Lf => Some("\n"),
@@ -2753,11 +2756,12 @@ mod _io {
27532756
Newlines::Universal if cfg!(windows) => Some("\r\n"),
27542757
_ => None,
27552758
};
2756-
let has_lf = (replace_nl.is_some() || textio.line_buffering) && data.contains('\n');
2757-
let flush = textio.line_buffering && (has_lf || data.contains('\r'));
2759+
let has_lf = (replace_nl.is_some() || textio.line_buffering)
2760+
&& data.contains_code_point('\n'.into());
2761+
let flush = textio.line_buffering && (has_lf || data.contains_code_point('\r'.into()));
27582762
let chunk = if let Some(replace_nl) = replace_nl {
27592763
if has_lf {
2760-
PyStr::from(data.replace('\n', replace_nl)).into_ref(&vm.ctx)
2764+
PyStr::from(data.replace("\n".as_ref(), replace_nl.as_ref())).into_ref(&vm.ctx)
27612765
} else {
27622766
obj
27632767
}
@@ -2834,16 +2838,16 @@ mod _io {
28342838
if self.is_full_slice() {
28352839
self.0.char_len()
28362840
} else {
2837-
self.slice().chars().count()
2841+
self.slice().code_points().count()
28382842
}
28392843
}
28402844
/// Returns `true` when the length of the stored range (`self.1`) is at least the byte
/// length of the stored string (`self.0`).
#[inline]
28412845
fn is_full_slice(&self) -> bool {
28422846
// Only the range's length is compared here; its start offset is not inspected.
self.1.len() >= self.0.byte_len()
28432847
}
28442848
/// Borrows the part of the stored string (`self.0`) selected by the stored range
/// (`self.1`).
///
/// In this diff the return type changes from `&str` to `&Wtf8`; the range is applied as
/// an index into the string's `as_str()` / `as_wtf8()` view respectively.
#[inline]
2845-
fn slice(&self) -> &str {
2846-
&self.0.as_str()[self.1.clone()]
2849+
fn slice(&self) -> &Wtf8 {
2850+
// The range is cloned because indexing consumes it and `self` is only borrowed.
&self.0.as_wtf8()[self.1.clone()]
28472851
}
28482852
#[inline]
28492853
fn slice_pystr(self, vm: &VirtualMachine) -> PyStrRef {
@@ -2894,24 +2898,24 @@ mod _io {
28942898
Some(remaining) => {
28952899
assert_eq!(textio.decoded_chars_used.bytes, 0);
28962900
offset_to_buffer = remaining.utf8_len();
2897-
let decoded_chars = decoded_chars.as_str();
2901+
let decoded_chars = decoded_chars.as_wtf8();
28982902
let line = if remaining.is_full_slice() {
28992903
let mut line = remaining.0;
29002904
line.concat_in_place(decoded_chars, vm);
29012905
line
29022906
} else {
29032907
let remaining = remaining.slice();
29042908
let mut s =
2905-
String::with_capacity(remaining.len() + decoded_chars.len());
2906-
s.push_str(remaining);
2907-
s.push_str(decoded_chars);
2909+
Wtf8Buf::with_capacity(remaining.len() + decoded_chars.len());
2910+
s.push_wtf8(remaining);
2911+
s.push_wtf8(decoded_chars);
29082912
PyStr::from(s).into_ref(&vm.ctx)
29092913
};
29102914
start = Utf8size::default();
29112915
line
29122916
}
29132917
};
2914-
let line_from_start = &line.as_str()[start.bytes..];
2918+
let line_from_start = &line.as_wtf8()[start.bytes..];
29152919
let nl_res = textio.newline.find_newline(line_from_start);
29162920
match nl_res {
29172921
Ok(p) | Err(p) => {
@@ -2922,7 +2926,7 @@ mod _io {
29222926
endpos = start
29232927
+ Utf8size {
29242928
chars: limit - chunked.chars,
2925-
bytes: crate::common::str::char_range_end(
2929+
bytes: crate::common::str::codepoint_range_end(
29262930
line_from_start,
29272931
limit - chunked.chars,
29282932
)
@@ -2963,9 +2967,9 @@ mod _io {
29632967
chunked += cur_line.byte_len();
29642968
chunks.push(cur_line);
29652969
}
2966-
let mut s = String::with_capacity(chunked);
2970+
let mut s = Wtf8Buf::with_capacity(chunked);
29672971
for chunk in chunks {
2968-
s.push_str(chunk.slice())
2972+
s.push_wtf8(chunk.slice())
29692973
}
29702974
PyStr::from(s).into_ref(&vm.ctx)
29712975
} else if let Some(cur_line) = cur_line {
@@ -3100,7 +3104,7 @@ mod _io {
31003104
return None;
31013105
}
31023106
let decoded_chars = self.decoded_chars.as_ref()?;
3103-
let avail = &decoded_chars.as_str()[self.decoded_chars_used.bytes..];
3107+
let avail = &decoded_chars.as_wtf8()[self.decoded_chars_used.bytes..];
31043108
if avail.is_empty() {
31053109
return None;
31063110
}
@@ -3112,7 +3116,7 @@ mod _io {
31123116
(PyStr::from(avail).into_ref(&vm.ctx), avail_chars)
31133117
}
31143118
} else {
3115-
let s = crate::common::str::get_chars(avail, 0..n);
3119+
let s = crate::common::str::get_codepoints(avail, 0..n);
31163120
(PyStr::from(s).into_ref(&vm.ctx), n)
31173121
};
31183122
self.decoded_chars_used += Utf8size {
@@ -3142,11 +3146,11 @@ mod _io {
31423146
return decoded_chars;
31433147
}
31443148
// TODO: in-place editing of `str` when refcount == 1
3145-
let decoded_chars_unused = &decoded_chars.as_str()[chars_pos..];
3146-
let mut s = String::with_capacity(decoded_chars_unused.len() + append_len);
3147-
s.push_str(decoded_chars_unused);
3149+
let decoded_chars_unused = &decoded_chars.as_wtf8()[chars_pos..];
3150+
let mut s = Wtf8Buf::with_capacity(decoded_chars_unused.len() + append_len);
3151+
s.push_wtf8(decoded_chars_unused);
31483152
if let Some(append) = append {
3149-
s.push_str(append.as_str())
3153+
s.push_wtf8(append.as_wtf8())
31503154
}
31513155
PyStr::from(s).into_ref(&vm.ctx)
31523156
}

0 commit comments

Comments
 (0)