diff --git a/Lib/test/string_tests.py b/Lib/test/string_tests.py index ad980812a5..7e544d6c0d 100644 --- a/Lib/test/string_tests.py +++ b/Lib/test/string_tests.py @@ -360,7 +360,6 @@ def test_expandtabs(self): self.checkraises(OverflowError, '\ta\n\tb', 'expandtabs', sys.maxsize) - @unittest.skip("TODO: RUSTPYTHON test_bytes") def test_split(self): # by a char self.checkequal(['a', 'b', 'c', 'd'], 'a|b|c|d', 'split', '|') @@ -431,7 +430,6 @@ def test_split(self): self.checkraises(ValueError, 'hello', 'split', '') self.checkraises(ValueError, 'hello', 'split', '', 0) - @unittest.skip("TODO: RUSTPYTHON test_bytes") def test_rsplit(self): # by a char self.checkequal(['a', 'b', 'c', 'd'], 'a|b|c|d', 'rsplit', '|') @@ -697,8 +695,6 @@ def test_capitalize(self): self.checkraises(TypeError, 'hello', 'capitalize', 42) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_additional_split(self): self.checkequal(['this', 'is', 'the', 'split', 'function'], 'this is the split function', 'split') @@ -735,8 +731,6 @@ def test_additional_split(self): self.checkequal(['arf', 'barf'], b, 'split', None) self.checkequal(['arf', 'barf'], b, 'split', None, 2) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_additional_rsplit(self): self.checkequal(['this', 'is', 'the', 'rsplit', 'function'], 'this is the rsplit function', 'rsplit') diff --git a/Lib/test/test_bytes.py b/Lib/test/test_bytes.py index 65e12683ea..1bbdef59e7 100644 --- a/Lib/test/test_bytes.py +++ b/Lib/test/test_bytes.py @@ -722,8 +722,6 @@ def test_split_int_error(self): self.assertRaises(TypeError, self.type2test(b'a b').split, 32) self.assertRaises(TypeError, self.type2test(b'a b').rsplit, 32) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_split_unicodewhitespace(self): for b in (b'a\x1Cb', b'a\x1Db', b'a\x1Eb', b'a\x1Fb'): b = self.type2test(b) @@ -731,8 +729,6 @@ def test_split_unicodewhitespace(self): b = self.type2test(b"\x09\x0A\x0B\x0C\x0D\x1C\x1D\x1E\x1F") self.assertEqual(b.split(), [b'\x1c\x1d\x1e\x1f']) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_rsplit_unicodewhitespace(self): b = self.type2test(b"\x09\x0A\x0B\x0C\x0D\x1C\x1D\x1E\x1F") self.assertEqual(b.rsplit(), [b'\x1c\x1d\x1e\x1f']) @@ -1841,8 +1837,6 @@ class BytearrayPEP3137Test(unittest.TestCase): def marshal(self, x): return bytearray(x) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_returns_new_copy(self): val = self.marshal(b'1234') # On immutable types these MAY return a reference to themselves diff --git a/vm/src/obj/mod.rs b/vm/src/obj/mod.rs index 9a5523e462..f11d0eccfe 100644 --- a/vm/src/obj/mod.rs +++ b/vm/src/obj/mod.rs @@ -44,3 +44,4 @@ pub mod objtype; pub mod objweakproxy; pub mod objweakref; pub mod objzip; +mod pystr; diff --git a/vm/src/obj/objbytearray.rs b/vm/src/obj/objbytearray.rs index 1e87a92a35..9a2fc0cdf5 100644 --- a/vm/src/obj/objbytearray.rs +++ b/vm/src/obj/objbytearray.rs @@ -69,8 +69,8 @@ impl PyByteArray { } impl From> for PyByteArray { - fn from(elements: Vec) -> PyByteArray { - PyByteArray::new(elements) + fn from(elements: Vec) -> Self { + Self::new(elements) } } @@ -389,24 +389,14 @@ impl PyByteArray { #[pymethod(name = "split")] fn split(&self, options: ByteInnerSplitOptions, vm: &VirtualMachine) -> PyResult { - let as_bytes = self - .borrow_value() - .split(options, false)? - .iter() - .map(|x| vm.ctx.new_bytearray(x.to_vec())) - .collect::>(); - Ok(vm.ctx.new_list(as_bytes)) + self.borrow_value() + .split(options, |s, vm| vm.ctx.new_bytearray(s.to_vec()), vm) } #[pymethod(name = "rsplit")] fn rsplit(&self, options: ByteInnerSplitOptions, vm: &VirtualMachine) -> PyResult { - let as_bytes = self - .borrow_value() - .split(options, true)? - .iter() - .map(|x| vm.ctx.new_bytearray(x.to_vec())) - .collect::>(); - Ok(vm.ctx.new_list(as_bytes)) + self.borrow_value() + .rsplit(options, |s, vm| vm.ctx.new_bytearray(s.to_vec()), vm) } #[pymethod(name = "partition")] diff --git a/vm/src/obj/objbyteinner.rs b/vm/src/obj/objbyteinner.rs index 90056738ae..753bbf4e35 100644 --- a/vm/src/obj/objbyteinner.rs +++ b/vm/src/obj/objbyteinner.rs @@ -15,6 +15,7 @@ use super::objsequence::{is_valid_slice_arg, PySliceableSequence}; use super::objslice::PySliceRef; use super::objstr::{self, adjust_indices, PyString, PyStringRef, StringRange}; use super::objtuple::PyTupleRef; +use super::pystr::PyCommonString; use crate::function::{OptionalArg, OptionalOption}; use crate::pyhash; use crate::pyobject::{ @@ -28,6 +29,12 @@ pub struct PyByteInner { pub elements: Vec, } +impl From> for PyByteInner { + fn from(elements: Vec) -> PyByteInner { + Self { elements } + } +} + impl ThreadSafe for PyByteInner {} impl TryFromObject for PyByteInner { @@ -251,26 +258,24 @@ impl ByteInnerTranslateOptions { #[derive(FromArgs)] pub struct ByteInnerSplitOptions { - #[pyarg(positional_or_keyword, optional = true)] - sep: OptionalArg>, - #[pyarg(positional_or_keyword, optional = true)] - maxsplit: OptionalArg, + #[pyarg(positional_or_keyword, default = "None")] + sep: Option, + #[pyarg(positional_or_keyword, default = "-1")] + maxsplit: isize, } impl ByteInnerSplitOptions { - pub fn get_value(self) -> PyResult<(Vec, isize)> { - let sep = match self.sep.into_option() { - Some(Some(bytes)) => bytes.elements, - _ => vec![], - }; - - let maxsplit = if let OptionalArg::Present(value) = self.maxsplit { - value + pub fn get_value(self, vm: &VirtualMachine) -> PyResult<(Option>, isize)> { + let sep = if let Some(s) = self.sep { + let sep = s.elements; + if sep.is_empty() { + return Err(vm.new_value_error("empty separator".to_owned())); + } + Some(sep) } else { - -1 + None }; - - Ok((sep, maxsplit)) + Ok((sep, self.maxsplit)) } } @@ -995,21 +1000,55 @@ impl PyByteInner { .to_owned() } - pub fn split(&self, options: ByteInnerSplitOptions, reverse: bool) -> PyResult> { - let (sep, maxsplit) = options.get_value()?; - - if self.elements.is_empty() { - if !sep.is_empty() { - return Ok(vec![&[]]); - } - return Ok(vec![]); - } + pub fn split( + &self, + options: ByteInnerSplitOptions, + convert: F, + vm: &VirtualMachine, + ) -> PyResult + where + F: Fn(&[u8], &VirtualMachine) -> PyObjectRef, + { + let (sep, maxsplit) = options.get_value(vm)?; + let sep_ref = match sep { + Some(ref v) => Some(&v[..]), + None => None, + }; + let elements = self.elements.py_split( + sep_ref, + maxsplit, + vm, + |v, s, vm| v.split_str(s).map(|v| convert(v, vm)).collect(), + |v, s, n, vm| v.splitn_str(n, s).map(|v| convert(v, vm)).collect(), + |v, n, vm| v.py_split_whitespace(n, |v| convert(v, vm)), + ); + Ok(vm.ctx.new_list(elements)) + } - if reverse { - Ok(split_slice_reverse(&self.elements, &sep, maxsplit)) - } else { - Ok(split_slice(&self.elements, &sep, maxsplit)) - } + pub fn rsplit( + &self, + options: ByteInnerSplitOptions, + convert: F, + vm: &VirtualMachine, + ) -> PyResult + where + F: Fn(&[u8], &VirtualMachine) -> PyObjectRef, + { + let (sep, maxsplit) = options.get_value(vm)?; + let sep_ref = match sep { + Some(ref v) => Some(&v[..]), + None => None, + }; + let mut elements = self.elements.py_split( + sep_ref, + maxsplit, + vm, + |v, s, vm| v.rsplit_str(s).map(|v| convert(v, vm)).collect(), + |v, s, n, vm| v.rsplitn_str(n, s).map(|v| convert(v, vm)).collect(), + |v, n, vm| v.py_rsplit_whitespace(n, |v| convert(v, vm)), + ); + elements.reverse(); + Ok(vm.ctx.new_list(elements)) } pub fn partition( @@ -1247,187 +1286,6 @@ pub trait ByteOr: ToPrimitive { impl ByteOr for BigInt {} -fn split_slice<'a>(slice: &'a [u8], sep: &[u8], maxsplit: isize) -> Vec<&'a [u8]> { - let mut splitted: Vec<&[u8]> = vec![]; - let mut prev_index = 0; - let mut index = 0; - let mut count = 0; - let mut in_string = false; - - // No sep given, will split for any \t \n \r and space = [9, 10, 13, 32] - if sep.is_empty() { - // split wihtout sep always trim left spaces for any maxsplit - // so we have to ignore left spaces. - loop { - if [9, 10, 13, 32].contains(&slice[index]) { - index += 1 - } else { - prev_index = index; - break; - } - } - - // most simple case - if maxsplit == 0 { - splitted.push(&slice[index..slice.len()]); - return splitted; - } - - // main loop. in_string means previous char is ascii char(true) or space(false) - // loop from left to right - loop { - if [9, 10, 13, 32].contains(&slice[index]) { - if in_string { - splitted.push(&slice[prev_index..index]); - in_string = false; - count += 1; - if count == maxsplit { - // while index < slice.len() - splitted.push(&slice[index + 1..slice.len()]); - break; - } - } - } else if !in_string { - prev_index = index; - in_string = true; - } - - index += 1; - - // handle last item in slice - if index == slice.len() { - if in_string { - if [9, 10, 13, 32].contains(&slice[index - 1]) { - splitted.push(&slice[prev_index..index - 1]); - } else { - splitted.push(&slice[prev_index..index]); - } - } - break; - } - } - } else { - // sep is given, we match exact slice - while index != slice.len() { - if index + sep.len() >= slice.len() { - if &slice[index..slice.len()] == sep { - splitted.push(&slice[prev_index..index]); - splitted.push(&[]); - break; - } - splitted.push(&slice[prev_index..slice.len()]); - break; - } - - if &slice[index..index + sep.len()] == sep { - splitted.push(&slice[prev_index..index]); - index += sep.len(); - prev_index = index; - count += 1; - if count == maxsplit { - // maxsplit reached, append, the remaing - splitted.push(&slice[prev_index..slice.len()]); - break; - } - continue; - } - - index += 1; - } - } - splitted -} - -fn split_slice_reverse<'a>(slice: &'a [u8], sep: &[u8], maxsplit: isize) -> Vec<&'a [u8]> { - let mut splitted: Vec<&[u8]> = vec![]; - let mut prev_index = slice.len(); - let mut index = slice.len(); - let mut count = 0; - - // No sep given, will split for any \t \n \r and space = [9, 10, 13, 32] - if sep.is_empty() { - //adjust index - index -= 1; - - // rsplit without sep always trim right spaces for any maxsplit - // so we have to ignore right spaces. - loop { - if [9, 10, 13, 32].contains(&slice[index]) { - index -= 1 - } else { - break; - } - } - prev_index = index + 1; - - // most simple case - if maxsplit == 0 { - splitted.push(&slice[0..=index]); - return splitted; - } - - // main loop. in_string means previous char is ascii char(true) or space(false) - // loop from right to left and reverse result the end - let mut in_string = true; - loop { - if [9, 10, 13, 32].contains(&slice[index]) { - if in_string { - splitted.push(&slice[index + 1..prev_index]); - count += 1; - if count == maxsplit { - // maxsplit reached, append, the remaing - splitted.push(&slice[0..index]); - break; - } - in_string = false; - index -= 1; - continue; - } - } else if !in_string { - in_string = true; - if index == 0 { - splitted.push(&slice[0..1]); - break; - } - prev_index = index + 1; - } - if index == 0 { - break; - } - index -= 1; - } - } else { - // sep is give, we match exact slice going backwards - while index != 0 { - if index <= sep.len() { - if &slice[0..index] == sep { - splitted.push(&slice[index..prev_index]); - splitted.push(&[]); - break; - } - splitted.push(&slice[0..prev_index]); - break; - } - if &slice[(index - sep.len())..index] == sep { - splitted.push(&slice[index..prev_index]); - index -= sep.len(); - prev_index = index; - count += 1; - if count == maxsplit { - // maxsplit reached, append, the remaing - splitted.push(&slice[0..prev_index]); - break; - } - continue; - } - - index -= 1; - } - } - splitted.reverse(); - splitted -} - pub enum PyBytesLike { Bytes(PyBytesRef), Bytearray(PyByteArrayRef), @@ -1480,3 +1338,53 @@ pub fn bytes_zfill(bytes: &[u8], width: usize) -> Vec { filled } } + +const ASCII_WHITESPACES: [u8; 6] = [0x20, 0x09, 0x0a, 0x0c, 0x0d, 0x0b]; + +impl PyCommonString<'_, u8> for [u8] { + fn py_split_whitespace(&self, maxsplit: isize, convert: F) -> Vec + where + F: Fn(&Self) -> PyObjectRef, + { + let mut splited = Vec::new(); + let mut count = maxsplit; + let mut haystack = &self[..]; + while let Some(offset) = haystack.find_byteset(ASCII_WHITESPACES) { + if offset != 0 { + if count == 0 { + break; + } + splited.push(convert(&haystack[..offset])); + count -= 1; + } + haystack = &haystack[offset + 1..]; + } + if !haystack.is_empty() { + splited.push(convert(haystack)); + } + splited + } + + fn py_rsplit_whitespace(&self, maxsplit: isize, convert: F) -> Vec + where + F: Fn(&Self) -> PyObjectRef, + { + let mut splited = Vec::new(); + let mut count = maxsplit; + let mut haystack = &self[..]; + while let Some(offset) = haystack.rfind_byteset(ASCII_WHITESPACES) { + if offset + 1 != haystack.len() { + if count == 0 { + break; + } + splited.push(convert(&haystack[offset + 1..])); + count -= 1; + } + haystack = &haystack[..offset]; + } + if !haystack.is_empty() { + splited.push(convert(haystack)); + } + splited + } +} diff --git a/vm/src/obj/objbytes.rs b/vm/src/obj/objbytes.rs index 94272742d6..1e7675f23d 100644 --- a/vm/src/obj/objbytes.rs +++ b/vm/src/obj/objbytes.rs @@ -57,8 +57,8 @@ impl PyBytes { } impl From> for PyBytes { - fn from(elements: Vec) -> PyBytes { - PyBytes::new(elements) + fn from(elements: Vec) -> Self { + Self::new(elements) } } @@ -344,24 +344,14 @@ impl PyBytes { #[pymethod(name = "split")] fn split(&self, options: ByteInnerSplitOptions, vm: &VirtualMachine) -> PyResult { - let as_bytes = self - .inner - .split(options, false)? - .iter() - .map(|x| vm.ctx.new_bytes(x.to_vec())) - .collect::>(); - Ok(vm.ctx.new_list(as_bytes)) + self.inner + .split(options, |s, vm| vm.ctx.new_bytes(s.to_vec()), vm) } #[pymethod(name = "rsplit")] fn rsplit(&self, options: ByteInnerSplitOptions, vm: &VirtualMachine) -> PyResult { - let as_bytes = self - .inner - .split(options, true)? - .iter() - .map(|x| vm.ctx.new_bytes(x.to_vec())) - .collect::>(); - Ok(vm.ctx.new_list(as_bytes)) + self.inner + .rsplit(options, |s, vm| vm.ctx.new_bytes(s.to_vec()), vm) } #[pymethod(name = "partition")] diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs index 4110cfb6d4..59e082e501 100644 --- a/vm/src/obj/objstr.rs +++ b/vm/src/obj/objstr.rs @@ -21,6 +21,7 @@ use super::objsequence::PySliceableSequence; use super::objslice::PySliceRef; use super::objtuple; use super::objtype::{self, PyClassRef}; +use super::pystr::PyCommonString; use crate::cformat::{ CFormatPart, CFormatPreconversor, CFormatQuantity, CFormatSpec, CFormatString, CFormatType, CNumberType, @@ -455,61 +456,27 @@ impl PyString { #[pymethod] fn split(&self, args: SplitArgs, vm: &VirtualMachine) -> PyResult { - let value = &self.value; - let pattern = args.non_empty_sep(vm)?; - let num_splits = args.maxsplit; - let elements: Vec<_> = match (pattern, num_splits.is_negative()) { - (Some(pattern), true) => value - .split(pattern) - .map(|o| vm.ctx.new_str(o.to_owned())) - .collect(), - (Some(pattern), false) => value - .splitn(num_splits as usize + 1, pattern) - .map(|o| vm.ctx.new_str(o.to_owned())) - .collect(), - (None, true) => value - .trim_start() - .split(|c: char| c.is_ascii_whitespace()) - .filter(|s| !s.is_empty()) - .map(|o| vm.ctx.new_str(o.to_owned())) - .collect(), - (None, false) => value - .trim_start() - .splitn(num_splits as usize + 1, |c: char| c.is_ascii_whitespace()) - .filter(|s| !s.is_empty()) - .map(|o| vm.ctx.new_str(o.to_owned())) - .collect(), - }; + let elements = self.value.py_split( + args.non_empty_sep(vm)?, + args.maxsplit, + vm, + |v, s, vm| v.split(s).map(|s| vm.ctx.new_str(s)).collect(), + |v, s, n, vm| v.splitn(n, s).map(|s| vm.ctx.new_str(s)).collect(), + |v, n, vm| v.py_split_whitespace(n, |s| vm.ctx.new_str(s)), + ); Ok(vm.ctx.new_list(elements)) } #[pymethod] fn rsplit(&self, args: SplitArgs, vm: &VirtualMachine) -> PyResult { - let value = &self.value; - let pattern = args.non_empty_sep(vm)?; - let num_splits = args.maxsplit; - let mut elements: Vec<_> = match (pattern, num_splits.is_negative()) { - (Some(pattern), true) => value - .rsplit(pattern) - .map(|o| vm.ctx.new_str(o.to_owned())) - .collect(), - (Some(pattern), false) => value - .rsplitn(num_splits as usize + 1, pattern) - .map(|o| vm.ctx.new_str(o.to_owned())) - .collect(), - (None, true) => value - .trim_end() - .rsplit(|c: char| c.is_ascii_whitespace()) - .filter(|s| !s.is_empty()) - .map(|o| vm.ctx.new_str(o.to_owned())) - .collect(), - (None, false) => value - .trim_end() - .rsplitn(num_splits as usize + 1, |c: char| c.is_ascii_whitespace()) - .filter(|s| !s.is_empty()) - .map(|o| vm.ctx.new_str(o.to_owned())) - .collect(), - }; + let mut elements = self.value.py_split( + args.non_empty_sep(vm)?, + args.maxsplit, + vm, + |v, s, vm| v.rsplit(s).map(|s| vm.ctx.new_str(s)).collect(), + |v, s, n, vm| v.rsplitn(n, s).map(|s| vm.ctx.new_str(s)).collect(), + |v, n, vm| v.py_rsplit_whitespace(n, |s| vm.ctx.new_str(s)), + ); // Unlike Python rsplit, Rust rsplitn returns an iterator that // starts from the end of the string. elements.reverse(); @@ -1882,3 +1849,57 @@ mod tests { assert_eq!(translated.unwrap_err().class().name, "TypeError".to_owned()); } } + +impl PyCommonString<'_, char> for str { + fn py_split_whitespace(&self, maxsplit: isize, convert: F) -> Vec + where + F: Fn(&Self) -> PyObjectRef, + { + // CPython split_whitespace + let mut splited = Vec::new(); + let mut last_offset = 0; + let mut count = maxsplit; + for (offset, _) in self.match_indices(|c: char| c.is_ascii_whitespace() || c == '\x0b') { + if last_offset == offset { + last_offset += 1; + continue; + } + if count == 0 { + break; + } + splited.push(convert(&self[last_offset..offset])); + last_offset = offset + 1; + count -= 1; + } + if last_offset != self.len() { + splited.push(convert(&self[last_offset..])); + } + splited + } + + fn py_rsplit_whitespace(&self, maxsplit: isize, convert: F) -> Vec + where + F: Fn(&Self) -> PyObjectRef, + { + // CPython rsplit_whitespace + let mut splited = Vec::new(); + let mut last_offset = self.len(); + let mut count = maxsplit; + for (offset, _) in self.rmatch_indices(|c: char| c.is_ascii_whitespace() || c == '\x0b') { + if last_offset == offset + 1 { + last_offset -= 1; + continue; + } + if count == 0 { + break; + } + splited.push(convert(&self[offset + 1..last_offset])); + last_offset = offset; + count -= 1; + } + if last_offset != 0 { + splited.push(convert(&self[..last_offset])); + } + splited + } +} diff --git a/vm/src/obj/pystr.rs b/vm/src/obj/pystr.rs new file mode 100644 index 0000000000..54be550a15 --- /dev/null +++ b/vm/src/obj/pystr.rs @@ -0,0 +1,38 @@ +use crate::pyobject::PyObjectRef; +use crate::vm::VirtualMachine; + +pub trait PyCommonString<'a, E> +where + Self: 'a, +{ + fn py_split( + &self, + sep: Option<&Self>, + maxsplit: isize, + vm: &VirtualMachine, + split: SP, + splitn: SN, + splitw: SW, + ) -> Vec + where + SP: Fn(&Self, &Self, &VirtualMachine) -> Vec, + SN: Fn(&Self, &Self, usize, &VirtualMachine) -> Vec, + SW: Fn(&Self, isize, &VirtualMachine) -> Vec, + { + if let Some(pattern) = sep { + if maxsplit < 0 { + split(self, pattern, vm) + } else { + splitn(self, pattern, (maxsplit + 1) as usize, vm) + } + } else { + splitw(self, maxsplit, vm) + } + } + fn py_split_whitespace(&self, maxsplit: isize, convert: F) -> Vec + where + F: Fn(&Self) -> PyObjectRef; + fn py_rsplit_whitespace(&self, maxsplit: isize, convert: F) -> Vec + where + F: Fn(&Self) -> PyObjectRef; +} diff --git a/vm/src/pyobject.rs b/vm/src/pyobject.rs index 03987778af..3f0bfdc44b 100644 --- a/vm/src/pyobject.rs +++ b/vm/src/pyobject.rs @@ -420,7 +420,10 @@ impl PyContext { PyObject::new(PyComplex::from(value), self.complex_type(), None) } - pub fn new_str(&self, s: String) -> PyObjectRef { + pub fn new_str(&self, s: S) -> PyObjectRef + where + objstr::PyString: std::convert::From, + { PyObject::new(objstr::PyString::from(s), self.str_type(), None) } @@ -1305,7 +1308,7 @@ pub trait PyClassImpl: PyClassDef { class.slots.borrow_mut().flags = Self::TP_FLAGS; ctx.add_tp_new_wrapper(&class); if let Some(doc) = Self::DOC { - class.set_str_attr("__doc__", ctx.new_str(doc.into())); + class.set_str_attr("__doc__", ctx.new_str(doc)); } }