diff --git a/Lib/test/test_bytes.py b/Lib/test/test_bytes.py index 0bdd1a37e8..643ac188b3 100644 --- a/Lib/test/test_bytes.py +++ b/Lib/test/test_bytes.py @@ -606,8 +606,6 @@ def test_find(self): ValueError, r'byte must be in range\(0, 256\)', b.find, index) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_rfind(self): b = self.type2test(b'mississippi') i = 105 @@ -647,8 +645,6 @@ def test_index(self): self.assertEqual(b.index(i, 1, 3), 1) self.assertRaises(ValueError, b.index, w, 1, 3) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_rindex(self): b = self.type2test(b'mississippi') i = 105 diff --git a/vm/src/obj/objbytearray.rs b/vm/src/obj/objbytearray.rs index 440749730e..7af25d1efc 100644 --- a/vm/src/obj/objbytearray.rs +++ b/vm/src/obj/objbytearray.rs @@ -324,30 +324,26 @@ impl PyByteArray { #[pymethod(name = "find")] fn find(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { - self.inner.borrow().find(options, false, vm) + let index = self.inner.borrow().find(options, false, vm)?; + Ok(index.map_or(-1, |v| v as isize)) } #[pymethod(name = "index")] - fn index(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { - let res = self.inner.borrow().find(options, false, vm)?; - if res == -1 { - return Err(vm.new_value_error("substring not found".to_owned())); - } - Ok(res) + fn index(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { + let index = self.inner.borrow().find(options, false, vm)?; + index.ok_or_else(|| vm.new_value_error("substring not found".to_owned())) } #[pymethod(name = "rfind")] fn rfind(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { - self.inner.borrow().find(options, true, vm) + let index = self.inner.borrow().find(options, true, vm)?; + Ok(index.map_or(-1, |v| v as isize)) } #[pymethod(name = "rindex")] - fn rindex(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { - let res = self.inner.borrow().find(options, true, vm)?; - if res == -1 { - return Err(vm.new_value_error("substring not found".to_owned())); - } - Ok(res) + fn rindex(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { + let index = self.inner.borrow().find(options, true, vm)?; + index.ok_or_else(|| vm.new_value_error("substring not found".to_owned())) } #[pymethod(name = "remove")] diff --git a/vm/src/obj/objbyteinner.rs b/vm/src/obj/objbyteinner.rs index 58030e219e..47bceeb0c7 100644 --- a/vm/src/obj/objbyteinner.rs +++ b/vm/src/obj/objbyteinner.rs @@ -13,7 +13,7 @@ use super::objmemory::PyMemoryView; use super::objnone::PyNoneRef; use super::objsequence::{is_valid_slice_arg, PySliceableSequence}; use super::objslice::PySliceRef; -use super::objstr::{self, PyString, PyStringRef}; +use super::objstr::{self, adjust_indices, PyString, PyStringRef, StringRange}; use super::objtuple::PyTupleRef; use crate::function::OptionalArg; use crate::pyhash; @@ -137,34 +137,22 @@ pub struct ByteInnerFindOptions { #[pyarg(positional_only, optional = false)] sub: Either, #[pyarg(positional_only, optional = true)] - start: OptionalArg>, + start: OptionalArg>, #[pyarg(positional_only, optional = true)] - end: OptionalArg>, + end: OptionalArg>, } impl ByteInnerFindOptions { pub fn get_value( self, - elements: &[u8], + len: usize, vm: &VirtualMachine, - ) -> PyResult<(Vec, Range)> { + ) -> PyResult<(Vec, std::ops::Range)> { let sub = match self.sub { Either::A(v) => v.elements.to_vec(), Either::B(int) => vec![int.as_bigint().byte_or(vm)?], }; - - let start = match self.start { - OptionalArg::Present(Some(int)) => Some(int.as_bigint().clone()), - _ => None, - }; - - let end = match self.end { - OptionalArg::Present(Some(int)) => Some(int.as_bigint().clone()), - _ => None, - }; - - let range = elements.to_vec().get_slice_range(&start, &end); - + let range = adjust_indices(self.start, self.end, len); Ok((sub, range)) } } @@ -808,25 +796,18 @@ impl PyByteInner { } pub fn count(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { - let (sub, range) = options.get_value(&self.elements, vm)?; - - if sub.is_empty() { - return Ok(self.len() + 1); + let (needle, range) = options.get_value(self.elements.len(), vm)?; + if !range.is_normal() { + return Ok(0); } - - let mut total: usize = 0; - let mut i_start = range.start; - let i_end = range.end; - - for i in self.elements.do_slice(range) { - if i_start + sub.len() <= i_end - && i == sub[0] - && &self.elements[i_start..(i_start + sub.len())] == sub.as_slice() - { - total += 1; - } - i_start += 1; + if needle.is_empty() { + return Ok(range.len() + 1); } + let haystack = &self.elements[range]; + let total = haystack + .windows(needle.len()) + .filter(|w| *w == needle.as_slice()) + .count(); Ok(total) } @@ -884,37 +865,36 @@ impl PyByteInner { Ok(suff.as_slice() == &self.elements.do_slice(range)[offset]) } + #[inline] pub fn find( &self, options: ByteInnerFindOptions, reverse: bool, vm: &VirtualMachine, - ) -> PyResult { - let (sub, range) = options.get_value(&self.elements, vm)?; - // not allowed for this method - if range.end < range.start { - return Ok(-1isize); + ) -> PyResult> { + let (needle, range) = options.get_value(self.elements.len(), vm)?; + if !range.is_normal() { + return Ok(None); } - - let start = range.start; - let end = range.end; - + if needle.is_empty() { + return Ok(Some(if reverse { range.end } else { range.start })); + } + let haystack = &self.elements[range.clone()]; + let windows = haystack.windows(needle.len()); if reverse { - let slice = self.elements.do_slice_reverse(range); - for (n, _) in slice.iter().enumerate() { - if n + sub.len() <= slice.len() && &slice[n..n + sub.len()] == sub.as_slice() { - return Ok((end - n - 1) as isize); + for (i, w) in windows.rev().enumerate() { + if w == needle.as_slice() { + return Ok(Some(range.end - i - needle.len())); } } } else { - let slice = self.elements.do_slice(range); - for (n, _) in slice.iter().enumerate() { - if n + sub.len() <= slice.len() && &slice[n..n + sub.len()] == sub.as_slice() { - return Ok((start + n) as isize); + for (i, w) in windows.enumerate() { + if w == needle.as_slice() { + return Ok(Some(range.start + i)); } } - }; - Ok(-1isize) + } + Ok(None) } pub fn maketrans(from: PyByteInner, to: PyByteInner, vm: &VirtualMachine) -> PyResult { diff --git a/vm/src/obj/objbytes.rs b/vm/src/obj/objbytes.rs index 762a075f60..18d8a7698d 100644 --- a/vm/src/obj/objbytes.rs +++ b/vm/src/obj/objbytes.rs @@ -294,30 +294,26 @@ impl PyBytes { #[pymethod(name = "find")] fn find(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { - self.inner.find(options, false, vm) + let index = self.inner.find(options, false, vm)?; + Ok(index.map_or(-1, |v| v as isize)) } #[pymethod(name = "index")] - fn index(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { - let res = self.inner.find(options, false, vm)?; - if res == -1 { - return Err(vm.new_value_error("substring not found".to_owned())); - } - Ok(res) + fn index(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { + let index = self.inner.find(options, false, vm)?; + index.ok_or_else(|| vm.new_value_error("substring not found".to_owned())) } #[pymethod(name = "rfind")] fn rfind(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { - self.inner.find(options, true, vm) + let index = self.inner.find(options, true, vm)?; + Ok(index.map_or(-1, |v| v as isize)) } #[pymethod(name = "rindex")] - fn rindex(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { - let res = self.inner.find(options, true, vm)?; - if res == -1 { - return Err(vm.new_value_error("substring not found".to_owned())); - } - Ok(res) + fn rindex(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { + let index = self.inner.find(options, true, vm)?; + index.ok_or_else(|| vm.new_value_error("substring not found".to_owned())) } #[pymethod(name = "translate")] diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs index 731f7054b8..c1e698609c 100644 --- a/vm/src/obj/objstr.rs +++ b/vm/src/obj/objstr.rs @@ -579,8 +579,9 @@ impl PyString { end: OptionalArg>, vm: &VirtualMachine, ) -> PyResult { - if let Some((start, end)) = adjust_indices(start, end, self.value.len()) { - let value = &self.value[start..end]; + let range = adjust_indices(start, end, self.value.len()); + if range.is_normal() { + let value = &self.value[range]; single_or_tuple_any( suffix, |s: PyStringRef| Ok(value.ends_with(&s.value)), @@ -605,8 +606,9 @@ impl PyString { end: OptionalArg>, vm: &VirtualMachine, ) -> PyResult { - if let Some((start, end)) = adjust_indices(start, end, self.value.len()) { - let value = &self.value[start..end]; + let range = adjust_indices(start, end, self.value.len()); + if range.is_normal() { + let value = &self.value[range]; single_or_tuple_any( prefix, |s: PyStringRef| Ok(value.starts_with(&s.value)), @@ -898,6 +900,25 @@ impl PyString { Ok(joined) } + fn _find( + &self, + sub: PyStringRef, + start: OptionalArg>, + end: OptionalArg>, + find: F, + ) -> Option + where + F: Fn(&str, &str) -> Option, + { + let range = adjust_indices(start, end, self.value.len()); + if range.is_normal() { + if let Some(index) = find(&self.value[range.clone()], &sub.value) { + return Some(range.start + index); + } + } + None + } + #[pymethod] fn find( &self, @@ -905,15 +926,8 @@ impl PyString { start: OptionalArg>, end: OptionalArg>, ) -> isize { - let value = &self.value; - if let Some((start, end)) = adjust_indices(start, end, value.len()) { - match value[start..end].find(&sub.value) { - Some(num) => (start + num) as isize, - None => -1 as isize, - } - } else { - -1 as isize - } + self._find(sub, start, end, |r, s| r.find(s)) + .map_or(-1, |v| v as isize) } #[pymethod] @@ -923,15 +937,8 @@ impl PyString { start: OptionalArg>, end: OptionalArg>, ) -> isize { - let value = &self.value; - if let Some((start, end)) = adjust_indices(start, end, value.len()) { - match value[start..end].rfind(&sub.value) { - Some(num) => (start + num) as isize, - None => -1 as isize, - } - } else { - -1 as isize - } + self._find(sub, start, end, |r, s| r.rfind(s)) + .map_or(-1, |v| v as isize) } #[pymethod] @@ -942,15 +949,8 @@ impl PyString { end: OptionalArg>, vm: &VirtualMachine, ) -> PyResult { - let value = &self.value; - if let Some((start, end)) = adjust_indices(start, end, value.len()) { - match value[start..end].find(&sub.value) { - Some(num) => Ok(start + num), - None => Err(vm.new_value_error("substring not found".to_owned())), - } - } else { - Err(vm.new_value_error("substring not found".to_owned())) - } + self._find(sub, start, end, |r, s| r.find(s)) + .ok_or_else(|| vm.new_value_error("substring not found".to_owned())) } #[pymethod] @@ -961,15 +961,8 @@ impl PyString { end: OptionalArg>, vm: &VirtualMachine, ) -> PyResult { - let value = &self.value; - if let Some((start, end)) = adjust_indices(start, end, value.len()) { - match value[start..end].rfind(&sub.value) { - Some(num) => Ok(start + num), - None => Err(vm.new_value_error("substring not found".to_owned())), - } - } else { - Err(vm.new_value_error("substring not found".to_owned())) - } + self._find(sub, start, end, |r, s| r.rfind(s)) + .ok_or_else(|| vm.new_value_error("substring not found".to_owned())) } #[pymethod] @@ -1048,9 +1041,9 @@ impl PyString { start: OptionalArg>, end: OptionalArg>, ) -> usize { - let value = &self.value; - if let Some((start, end)) = adjust_indices(start, end, value.len()) { - self.value[start..end].matches(&sub.value).count() + let range = adjust_indices(start, end, self.value.len()); + if range.is_normal() { + self.value[range].matches(&sub.value).count() } else { 0 } @@ -1764,12 +1757,22 @@ impl PySliceableSequence for String { } } +pub trait StringRange { + fn is_normal(&self) -> bool; +} + +impl StringRange for std::ops::Range { + fn is_normal(&self) -> bool { + self.start <= self.end + } +} + // help get optional string indices -fn adjust_indices( +pub fn adjust_indices( start: OptionalArg>, end: OptionalArg>, len: usize, -) -> Option<(usize, usize)> { +) -> std::ops::Range { let mut start = start.flat_option().unwrap_or(0); let mut end = end.flat_option().unwrap_or(len as isize); if end > len as isize { @@ -1786,11 +1789,7 @@ fn adjust_indices( start = 0; } } - if start > end { - None - } else { - Some((start as usize, end as usize)) - } + start as usize..end as usize } // According to python following categories aren't printable: