diff --git a/tests/snippets/bytes.py b/tests/snippets/bytes.py index c496bfeffb..2733533d8f 100644 --- a/tests/snippets/bytes.py +++ b/tests/snippets/bytes.py @@ -6,12 +6,20 @@ assert bytes(range(4)) assert bytes(3) assert b"bla" -assert bytes("bla", "utf8") +assert bytes("bla", "utf8") == bytes("bla", encoding="utf-8") == b"bla" with assertRaises(TypeError): bytes("bla") +with assertRaises(TypeError): + bytes("bla", encoding=b"jilj") -assert b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\x7f\x80\x81\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x8b\x8c\x8d\x8e\x8f\x90\x91\x92\x93\x94\x95\x96\x97\x98\x99\x9a\x9b\x9c\x9d\x9e\x9f\xa0\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xab\xac\xad\xae\xaf\xb0\xb1\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xbb\xbc\xbd\xbe\xbf\xc0\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf\xd0\xd1\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xdb\xdc\xdd\xde\xdf\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee\xef\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb\xfc\xfd\xfe\xff" == bytes(range(0,256)) -assert b'\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f !"#$%&\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\x7f\x80\x81\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x8b\x8c\x8d\x8e\x8f\x90\x91\x92\x93\x94\x95\x96\x97\x98\x99\x9a\x9b\x9c\x9d\x9e\x9f\xa0\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xab\xac\xad\xae\xaf\xb0\xb1\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xbb\xbc\xbd\xbe\xbf\xc0\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf\xd0\xd1\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xdb\xdc\xdd\xde\xdf\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee\xef\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb\xfc\xfd\xfe\xff' == bytes(range(0,256)) +assert ( + b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\x7f\x80\x81\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x8b\x8c\x8d\x8e\x8f\x90\x91\x92\x93\x94\x95\x96\x97\x98\x99\x9a\x9b\x9c\x9d\x9e\x9f\xa0\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xab\xac\xad\xae\xaf\xb0\xb1\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xbb\xbc\xbd\xbe\xbf\xc0\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf\xd0\xd1\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xdb\xdc\xdd\xde\xdf\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee\xef\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb\xfc\xfd\xfe\xff" + == bytes(range(0, 256)) +) +assert ( + b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\x7f\x80\x81\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x8b\x8c\x8d\x8e\x8f\x90\x91\x92\x93\x94\x95\x96\x97\x98\x99\x9a\x9b\x9c\x9d\x9e\x9f\xa0\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xab\xac\xad\xae\xaf\xb0\xb1\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xbb\xbc\xbd\xbe\xbf\xc0\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf\xd0\xd1\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xdb\xdc\xdd\xde\xdf\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee\xef\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb\xfc\xfd\xfe\xff" + == bytes(range(0, 256)) +) assert b"omkmok\Xaa" == bytes([111, 109, 107, 109, 111, 107, 92, 88, 97, 97]) @@ -113,7 +121,7 @@ assert bytes(b"Is Title Case").istitle() assert not bytes(b"is Not title casE").istitle() -# upper lower, capitalize +# upper lower, capitalize, swapcase l = bytes(b"lower") b = bytes(b"UPPER") assert l.lower().islower() @@ -121,6 +129,8 @@ assert l.capitalize() == b"Lower" assert b.capitalize() == b"Upper" assert bytes().capitalize() == bytes() +assert b"AaBbCc123'@/".swapcase().swapcase() == b"AaBbCc123'@/" +assert b"AaBbCc123'@/".swapcase() == b"aAbBcC123'@/" # hex from hex assert bytes([0, 1, 9, 23, 90, 234]).hex() == "000109175aea" @@ -134,7 +144,8 @@ bytes.fromhex("6Z2") except ValueError as e: str(e) == "non-hexadecimal number found in fromhex() arg at position 1" - +with assertRaises(TypeError): + bytes.fromhex(b"hhjjk") # center assert [b"koki".center(i, b"|") for i in range(3, 10)] == [ b"koki", @@ -161,4 +172,176 @@ b"b".center(2, "a") with assertRaises(TypeError): b"b".center(2, b"ba") -b"kok".center(5, bytearray(b"x")) +with assertRaises(TypeError): + b"b".center(b"ba") +assert b"kok".center(5, bytearray(b"x")) == b"xkokx" +b"kok".center(-5) + + +# ljust +assert [b"koki".ljust(i, b"|") for i in range(3, 10)] == [ + b"koki", + b"koki", + b"koki|", + b"koki||", + b"koki|||", + b"koki||||", + b"koki|||||", +] +assert [b"kok".ljust(i, b"|") for i in range(2, 10)] == [ + b"kok", + b"kok", + b"kok|", + b"kok||", + b"kok|||", + b"kok||||", + b"kok|||||", + b"kok||||||", +] + +b"kok".ljust(4) == b"kok " # " test no arg" +with assertRaises(TypeError): + b"b".ljust(2, "a") +with assertRaises(TypeError): + b"b".ljust(2, b"ba") +with assertRaises(TypeError): + b"b".ljust(b"ba") +assert b"kok".ljust(5, bytearray(b"x")) == b"kokxx" +assert b"kok".ljust(-5) == b"kok" + +# rjust +assert [b"koki".rjust(i, b"|") for i in range(3, 10)] == [ + b"koki", + b"koki", + b"|koki", + b"||koki", + b"|||koki", + b"||||koki", + b"|||||koki", +] +assert [b"kok".rjust(i, b"|") for i in range(2, 10)] == [ + b"kok", + b"kok", + b"|kok", + b"||kok", + b"|||kok", + b"||||kok", + b"|||||kok", + b"||||||kok", +] + + +b"kok".rjust(4) == b" kok" # " test no arg" +with assertRaises(TypeError): + b"b".rjust(2, "a") +with assertRaises(TypeError): + b"b".rjust(2, b"ba") +with assertRaises(TypeError): + b"b".rjust(b"ba") +assert b"kok".rjust(5, bytearray(b"x")) == b"xxkok" +assert b"kok".rjust(-5) == b"kok" + + +# count +assert b"azeazerazeazopia".count(b"aze") == 3 +assert b"azeazerazeazopia".count(b"az") == 4 +assert b"azeazerazeazopia".count(b"a") == 5 +assert b"123456789".count(b"") == 10 +assert b"azeazerazeazopia".count(bytearray(b"aze")) == 3 +assert b"azeazerazeazopia".count(memoryview(b"aze")) == 3 +assert b"azeazerazeazopia".count(memoryview(b"aze"), 1, 9) == 1 +assert b"azeazerazeazopia".count(b"aze", None, None) == 3 +assert b"azeazerazeazopia".count(b"aze", 2, None) == 2 +assert b"azeazerazeazopia".count(b"aze", 2) == 2 +assert b"azeazerazeazopia".count(b"aze", None, 7) == 2 +assert b"azeazerazeazopia".count(b"aze", None, 7) == 2 +assert b"azeazerazeazopia".count(b"aze", 2, 7) == 1 +assert b"azeazerazeazopia".count(b"aze", -13, -10) == 1 +assert b"azeazerazeazopia".count(b"aze", 1, 10000) == 2 +with assertRaises(ValueError): + b"ilj".count(3550) +assert b"azeazerazeazopia".count(97) == 5 + +# join +assert ( + b"".join((b"jiljl", bytearray(b"kmoomk"), memoryview(b"aaaa"))) + == b"jiljlkmoomkaaaa" +) +with assertRaises(TypeError): + b"".join((b"km", "kl")) + + +# endswith startswith +assert b"abcde".endswith(b"de") +assert b"abcde".endswith(b"") +assert not b"abcde".endswith(b"zx") +assert b"abcde".endswith(b"bc", 0, 3) +assert not b"abcde".endswith(b"bc", 2, 3) +assert b"abcde".endswith((b"c", b"de")) + +assert b"abcde".startswith(b"ab") +assert b"abcde".startswith(b"") +assert not b"abcde".startswith(b"zx") +assert b"abcde".startswith(b"cd", 2) +assert not b"abcde".startswith(b"cd", 1, 4) +assert b"abcde".startswith((b"a", b"bc")) + + +# index find +assert b"abcd".index(b"cd") == 2 +assert b"abcd".index(b"cd", 0) == 2 +assert b"abcd".index(b"cd", 1) == 2 +assert b"abcd".index(99) == 2 +with assertRaises(ValueError): + b"abcde".index(b"c", 3, 1) +with assertRaises(ValueError): + b"abcd".index(b"cdaaaaa") +with assertRaises(ValueError): + b"abcd".index(b"b", 3, 4) +with assertRaises(ValueError): + b"abcd".index(1) + + +assert b"abcd".find(b"cd") == 2 +assert b"abcd".find(b"cd", 0) == 2 +assert b"abcd".find(b"cd", 1) == 2 +assert b"abcde".find(b"c", 3, 1) == -1 +assert b"abcd".find(b"cdaaaaa") == -1 +assert b"abcd".find(b"b", 3, 4) == -1 +assert b"abcd".find(1) == -1 +assert b"abcd".find(99) == 2 + +assert b"abcdabcda".find(b"a") == 0 +assert b"abcdabcda".rfind(b"a") == 8 +assert b"abcdabcda".rfind(b"a", 2, 6) == 4 +assert b"abcdabcda".rfind(b"a", None, 6) == 4 +assert b"abcdabcda".rfind(b"a", 2, None) == 8 +assert b"abcdabcda".index(b"a") == 0 +assert b"abcdabcda".rindex(b"a") == 8 + + +# make trans +# fmt: off +assert ( + bytes.maketrans(memoryview(b"abc"), bytearray(b"zzz")) + == bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 122, 122, 122, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255]) +) +# fmt: on + +# translate +assert b"hjhtuyjyujuyj".translate(bytes.maketrans(b"hj", b"ab"), b"h") == b"btuybyubuyb" +assert ( + b"hjhtuyjyujuyj".translate(bytes.maketrans(b"hj", b"ab"), b"a") == b"abatuybyubuyb" +) +assert b"hjhtuyjyujuyj".translate(bytes.maketrans(b"hj", b"ab")) == b"abatuybyubuyb" +assert b"hjhtuyfjtyhuhjuyj".translate(None, b"ht") == b"juyfjyujuyj" +assert b"hjhtuyfjtyhuhjuyj".translate(None, delete=b"ht") == b"juyfjyujuyj" + + +# strip lstrip rstrip +assert b" spacious ".strip() == b"spacious" +assert b"www.example.com".strip(b"cmowz.") == b"example" +assert b" spacious ".lstrip() == b"spacious " +assert b"www.example.com".lstrip(b"cmowz.") == b"example.com" +assert b" spacious ".rstrip() == b" spacious" +assert b"mississippi".rstrip(b"ipz") == b"mississ" diff --git a/vm/src/obj/objbyteinner.rs b/vm/src/obj/objbyteinner.rs index d0a63633ac..b75f8be697 100644 --- a/vm/src/obj/objbyteinner.rs +++ b/vm/src/obj/objbyteinner.rs @@ -1,4 +1,13 @@ -use crate::pyobject::PyObjectRef; +use crate::obj::objint::PyIntRef; +use crate::obj::objnone::PyNoneRef; +use crate::obj::objslice::PySliceRef; +use crate::obj::objtuple::PyTupleRef; +use crate::pyobject::Either; +use crate::pyobject::PyRef; +use crate::pyobject::PyValue; +use crate::pyobject::TryFromObject; +use crate::pyobject::{PyIterable, PyObjectRef}; +use core::ops::Range; use num_bigint::BigInt; use crate::function::OptionalArg; @@ -7,52 +16,84 @@ use crate::vm::VirtualMachine; use crate::pyobject::{PyResult, TypeProtocol}; -use crate::obj::objstr::PyString; +use crate::obj::objstr::{PyString, PyStringRef}; use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; use super::objint; -use super::objsequence::PySliceableSequence; +use super::objsequence::{is_valid_slice_arg, PySliceableSequence}; + use crate::obj::objint::PyInt; use num_integer::Integer; use num_traits::ToPrimitive; use super::objbytearray::{get_value as get_value_bytearray, PyByteArray}; use super::objbytes::PyBytes; +use super::objmemory::PyMemoryView; + +use super::objsequence; #[derive(Debug, Default, Clone)] pub struct PyByteInner { pub elements: Vec, } -impl PyByteInner { - pub fn new( - val_option: OptionalArg, - enc_option: OptionalArg, - vm: &VirtualMachine, - ) -> PyResult { +impl TryFromObject for PyByteInner { + fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { + match_class!(obj, + + i @ PyBytes => Ok(PyByteInner{elements: i.get_value().to_vec()}), + j @ PyByteArray => Ok(PyByteInner{elements: get_value_bytearray(&j.as_object()).to_vec()}), + k @ PyMemoryView => Ok(PyByteInner{elements: k.get_obj_value().unwrap()}), + obj => Err(vm.new_type_error(format!( + "a bytes-like object is required, not {}", + obj.class() + ))) + ) + } +} + +impl TryFromObject for Either> { + fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult { + match PyByteInner::try_from_object(vm, obj.clone()) { + Ok(a) => Ok(Either::A(a)), + Err(_) => match obj.clone().downcast::() { + Ok(b) => Ok(Either::B(b)), + Err(_) => Err(vm.new_type_error(format!( + "a bytes-like object or {} is required, not {}", + B::class(vm), + obj.class() + ))), + }, + } + } +} + +#[derive(FromArgs)] +pub struct ByteInnerNewOptions { + #[pyarg(positional_only, optional = true)] + val_option: OptionalArg, + #[pyarg(positional_or_keyword, optional = true)] + encoding: OptionalArg, +} + +impl ByteInnerNewOptions { + pub fn get_value(self, vm: &VirtualMachine) -> PyResult { // First handle bytes(string, encoding[, errors]) - if let OptionalArg::Present(enc) = enc_option { - if let OptionalArg::Present(eval) = val_option { + if let OptionalArg::Present(enc) = self.encoding { + if let OptionalArg::Present(eval) = self.val_option { if let Ok(input) = eval.downcast::() { - if let Ok(encoding) = enc.clone().downcast::() { - if &encoding.value.to_lowercase() == "utf8" - || &encoding.value.to_lowercase() == "utf-8" - // TODO: different encoding - { - return Ok(PyByteInner { - elements: input.value.as_bytes().to_vec(), - }); - } else { - return Err( - vm.new_value_error(format!("unknown encoding: {}", encoding.value)), //should be lookup error - ); - } + let encoding = enc.as_str(); + if encoding.to_lowercase() == "utf8" || encoding.to_lowercase() == "utf-8" + // TODO: different encoding + { + return Ok(PyByteInner { + elements: input.value.as_bytes().to_vec(), + }); } else { - return Err(vm.new_type_error(format!( - "bytes() argument 2 must be str, not {}", - enc.class().name - ))); + return Err( + vm.new_value_error(format!("unknown encoding: {}", encoding)), //should be lookup error + ); } } else { return Err(vm.new_type_error("encoding without a string argument".to_string())); @@ -62,7 +103,7 @@ impl PyByteInner { } // Only one argument } else { - let value = if let OptionalArg::Present(ival) = val_option { + let value = if let OptionalArg::Present(ival) = self.val_option { match_class!(ival.clone(), i @ PyInt => { let size = objint::get_value(&i.into_object()).to_usize().unwrap(); @@ -94,7 +135,125 @@ impl PyByteInner { } } } +} + +#[derive(FromArgs)] +pub struct ByteInnerFindOptions { + #[pyarg(positional_only, optional = false)] + sub: Either, + #[pyarg(positional_only, optional = true)] + start: OptionalArg>, + #[pyarg(positional_only, optional = true)] + end: OptionalArg>, +} + +impl ByteInnerFindOptions { + pub fn get_value( + self, + elements: &[u8], + vm: &VirtualMachine, + ) -> PyResult<(Vec, Range)> { + let sub = match self.sub { + Either::A(v) => v.elements.to_vec(), + Either::B(int) => vec![int.as_bigint().byte_or(vm)?], + }; + + let start = match self.start { + OptionalArg::Present(Some(int)) => Some(int.as_bigint().clone()), + _ => None, + }; + + let end = match self.end { + OptionalArg::Present(Some(int)) => Some(int.as_bigint().clone()), + _ => None, + }; + + let range = elements.to_vec().get_slice_range(&start, &end); + + Ok((sub, range)) + } +} + +#[derive(FromArgs)] +pub struct ByteInnerPaddingOptions { + #[pyarg(positional_only, optional = false)] + width: PyIntRef, + #[pyarg(positional_only, optional = true)] + fillbyte: OptionalArg, +} +impl ByteInnerPaddingOptions { + fn get_value(self, fn_name: &str, len: usize, vm: &VirtualMachine) -> PyResult<(u8, usize)> { + let fillbyte = if let OptionalArg::Present(v) = &self.fillbyte { + match try_as_byte(&v) { + Some(x) => { + if x.len() == 1 { + x[0] + } else { + return Err(vm.new_type_error(format!( + "{}() argument 2 must be a byte string of length 1, not {}", + fn_name, &v + ))); + } + } + None => { + return Err(vm.new_type_error(format!( + "{}() argument 2 must be a byte string of length 1, not {}", + fn_name, &v + ))); + } + } + } else { + b' ' // default is space + }; + + // <0 = no change + let width = if let Some(x) = self.width.as_bigint().to_usize() { + if x <= len { + 0 + } else { + x + } + } else { + 0 + }; + + let diff: usize = if width != 0 { width - len } else { 0 }; + + Ok((fillbyte, diff)) + } +} + +#[derive(FromArgs)] +pub struct ByteInnerTranslateOptions { + #[pyarg(positional_only, optional = false)] + table: Either, + #[pyarg(positional_or_keyword, optional = true)] + delete: OptionalArg, +} + +impl ByteInnerTranslateOptions { + pub fn get_value(self, vm: &VirtualMachine) -> PyResult<(Vec, Vec)> { + let table = match self.table { + Either::A(v) => v.elements.to_vec(), + Either::B(_) => (0..=255).collect::>(), + }; + + if table.len() != 256 { + return Err( + vm.new_value_error("translation table must be 256 characters long".to_string()) + ); + } + + let delete = match self.delete { + OptionalArg::Present(byte) => byte.elements, + _ => vec![], + }; + + Ok((table, delete)) + } +} +impl PyByteInner { pub fn repr(&self) -> PyResult { let mut res = String::with_capacity(self.elements.len()); for i in self.elements.iter() { @@ -174,44 +333,45 @@ impl PyByteInner { elements } - pub fn contains_bytes(&self, other: &PyByteInner, vm: &VirtualMachine) -> PyResult { - for (n, i) in self.elements.iter().enumerate() { - if n + other.len() <= self.len() - && *i == other.elements[0] - && &self.elements[n..n + other.len()] == other.elements.as_slice() - { - return Ok(vm.new_bool(true)); - } - } - Ok(vm.new_bool(false)) - } - - pub fn contains_int(&self, int: &PyInt, vm: &VirtualMachine) -> PyResult { - if let Some(int) = int.as_bigint().to_u8() { - if self.elements.contains(&int) { - Ok(vm.new_bool(true)) - } else { + pub fn contains(&self, needle: Either, vm: &VirtualMachine) -> PyResult { + match needle { + Either::A(byte) => { + let other = &byte.elements[..]; + for (n, i) in self.elements.iter().enumerate() { + if n + other.len() <= self.len() + && *i == other[0] + && &self.elements[n..n + other.len()] == other + { + return Ok(vm.new_bool(true)); + } + } Ok(vm.new_bool(false)) } - } else { - Err(vm.new_value_error("byte must be in range(0, 256)".to_string())) + Either::B(int) => { + if self.elements.contains(&int.as_bigint().byte_or(vm)?) { + Ok(vm.new_bool(true)) + } else { + Ok(vm.new_bool(false)) + } + } } } - pub fn getitem_int(&self, int: &PyInt, vm: &VirtualMachine) -> PyResult { - if let Some(idx) = self.elements.get_pos(int.as_bigint().to_i32().unwrap()) { - Ok(vm.new_int(self.elements[idx])) - } else { - Err(vm.new_index_error("index out of range".to_string())) + pub fn getitem(&self, needle: Either, vm: &VirtualMachine) -> PyResult { + match needle { + Either::A(int) => { + if let Some(idx) = self.elements.get_pos(int.as_bigint().to_i32().unwrap()) { + Ok(vm.new_int(self.elements[idx])) + } else { + Err(vm.new_index_error("index out of range".to_string())) + } + } + Either::B(slice) => Ok(vm + .ctx + .new_bytes(self.elements.get_slice_items(vm, slice.as_object())?)), } } - pub fn getitem_slice(&self, slice: &PyObjectRef, vm: &VirtualMachine) -> PyResult { - Ok(vm - .ctx - .new_bytes(self.elements.get_slice_items(vm, slice).unwrap())) - } - pub fn isalnum(&self, vm: &VirtualMachine) -> PyResult { Ok(vm.new_bool( !self.elements.is_empty() @@ -319,6 +479,18 @@ impl PyByteInner { new } + pub fn swapcase(&self, _vm: &VirtualMachine) -> Vec { + let mut new: Vec = Vec::with_capacity(self.elements.len()); + for w in &self.elements { + match w { + 65..=90 => new.push(w.to_ascii_lowercase()), + 97..=122 => new.push(w.to_ascii_uppercase()), + x => new.push(*x), + } + } + new + } + pub fn hex(&self, vm: &VirtualMachine) -> PyResult { let bla = self .elements @@ -328,7 +500,7 @@ impl PyByteInner { Ok(vm.ctx.new_str(bla)) } - pub fn fromhex(string: String, vm: &VirtualMachine) -> Result, PyObjectRef> { + pub fn fromhex(string: &str, vm: &VirtualMachine) -> PyResult> { // first check for invalid character for (i, c) in string.char_indices() { if !c.is_digit(16) && !c.is_whitespace() { @@ -360,14 +532,13 @@ impl PyByteInner { .collect::>()) } - pub fn center(&self, width: &BigInt, fillbyte: u8, _vm: &VirtualMachine) -> Vec { - let width = width.to_usize().unwrap(); + pub fn center( + &self, + options: ByteInnerPaddingOptions, + vm: &VirtualMachine, + ) -> PyResult> { + let (fillbyte, diff) = options.get_value("center", self.len(), vm)?; - // adjust right et left side - if width <= self.len() { - return self.elements.clone(); - } - let diff: usize = width - self.len(); let mut ln: usize = diff / 2; let mut rn: usize = ln; @@ -384,14 +555,231 @@ impl PyByteInner { res.extend_from_slice(&self.elements[..]); res.extend_from_slice(&vec![fillbyte; rn][..]); - res + Ok(res) + } + + pub fn ljust( + &self, + options: ByteInnerPaddingOptions, + vm: &VirtualMachine, + ) -> PyResult> { + let (fillbyte, diff) = options.get_value("ljust", self.len(), vm)?; + + // merge all + let mut res = vec![]; + res.extend_from_slice(&self.elements[..]); + res.extend_from_slice(&vec![fillbyte; diff][..]); + + Ok(res) + } + + pub fn rjust( + &self, + options: ByteInnerPaddingOptions, + vm: &VirtualMachine, + ) -> PyResult> { + let (fillbyte, diff) = options.get_value("rjust", self.len(), vm)?; + + // merge all + let mut res = vec![fillbyte; diff]; + res.extend_from_slice(&self.elements[..]); + + Ok(res) + } + + pub fn count(&self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { + let (sub, range) = options.get_value(&self.elements, vm)?; + + if sub.is_empty() { + return Ok(self.len() + 1); + } + + let mut total: usize = 0; + let mut i_start = range.start; + let i_end = range.end; + + for i in self.elements.do_slice(range) { + if i_start + sub.len() <= i_end + && i == sub[0] + && &self.elements[i_start..(i_start + sub.len())] == sub.as_slice() + { + total += 1; + } + i_start += 1; + } + Ok(total) + } + + pub fn join(&self, iter: PyIterable, vm: &VirtualMachine) -> PyResult { + let mut refs = vec![]; + for v in iter.iter(vm)? { + let v = v?; + refs.extend(PyByteInner::try_from_object(vm, v)?.elements) + } + + Ok(vm.ctx.new_bytes(refs)) + } + + pub fn startsendswith( + &self, + arg: Either, + start: OptionalArg, + end: OptionalArg, + endswith: bool, // true for endswith, false for startswith + vm: &VirtualMachine, + ) -> PyResult { + let suff = match arg { + Either::A(byte) => byte.elements, + Either::B(tuple) => { + let mut flatten = vec![]; + for v in objsequence::get_elements(tuple.as_object()).to_vec() { + flatten.extend(PyByteInner::try_from_object(vm, v)?.elements) + } + flatten + } + }; + + if suff.is_empty() { + return Ok(vm.new_bool(true)); + } + let range = self.elements.get_slice_range( + &is_valid_slice_arg(start, vm)?, + &is_valid_slice_arg(end, vm)?, + ); + + if range.end - range.start < suff.len() { + return Ok(vm.new_bool(false)); + } + + let offset = if endswith { + (range.end - suff.len())..range.end + } else { + 0..suff.len() + }; + + Ok(vm.new_bool(suff.as_slice() == &self.elements.do_slice(range)[offset])) + } + + pub fn find( + &self, + options: ByteInnerFindOptions, + reverse: bool, + vm: &VirtualMachine, + ) -> PyResult { + let (sub, range) = options.get_value(&self.elements, vm)?; + // not allowed for this method + if range.end < range.start { + return Ok(-1isize); + } + + let start = range.start; + let end = range.end; + + if reverse { + let slice = self.elements.do_slice_reverse(range); + for (n, _) in slice.iter().enumerate() { + if n + sub.len() <= slice.len() && &slice[n..n + sub.len()] == sub.as_slice() { + return Ok((end - n - 1) as isize); + } + } + } else { + let slice = self.elements.do_slice(range); + for (n, _) in slice.iter().enumerate() { + if n + sub.len() <= slice.len() && &slice[n..n + sub.len()] == sub.as_slice() { + return Ok((start + n) as isize); + } + } + }; + Ok(-1isize) + } + + pub fn maketrans(from: PyByteInner, to: PyByteInner, vm: &VirtualMachine) -> PyResult { + let mut res = vec![]; + + for i in 0..=255 { + res.push( + if let Some(position) = from.elements.iter().position(|&x| x == i) { + to.elements[position] + } else { + i + }, + ); + } + + Ok(vm.ctx.new_bytes(res)) + } + + pub fn translate(&self, options: ByteInnerTranslateOptions, vm: &VirtualMachine) -> PyResult { + let (table, delete) = options.get_value(vm)?; + + let mut res = vec![]; + + for i in self.elements.iter() { + if !delete.contains(&i) { + res.push(table[*i as usize]); + } + } + + Ok(vm.ctx.new_bytes(res)) + } + + pub fn strip( + &self, + chars: OptionalArg, + position: ByteInnerPosition, + _vm: &VirtualMachine, + ) -> PyResult> { + let chars = if let OptionalArg::Present(bytes) = chars { + bytes.elements + } else { + vec![b' '] + }; + + let mut start = 0; + let mut end = self.len(); + + if let ByteInnerPosition::Left | ByteInnerPosition::All = position { + for (n, i) in self.elements.iter().enumerate() { + if !chars.contains(i) { + start = n; + break; + } + } + } + + if let ByteInnerPosition::Right | ByteInnerPosition::All = position { + for (n, i) in self.elements.iter().rev().enumerate() { + if !chars.contains(i) { + end = self.len() - n; + break; + } + } + } + Ok(self.elements[start..end].to_vec()) } } -pub fn is_byte(obj: &PyObjectRef) -> Option> { +pub fn try_as_byte(obj: &PyObjectRef) -> Option> { match_class!(obj.clone(), i @ PyBytes => Some(i.get_value().to_vec()), j @ PyByteArray => Some(get_value_bytearray(&j.as_object()).to_vec()), _ => None) } + +pub trait ByteOr: ToPrimitive { + fn byte_or(&self, vm: &VirtualMachine) -> Result { + match self.to_u8() { + Some(value) => Ok(value), + None => Err(vm.new_value_error("byte must be in range(0, 256)".to_string())), + } + } +} + +impl ByteOr for BigInt {} + +pub enum ByteInnerPosition { + Left, + Right, + All, +} diff --git a/vm/src/obj/objbytes.rs b/vm/src/obj/objbytes.rs index 0556de0535..c16fd123fe 100644 --- a/vm/src/obj/objbytes.rs +++ b/vm/src/obj/objbytes.rs @@ -1,17 +1,23 @@ -use crate::obj::objint::PyInt; -use crate::obj::objstr::PyString; +use crate::obj::objint::PyIntRef; +use crate::obj::objslice::PySliceRef; +use crate::obj::objstr::PyStringRef; +use crate::obj::objtuple::PyTupleRef; + +use crate::pyobject::Either; use crate::vm::VirtualMachine; use core::cell::Cell; use std::ops::Deref; use crate::function::OptionalArg; -use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; +use crate::pyobject::{PyClassImpl, PyContext, PyIterable, PyObjectRef, PyRef, PyResult, PyValue}; -use super::objbyteinner::{is_byte, PyByteInner}; +use super::objbyteinner::{ + ByteInnerFindOptions, ByteInnerNewOptions, ByteInnerPaddingOptions, ByteInnerPosition, + ByteInnerTranslateOptions, PyByteInner, +}; use super::objiter; -use super::objslice::PySlice; -use super::objtype::PyClassRef; +use super::objtype::PyClassRef; /// "bytes(iterable_of_ints) -> bytes\n\ /// bytes(string, encoding[, errors]) -> bytes\n\ /// bytes(bytes_or_buffer) -> immutable copy of bytes_or_buffer\n\ @@ -34,7 +40,6 @@ impl PyBytes { inner: PyByteInner { elements }, } } - pub fn get_value(&self) -> &[u8] { &self.inner.elements } @@ -63,6 +68,8 @@ pub fn init(context: &PyContext) { let bytes_type = &context.bytes_type; extend_class!(context, bytes_type, { "fromhex" => context.new_rustfunc(PyBytesRef::fromhex), + "maketrans" => context.new_rustfunc(PyByteInner::maketrans), + }); let bytesiterator_type = &context.bytesiterator_type; extend_class!(context, bytesiterator_type, { @@ -76,12 +83,11 @@ impl PyBytesRef { #[pymethod(name = "__new__")] fn bytes_new( cls: PyClassRef, - val_option: OptionalArg, - enc_option: OptionalArg, + options: ByteInnerNewOptions, vm: &VirtualMachine, ) -> PyResult { PyBytes { - inner: PyByteInner::new(val_option, enc_option, vm)?, + inner: options.get_value(vm)?, } .into_ref_with_type(vm, cls) } @@ -148,19 +154,13 @@ impl PyBytesRef { } #[pymethod(name = "__contains__")] - fn contains(self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { - match_class!(needle, - bytes @ PyBytes => self.inner.contains_bytes(&bytes.inner, vm), - int @ PyInt => self.inner.contains_int(&int, vm), - obj => Err(vm.new_type_error(format!("a bytes-like object is required, not {}", obj)))) + fn contains(self, needle: Either, vm: &VirtualMachine) -> PyResult { + self.inner.contains(needle, vm) } #[pymethod(name = "__getitem__")] - fn getitem(self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult { - match_class!(needle, - int @ PyInt => self.inner.getitem_int(&int, vm), - slice @ PySlice => self.inner.getitem_slice(slice.as_object(), vm), - obj => Err(vm.new_type_error(format!("byte indices must be integers or slices, not {}", obj)))) + fn getitem(self, needle: Either, vm: &VirtualMachine) -> PyResult { + self.inner.getitem(needle, vm) } #[pymethod(name = "isalnum")] @@ -218,56 +218,119 @@ impl PyBytesRef { Ok(vm.ctx.new_bytes(self.inner.capitalize(vm))) } + #[pymethod(name = "swapcase")] + fn swapcase(self, vm: &VirtualMachine) -> PyResult { + Ok(vm.ctx.new_bytes(self.inner.swapcase(vm))) + } + #[pymethod(name = "hex")] fn hex(self, vm: &VirtualMachine) -> PyResult { self.inner.hex(vm) } - // #[pymethod(name = "fromhex")] - fn fromhex(string: PyObjectRef, vm: &VirtualMachine) -> PyResult { - match_class!(string, - s @ PyString => { - match PyByteInner::fromhex(s.to_string(), vm) { - Ok(x) => Ok(vm.ctx.new_bytes(x)), - Err(y) => Err(y)}}, - obj => Err(vm.new_type_error(format!("fromhex() argument must be str, not {}", obj ))) - ) + fn fromhex(string: PyStringRef, vm: &VirtualMachine) -> PyResult { + Ok(vm.ctx.new_bytes(PyByteInner::fromhex(string.as_str(), vm)?)) } #[pymethod(name = "center")] - fn center( + fn center(self, options: ByteInnerPaddingOptions, vm: &VirtualMachine) -> PyResult { + Ok(vm.ctx.new_bytes(self.inner.center(options, vm)?)) + } + + #[pymethod(name = "ljust")] + fn ljust(self, options: ByteInnerPaddingOptions, vm: &VirtualMachine) -> PyResult { + Ok(vm.ctx.new_bytes(self.inner.ljust(options, vm)?)) + } + + #[pymethod(name = "rjust")] + fn rjust(self, options: ByteInnerPaddingOptions, vm: &VirtualMachine) -> PyResult { + Ok(vm.ctx.new_bytes(self.inner.rjust(options, vm)?)) + } + + #[pymethod(name = "count")] + fn count(self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { + self.inner.count(options, vm) + } + + #[pymethod(name = "join")] + fn join(self, iter: PyIterable, vm: &VirtualMachine) -> PyResult { + self.inner.join(iter, vm) + } + + #[pymethod(name = "endswith")] + fn endswith( self, - width: PyObjectRef, - fillbyte: OptionalArg, + suffix: Either, + start: OptionalArg, + end: OptionalArg, vm: &VirtualMachine, ) -> PyResult { - let sym = if let OptionalArg::Present(v) = fillbyte { - match is_byte(&v) { - Some(x) => { - if x.len() == 1 { - x[0] - } else { - return Err(vm.new_type_error(format!( - "center() argument 2 must be a byte string of length 1, not {}", - &v - ))); - } - } - None => { - return Err(vm.new_type_error(format!( - "center() argument 2 must be a byte string of length 1, not {}", - &v - ))); - } - } - } else { - 32 // default is space - }; + self.inner.startsendswith(suffix, start, end, true, vm) + } + + #[pymethod(name = "startswith")] + fn startswith( + self, + prefix: Either, + start: OptionalArg, + end: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult { + self.inner.startsendswith(prefix, start, end, false, vm) + } + + #[pymethod(name = "find")] + fn find(self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { + self.inner.find(options, false, vm) + } + + #[pymethod(name = "index")] + fn index(self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { + let res = self.inner.find(options, false, vm)?; + if res == -1 { + return Err(vm.new_value_error("substring not found".to_string())); + } + Ok(res) + } + + #[pymethod(name = "rfind")] + fn rfind(self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { + self.inner.find(options, true, vm) + } + + #[pymethod(name = "rindex")] + fn rindex(self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { + let res = self.inner.find(options, true, vm)?; + if res == -1 { + return Err(vm.new_value_error("substring not found".to_string())); + } + Ok(res) + } + + #[pymethod(name = "translate")] + fn translate(self, options: ByteInnerTranslateOptions, vm: &VirtualMachine) -> PyResult { + self.inner.translate(options, vm) + } + + #[pymethod(name = "strip")] + fn strip(self, chars: OptionalArg, vm: &VirtualMachine) -> PyResult { + Ok(vm + .ctx + .new_bytes(self.inner.strip(chars, ByteInnerPosition::All, vm)?)) + } + + #[pymethod(name = "lstrip")] + fn lstrip(self, chars: OptionalArg, vm: &VirtualMachine) -> PyResult { + Ok(vm + .ctx + .new_bytes(self.inner.strip(chars, ByteInnerPosition::Left, vm)?)) + } - match_class!(width, - i @PyInt => Ok(vm.ctx.new_bytes(self.inner.center(i.as_bigint(), sym, vm))), - obj => {Err(vm.new_type_error(format!("{} cannot be interpreted as an integer", obj)))} - ) + #[pymethod(name = "rstrip")] + fn rstrip(self, chars: OptionalArg, vm: &VirtualMachine) -> PyResult { + Ok(vm + .ctx + .new_bytes(self.inner.strip(chars, ByteInnerPosition::Right, vm)?)) } } diff --git a/vm/src/obj/objmemory.rs b/vm/src/obj/objmemory.rs index cfe44dea38..88040ae14a 100644 --- a/vm/src/obj/objmemory.rs +++ b/vm/src/obj/objmemory.rs @@ -1,3 +1,4 @@ +use crate::obj::objbyteinner::try_as_byte; use crate::obj::objtype::PyClassRef; use crate::pyobject::{PyContext, PyObjectRef, PyRef, PyResult, PyValue}; use crate::vm::VirtualMachine; @@ -9,6 +10,12 @@ pub struct PyMemoryView { obj: PyObjectRef, } +impl PyMemoryView { + pub fn get_obj_value(&self) -> Option> { + try_as_byte(&self.obj) + } +} + impl PyValue for PyMemoryView { fn class(vm: &VirtualMachine) -> PyClassRef { vm.ctx.memoryview_type() diff --git a/vm/src/obj/objsequence.rs b/vm/src/obj/objsequence.rs index 5594ac8586..af20343708 100644 --- a/vm/src/obj/objsequence.rs +++ b/vm/src/obj/objsequence.rs @@ -1,3 +1,5 @@ +use crate::function::OptionalArg; +use crate::obj::objnone::PyNone; use std::cell::RefCell; use std::marker::Sized; use std::ops::{Deref, DerefMut, Range}; @@ -371,3 +373,20 @@ pub fn get_mut_elements<'a>(obj: &'a PyObjectRef) -> impl DerefMut, + vm: &VirtualMachine, +) -> Result, PyObjectRef> { + if let OptionalArg::Present(value) = arg { + match_class!(value, + i @ PyInt => Ok(Some(i.as_bigint().clone())), + _obj @ PyNone => Ok(None), + _=> {return Err(vm.new_type_error("slice indices must be integers or None or have an __index__ method".to_string()));} + // TODO: check for an __index__ method + ) + } else { + Ok(None) + } +}