Skip to content

Commit 85751dc

Browse files
authored
Merge pull request RustPython#1847 from youknowone/fix-bytes-replace
Fix bytes.replace
2 parents 22527a6 + 67a9dfd commit 85751dc

File tree

4 files changed

+167
-31
lines changed

4 files changed

+167
-31
lines changed

Lib/test/string_tests.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,6 @@ def test_rsplit(self):
500500
self.checkraises(ValueError, 'hello', 'rsplit', '')
501501
self.checkraises(ValueError, 'hello', 'rsplit', '', 0)
502502

503-
@unittest.skip("TODO: RUSTPYTHON test_bytes")
504503
def test_replace(self):
505504
EQ = self.checkequal
506505

vm/src/obj/objbytearray.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -465,9 +465,10 @@ impl PyByteArray {
465465
&self,
466466
old: PyByteInner,
467467
new: PyByteInner,
468-
count: OptionalArg<PyIntRef>,
468+
count: OptionalArg<isize>,
469+
vm: &VirtualMachine,
469470
) -> PyResult<PyByteArray> {
470-
Ok(self.borrow_value().replace(old, new, count)?.into())
471+
Ok(self.borrow_value().replace(old, new, count, vm)?.into())
471472
}
472473

473474
#[pymethod(name = "clear")]

vm/src/obj/objbyteinner.rs

Lines changed: 161 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1127,41 +1127,166 @@ impl PyByteInner {
11271127
bytes_zfill(&self.elements, width.to_usize().unwrap_or(0))
11281128
}
11291129

