diff --git a/vm/src/obj/objbytearray.rs b/vm/src/obj/objbytearray.rs
index 36d032e21f..64d3114273 100644
--- a/vm/src/obj/objbytearray.rs
+++ b/vm/src/obj/objbytearray.rs
@@ -451,13 +451,10 @@ impl PyByteArray {
 
     #[pymethod(name = "splitlines")]
     fn splitlines(&self, options: pystr::SplitLinesArgs, vm: &VirtualMachine) -> PyResult {
-        let as_bytes = self
+        let lines = self
             .borrow_value()
-            .splitlines(options)
-            .iter()
-            .map(|x| vm.ctx.new_bytearray(x.to_vec()))
-            .collect::<Vec<PyObjectRef>>();
-        Ok(vm.ctx.new_list(as_bytes))
+            .splitlines(options, |x| vm.ctx.new_bytearray(x.to_vec()));
+        Ok(vm.ctx.new_list(lines))
     }
 
     #[pymethod(name = "zfill")]
diff --git a/vm/src/obj/objbyteinner.rs b/vm/src/obj/objbyteinner.rs
index 02139637bf..b2049aac63 100644
--- a/vm/src/obj/objbyteinner.rs
+++ b/vm/src/obj/objbyteinner.rs
@@ -961,46 +961,11 @@ impl PyByteInner {
         res
     }
 
-    pub fn splitlines(&self, options: pystr::SplitLinesArgs) -> Vec<&[u8]> {
-        let mut res = vec![];
-
-        if self.elements.is_empty() {
-            return vec![];
-        }
-
-        let mut prev_index = 0;
-        let mut index = 0;
-        let keep = if options.keepends { 1 } else { 0 };
-        let slice = &self.elements;
-
-        while index < slice.len() {
-            match slice[index] {
-                b'\n' => {
-                    res.push(&slice[prev_index..index + keep]);
-                    index += 1;
-                    prev_index = index;
-                }
-                b'\r' => {
-                    if index + 2 <= slice.len() && slice[index + 1] == b'\n' {
-                        res.push(&slice[prev_index..index + keep + keep]);
-                        index += 2;
-                    } else {
-                        res.push(&slice[prev_index..index + keep]);
-                        index += 1;
-                    }
-                    prev_index = index;
-                }
-                _x => {
-                    if index == slice.len() - 1 {
-                        res.push(&slice[prev_index..=index]);
-                        break;
-                    }
-                    index += 1
-                }
-            }
-        }
-
-        res
+    pub fn splitlines<FW, W>(&self, options: pystr::SplitLinesArgs, into_wrapper: FW) -> Vec<W>
+    where
+        FW: Fn(&[u8]) -> W,
+    {
+        self.elements.py_splitlines(options, into_wrapper)
     }
 
     pub fn zfill(&self, width: isize) -> Vec<u8> {
@@ -1329,6 +1294,10 @@ impl PyCommonString for [u8] {
         Vec::with_capacity(capacity)
     }
 
+    fn as_bytes(&self) -> &[u8] {
+        self
+    }
+
     fn get_bytes<'a>(&'a self, range: std::ops::Range<usize>) -> &'a Self {
         &self[range]
     }
diff --git a/vm/src/obj/objbytes.rs b/vm/src/obj/objbytes.rs
index b64656de01..19d9f71e78 100644
--- a/vm/src/obj/objbytes.rs
+++ b/vm/src/obj/objbytes.rs
@@ -418,13 +418,10 @@ impl PyBytes {
 
     #[pymethod(name = "splitlines")]
     fn splitlines(&self, options: pystr::SplitLinesArgs, vm: &VirtualMachine) -> PyResult {
-        let as_bytes = self
+        let lines = self
             .inner
-            .splitlines(options)
-            .iter()
-            .map(|x| vm.ctx.new_bytes(x.to_vec()))
-            .collect::<Vec<PyObjectRef>>();
-        Ok(vm.ctx.new_list(as_bytes))
+            .splitlines(options, |x| vm.ctx.new_bytes(x.to_vec()));
+        Ok(vm.ctx.new_list(lines))
     }
 
     #[pymethod(name = "zfill")]
diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs
index d9998a50d7..dc1959d3db 100644
--- a/vm/src/obj/objstr.rs
+++ b/vm/src/obj/objstr.rs
@@ -803,27 +803,8 @@ impl PyString {
 
     #[pymethod]
     fn splitlines(&self, args: pystr::SplitLinesArgs, vm: &VirtualMachine) -> PyObjectRef {
-        let mut elements = vec![];
-        let mut curr = "".to_owned();
-        let mut chars = self.value.chars().peekable();
-        while let Some(ch) = chars.next() {
-            if ch == '\n' || ch == '\r' {
-                if args.keepends {
-                    curr.push(ch);
-                }
-                if ch == '\r' && chars.peek() == Some(&'\n') {
-                    continue;
-                }
-                elements.push(vm.ctx.new_str(curr.clone()));
-                curr.clear();
-            } else {
-                curr.push(ch);
-            }
-        }
-        if !curr.is_empty() {
-            elements.push(vm.ctx.new_str(curr));
-        }
-        vm.ctx.new_list(elements)
+        vm.ctx
+            .new_list(self.value.py_splitlines(args, |s| vm.new_str(s.to_owned())))
     }
 
     #[pymethod]
@@ -1752,6 +1733,10 @@ impl PyCommonString for str {
         String::with_capacity(capacity)
     }
 
+    fn as_bytes(&self) -> &[u8] {
+        self.as_bytes()
+    }
+
     fn get_bytes<'a>(&'a self, range: std::ops::Range<usize>) -> &'a Self {
         &self[range]
     }
diff --git a/vm/src/obj/pystr.rs b/vm/src/obj/pystr.rs
index 1e09386b4f..6ac9e83fb9 100644
--- a/vm/src/obj/pystr.rs
+++ b/vm/src/obj/pystr.rs
@@ -129,6 +129,7 @@ pub trait PyCommonString {
     type Container;
 
     fn with_capacity(capacity: usize) -> Self::Container;
+    fn as_bytes(&self) -> &[u8];
     fn get_bytes<'a>(&'a self, range: std::ops::Range<usize>) -> &'a Self;
     // FIXME: get_chars is expensive for str
     fn get_chars<'a>(&'a self, range: std::ops::Range<usize>) -> &'a Self;
@@ -297,4 +298,38 @@ pub trait PyCommonString {
             &self
         }
     }
+
+    fn py_splitlines<FW, W>(&self, options: SplitLinesArgs, into_wrapper: FW) -> Vec<W>
+    where
+        FW: Fn(&Self) -> W,
+    {
+        let keep = if options.keepends { 1 } else { 0 };
+        let mut elements = Vec::new();
+        let mut last_i = 0;
+        let mut enumerated = self.as_bytes().iter().enumerate().peekable();
+        while let Some((i, ch)) = enumerated.next() {
+            let (end_len, i_diff) = match *ch {
+                b'\n' => (keep, 1),
+                b'\r' => {
+                    let is_rn = enumerated.peek().map_or(false, |(_, ch)| **ch == b'\n');
+                    if is_rn {
+                        let _ = enumerated.next();
+                        (keep + keep, 2)
+                    } else {
+                        (keep, 1)
+                    }
+                }
+                _ => {
+                    continue;
+                }
+            };
+            let range = last_i..i + end_len;
+            last_i = i + i_diff;
+            elements.push(into_wrapper(self.get_bytes(range)));
+        }
+        if last_i != self.bytes_len() {
+            elements.push(into_wrapper(self.get_bytes(last_i..self.bytes_len())));
+        }
+        elements
+    }
 }