Enable the bytes, bytearray and str partition()/rpartition() tests and rework
the Rust implementations behind them: an empty separator now raises
ValueError("empty separator"), and for bytes/bytearray the middle element of
the result is empty when the separator is not found.

NOTE(review): this patch was recovered from a whitespace-collapsed copy, so the
indentation of context lines is reconstructed -- confirm against the target
files (or apply with --ignore-whitespace).

diff --git a/Lib/test/test_bytes.py b/Lib/test/test_bytes.py
index 1e1b484c6b..65e12683ea 100644
--- a/Lib/test/test_bytes.py
+++ b/Lib/test/test_bytes.py
@@ -737,13 +737,11 @@ def test_rsplit_unicodewhitespace(self):
         b = self.type2test(b"\x09\x0A\x0B\x0C\x0D\x1C\x1D\x1E\x1F")
         self.assertEqual(b.rsplit(), [b'\x1c\x1d\x1e\x1f'])
 
-    @unittest.skip("TODO: RUSTPYTHON")
     def test_partition(self):
         b = self.type2test(b'mississippi')
         self.assertEqual(b.partition(b'ss'), (b'mi', b'ss', b'issippi'))
         self.assertEqual(b.partition(b'w'), (b'mississippi', b'', b''))
 
-    @unittest.skip("TODO: RUSTPYTHON")
     def test_rpartition(self):
         b = self.type2test(b'mississippi')
         self.assertEqual(b.rpartition(b'ss'), (b'missi', b'ss', b'ippi'))
@@ -1578,7 +1576,6 @@ def test_copied(self):
         x = bytearray(b'')
         self.assertIsNot(x, x.translate(t))
 
-    @unittest.skip("TODO: RUSTPYTHON")
     def test_partition_bytearray_doesnt_share_nullstring(self):
         a, b, c = bytearray(b"x").partition(b"y")
         self.assertEqual(b, b"")
diff --git a/Lib/test/test_unicode.py b/Lib/test/test_unicode.py
index 8c0d07ed35..476540e215 100644
--- a/Lib/test/test_unicode.py
+++ b/Lib/test/test_unicode.py
@@ -419,8 +419,6 @@ def test_rsplit(self):
                 self.checkequal([left, right],
                                 left + delim * 2 + right, 'rsplit', delim *2)
 
-    # TODO: RUSTPYTHON
-    @unittest.expectedFailure
     def test_partition(self):
         string_tests.MixinStrUnicodeUserStringTest.test_partition(self)
         # test mixed kinds
@@ -438,8 +436,6 @@ def test_partition(self):
                 self.checkequal((left, delim * 2, right),
                                 left + delim * 2 + right, 'partition', delim * 2)
 
-    # TODO: RUSTPYTHON
-    @unittest.expectedFailure
     def test_rpartition(self):
         string_tests.MixinStrUnicodeUserStringTest.test_rpartition(self)
         # test mixed kinds
diff --git a/vm/src/obj/objbytearray.rs b/vm/src/obj/objbytearray.rs
index b62a9603b5..4cf49e746b 100644
--- a/vm/src/obj/objbytearray.rs
+++ b/vm/src/obj/objbytearray.rs
@@ -413,21 +413,25 @@ impl PyByteArray {
     fn partition(&self, sep: PyByteInner, vm: &VirtualMachine) -> PyResult {
         // sep ALWAYS converted to bytearray even it's bytes or memoryview
         // so its ok to accept PyByteInner
-        let (left, right) = self.borrow_value().partition(&sep, false)?;
+        let value = self.borrow_value();
+        let (front, has_mid, back) = value.partition(&sep, vm)?;
         Ok(vm.ctx.new_tuple(vec![
-            vm.ctx.new_bytearray(left),
-            vm.ctx.new_bytearray(sep.elements),
-            vm.ctx.new_bytearray(right),
+            vm.ctx.new_bytearray(front),
+            vm.ctx
+                .new_bytearray(if has_mid { sep.elements } else { Vec::new() }),
+            vm.ctx.new_bytearray(back),
         ]))
     }
 
     #[pymethod(name = "rpartition")]
     fn rpartition(&self, sep: PyByteInner, vm: &VirtualMachine) -> PyResult {
-        let (left, right) = self.borrow_value().partition(&sep, true)?;
+        let value = self.borrow_value();
+        let (front, has_mid, back) = value.rpartition(&sep, vm)?;
         Ok(vm.ctx.new_tuple(vec![
-            vm.ctx.new_bytearray(left),
-            vm.ctx.new_bytearray(sep.elements),
-            vm.ctx.new_bytearray(right),
+            vm.ctx.new_bytearray(front),
+            vm.ctx
+                .new_bytearray(if has_mid { sep.elements } else { Vec::new() }),
+            vm.ctx.new_bytearray(back),
         ]))
     }
 
diff --git a/vm/src/obj/objbyteinner.rs b/vm/src/obj/objbyteinner.rs
index 01bdc49a24..90056738ae 100644
--- a/vm/src/obj/objbyteinner.rs
+++ b/vm/src/obj/objbyteinner.rs
@@ -1012,13 +1012,52 @@ impl PyByteInner {
         }
     }
 
