diff --git a/Lib/test/string_tests.py b/Lib/test/string_tests.py index 7e544d6c0d..4950444336 100644 --- a/Lib/test/string_tests.py +++ b/Lib/test/string_tests.py @@ -500,7 +500,6 @@ def test_rsplit(self): self.checkraises(ValueError, 'hello', 'rsplit', '') self.checkraises(ValueError, 'hello', 'rsplit', '', 0) - @unittest.skip("TODO: RUSTPYTHON test_bytes") def test_replace(self): EQ = self.checkequal diff --git a/vm/src/obj/objbytearray.rs b/vm/src/obj/objbytearray.rs index c79aa38f68..a69076a3d1 100644 --- a/vm/src/obj/objbytearray.rs +++ b/vm/src/obj/objbytearray.rs @@ -465,9 +465,10 @@ impl PyByteArray { &self, old: PyByteInner, new: PyByteInner, - count: OptionalArg, + count: OptionalArg, + vm: &VirtualMachine, ) -> PyResult { - Ok(self.borrow_value().replace(old, new, count)?.into()) + Ok(self.borrow_value().replace(old, new, count, vm)?.into()) } #[pymethod(name = "clear")] diff --git a/vm/src/obj/objbyteinner.rs b/vm/src/obj/objbyteinner.rs index 1a4c1d305c..3fbd7d10bb 100644 --- a/vm/src/obj/objbyteinner.rs +++ b/vm/src/obj/objbyteinner.rs @@ -1127,41 +1127,166 @@ impl PyByteInner { bytes_zfill(&self.elements, width.to_usize().unwrap_or(0)) } - pub fn replace( + // len(self)>=1, from="", len(to)>=1, maxcount>=1 + fn replace_interleave(&self, to: PyByteInner, maxcount: Option) -> Vec { + let place_count = self.elements.len() + 1; + let count = maxcount.map_or(place_count, |v| std::cmp::min(v, place_count)) - 1; + let capacity = self.elements.len() + count * to.len(); + let mut result = Vec::with_capacity(capacity); + let to_slice = to.elements.as_slice(); + result.extend_from_slice(to_slice); + for c in &self.elements[..count] { + result.push(*c); + result.extend_from_slice(to_slice); + } + result.extend_from_slice(&self.elements[count..]); + result + } + + fn replace_delete(&self, from: PyByteInner, maxcount: Option) -> Vec { + let count = count_substring(self.elements.as_slice(), from.elements.as_slice(), maxcount); + if count == 0 { + // no matches + return self.elements.clone(); + } + + let result_len = self.len() - (count * from.len()); + debug_assert!(self.len() >= count * from.len()); + + let mut result = Vec::with_capacity(result_len); + let mut last_end = 0; + let mut count = count; + for offset in self.elements.find_iter(&from.elements) { + result.extend_from_slice(&self.elements[last_end..offset]); + last_end = offset + from.len(); + count -= 1; + if count == 0 { + break; + } + } + result.extend_from_slice(&self.elements[last_end..]); + result + } + + pub fn replace_in_place( &self, - old: PyByteInner, - new: PyByteInner, - count: OptionalArg, - ) -> PyResult> { - let count = match count.into_option() { - Some(int) => int - .as_bigint() - .to_u32() - .unwrap_or(self.elements.len() as u32), - None => self.elements.len() as u32, + from: PyByteInner, + to: PyByteInner, + maxcount: Option, + ) -> Vec { + let len = from.len(); + let mut iter = self.elements.find_iter(&from.elements); + + let mut new = if let Some(offset) = iter.next() { + let mut new = self.elements.clone(); + new[offset..offset + len].clone_from_slice(to.elements.as_slice()); + if maxcount == Some(1) { + return new; + } else { + new + } + } else { + return self.elements.clone(); }; - let mut res = vec![]; - let mut index = 0; - let mut done = 0; + let mut count = maxcount.unwrap_or(std::usize::MAX) - 1; + for offset in iter { + new[offset..offset + len].clone_from_slice(to.elements.as_slice()); + count -= 1; + if count == 0 { + break; + } + } + new + } - let slice = &self.elements; - loop { - if done == count || index > slice.len() - old.len() { - res.extend_from_slice(&slice[index..]); + fn replace_general( + &self, + from: PyByteInner, + to: PyByteInner, + maxcount: Option, + vm: &VirtualMachine, + ) -> PyResult> { + let count = count_substring(self.elements.as_slice(), from.elements.as_slice(), maxcount); + if count == 0 { + // no matches, return unchanged + return Ok(self.elements.clone()); + } + + // Check for overflow + // result_len = self_len + count * (to_len-from_len) + debug_assert!(count > 0); + if to.len() as isize - from.len() as isize + > (std::isize::MAX - self.elements.len() as isize) / count as isize + { + return Err(vm.new_overflow_error("replace bytes is too long".to_owned())); + } + let result_len = self.elements.len() + count * (to.len() - from.len()); + + let mut result = Vec::with_capacity(result_len); + let mut last_end = 0; + let mut count = count; + for offset in self.elements.find_iter(&from.elements) { + result.extend_from_slice(&self.elements[last_end..offset]); + result.extend_from_slice(to.elements.as_slice()); + last_end = offset + from.len(); + count -= 1; + if count == 0 { break; } - if &slice[index..index + old.len()] == old.elements.as_slice() { - res.extend_from_slice(&new.elements); - index += old.len(); - done += 1; - } else { - res.push(slice[index]); - index += 1 + } + result.extend_from_slice(&self.elements[last_end..]); + Ok(result) + } + + pub fn replace( + &self, + from: PyByteInner, + to: PyByteInner, + maxcount: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult> { + // stringlib_replace in CPython + let maxcount = match maxcount { + OptionalArg::Present(maxcount) if maxcount >= 0 => { + if maxcount == 0 || self.elements.is_empty() { + // nothing to do; return the original bytes + return Ok(self.elements.clone()); + } + Some(maxcount as usize) } + _ => None, + }; + + // Handle zero-length special cases + if from.elements.is_empty() { + if to.elements.is_empty() { + // nothing to do; return the original bytes + return Ok(self.elements.clone()); + } + // insert the 'to' bytes everywhere. + // >>> b"Python".replace(b"", b".") + // b'.P.y.t.h.o.n.' + return Ok(self.replace_interleave(to, maxcount)); } - Ok(res) + // Except for b"".replace(b"", b"A") == b"A" there is no way beyond this + // point for an empty self bytes to generate a non-empty bytes + // Special case so the remaining code always gets a non-empty bytes + if self.elements.is_empty() { + return Ok(self.elements.clone()); + } + + if to.elements.is_empty() { + // delete all occurrences of 'from' bytes + Ok(self.replace_delete(from, maxcount)) + } else if from.len() == to.len() { + // Handle special case where both bytes have the same length + Ok(self.replace_in_place(from, to, maxcount)) + } else { + // Otherwise use the more generic algorithms + self.replace_general(from, to, maxcount, vm) + } } pub fn title(&self) -> Vec { @@ -1233,6 +1358,16 @@ pub fn try_as_byte(obj: &PyObjectRef) -> Option> { }) } +#[inline] +fn count_substring(haystack: &[u8], needle: &[u8], maxcount: Option) -> usize { + let substrings = haystack.find_iter(needle); + if let Some(maxcount) = maxcount { + std::cmp::min(substrings.take(maxcount).count(), maxcount) + } else { + substrings.count() + } +} + pub trait ByteOr: ToPrimitive { fn byte_or(&self, vm: &VirtualMachine) -> PyResult { match self.to_u8() { diff --git a/vm/src/obj/objbytes.rs b/vm/src/obj/objbytes.rs index ffd18f5455..3127862d26 100644 --- a/vm/src/obj/objbytes.rs +++ b/vm/src/obj/objbytes.rs @@ -426,9 +426,10 @@ impl PyBytes { &self, old: PyByteInner, new: PyByteInner, - count: OptionalArg, + count: OptionalArg, + vm: &VirtualMachine, ) -> PyResult { - Ok(self.inner.replace(old, new, count)?.into()) + Ok(self.inner.replace(old, new, count, vm)?.into()) } #[pymethod(name = "title")]