Skip to content

Commit 436a7e8

Browse files
committed
Allow &'static str to be stored in StringRef with no allocation
1 parent 3b68321 commit 436a7e8

File tree

6 files changed

+73
-35
lines changed

6 files changed

+73
-35
lines changed

bytecode/src/bytecode.rs

+53-17
Original file line numberDiff line numberDiff line change
@@ -690,9 +690,44 @@ pub struct FrozenModule {
690690
pub package: bool,
691691
}
692692

693+
#[derive(Clone)]
694+
enum StringDataInner {
695+
Static(&'static str),
696+
Owned(Box<str>),
697+
}
698+
impl StringDataInner {
699+
fn as_str(&self) -> &str {
700+
match self {
701+
Self::Static(s) => s,
702+
Self::Owned(ref s) => &*s,
703+
}
704+
}
705+
}
706+
impl fmt::Debug for StringDataInner {
707+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
708+
fmt::Debug::fmt(self.as_str(), f)
709+
}
710+
}
711+
impl serde::Serialize for StringDataInner {
712+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
713+
where
714+
S: serde::Serializer,
715+
{
716+
serializer.serialize_str(self.as_str())
717+
}
718+
}
719+
impl<'de> serde::Deserialize<'de> for StringDataInner {
720+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
721+
where
722+
D: serde::Deserializer<'de>,
723+
{
724+
serde::Deserialize::deserialize(deserializer).map(Self::Owned)
725+
}
726+
}
727+
693728
#[derive(Debug, Clone, Serialize, Deserialize)]
694729
pub struct StringData {
695-
s: Box<str>,
730+
inner: StringDataInner,
696731
#[serde(skip)]
697732
hash: OnceCell<i64>,
698733
#[serde(skip)]
@@ -702,30 +737,33 @@ pub struct StringData {
702737
impl StringData {
703738
#[inline]
704739
pub fn as_str(&self) -> &str {
705-
self.as_ref()
740+
self.inner.as_str()
706741
}
707742

708743
pub fn hash_value(&self) -> i64 {
709744
*self.hash.get_or_init(|| {
710745
use std::hash::*;
711746
let mut hasher = std::collections::hash_map::DefaultHasher::new();
712-
self.s.hash(&mut hasher);
747+
self.as_str().hash(&mut hasher);
713748
hasher.finish() as i64
714749
})
715750
}
716751

717752
pub fn char_len(&self) -> usize {
718-
*self.len.get_or_init(|| self.s.chars().count())
753+
*self.len.get_or_init(|| self.as_str().chars().count())
719754
}
720755

721756
pub fn into_string(self) -> String {
722-
self.s.into_string()
757+
match self.inner {
758+
StringDataInner::Static(s) => s.to_owned(),
759+
StringDataInner::Owned(s) => s.into(),
760+
}
723761
}
724762
}
725763

726764
impl PartialEq for StringData {
727765
fn eq(&self, other: &Self) -> bool {
728-
self.s == other.s
766+
self.as_str() == other.as_str()
729767
}
730768
}
731769

@@ -734,14 +772,14 @@ impl Eq for StringData {}
734772
impl AsRef<str> for StringData {
735773
#[inline]
736774
fn as_ref(&self) -> &str {
737-
&self.s
775+
self.as_str()
738776
}
739777
}
740778

741779
impl From<Box<str>> for StringData {
742780
fn from(s: Box<str>) -> Self {
743781
StringData {
744-
s,
782+
inner: StringDataInner::Owned(s),
745783
hash: OnceCell::new(),
746784
len: OnceCell::new(),
747785
}
@@ -754,15 +792,13 @@ impl From<String> for StringData {
754792
}
755793
}
756794

757-
impl From<&str> for StringData {
758-
fn from(s: &str) -> Self {
759-
Box::<str>::from(s).into()
760-
}
761-
}
762-
763-
impl From<&String> for StringData {
764-
fn from(s: &String) -> Self {
765-
s.as_str().into()
795+
impl From<&'static str> for StringData {
796+
fn from(s: &'static str) -> Self {
797+
StringData {
798+
inner: StringDataInner::Static(s),
799+
hash: OnceCell::new(),
800+
len: OnceCell::new(),
801+
}
766802
}
767803
}
768804

compiler/src/compile.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -2290,9 +2290,9 @@ fn compile_conversion_flag(conversion_flag: ast::ConversionFlag) -> bytecode::Co
22902290
}
22912291
}
22922292

2293-
fn string_constant(s: impl Into<bytecode::StringData>) -> bytecode::Constant {
2293+
fn string_constant(s: impl Into<String>) -> bytecode::Constant {
22942294
bytecode::Constant::String {
2295-
value: s.into().into(),
2295+
value: bytecode::StringData::from(s.into()).into(),
22962296
}
22972297
}
22982298

src/shell/helper.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ impl<'vm> ShellHelper<'vm> {
7979

8080
let mut current = self
8181
.scope
82-
.load_global(self.vm, &StringRef::new(first.into()))?;
82+
.load_global(self.vm, &StringRef::new(first.to_owned().into()))?;
8383

8484
for attr in parents {
8585
current = self.vm.get_attribute(current.clone(), attr.as_str()).ok()?;

vm/src/obj/objstr.rs

+8-8
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ impl TryIntoRef<PyString> for String {
9090

9191
impl TryIntoRef<PyString> for &str {
9292
fn try_into_ref(self, vm: &VirtualMachine) -> PyResult<PyRef<PyString>> {
93-
Ok(PyString::from(self).into_ref(vm))
93+
Ok(PyString::from(self.to_owned()).into_ref(vm))
9494
}
9595
}
9696

@@ -197,7 +197,7 @@ impl PyString {
197197
if string.class().is(&cls) {
198198
Ok(string)
199199
} else {
200-
PyString::from(string.as_str()).into_ref_with_type(vm, cls)
200+
PyString::from(string.as_str().to_owned()).into_ref_with_type(vm, cls)
201201
}
202202
}
203203
#[pymethod(name = "__add__")]
@@ -433,9 +433,9 @@ impl PyString {
433433
let elements = self.as_str().py_split(
434434
args,
435435
vm,
436-
|v, s, vm| v.split(s).map(|s| vm.ctx.new_str(s)).collect(),
437-
|v, s, n, vm| v.splitn(n, s).map(|s| vm.ctx.new_str(s)).collect(),
438-
|v, n, vm| v.py_split_whitespace(n, |s| vm.ctx.new_str(s)),
436+
|v, s, vm| v.split(s).map(|s| vm.new_str(s.to_owned())).collect(),
437+
|v, s, n, vm| v.splitn(n, s).map(|s| vm.new_str(s.to_owned())).collect(),
438+
|v, n, vm| v.py_split_whitespace(n, |s| vm.new_str(s.to_owned())),
439439
)?;
440440
Ok(vm.ctx.new_list(elements))
441441
}
@@ -445,9 +445,9 @@ impl PyString {
445445
let mut elements = self.as_str().py_split(
446446
args,
447447
vm,
448-
|v, s, vm| v.rsplit(s).map(|s| vm.ctx.new_str(s)).collect(),
449-
|v, s, n, vm| v.rsplitn(n, s).map(|s| vm.ctx.new_str(s)).collect(),
450-
|v, n, vm| v.py_rsplit_whitespace(n, |s| vm.ctx.new_str(s)),
448+
|v, s, vm| v.rsplit(s).map(|s| vm.new_str(s.to_owned())).collect(),
449+
|v, s, n, vm| v.rsplitn(n, s).map(|s| vm.new_str(s.to_owned())).collect(),
450+
|v, n, vm| v.py_rsplit_whitespace(n, |s| vm.new_str(s.to_owned())),
451451
)?;
452452
// Unlike Python rsplit, Rust rsplitn returns an iterator that
453453
// starts from the end of the string.

vm/src/stdlib/io.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -850,12 +850,12 @@ fn text_io_wrapper_init(
850850
let mut self_encoding = None; // TODO: Try os.device_encoding(fileno)
851851
if encoding.is_none() && self_encoding.is_none() {
852852
// TODO: locale module
853-
self_encoding = Some("utf-8");
853+
self_encoding = Some("utf-8".to_owned());
854854
}
855-
if let Some(self_encoding) = self_encoding {
856-
encoding = Some(PyString::from(self_encoding).into_ref(vm));
855+
if let Some(ref self_encoding) = self_encoding {
856+
encoding = Some(PyString::from(self_encoding.clone()).into_ref(vm));
857857
} else if let Some(ref encoding) = encoding {
858-
self_encoding = Some(encoding.as_str())
858+
self_encoding = Some(encoding.as_str().to_owned())
859859
} else {
860860
return Err(vm.new_os_error("could not determine default encoding".to_owned()));
861861
}
@@ -870,7 +870,7 @@ fn text_io_wrapper_init(
870870
vm.set_attr(
871871
&instance,
872872
"encoding",
873-
self_encoding.map_or_else(|| vm.get_none(), |s| vm.ctx.new_str(s)),
873+
self_encoding.map_or_else(|| vm.get_none(), |s| vm.new_str(s)),
874874
)?;
875875
vm.set_attr(&instance, "errors", errors)?;
876876
vm.set_attr(&instance, "buffer", args.buffer.clone())?;

vm/src/stdlib/pystruct.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -695,8 +695,10 @@ mod _struct {
695695
) -> PyResult<PyRef<Self>> {
696696
let fmt_str = match fmt {
697697
Either::A(s) => s,
698-
Either::B(b) => PyString::from(std::str::from_utf8(b.get_value()).unwrap())
699-
.into_ref_with_type(vm, vm.ctx.str_type())?,
698+
Either::B(b) => {
699+
PyString::from(std::str::from_utf8(b.get_value()).unwrap().to_owned())
700+
.into_ref_with_type(vm, vm.ctx.str_type())?
701+
}
700702
};
701703
let spec = FormatSpec::parse(fmt_str.as_str()).map_err(|e| new_struct_error(vm, e))?;
702704
PyStruct { spec, fmt_str }.into_ref_with_type(vm, cls)

0 commit comments

Comments
 (0)