Skip to content

Commit f18f92c

Browse files
committed
Refactor and fix str/bytes count/find/index
1 parent f5de59a commit f18f92c

File tree

5 files changed

+116
-134
lines changed

5 files changed

+116
-134
lines changed

Lib/test/test_bytes.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -606,8 +606,6 @@ def test_find(self):
606606
ValueError, r'byte must be in range\(0, 256\)',
607607
b.find, index)
608608

609-
# TODO: RUSTPYTHON
610-
@unittest.expectedFailure
611609
def test_rfind(self):
612610
b = self.type2test(b'mississippi')
613611
i = 105
@@ -647,8 +645,6 @@ def test_index(self):
647645
self.assertEqual(b.index(i, 1, 3), 1)
648646
self.assertRaises(ValueError, b.index, w, 1, 3)
649647

650-
# TODO: RUSTPYTHON
651-
@unittest.expectedFailure
652648
def test_rindex(self):
653649
b = self.type2test(b'mississippi')
654650
i = 105

vm/src/obj/objbytearray.rs

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -324,30 +324,26 @@ impl PyByteArray {
324324

325325
#[pymethod(name = "find")]
326326
fn find(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult<isize> {
327-
self.inner.borrow().find(options, false, vm)
327+
let index = self.inner.borrow().find(options, false, vm)?;
328+
Ok(index.map_or(-1, |v| v as isize))
328329
}
329330

330331
#[pymethod(name = "index")]
331-
fn index(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult<isize> {
332-
let res = self.inner.borrow().find(options, false, vm)?;
333-
if res == -1 {
334-
return Err(vm.new_value_error("substring not found".to_owned()));
335-
}
336-
Ok(res)
332+
fn index(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult<usize> {
333+
let index = self.inner.borrow().find(options, false, vm)?;
334+
index.ok_or_else(|| vm.new_value_error("substring not found".to_owned()))
337335
}
338336

339337
#[pymethod(name = "rfind")]
340338
fn rfind(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult<isize> {
341-
self.inner.borrow().find(options, true, vm)
339+
let index = self.inner.borrow().find(options, true, vm)?;
340+
Ok(index.map_or(-1, |v| v as isize))
342341
}
343342

344343
#[pymethod(name = "rindex")]
345-
fn rindex(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult<isize> {
346-
let res = self.inner.borrow().find(options, true, vm)?;
347-
if res == -1 {
348-
return Err(vm.new_value_error("substring not found".to_owned()));
349-
}
350-
Ok(res)
344+
fn rindex(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult<usize> {
345+
let index = self.inner.borrow().find(options, true, vm)?;
346+
index.ok_or_else(|| vm.new_value_error("substring not found".to_owned()))
351347
}
352348

353349
#[pymethod(name = "remove")]

vm/src/obj/objbyteinner.rs

Lines changed: 34 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use super::objmemory::PyMemoryView;
1313
use super::objnone::PyNoneRef;
1414
use super::objsequence::{is_valid_slice_arg, PySliceableSequence};
1515
use super::objslice::PySliceRef;
16-
use super::objstr::{self, PyString, PyStringRef};
16+
use super::objstr::{self, adjust_indices, PyString, PyStringRef, StringRange};
1717
use super::objtuple::PyTupleRef;
1818
use crate::function::OptionalArg;
1919
use crate::pyhash;
@@ -137,34 +137,22 @@ pub struct ByteInnerFindOptions {
137137
#[pyarg(positional_only, optional = false)]
138138
sub: Either<PyByteInner, PyIntRef>,
139139
#[pyarg(positional_only, optional = true)]
140-
start: OptionalArg<Option<PyIntRef>>,
140+
start: OptionalArg<Option<isize>>,
141141
#[pyarg(positional_only, optional = true)]
142-
end: OptionalArg<Option<PyIntRef>>,
142+
end: OptionalArg<Option<isize>>,
143143
}
144144

145145
impl ByteInnerFindOptions {
146146
pub fn get_value(
147147
self,
148-
elements: &[u8],
148+
len: usize,
149149
vm: &VirtualMachine,
150-
) -> PyResult<(Vec<u8>, Range<usize>)> {
150+
) -> PyResult<(Vec<u8>, std::ops::Range<usize>)> {
151151
let sub = match self.sub {
152152
Either::A(v) => v.elements.to_vec(),
153153
Either::B(int) => vec![int.as_bigint().byte_or(vm)?],
154154
};
155-
156-
let start = match self.start {
157-
OptionalArg::Present(Some(int)) => Some(int.as_bigint().clone()),
158-
_ => None,
159-
};
160-
161-
let end = match self.end {
162-
OptionalArg::Present(Some(int)) => Some(int.as_bigint().clone()),
163-
_ => None,
164-
};
165-
166-
let range = elements.to_vec().get_slice_range(&start, &end);
167-
155+
let range = adjust_indices(self.start, self.end, len);
168156
Ok((sub, range))
169157
}
170158
}
@@ -808,25 +796,18 @@ impl PyByteInner {
808796
}
809797

810798
pub fn count(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult<usize> {
811-
let (sub, range) = options.get_value(&self.elements, vm)?;
812-
813-
if sub.is_empty() {
814-
return Ok(self.len() + 1);
799+
let (needle, range) = options.get_value(self.elements.len(), vm)?;
800+
if !range.is_normal() {
801+
return Ok(0);
815802
}
816-
817-
let mut total: usize = 0;
818-
let mut i_start = range.start;
819-
let i_end = range.end;
820-
821-
for i in self.elements.do_slice(range) {
822-
if i_start + sub.len() <= i_end
823-
&& i == sub[0]
824-
&& &self.elements[i_start..(i_start + sub.len())] == sub.as_slice()
825-
{
826-
total += 1;
827-
}
828-
i_start += 1;
803+
if needle.is_empty() {
804+
return Ok(range.len() + 1);
829805
}
806+
let haystack = &self.elements[range];
807+
let total = haystack
808+
.windows(needle.len())
809+
.filter(|w| *w == needle.as_slice())
810+
.count();
830811
Ok(total)
831812
}
832813

@@ -884,37 +865,36 @@ impl PyByteInner {
884865
Ok(suff.as_slice() == &self.elements.do_slice(range)[offset])
885866
}
886867

868+
#[inline]
887869
pub fn find(
888870
&self,
889871
options: ByteInnerFindOptions,
890872
reverse: bool,
891873
vm: &VirtualMachine,
892-
) -> PyResult<isize> {
893-
let (sub, range) = options.get_value(&self.elements, vm)?;
894-
// not allowed for this method
895-
if range.end < range.start {
896-
return Ok(-1isize);
874+
) -> PyResult<Option<usize>> {
875+
let (needle, range) = options.get_value(self.elements.len(), vm)?;
876+
if !range.is_normal() {
877+
return Ok(None);
897878
}
898-
899-
let start = range.start;
900-
let end = range.end;
901-
879+
if needle.is_empty() {
880+
return Ok(Some(if reverse { range.end } else { range.start }));
881+
}
882+
let haystack = &self.elements[range.clone()];
883+
let windows = haystack.windows(needle.len());
902884
if reverse {
903-
let slice = self.elements.do_slice_reverse(range);
904-
for (n, _) in slice.iter().enumerate() {
905-
if n + sub.len() <= slice.len() && &slice[n..n + sub.len()] == sub.as_slice() {
906-
return Ok((end - n - 1) as isize);
885+
for (i, w) in windows.rev().enumerate() {
886+
if w == needle.as_slice() {
887+
return Ok(Some(range.end - i - needle.len()));
907888
}
908889
}
909890
} else {
910-
let slice = self.elements.do_slice(range);
911-
for (n, _) in slice.iter().enumerate() {
912-
if n + sub.len() <= slice.len() && &slice[n..n + sub.len()] == sub.as_slice() {
913-
return Ok((start + n) as isize);
891+
for (i, w) in windows.enumerate() {
892+
if w == needle.as_slice() {
893+
return Ok(Some(range.start + i));
914894
}
915895
}
916-
};
917-
Ok(-1isize)
896+
}
897+
Ok(None)
918898
}
919899

920900
pub fn maketrans(from: PyByteInner, to: PyByteInner, vm: &VirtualMachine) -> PyResult {

vm/src/obj/objbytes.rs

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -294,30 +294,32 @@ impl PyBytes {
294294

295295
#[pymethod(name = "find")]
296296
fn find(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult<isize> {
297-
self.inner.find(options, false, vm)
297+
let index = self.inner.find(options, false, vm)?;
298+
Ok(index.map_or(-1, |v| v as isize))
298299
}
299300

300301
#[pymethod(name = "index")]
301-
fn index(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult<isize> {
302-
let res = self.inner.find(options, false, vm)?;
303-
if res == -1 {
304-
return Err(vm.new_value_error("substring not found".to_owned()));
302+
fn index(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult<usize> {
303+
if let Some(index) = self.inner.find(options, false, vm)? {
304+
Ok(index)
305+
} else {
306+
Err(vm.new_value_error("substring not found".to_owned()))
305307
}
306-
Ok(res)
307308
}
308309

309310
#[pymethod(name = "rfind")]
310311
fn rfind(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult<isize> {
311-
self.inner.find(options, true, vm)
312+
let index = self.inner.find(options, true, vm)?;
313+
Ok(index.map_or(-1, |v| v as isize))
312314
}
313315

314316
#[pymethod(name = "rindex")]
315-
fn rindex(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult<isize> {
316-
let res = self.inner.find(options, true, vm)?;
317-
if res == -1 {
318-
return Err(vm.new_value_error("substring not found".to_owned()));
317+
fn rindex(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult<usize> {
318+
if let Some(index) = self.inner.find(options, true, vm)? {
319+
Ok(index)
320+
} else {
321+
Err(vm.new_value_error("substring not found".to_owned()))
319322
}
320-
Ok(res)
321323
}
322324

323325
#[pymethod(name = "translate")]

0 commit comments

Comments
 (0)