-    pub fn partition(&self, sep: &PyByteInner, reverse: bool) -> PyResult<(Vec<u8>, Vec<u8>)> {
-        let splitted = if reverse {
-            split_slice_reverse(&self.elements, &sep.elements, 1)
+    /// Split `self` around the first occurrence of `sub`.
+    ///
+    /// Returns `(front, has_mid, back)`. When `sub` does not occur, `front`
+    /// holds every byte, `has_mid` is `false` and `back` is empty. An empty
+    /// `sub` is rejected with a `ValueError`.
+    pub fn partition(
+        &self,
+        sub: &PyByteInner,
+        vm: &VirtualMachine,
+    ) -> PyResult<(Vec<u8>, bool, Vec<u8>)> {
+        if sub.elements.is_empty() {
+            return Err(vm.new_value_error("empty separator".to_owned()));
+        }
+
+        let mut sp = self.elements.splitn_str(2, &sub.elements);
+        let front = sp.next().unwrap().to_vec();
+        let (has_mid, back) = if let Some(back) = sp.next() {
+            (true, back.to_vec())
+        } else {
+            (false, Vec::new())
+        };
+        Ok((front, has_mid, back))
+    }
+
+    /// Split `self` around the last occurrence of `sub`.
+    ///
+    /// Returns `(front, has_mid, back)`. When `sub` does not occur, `front`
+    /// is empty, `has_mid` is `false` and `back` holds every byte. An empty
+    /// `sub` is rejected with a `ValueError`.
+    pub fn rpartition(
+        &self,
+        sub: &PyByteInner,
+        vm: &VirtualMachine,
+    ) -> PyResult<(Vec<u8>, bool, Vec<u8>)> {
+        if sub.elements.is_empty() {
+            return Err(vm.new_value_error("empty separator".to_owned()));
+        }
+
+        let mut sp = self.elements.rsplitn_str(2, &sub.elements);
+        let back = sp.next().unwrap().to_vec();
+        let (has_mid, front) = if let Some(front) = sp.next() {
+            (true, front.to_vec())
         } else {
-            split_slice(&self.elements, &sep.elements, 1)
+            (false, Vec::new())
         };
-        Ok((splitted[0].to_vec(), splitted[1].to_vec()))
+        Ok((front, has_mid, back))
     }
 
     pub fn expandtabs(&self, options: ByteInnerExpandtabsOptions) -> Vec<u8> {
diff --git a/vm/src/obj/objbytes.rs b/vm/src/obj/objbytes.rs
index 94a65d843a..94272742d6 100644
--- a/vm/src/obj/objbytes.rs
+++ b/vm/src/obj/objbytes.rs
@@ -366,21 +366,32 @@ impl PyBytes {
 
     #[pymethod(name = "partition")]
     fn partition(&self, sep: PyObjectRef, vm: &VirtualMachine) -> PyResult {
-        let sepa = PyByteInner::try_from_object(vm, sep.clone())?;
-
-        let (left, right) = self.inner.partition(&sepa, false)?;
-        Ok(vm
-            .ctx
-            .new_tuple(vec![vm.ctx.new_bytes(left), sep, vm.ctx.new_bytes(right)]))
+        let sub = PyByteInner::try_from_object(vm, sep.clone())?;
+        let (front, has_mid, back) = self.inner.partition(&sub, vm)?;
+        Ok(vm.ctx.new_tuple(vec![
+            vm.ctx.new_bytes(front),
+            if has_mid {
+                sep
+            } else {
+                vm.ctx.new_bytes(Vec::new())
+            },
+            vm.ctx.new_bytes(back),
+        ]))
     }
+
     #[pymethod(name = "rpartition")]
     fn rpartition(&self, sep: PyObjectRef, vm: &VirtualMachine) -> PyResult {
-        let sepa = PyByteInner::try_from_object(vm, sep.clone())?;
-
-        let (left, right) = self.inner.partition(&sepa, true)?;
-        Ok(vm
-            .ctx
-            .new_tuple(vec![vm.ctx.new_bytes(left), sep, vm.ctx.new_bytes(right)]))
+        let sub = PyByteInner::try_from_object(vm, sep.clone())?;
+        let (front, has_mid, back) = self.inner.rpartition(&sub, vm)?;
+        Ok(vm.ctx.new_tuple(vec![
+            vm.ctx.new_bytes(front),
+            if has_mid {
+                sep
+            } else {
+                vm.ctx.new_bytes(Vec::new())
+            },
+            vm.ctx.new_bytes(back),
+        ]))
     }
 
     #[pymethod(name = "expandtabs")]
diff --git a/vm/src/obj/objstr.rs b/vm/src/obj/objstr.rs
index a3e8c347ca..4110cfb6d4 100644
--- a/vm/src/obj/objstr.rs
+++ b/vm/src/obj/objstr.rs
@@ -954,42 +954,43 @@ impl PyString {
     }
 
     #[pymethod]
-    fn partition(&self, sub: PyStringRef, vm: &VirtualMachine) -> PyObjectRef {
-        let value = &self.value;
-        let sub = &sub.value;
-        let mut new_tup = Vec::new();
-        if value.contains(sub) {
-            new_tup = value
-                .splitn(2, sub)
-                .map(|s| vm.ctx.new_str(s.to_owned()))
-                .collect();
-            new_tup.insert(1, vm.ctx.new_str(sub.clone()));
-        } else {
-            new_tup.push(vm.ctx.new_str(value.clone()));
-            new_tup.push(vm.ctx.new_str("".to_owned()));
-            new_tup.push(vm.ctx.new_str("".to_owned()));
+    fn partition(&self, sub: PyStringRef, vm: &VirtualMachine) -> PyResult {
+        if sub.value.is_empty() {
+            return Err(vm.new_value_error("empty separator".to_owned()));
         }
-        vm.ctx.new_tuple(new_tup)
+        let mut sp = self.value.splitn(2, &sub.value);
+        let front = sp.next().unwrap();
+        let elems = if let Some(back) = sp.next() {
+            [front, &sub.value, back]
+        } else {
+            [front, "", ""]
+        };
+        Ok(vm.ctx.new_tuple(
+            elems
+                .iter()
+                .map(|&s| vm.ctx.new_str(s.to_owned()))
+                .collect(),
+        ))
     }
 
     #[pymethod]
-    fn rpartition(&self, sub: PyStringRef, vm: &VirtualMachine) -> PyObjectRef {
-        let value = &self.value;
-        let sub = &sub.value;
-        let mut new_tup = Vec::new();
-        if value.contains(sub) {
-            new_tup = value
-                .rsplitn(2, sub)
-                .map(|s| vm.ctx.new_str(s.to_owned()))
-                .collect();
-            new_tup.swap(0, 1); // so it's in the right order
-            new_tup.insert(1, vm.ctx.new_str(sub.clone()));
-        } else {
-            new_tup.push(vm.ctx.new_str("".to_owned()));
-            new_tup.push(vm.ctx.new_str("".to_owned()));
-            new_tup.push(vm.ctx.new_str(value.clone()));
+    fn rpartition(&self, sub: PyStringRef, vm: &VirtualMachine) -> PyResult {
+        if sub.value.is_empty() {
+            return Err(vm.new_value_error("empty separator".to_owned()));
         }
-        vm.ctx.new_tuple(new_tup)
+        let mut sp = self.value.rsplitn(2, &sub.value);
+        let back = sp.next().unwrap();
+        let elems = if let Some(front) = sp.next() {
+            [front, &sub.value, back]
+        } else {
+            ["", "", back]
+        };
+        Ok(vm.ctx.new_tuple(
+            elems
+                .iter()
+                .map(|&s| vm.ctx.new_str(s.to_owned()))
+                .collect(),
+        ))
     }
 
     /// Return `true` if the sequence is ASCII titlecase and the sequence is not