1130-
pub fn replace(
1130+
// len(self)>=1, from="", len(to)>=1, maxcount>=1
1131+
fn replace_interleave(&self, to: PyByteInner, maxcount: Option<usize>) -> Vec<u8> {
1132+
let place_count = self.elements.len() + 1;
1133+
let count = maxcount.map_or(place_count, |v| std::cmp::min(v, place_count)) - 1;
1134+
let capacity = self.elements.len() + count * to.len();
1135+
let mut result = Vec::with_capacity(capacity);
1136+
let to_slice = to.elements.as_slice();
1137+
result.extend_from_slice(to_slice);
1138+
for c in &self.elements[..count] {
1139+
result.push(*c);
1140+
result.extend_from_slice(to_slice);
1141+
}
1142+
result.extend_from_slice(&self.elements[count..]);
1143+
result
1144+
}
1145+
1146+
fn replace_delete(&self, from: PyByteInner, maxcount: Option<usize>) -> Vec<u8> {
1147+
let count = count_substring(self.elements.as_slice(), from.elements.as_slice(), maxcount);
1148+
if count == 0 {
1149+
// no matches
1150+
return self.elements.clone();
1151+
}
1152+
1153+
let result_len = self.len() - (count * from.len());
1154+
debug_assert!(self.len() >= count * from.len());
1155+
1156+
let mut result = Vec::with_capacity(result_len);
1157+
let mut last_end = 0;
1158+
let mut count = count;
1159+
for offset in self.elements.find_iter(&from.elements) {
1160+
result.extend_from_slice(&self.elements[last_end..offset]);
1161+
last_end = offset + from.len();
1162+
count -= 1;
1163+
if count == 0 {
1164+
break;
1165+
}
1166+
}
1167+
result.extend_from_slice(&self.elements[last_end..]);
1168+
result
1169+
}
1170+
1171+
pub fn replace_in_place(
11311172
&self,
1132-
old: PyByteInner,
1133-
new: PyByteInner,
1134-
count: OptionalArg<PyIntRef>,
1135-
) -> PyResult<Vec<u8>> {
1136-
let count = match count.into_option() {
1137-
Some(int) => int
1138-
.as_bigint()
1139-
.to_u32()
1140-
.unwrap_or(self.elements.len() as u32),
1141-
None => self.elements.len() as u32,
1173+
from: PyByteInner,
1174+
to: PyByteInner,
1175+
maxcount: Option<usize>,
1176+
) -> Vec<u8> {
1177+
let len = from.len();
1178+
let mut iter = self.elements.find_iter(&from.elements);
1179+
1180+
let mut new = if let Some(offset) = iter.next() {
1181+
let mut new = self.elements.clone();
1182+
new[offset..offset + len].clone_from_slice(to.elements.as_slice());
1183+
if maxcount == Some(1) {
1184+
return new;
1185+
} else {
1186+
new
1187+
}
1188+
} else {
1189+
return self.elements.clone();
11421190
};
11431191

1144-
let mut res = vec![];
1145-
let mut index = 0;
1146-
let mut done = 0;
1192+
let mut count = maxcount.unwrap_or(std::usize::MAX) - 1;
1193+
for offset in iter {
1194+
new[offset..offset + len].clone_from_slice(to.elements.as_slice());
1195+
count -= 1;
1196+
if count == 0 {
1197+
break;
1198+
}
1199+
}
1200+
new
1201+
}
11471202

1148-
let slice = &self.elements;
1149-
loop {
1150-
if done == count || index > slice.len() - old.len() {
1151-
res.extend_from_slice(&slice[index..]);
1203+
fn replace_general(
1204+
&self,
1205+
from: PyByteInner,
1206+
to: PyByteInner,
1207+
maxcount: Option<usize>,
1208+
vm: &VirtualMachine,
1209+
) -> PyResult<Vec<u8>> {
1210+
let count = count_substring(self.elements.as_slice(), from.elements.as_slice(), maxcount);
1211+
if count == 0 {
1212+
// no matches, return unchanged
1213+
return Ok(self.elements.clone());
1214+
}
1215+
1216+
// Check for overflow
1217+
// result_len = self_len + count * (to_len-from_len)
1218+
debug_assert!(count > 0);
1219+
if to.len() as isize - from.len() as isize
1220+
> (std::isize::MAX - self.elements.len() as isize) / count as isize
1221+
{
1222+
return Err(vm.new_overflow_error("replace bytes is too long".to_owned()));
1223+
}
1224+
let result_len = self.elements.len() + count * (to.len() - from.len());
1225+
1226+
let mut result = Vec::with_capacity(result_len);
1227+
let mut last_end = 0;
1228+
let mut count = count;
1229+
for offset in self.elements.find_iter(&from.elements) {
1230+
result.extend_from_slice(&self.elements[last_end..offset]);
1231+
result.extend_from_slice(to.elements.as_slice());
1232+
last_end = offset + from.len();
1233+
count -= 1;
1234+
if count == 0 {
11521235
break;
11531236
}
1154-
if &slice[index..index + old.len()] == old.elements.as_slice() {
1155-
res.extend_from_slice(&new.elements);
1156-
index += old.len();
1157-
done += 1;
1158-
} else {
1159-
res.push(slice[index]);
1160-
index += 1
1237+
}
1238+
result.extend_from_slice(&self.elements[last_end..]);
1239+
Ok(result)
1240+
}
1241+
1242+
pub fn replace(
1243+
&self,
1244+
from: PyByteInner,
1245+
to: PyByteInner,
1246+
maxcount: OptionalArg<isize>,
1247+
vm: &VirtualMachine,
1248+
) -> PyResult<Vec<u8>> {
1249+
// stringlib_replace in CPython
1250+
let maxcount = match maxcount {
1251+
OptionalArg::Present(maxcount) if maxcount >= 0 => {
1252+
if maxcount == 0 || self.elements.is_empty() {
1253+
// nothing to do; return the original bytes
1254+
return Ok(self.elements.clone());
1255+
}
1256+
Some(maxcount as usize)
11611257
}
1258+
_ => None,
1259+
};
1260+
1261+
// Handle zero-length special cases
1262+
if from.elements.is_empty() {
1263+
if to.elements.is_empty() {
1264+
// nothing to do; return the original bytes
1265+
return Ok(self.elements.clone());
1266+
}
1267+
// insert the 'to' bytes everywhere.
1268+
// >>> b"Python".replace(b"", b".")
1269+
// b'.P.y.t.h.o.n.'
1270+
return Ok(self.replace_interleave(to, maxcount));
11621271
}
11631272

1164-
Ok(res)
1273+
// Except for b"".replace(b"", b"A") == b"A" there is no way beyond this
1274+
// point for an empty self bytes to generate a non-empty bytes
1275+
// Special case so the remaining code always gets a non-empty bytes
1276+
if self.elements.is_empty() {
1277+
return Ok(self.elements.clone());
1278+
}
1279+
1280+
if to.elements.is_empty() {
1281+
// delete all occurrences of 'from' bytes
1282+
Ok(self.replace_delete(from, maxcount))
1283+
} else if from.len() == to.len() {
1284+
// Handle special case where both bytes have the same length
1285+
Ok(self.replace_in_place(from, to, maxcount))
1286+
} else {
1287+
// Otherwise use the more generic algorithms
1288+
self.replace_general(from, to, maxcount, vm)
1289+
}
11651290
}
11661291

11671292
pub fn title(&self) -> Vec<u8> {
@@ -1233,6 +1358,16 @@ pub fn try_as_byte(obj: &PyObjectRef) -> Option<Vec<u8>> {
12331358
})
12341359
}
12351360

1361+
#[inline]
1362+
fn count_substring(haystack: &[u8], needle: &[u8], maxcount: Option<usize>) -> usize {
1363+
let substrings = haystack.find_iter(needle);
1364+
if let Some(maxcount) = maxcount {
1365+
std::cmp::min(substrings.take(maxcount).count(), maxcount)
1366+
} else {
1367+
substrings.count()
1368+
}
1369+
}
1370+
12361371
pub trait ByteOr: ToPrimitive {
12371372
fn byte_or(&self, vm: &VirtualMachine) -> PyResult<u8> {
12381373
match self.to_u8() {

vm/src/obj/objbytes.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -426,9 +426,10 @@ impl PyBytes {
426426
&self,
427427
old: PyByteInner,
428428
new: PyByteInner,
429-
count: OptionalArg<PyIntRef>,
429+
count: OptionalArg<isize>,
430+
vm: &VirtualMachine,
430431
) -> PyResult<PyBytes> {
431-
Ok(self.inner.replace(old, new, count)?.into())
432+
Ok(self.inner.replace(old, new, count, vm)?.into())
432433
}
433434

434435
#[pymethod(name = "title")]

0 commit comments

Comments
 (0)