Skip to content

Fix bytes.replace #1847

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Lib/test/string_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,6 @@ def test_rsplit(self):
self.checkraises(ValueError, 'hello', 'rsplit', '')
self.checkraises(ValueError, 'hello', 'rsplit', '', 0)

@unittest.skip("TODO: RUSTPYTHON test_bytes")
def test_replace(self):
EQ = self.checkequal

Expand Down
5 changes: 3 additions & 2 deletions vm/src/obj/objbytearray.rs
Original file line number Diff line number Diff line change
Expand Up @@ -465,9 +465,10 @@ impl PyByteArray {
&self,
old: PyByteInner,
new: PyByteInner,
count: OptionalArg<PyIntRef>,
count: OptionalArg<isize>,
vm: &VirtualMachine,
) -> PyResult<PyByteArray> {
Ok(self.borrow_value().replace(old, new, count)?.into())
Ok(self.borrow_value().replace(old, new, count, vm)?.into())
}

#[pymethod(name = "clear")]
Expand Down
187 changes: 161 additions & 26 deletions vm/src/obj/objbyteinner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1127,41 +1127,166 @@ impl PyByteInner {
bytes_zfill(&self.elements, width.to_usize().unwrap_or(0))
}

pub fn replace(
// len(self)>=1, from="", len(to)>=1, maxcount>=1
fn replace_interleave(&self, to: PyByteInner, maxcount: Option<usize>) -> Vec<u8> {
let place_count = self.elements.len() + 1;
let count = maxcount.map_or(place_count, |v| std::cmp::min(v, place_count)) - 1;
let capacity = self.elements.len() + count * to.len();
let mut result = Vec::with_capacity(capacity);
let to_slice = to.elements.as_slice();
result.extend_from_slice(to_slice);
for c in &self.elements[..count] {
result.push(*c);
result.extend_from_slice(to_slice);
}
result.extend_from_slice(&self.elements[count..]);
result
}

fn replace_delete(&self, from: PyByteInner, maxcount: Option<usize>) -> Vec<u8> {
let count = count_substring(self.elements.as_slice(), from.elements.as_slice(), maxcount);
if count == 0 {
// no matches
return self.elements.clone();
}

let result_len = self.len() - (count * from.len());
debug_assert!(self.len() >= count * from.len());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe remove this? I know is is only in debug but it does not provide additional data.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is line by line port of CPython implementation. Do you still want to remove this regardless of it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am OK with you keeping this. Just not sure if that is necessary. Your call.


let mut result = Vec::with_capacity(result_len);
let mut last_end = 0;
let mut count = count;
for offset in self.elements.find_iter(&from.elements) {
result.extend_from_slice(&self.elements[last_end..offset]);
last_end = offset + from.len();
count -= 1;
if count == 0 {
break;
}
}
result.extend_from_slice(&self.elements[last_end..]);
result
}

pub fn replace_in_place(
&self,
old: PyByteInner,
new: PyByteInner,
count: OptionalArg<PyIntRef>,
) -> PyResult<Vec<u8>> {
let count = match count.into_option() {
Some(int) => int
.as_bigint()
.to_u32()
.unwrap_or(self.elements.len() as u32),
None => self.elements.len() as u32,
from: PyByteInner,
to: PyByteInner,
maxcount: Option<usize>,
) -> Vec<u8> {
let len = from.len();
let mut iter = self.elements.find_iter(&from.elements);

let mut new = if let Some(offset) = iter.next() {
let mut new = self.elements.clone();
new[offset..offset + len].clone_from_slice(to.elements.as_slice());
if maxcount == Some(1) {
return new;
} else {
new
}
} else {
return self.elements.clone();
};

let mut res = vec![];
let mut index = 0;
let mut done = 0;
let mut count = maxcount.unwrap_or(std::usize::MAX) - 1;
for offset in iter {
new[offset..offset + len].clone_from_slice(to.elements.as_slice());
count -= 1;
if count == 0 {
break;
}
}
new
}

let slice = &self.elements;
loop {
if done == count || index > slice.len() - old.len() {
res.extend_from_slice(&slice[index..]);
fn replace_general(
&self,
from: PyByteInner,
to: PyByteInner,
maxcount: Option<usize>,
vm: &VirtualMachine,
) -> PyResult<Vec<u8>> {
let count = count_substring(self.elements.as_slice(), from.elements.as_slice(), maxcount);
if count == 0 {
// no matches, return unchanged
return Ok(self.elements.clone());
}

// Check for overflow
// result_len = self_len + count * (to_len-from_len)
debug_assert!(count > 0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this?

if to.len() as isize - from.len() as isize
> (std::isize::MAX - self.elements.len() as isize) / count as isize
{
return Err(vm.new_overflow_error("replace bytes is too long".to_owned()));
}
let result_len = self.elements.len() + count * (to.len() - from.len());

let mut result = Vec::with_capacity(result_len);
let mut last_end = 0;
let mut count = count;
for offset in self.elements.find_iter(&from.elements) {
result.extend_from_slice(&self.elements[last_end..offset]);
result.extend_from_slice(to.elements.as_slice());
last_end = offset + from.len();
count -= 1;
if count == 0 {
break;
}
if &slice[index..index + old.len()] == old.elements.as_slice() {
res.extend_from_slice(&new.elements);
index += old.len();
done += 1;
} else {
res.push(slice[index]);
index += 1
}
result.extend_from_slice(&self.elements[last_end..]);
Ok(result)
}

pub fn replace(
&self,
from: PyByteInner,
to: PyByteInner,
maxcount: OptionalArg<isize>,
vm: &VirtualMachine,
) -> PyResult<Vec<u8>> {
// stringlib_replace in CPython
let maxcount = match maxcount {
OptionalArg::Present(maxcount) if maxcount >= 0 => {
if maxcount == 0 || self.elements.is_empty() {
// nothing to do; return the original bytes
return Ok(self.elements.clone());
}
Some(maxcount as usize)
}
_ => None,
};

// Handle zero-length special cases
if from.elements.is_empty() {
if to.elements.is_empty() {
// nothing to do; return the original bytes
return Ok(self.elements.clone());
}
// insert the 'to' bytes everywhere.
// >>> b"Python".replace(b"", b".")
// b'.P.y.t.h.o.n.'
return Ok(self.replace_interleave(to, maxcount));
}

Ok(res)
// Except for b"".replace(b"", b"A") == b"A" there is no way beyond this
// point for an empty self bytes to generate a non-empty bytes
// Special case so the remaining code always gets a non-empty bytes
if self.elements.is_empty() {
return Ok(self.elements.clone());
}

if to.elements.is_empty() {
// delete all occurrences of 'from' bytes
Ok(self.replace_delete(from, maxcount))
} else if from.len() == to.len() {
// Handle special case where both bytes have the same length
Ok(self.replace_in_place(from, to, maxcount))
} else {
// Otherwise use the more generic algorithms
self.replace_general(from, to, maxcount, vm)
}
}

pub fn title(&self) -> Vec<u8> {
Expand Down Expand Up @@ -1233,6 +1358,16 @@ pub fn try_as_byte(obj: &PyObjectRef) -> Option<Vec<u8>> {
})
}

#[inline]
fn count_substring(haystack: &[u8], needle: &[u8], maxcount: Option<usize>) -> usize {
let substrings = haystack.find_iter(needle);
if let Some(maxcount) = maxcount {
std::cmp::min(substrings.take(maxcount).count(), maxcount)
} else {
substrings.count()
}
}

pub trait ByteOr: ToPrimitive {
fn byte_or(&self, vm: &VirtualMachine) -> PyResult<u8> {
match self.to_u8() {
Expand Down
5 changes: 3 additions & 2 deletions vm/src/obj/objbytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -426,9 +426,10 @@ impl PyBytes {
&self,
old: PyByteInner,
new: PyByteInner,
count: OptionalArg<PyIntRef>,
count: OptionalArg<isize>,
vm: &VirtualMachine,
) -> PyResult<PyBytes> {
Ok(self.inner.replace(old, new, count)?.into())
Ok(self.inner.replace(old, new, count, vm)?.into())
}

#[pymethod(name = "title")]
Expand Down