Skip to content

Commit d1f3c2a

Browse files
committed
Remove try_to_str
1 parent 81e8638 commit d1f3c2a

File tree

7 files changed

+68
-46
lines changed

7 files changed

+68
-46
lines changed

stdlib/src/sqlite.rs

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ mod _sqlite {
5959
builtins::{
6060
PyBaseException, PyBaseExceptionRef, PyByteArray, PyBytes, PyDict, PyDictRef, PyFloat,
6161
PyInt, PyIntRef, PySlice, PyStr, PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef,
62+
PyUtf8Str, PyUtf8StrRef,
6263
},
6364
convert::IntoObject,
6465
function::{ArgCallable, ArgIterable, FsPath, FuncArgs, OptionalArg, PyComparisonValue},
@@ -848,7 +849,7 @@ mod _sqlite {
848849
}
849850

850851
impl Callable for Connection {
851-
type Args = (PyStrRef,);
852+
type Args = (PyUtf8StrRef,);
852853

853854
fn call(zelf: &Py<Self>, args: Self::Args, vm: &VirtualMachine) -> PyResult {
854855
if let Some(stmt) = Statement::new(zelf, args.0, vm)? {
@@ -983,7 +984,7 @@ mod _sqlite {
983984
#[pymethod]
984985
fn execute(
985986
zelf: PyRef<Self>,
986-
sql: PyStrRef,
987+
sql: PyUtf8StrRef,
987988
parameters: OptionalArg<PyObjectRef>,
988989
vm: &VirtualMachine,
989990
) -> PyResult<PyRef<Cursor>> {
@@ -995,7 +996,7 @@ mod _sqlite {
995996
#[pymethod]
996997
fn executemany(
997998
zelf: PyRef<Self>,
998-
sql: PyStrRef,
999+
sql: PyUtf8StrRef,
9991000
seq_of_params: ArgIterable,
10001001
vm: &VirtualMachine,
10011002
) -> PyResult<PyRef<Cursor>> {
@@ -1477,7 +1478,7 @@ mod _sqlite {
14771478
#[pymethod]
14781479
fn execute(
14791480
zelf: PyRef<Self>,
1480-
sql: PyStrRef,
1481+
sql: PyUtf8StrRef,
14811482
parameters: OptionalArg<PyObjectRef>,
14821483
vm: &VirtualMachine,
14831484
) -> PyResult<PyRef<Self>> {
@@ -1549,7 +1550,7 @@ mod _sqlite {
15491550
#[pymethod]
15501551
fn executemany(
15511552
zelf: PyRef<Self>,
1552-
sql: PyStrRef,
1553+
sql: PyUtf8StrRef,
15531554
seq_of_params: ArgIterable,
15541555
vm: &VirtualMachine,
15551556
) -> PyResult<PyRef<Self>> {
@@ -2298,10 +2299,9 @@ mod _sqlite {
22982299
impl Statement {
22992300
fn new(
23002301
connection: &Connection,
2301-
sql: PyStrRef,
2302+
sql: PyUtf8StrRef,
23022303
vm: &VirtualMachine,
23032304
) -> PyResult<Option<Self>> {
2304-
let _ = sql.try_to_str(vm)?;
23052305
if sql.as_str().contains('\0') {
23062306
return Err(new_programming_error(
23072307
vm,
@@ -2654,6 +2654,7 @@ mod _sqlite {
26542654
let val = val.to_f64();
26552655
unsafe { sqlite3_bind_double(self.st, pos, val) }
26562656
} else if let Some(val) = obj.downcast_ref::<PyStr>() {
2657+
let val = val.try_as_utf8(vm)?;
26572658
let (ptr, len) = str_to_ptr_len(val, vm)?;
26582659
unsafe { sqlite3_bind_text(self.st, pos, ptr, len, SQLITE_TRANSIENT()) }
26592660
} else if let Ok(buffer) = PyBuffer::try_from_borrowed_object(vm, obj) {
@@ -2905,6 +2906,7 @@ mod _sqlite {
29052906
} else if let Some(val) = val.downcast_ref::<PyFloat>() {
29062907
sqlite3_result_double(self.ctx, val.to_f64())
29072908
} else if let Some(val) = val.downcast_ref::<PyStr>() {
2909+
let val = val.try_as_utf8(vm)?;
29082910
let (ptr, len) = str_to_ptr_len(val, vm)?;
29092911
sqlite3_result_text(self.ctx, ptr, len, SQLITE_TRANSIENT())
29102912
} else if let Ok(buffer) = PyBuffer::try_from_borrowed_object(vm, val) {
@@ -2985,8 +2987,8 @@ mod _sqlite {
29852987
}
29862988
}
29872989

2988-
fn str_to_ptr_len(s: &PyStr, vm: &VirtualMachine) -> PyResult<(*const libc::c_char, i32)> {
2989-
let s_str = s.try_to_str(vm)?;
2990+
fn str_to_ptr_len(s: &PyUtf8Str, vm: &VirtualMachine) -> PyResult<(*const libc::c_char, i32)> {
2991+
let s_str = s.as_str();
29902992
let len = c_int::try_from(s_str.len())
29912993
.map_err(|_| vm.new_overflow_error("TEXT longer than INT_MAX bytes"))?;
29922994
let ptr = s_str.as_ptr().cast();

vm/src/builtins/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ pub(crate) mod bool_;
5959
pub use bool_::PyBool;
6060
#[path = "str.rs"]
6161
pub(crate) mod pystr;
62-
pub use pystr::{PyStr, PyStrInterned, PyStrRef, PyWtf8Str, PyWtf8StrRef};
62+
pub use pystr::{PyStr, PyStrInterned, PyStrRef, PyUtf8Str, PyUtf8StrRef, PyWtf8Str, PyWtf8StrRef};
6363
#[path = "super.rs"]
6464
pub(crate) mod super_;
6565
pub use super_::PySuper;

vm/src/builtins/str.rs

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ impl Default for PyStr {
219219

220220
pub type PyStrRef = PyRef<PyStr>;
221221
pub type PyWtf8StrRef = PyStrRef;
222+
pub type PyUtf8StrRef = PyRef<PyUtf8Str>;
222223

223224
impl fmt::Display for PyStr {
224225
#[inline]
@@ -434,26 +435,6 @@ impl PyStr {
434435
self.data.as_str()
435436
}
436437

437-
pub fn try_to_str(&self, vm: &VirtualMachine) -> PyResult<&str> {
438-
if self.is_utf8() {
439-
// SAFETY: is_utf8() passed, so unwrap is safe.
440-
Ok(unsafe { self.to_str().unwrap_unchecked() })
441-
} else {
442-
let start = self
443-
.as_wtf8()
444-
.code_points()
445-
.position(|c| c.to_char().is_none())
446-
.unwrap();
447-
Err(vm.new_unicode_encode_error_real(
448-
identifier!(vm, utf_8).to_owned(),
449-
vm.ctx.new_str(self.data.clone()),
450-
start,
451-
start + 1,
452-
vm.ctx.new_str("surrogates not allowed"),
453-
))
454-
}
455-
}
456-
457438
pub fn to_string_lossy(&self) -> Cow<'_, str> {
458439
self.to_str()
459440
.map(Cow::Borrowed)
@@ -1505,6 +1486,15 @@ impl PyRef<PyWtf8Str> {
15051486
}
15061487
}
15071488

1489+
impl Py<PyWtf8Str> {
1490+
pub fn try_as_utf8<'a>(&'a self, vm: &VirtualMachine) -> PyResult<&'a Py<PyUtf8Str>> {
1491+
// Check if the string contains surrogates
1492+
self.ensure_valid_utf8(vm)?;
1493+
// If no surrogates, we can safely cast to PyUtf8Str
1494+
Ok(unsafe { &*(self as *const _ as *const Py<PyUtf8Str>) })
1495+
}
1496+
}
1497+
15081498
impl Representable for PyStr {
15091499
#[inline]
15101500
fn repr_str(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<String> {
@@ -1924,12 +1914,26 @@ impl AnyStrWrapper<AsciiStr> for PyStrRef {
19241914
#[derive(Debug)]
19251915
pub struct PyUtf8Str(PyStr);
19261916

1917+
impl fmt::Display for PyUtf8Str {
1918+
#[inline]
1919+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1920+
self.0.fmt(f)
1921+
}
1922+
}
1923+
19271924
impl MaybeTraverse for PyUtf8Str {
19281925
fn try_traverse(&self, traverse_fn: &mut TraverseFn<'_>) {
19291926
self.0.try_traverse(traverse_fn);
19301927
}
19311928
}
19321929

1930+
impl AsRef<Wtf8> for PyUtf8Str {
1931+
#[inline]
1932+
fn as_ref(&self) -> &Wtf8 {
1933+
self.0.as_wtf8()
1934+
}
1935+
}
1936+
19331937
impl PyPayload for PyUtf8Str {
19341938
#[inline]
19351939
fn class(ctx: &Context) -> &'static Py<PyType> {
@@ -1964,6 +1968,21 @@ impl PyUtf8Str {
19641968
// Safety: This is safe because the type invariant guarantees UTF-8 validity.
19651969
unsafe { self.0.to_str().unwrap_unchecked() }
19661970
}
1971+
1972+
#[inline]
1973+
pub fn byte_len(&self) -> usize {
1974+
self.0.byte_len()
1975+
}
1976+
1977+
#[inline]
1978+
pub fn is_empty(&self) -> bool {
1979+
self.0.is_empty()
1980+
}
1981+
1982+
#[inline]
1983+
pub fn char_len(&self) -> usize {
1984+
self.0.char_len()
1985+
}
19671986
}
19681987

19691988
impl Py<PyUtf8Str> {

vm/src/stdlib/codecs.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ mod _codecs {
77
use crate::common::wtf8::Wtf8Buf;
88
use crate::{
99
AsObject, PyObjectRef, PyResult, VirtualMachine,
10-
builtins::PyStrRef,
10+
builtins::{PyStrRef, PyUtf8StrRef},
1111
codecs,
1212
function::{ArgBytesLike, FuncArgs},
1313
};
@@ -23,10 +23,10 @@ mod _codecs {
2323
}
2424

2525
#[pyfunction]
26-
fn lookup(encoding: PyStrRef, vm: &VirtualMachine) -> PyResult {
26+
fn lookup(encoding: PyUtf8StrRef, vm: &VirtualMachine) -> PyResult {
2727
vm.state
2828
.codec_registry
29-
.lookup(encoding.try_to_str(vm)?, vm)
29+
.lookup(encoding.as_str(), vm)
3030
.map(|codec| codec.into_tuple().into())
3131
}
3232

vm/src/stdlib/io.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ mod _io {
120120
TryFromBorrowedObject, TryFromObject,
121121
builtins::{
122122
PyBaseExceptionRef, PyByteArray, PyBytes, PyBytesRef, PyIntRef, PyMemoryView, PyStr,
123-
PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef, PyWtf8Str,
123+
PyStrRef, PyTuple, PyTupleRef, PyType, PyTypeRef, PyUtf8StrRef, PyWtf8Str,
124124
},
125125
class::StaticType,
126126
common::lock::{
@@ -2309,7 +2309,7 @@ mod _io {
23092309

23102310
let newline = args.newline.unwrap_or_default();
23112311
let (encoder, decoder) =
2312-
Self::find_coder(&buffer, encoding.try_to_str(vm)?, &errors, newline, vm)?;
2312+
Self::find_coder(&buffer, encoding.as_str(), &errors, newline, vm)?;
23132313

23142314
*data = Some(TextIOData {
23152315
buffer,
@@ -2409,7 +2409,7 @@ mod _io {
24092409
if let Some(encoding) = args.encoding {
24102410
let (encoder, decoder) = Self::find_coder(
24112411
&data.buffer,
2412-
encoding.try_to_str(vm)?,
2412+
encoding.as_str(),
24132413
&data.errors,
24142414
data.newline,
24152415
vm,
@@ -3908,7 +3908,7 @@ mod _io {
39083908
#[pyarg(any, default = -1)]
39093909
pub buffering: isize,
39103910
#[pyarg(any, default)]
3911-
pub encoding: Option<PyStrRef>,
3911+
pub encoding: Option<PyUtf8StrRef>,
39123912
#[pyarg(any, default)]
39133913
pub errors: Option<PyStrRef>,
39143914
#[pyarg(any, default)]

vm/src/stdlib/time.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ unsafe extern "C" {
3535
mod decl {
3636
use crate::{
3737
AsObject, PyObjectRef, PyResult, TryFromObject, VirtualMachine,
38-
builtins::{PyStrRef, PyTypeRef},
38+
builtins::{PyStrRef, PyTypeRef, PyUtf8StrRef},
3939
function::{Either, FuncArgs, OptionalArg},
4040
types::PyStructSequence,
4141
};
@@ -345,7 +345,11 @@ mod decl {
345345
}
346346

347347
#[pyfunction]
348-
fn strftime(format: PyStrRef, t: OptionalArg<PyStructTime>, vm: &VirtualMachine) -> PyResult {
348+
fn strftime(
349+
format: PyUtf8StrRef,
350+
t: OptionalArg<PyStructTime>,
351+
vm: &VirtualMachine,
352+
) -> PyResult {
349353
use std::fmt::Write;
350354

351355
let instant = t.naive_or_local(vm)?;
@@ -356,12 +360,8 @@ mod decl {
356360
* raises an error if unsupported format is supplied.
357361
* If an error happens, we use the input format string as the result.
358362
*/
359-
write!(
360-
&mut formatted_time,
361-
"{}",
362-
instant.format(format.try_to_str(vm)?)
363-
)
364-
.unwrap_or_else(|_| formatted_time = format.to_string());
363+
write!(&mut formatted_time, "{}", instant.format(format.as_str()))
364+
.unwrap_or_else(|_| formatted_time = format.to_string());
365365
Ok(vm.ctx.new_str(formatted_time).into())
366366
}
367367

vm/src/utils.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use rustpython_common::wtf8::Wtf8;
22

33
use crate::{
44
PyObjectRef, PyResult, VirtualMachine,
5-
builtins::PyStr,
5+
builtins::{PyStr, PyUtf8Str},
66
convert::{ToPyException, ToPyObject},
77
exceptions::cstring_error,
88
};
@@ -35,6 +35,7 @@ pub trait ToCString: AsRef<Wtf8> {
3535

3636
impl ToCString for &str {}
3737
impl ToCString for PyStr {}
38+
impl ToCString for PyUtf8Str {}
3839

3940
pub(crate) fn collection_repr<'a, I>(
4041
class_name: Option<&str>,

0 commit comments

Comments
 (0)