From 5864c6e063ab2d2f8f5378029def87a5c242e496 Mon Sep 17 00:00:00 2001 From: jgirardet Date: Sat, 4 May 2019 01:14:08 +0200 Subject: [PATCH] rewrite pybytearray with pybyteinner --- tests/snippets/bytearray.py | 384 ++++++++++++++++++++++++--- tests/snippets/bytes.py | 3 +- vm/src/obj/objbytearray.rs | 500 +++++++++++++++++++++--------------- vm/src/obj/objbyteinner.rs | 55 ++-- vm/src/obj/objbytes.rs | 34 +-- vm/src/stdlib/io.rs | 8 +- 6 files changed, 695 insertions(+), 289 deletions(-) diff --git a/tests/snippets/bytearray.py b/tests/snippets/bytearray.py index 4286f81710..4f29fa1877 100644 --- a/tests/snippets/bytearray.py +++ b/tests/snippets/bytearray.py @@ -1,53 +1,368 @@ -#__getitem__ not implemented yet -#a = bytearray(b'abc') -#assert a[0] == b'a' -#assert a[1] == b'b' +from testutils import assertRaises -assert len(bytearray([1,2,3])) == 3 +# new +assert bytearray([1, 2, 3]) +assert bytearray((1, 2, 3)) +assert bytearray(range(4)) +assert bytearray(3) +assert b"bla" +assert ( + bytearray("bla", "utf8") == bytearray("bla", encoding="utf-8") == bytearray(b"bla") +) +with assertRaises(TypeError): + bytearray("bla") +with assertRaises(TypeError): + bytearray("bla", encoding=b"jilj") -assert bytearray(b'1a23').isalnum() -assert not bytearray(b'1%a23').isalnum() +assert bytearray( + b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\x7f\x80\x81\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x8b\x8c\x8d\x8e\x8f\x90\x91\x92\x93\x94\x95\x96\x97\x98\x99\x9a\x9b\x9c\x9d\x9e\x9f\xa0\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xab\xac\xad\xae\xaf\xb0\xb1\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xbb\xbc\xbd\xbe\xbf\xc0\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf\xd0\xd1\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xdb\xdc\xdd\xde\xdf\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee\xef\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb\xfc\xfd\xfe\xff" +) == bytearray(range(0, 256)) +assert bytearray( + b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\x7f\x80\x81\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x8b\x8c\x8d\x8e\x8f\x90\x91\x92\x93\x94\x95\x96\x97\x98\x99\x9a\x9b\x9c\x9d\x9e\x9f\xa0\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xab\xac\xad\xae\xaf\xb0\xb1\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xbb\xbc\xbd\xbe\xbf\xc0\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf\xd0\xd1\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xdb\xdc\xdd\xde\xdf\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee\xef\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb\xfc\xfd\xfe\xff" +) == bytearray(range(0, 256)) +assert bytearray(b"omkmok\Xaa") == bytearray( + [111, 109, 107, 109, 111, 107, 92, 88, 97, 97] +) -assert bytearray(b'abc').isalpha() -assert not bytearray(b'abc1').isalpha() + +a = bytearray(b"abcd") +b = bytearray(b"ab") +c = bytearray(b"abcd") + + +# repr +assert repr(bytearray([0, 1, 2])) == repr(bytearray(b"\x00\x01\x02")) +assert ( + repr(bytearray([0, 1, 9, 10, 11, 13, 31, 32, 33, 89, 120, 255])) + == "bytearray(b'\\x00\\x01\\t\\n\\x0b\\r\\x1f !Yx\\xff')" +) +assert repr(bytearray(b"abcd")) == "bytearray(b'abcd')" + +# len +assert len(bytearray("abcdé", "utf8")) == 6 + +# comp +assert a == b"abcd" +assert a > b +assert a >= b +assert b < a +assert b <= a + +assert bytearray(b"foobar").__eq__(2) == NotImplemented +assert bytearray(b"foobar").__ne__(2) == NotImplemented +assert bytearray(b"foobar").__gt__(2) == NotImplemented +assert bytearray(b"foobar").__ge__(2) == NotImplemented +assert bytearray(b"foobar").__lt__(2) == NotImplemented +assert bytearray(b"foobar").__le__(2) == NotImplemented + +# # hash +with assertRaises(TypeError): + hash(bytearray(b"abcd")) # unashable + +# # iter +[i for i in bytearray(b"abcd")] == ["a", "b", "c", "d"] +assert list(bytearray(3)) == [0, 0, 0] + +# add +assert a + b == bytearray(b"abcdab") + +# contains +assert bytearray(b"ab") in bytearray(b"abcd") +assert bytearray(b"cd") in bytearray(b"abcd") +assert bytearray(b"abcd") in bytearray(b"abcd") +assert bytearray(b"a") in bytearray(b"abcd") +assert bytearray(b"d") in bytearray(b"abcd") +assert bytearray(b"dc") not in bytearray(b"abcd") +assert 97 in bytearray(b"abcd") +assert 150 not in bytearray(b"abcd") +with assertRaises(ValueError): + 350 in bytearray(b"abcd") + + +# getitem +d = bytearray(b"abcdefghij") + +assert d[1] == 98 +assert d[-1] == 106 +assert d[2:6] == bytearray(b"cdef") +assert d[-6:] == bytearray(b"efghij") +assert d[1:8:2] == bytearray(b"bdfh") +assert d[8:1:-2] == bytearray(b"igec") + + +# # is_xx methods + +assert bytearray(b"1a23").isalnum() +assert not bytearray(b"1%a23").isalnum() + +assert bytearray(b"abc").isalpha() +assert not bytearray(b"abc1").isalpha() # travis doesn't like this -#assert bytearray(b'xyz').isascii() -#assert not bytearray([128, 157, 32]).isascii() +# assert bytearray(b'xyz').isascii() +# assert not bytearray([128, 157, 32]).isascii() + +assert bytearray(b"1234567890").isdigit() +assert not bytearray(b"12ab").isdigit() -assert bytearray(b'1234567890').isdigit() -assert not bytearray(b'12ab').isdigit() +l = bytearray(b"lower") +b = bytearray(b"UPPER") -l = bytearray(b'lower') assert l.islower() assert not l.isupper() -assert l.upper().isupper() -assert not bytearray(b'Super Friends').islower() +assert b.isupper() +assert not bytearray(b"Super Friends").islower() -assert bytearray(b' \n\t').isspace() -assert not bytearray(b'\td\n').isspace() +assert bytearray(b" \n\t").isspace() +assert not bytearray(b"\td\n").isspace() -b = bytearray(b'UPPER') assert b.isupper() assert not b.islower() -assert b.lower().islower() -assert not bytearray(b'tuPpEr').isupper() +assert l.islower() +assert not bytearray(b"tuPpEr").isupper() -assert bytearray(b'Is Title Case').istitle() -assert not bytearray(b'is Not title casE').istitle() +assert bytearray(b"Is Title Case").istitle() +assert not bytearray(b"is Not title casE").istitle() -a = bytearray(b'abcd') -a.clear() -assert len(a) == 0 +# upper lower, capitalize, swapcase +l = bytearray(b"lower") +b = bytearray(b"UPPER") +assert l.lower().islower() +assert b.upper().isupper() +assert l.capitalize() == b"Lower" +assert b.capitalize() == b"Upper" +assert bytearray().capitalize() == bytearray() +assert b"AaBbCc123'@/".swapcase().swapcase() == b"AaBbCc123'@/" +assert b"AaBbCc123'@/".swapcase() == b"aAbBcC123'@/" + +# # hex from hex +assert bytearray([0, 1, 9, 23, 90, 234]).hex() == "000109175aea" +bytearray.fromhex("62 6c7a 34350a ") == b"blz45\n" try: - bytearray([400]) -except ValueError: - pass -else: - assert False + bytearray.fromhex("62 a 21") +except ValueError as e: + str(e) == "non-hexadecimal number found in fromhex() arg at position 4" +try: + bytearray.fromhex("6Z2") +except ValueError as e: + str(e) == "non-hexadecimal number found in fromhex() arg at position 1" +with assertRaises(TypeError): + bytearray.fromhex(b"hhjjk") +# center +assert [bytearray(b"koki").center(i, b"|") for i in range(3, 10)] == [ + b"koki", + b"koki", + b"|koki", + b"|koki|", + b"||koki|", + b"||koki||", + b"|||koki||", +] + +assert [bytearray(b"kok").center(i, b"|") for i in range(2, 10)] == [ + b"kok", + b"kok", + b"kok|", + b"|kok|", + b"|kok||", + b"||kok||", + b"||kok|||", + b"|||kok|||", +] +bytearray(b"kok").center(4) == b" kok" # " test no arg" +with assertRaises(TypeError): + bytearray(b"b").center(2, "a") +with assertRaises(TypeError): + bytearray(b"b").center(2, b"ba") +with assertRaises(TypeError): + bytearray(b"b").center(b"ba") +assert bytearray(b"kok").center(5, bytearray(b"x")) == b"xkokx" +bytearray(b"kok").center(-5) == b"kok" -b = bytearray(b'test') + +# ljust +assert [bytearray(b"koki").ljust(i, b"|") for i in range(3, 10)] == [ + b"koki", + b"koki", + b"koki|", + b"koki||", + b"koki|||", + b"koki||||", + b"koki|||||", +] +assert [bytearray(b"kok").ljust(i, b"|") for i in range(2, 10)] == [ + b"kok", + b"kok", + b"kok|", + b"kok||", + b"kok|||", + b"kok||||", + b"kok|||||", + b"kok||||||", +] + +bytearray(b"kok").ljust(4) == b"kok " # " test no arg" +with assertRaises(TypeError): + bytearray(b"b").ljust(2, "a") +with assertRaises(TypeError): + bytearray(b"b").ljust(2, b"ba") +with assertRaises(TypeError): + bytearray(b"b").ljust(b"ba") +assert bytearray(b"kok").ljust(5, bytearray(b"x")) == b"kokxx" +assert bytearray(b"kok").ljust(-5) == b"kok" + +# rjust +assert [bytearray(b"koki").rjust(i, b"|") for i in range(3, 10)] == [ + b"koki", + b"koki", + b"|koki", + b"||koki", + b"|||koki", + b"||||koki", + b"|||||koki", +] +assert [bytearray(b"kok").rjust(i, b"|") for i in range(2, 10)] == [ + b"kok", + b"kok", + b"|kok", + b"||kok", + b"|||kok", + b"||||kok", + b"|||||kok", + b"||||||kok", +] + + +bytearray(b"kok").rjust(4) == b" kok" # " test no arg" +with assertRaises(TypeError): + bytearray(b"b").rjust(2, "a") +with assertRaises(TypeError): + bytearray(b"b").rjust(2, b"ba") +with assertRaises(TypeError): + bytearray(b"b").rjust(b"ba") +assert bytearray(b"kok").rjust(5, bytearray(b"x")) == b"xxkok" +assert bytearray(b"kok").rjust(-5) == b"kok" + + +# count +assert bytearray(b"azeazerazeazopia").count(b"aze") == 3 +assert bytearray(b"azeazerazeazopia").count(b"az") == 4 +assert bytearray(b"azeazerazeazopia").count(b"a") == 5 +assert bytearray(b"123456789").count(b"") == 10 +assert bytearray(b"azeazerazeazopia").count(bytearray(b"aze")) == 3 +assert bytearray(b"azeazerazeazopia").count(memoryview(b"aze")) == 3 +assert bytearray(b"azeazerazeazopia").count(memoryview(b"aze"), 1, 9) == 1 +assert bytearray(b"azeazerazeazopia").count(b"aze", None, None) == 3 +assert bytearray(b"azeazerazeazopia").count(b"aze", 2, None) == 2 +assert bytearray(b"azeazerazeazopia").count(b"aze", 2) == 2 +assert bytearray(b"azeazerazeazopia").count(b"aze", None, 7) == 2 +assert bytearray(b"azeazerazeazopia").count(b"aze", None, 7) == 2 +assert bytearray(b"azeazerazeazopia").count(b"aze", 2, 7) == 1 +assert bytearray(b"azeazerazeazopia").count(b"aze", -13, -10) == 1 +assert bytearray(b"azeazerazeazopia").count(b"aze", 1, 10000) == 2 +with assertRaises(ValueError): + bytearray(b"ilj").count(3550) +assert bytearray(b"azeazerazeazopia").count(97) == 5 + +# join +assert bytearray(b"").join( + (b"jiljl", bytearray(b"kmoomk"), memoryview(b"aaaa")) +) == bytearray(b"jiljlkmoomkaaaa") +with assertRaises(TypeError): + bytearray(b"").join((b"km", "kl")) + + +# endswith startswith +assert bytearray(b"abcde").endswith(b"de") +assert bytearray(b"abcde").endswith(b"") +assert not bytearray(b"abcde").endswith(b"zx") +assert bytearray(b"abcde").endswith(b"bc", 0, 3) +assert not bytearray(b"abcde").endswith(b"bc", 2, 3) +assert bytearray(b"abcde").endswith((b"c", bytearray(b"de"))) + +assert bytearray(b"abcde").startswith(b"ab") +assert bytearray(b"abcde").startswith(b"") +assert not bytearray(b"abcde").startswith(b"zx") +assert bytearray(b"abcde").startswith(b"cd", 2) +assert not bytearray(b"abcde").startswith(b"cd", 1, 4) +assert bytearray(b"abcde").startswith((b"a", bytearray(b"bc"))) + + +# index find +assert bytearray(b"abcd").index(b"cd") == 2 +assert bytearray(b"abcd").index(b"cd", 0) == 2 +assert bytearray(b"abcd").index(b"cd", 1) == 2 +assert bytearray(b"abcd").index(99) == 2 +with assertRaises(ValueError): + bytearray(b"abcde").index(b"c", 3, 1) +with assertRaises(ValueError): + bytearray(b"abcd").index(b"cdaaaaa") +with assertRaises(ValueError): + bytearray(b"abcd").index(b"b", 3, 4) +with assertRaises(ValueError): + bytearray(b"abcd").index(1) + + +assert bytearray(b"abcd").find(b"cd") == 2 +assert bytearray(b"abcd").find(b"cd", 0) == 2 +assert bytearray(b"abcd").find(b"cd", 1) == 2 +assert bytearray(b"abcde").find(b"c", 3, 1) == -1 +assert bytearray(b"abcd").find(b"cdaaaaa") == -1 +assert bytearray(b"abcd").find(b"b", 3, 4) == -1 +assert bytearray(b"abcd").find(1) == -1 +assert bytearray(b"abcd").find(99) == 2 + +assert bytearray(b"abcdabcda").find(b"a") == 0 +assert bytearray(b"abcdabcda").rfind(b"a") == 8 +assert bytearray(b"abcdabcda").rfind(b"a", 2, 6) == 4 +assert bytearray(b"abcdabcda").rfind(b"a", None, 6) == 4 +assert bytearray(b"abcdabcda").rfind(b"a", 2, None) == 8 +assert bytearray(b"abcdabcda").index(b"a") == 0 +assert bytearray(b"abcdabcda").rindex(b"a") == 8 + + +# make trans +# fmt: off +assert ( + bytearray.maketrans(memoryview(b"abc"), bytearray(b"zzz")) + == bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 122, 122, 122, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255]) +) +# fmt: on + +# translate +assert bytearray(b"hjhtuyjyujuyj").translate( + bytearray.maketrans(b"hj", bytearray(b"ab")), bytearray(b"h") +) == bytearray(b"btuybyubuyb") +assert bytearray(b"hjhtuyjyujuyj").translate( + bytearray.maketrans(b"hj", bytearray(b"ab")), bytearray(b"a") +) == bytearray(b"abatuybyubuyb") +assert bytearray(b"hjhtuyjyujuyj").translate( + bytearray.maketrans(b"hj", bytearray(b"ab")) +) == bytearray(b"abatuybyubuyb") +assert bytearray(b"hjhtuyfjtyhuhjuyj").translate(None, bytearray(b"ht")) == bytearray( + b"juyfjyujuyj" +) +assert bytearray(b"hjhtuyfjtyhuhjuyj").translate(None, delete=b"ht") == bytearray( + b"juyfjyujuyj" +) + + +# strip lstrip rstrip +assert bytearray(b" spacious ").strip() == bytearray(b"spacious") +assert bytearray(b"www.example.com").strip(b"cmowz.") == bytearray(b"example") +assert bytearray(b" spacious ").lstrip() == bytearray(b"spacious ") +assert bytearray(b"www.example.com").lstrip(b"cmowz.") == bytearray(b"example.com") +assert bytearray(b" spacious ").rstrip() == bytearray(b" spacious") +assert bytearray(b"mississippi").rstrip(b"ipz") == bytearray(b"mississ") + + +# clear +a = bytearray(b"abcd") +a.clear() +assert len(a) == 0 + +b = bytearray(b"test") assert len(b) == 4 b.pop() assert len(b) == 3 @@ -66,8 +381,11 @@ else: assert False -a = bytearray(b'appen') +a = bytearray(b"appen") assert len(a) == 5 a.append(100) +assert a == bytearray(b"append") assert len(a) == 6 assert a.pop() == 100 + +import bytes as bbytes diff --git a/tests/snippets/bytes.py b/tests/snippets/bytes.py index 2733533d8f..7086a1423e 100644 --- a/tests/snippets/bytes.py +++ b/tests/snippets/bytes.py @@ -175,7 +175,8 @@ with assertRaises(TypeError): b"b".center(b"ba") assert b"kok".center(5, bytearray(b"x")) == b"xkokx" -b"kok".center(-5) +b"kok".center(-5) == b"kok" + # ljust diff --git a/vm/src/obj/objbytearray.rs b/vm/src/obj/objbytearray.rs index 6f78ecf481..4336243399 100644 --- a/vm/src/obj/objbytearray.rs +++ b/vm/src/obj/objbytearray.rs @@ -1,32 +1,62 @@ //! Implementation of the python bytearray object. -use std::cell::{Cell, RefCell}; -use std::fmt::Write; -use std::ops::{Deref, DerefMut}; - -use num_traits::ToPrimitive; - use crate::function::OptionalArg; -use crate::pyobject::{PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue}; +use crate::obj::objbyteinner::{ + ByteInnerFindOptions, ByteInnerNewOptions, ByteInnerPaddingOptions, ByteInnerPosition, + ByteInnerTranslateOptions, ByteOr, PyByteInner, +}; +use crate::obj::objint::PyIntRef; +use crate::obj::objslice::PySliceRef; +use crate::obj::objstr::PyStringRef; +use crate::obj::objtuple::PyTupleRef; +use crate::pyobject::{ + Either, PyClassImpl, PyContext, PyIterable, PyObjectRef, PyRef, PyResult, PyValue, + TryFromObject, +}; use crate::vm::VirtualMachine; +use std::cell::{Cell, RefCell}; -use super::objint; use super::objiter; use super::objtype::PyClassRef; -#[derive(Debug)] +/// "bytearray(iterable_of_ints) -> bytearray\n\ +/// bytearray(string, encoding[, errors]) -> bytearray\n\ +/// bytearray(bytes_or_buffer) -> mutable copy of bytes_or_buffer\n\ +/// bytearray(int) -> bytes array of size given by the parameter initialized with null bytes\n\ +/// bytearray() -> empty bytes array\n\n\ +/// Construct a mutable bytearray object from:\n \ +/// - an iterable yielding integers in range(256)\n \ +/// - a text string encoded using the specified encoding\n \ +/// - a bytes or a buffer object\n \ +/// - any object implementing the buffer API.\n \ +/// - an integer"; +#[pyclass(name = "bytearray")] +#[derive(Clone, Debug)] pub struct PyByteArray { - // TODO: shouldn't be public - pub value: RefCell>, + pub inner: RefCell, } type PyByteArrayRef = PyRef; impl PyByteArray { pub fn new(data: Vec) -> Self { PyByteArray { - value: RefCell::new(data), + inner: RefCell::new(PyByteInner { elements: data }), } } + + pub fn from_inner(inner: PyByteInner) -> Self { + PyByteArray { + inner: RefCell::new(inner), + } + } + + // pub fn get_value(&self) -> Vec { + // self.inner.borrow().clone().elements + // } + + // pub fn get_value_mut(&self) -> Vec { + // self.inner.borrow_mut().clone().elements + // } } impl PyValue for PyByteArray { @@ -35,254 +65,318 @@ impl PyValue for PyByteArray { } } -pub fn get_value<'a>(obj: &'a PyObjectRef) -> impl Deref> + 'a { - obj.payload::().unwrap().value.borrow() -} +// pub fn get_value(obj: &PyObjectRef) -> Vec { +// obj.payload::().unwrap().get_value() +// } -pub fn get_mut_value<'a>(obj: &'a PyObjectRef) -> impl DerefMut> + 'a { - obj.payload::().unwrap().value.borrow_mut() -} - -// Binary data support +// pub fn get_value_mut(obj: &PyObjectRef) -> Vec { +// obj.payload::().unwrap().get_value_mut() +// } /// Fill bytearray class methods dictionary. pub fn init(context: &PyContext) { + PyByteArrayRef::extend_class(context, &context.bytearray_type); let bytearray_type = &context.bytearray_type; - - let bytearray_doc = - "bytearray(iterable_of_ints) -> bytearray\n\ - bytearray(string, encoding[, errors]) -> bytearray\n\ - bytearray(bytes_or_buffer) -> mutable copy of bytes_or_buffer\n\ - bytearray(int) -> bytes array of size given by the parameter initialized with null bytes\n\ - bytearray() -> empty bytes array\n\n\ - Construct a mutable bytearray object from:\n \ - - an iterable yielding integers in range(256)\n \ - - a text string encoded using the specified encoding\n \ - - a bytes or a buffer object\n \ - - any object implementing the buffer API.\n \ - - an integer"; - extend_class!(context, bytearray_type, { - "__doc__" => context.new_str(bytearray_doc.to_string()), - "__new__" => context.new_rustfunc(bytearray_new), - "__eq__" => context.new_rustfunc(PyByteArrayRef::eq), - "__len__" => context.new_rustfunc(PyByteArrayRef::len), - "__repr__" => context.new_rustfunc(PyByteArrayRef::repr), - "__iter__" => context.new_rustfunc(PyByteArrayRef::iter), - "clear" => context.new_rustfunc(PyByteArrayRef::clear), - "isalnum" => context.new_rustfunc(PyByteArrayRef::isalnum), - "isalpha" => context.new_rustfunc(PyByteArrayRef::isalpha), - "isascii" => context.new_rustfunc(PyByteArrayRef::isascii), - "isdigit" => context.new_rustfunc(PyByteArrayRef::isdigit), - "islower" => context.new_rustfunc(PyByteArrayRef::islower), - "isspace" => context.new_rustfunc(PyByteArrayRef::isspace), - "istitle" =>context.new_rustfunc(PyByteArrayRef::istitle), - "isupper" => context.new_rustfunc(PyByteArrayRef::isupper), - "lower" => context.new_rustfunc(PyByteArrayRef::lower), - "append" => context.new_rustfunc(PyByteArrayRef::append), - "pop" => context.new_rustfunc(PyByteArrayRef::pop), - "upper" => context.new_rustfunc(PyByteArrayRef::upper) + "fromhex" => context.new_rustfunc(PyByteArrayRef::fromhex), + "maketrans" => context.new_rustfunc(PyByteInner::maketrans), }); PyByteArrayIterator::extend_class(context, &context.bytearrayiterator_type); } -fn bytearray_new( - cls: PyClassRef, - val_option: OptionalArg, - vm: &VirtualMachine, -) -> PyResult { - // Create bytes data: - let value = if let OptionalArg::Present(ival) = val_option { - let elements = vm.extract_elements(&ival)?; - let mut data_bytes = vec![]; - for elem in elements.iter() { - let v = objint::to_int(vm, elem, 10)?; - if let Some(i) = v.to_u8() { - data_bytes.push(i); - } else { - return Err(vm.new_value_error("byte must be in range(0, 256)".to_string())); - } - } - data_bytes - // return Err(vm.new_type_error("Cannot construct bytes".to_string())); - } else { - vec![] - }; - PyByteArray::new(value).into_ref_with_type(vm, cls.clone()) -} - +#[pyimpl] impl PyByteArrayRef { + #[pymethod(name = "__new__")] + fn bytearray_new( + cls: PyClassRef, + options: ByteInnerNewOptions, + vm: &VirtualMachine, + ) -> PyResult { + PyByteArray::from_inner(options.get_value(vm)?).into_ref_with_type(vm, cls) + } + + #[pymethod(name = "__repr__")] + fn repr(self, vm: &VirtualMachine) -> PyResult { + Ok(vm.new_str(format!("bytearray(b'{}')", self.inner.borrow().repr()?))) + } + + #[pymethod(name = "__len__")] fn len(self, _vm: &VirtualMachine) -> usize { - self.value.borrow().len() + self.inner.borrow().len() + } + + #[pymethod(name = "__eq__")] + fn eq(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.inner.borrow().eq(other, vm) + } + + #[pymethod(name = "__ge__")] + fn ge(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.inner.borrow().ge(other, vm) + } + + #[pymethod(name = "__le__")] + fn le(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.inner.borrow().le(other, vm) + } + + #[pymethod(name = "__gt__")] + fn gt(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.inner.borrow().gt(other, vm) } - fn eq(self, other: PyObjectRef, vm: &VirtualMachine) -> PyObjectRef { - if let Ok(other) = other.downcast::() { - vm.ctx - .new_bool(self.value.borrow().as_slice() == other.value.borrow().as_slice()) + #[pymethod(name = "__lt__")] + fn lt(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + self.inner.borrow().lt(other, vm) + } + + #[pymethod(name = "__hash__")] + fn hash(self, vm: &VirtualMachine) -> PyResult { + Err(vm.new_type_error("unhashable type: bytearray".to_string())) + } + + #[pymethod(name = "__iter__")] + fn iter(self, _vm: &VirtualMachine) -> PyByteArrayIterator { + PyByteArrayIterator { + position: Cell::new(0), + bytearray: self, + } + } + + #[pymethod(name = "__add__")] + fn add(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + if let Ok(other) = PyByteInner::try_from_object(vm, other) { + Ok(vm.ctx.new_bytearray(self.inner.borrow().add(other))) } else { - vm.ctx.not_implemented() + Ok(vm.ctx.not_implemented()) } } - fn isalnum(self, _vm: &VirtualMachine) -> bool { - let bytes = self.value.borrow(); - !bytes.is_empty() && bytes.iter().all(|x| char::from(*x).is_alphanumeric()) + #[pymethod(name = "__contains__")] + fn contains(self, needle: Either, vm: &VirtualMachine) -> PyResult { + self.inner.borrow().contains(needle, vm) } - fn isalpha(self, _vm: &VirtualMachine) -> bool { - let bytes = self.value.borrow(); - !bytes.is_empty() && bytes.iter().all(|x| char::from(*x).is_alphabetic()) + #[pymethod(name = "__getitem__")] + fn getitem(self, needle: Either, vm: &VirtualMachine) -> PyResult { + self.inner.borrow().getitem(needle, vm) } - fn isascii(self, _vm: &VirtualMachine) -> bool { - let bytes = self.value.borrow(); - !bytes.is_empty() && bytes.iter().all(|x| char::from(*x).is_ascii()) + #[pymethod(name = "isalnum")] + fn isalnum(self, vm: &VirtualMachine) -> PyResult { + self.inner.borrow().isalnum(vm) } - fn isdigit(self, _vm: &VirtualMachine) -> bool { - let bytes = self.value.borrow(); - !bytes.is_empty() && bytes.iter().all(|x| char::from(*x).is_digit(10)) + #[pymethod(name = "isalpha")] + fn isalpha(self, vm: &VirtualMachine) -> PyResult { + self.inner.borrow().isalpha(vm) } - fn islower(self, _vm: &VirtualMachine) -> bool { - let bytes = self.value.borrow(); - !bytes.is_empty() - && bytes - .iter() - .filter(|x| !char::from(**x).is_whitespace()) - .all(|x| char::from(*x).is_lowercase()) + #[pymethod(name = "isascii")] + fn isascii(self, vm: &VirtualMachine) -> PyResult { + self.inner.borrow().isascii(vm) } - fn isspace(self, _vm: &VirtualMachine) -> bool { - let bytes = self.value.borrow(); - !bytes.is_empty() && bytes.iter().all(|x| char::from(*x).is_whitespace()) + #[pymethod(name = "isdigit")] + fn isdigit(self, vm: &VirtualMachine) -> PyResult { + self.inner.borrow().isdigit(vm) } - fn isupper(self, _vm: &VirtualMachine) -> bool { - let bytes = self.value.borrow(); - !bytes.is_empty() - && bytes - .iter() - .filter(|x| !char::from(**x).is_whitespace()) - .all(|x| char::from(*x).is_uppercase()) + #[pymethod(name = "islower")] + fn islower(self, vm: &VirtualMachine) -> PyResult { + self.inner.borrow().islower(vm) } - fn istitle(self, _vm: &VirtualMachine) -> bool { - let bytes = self.value.borrow(); - if bytes.is_empty() { - return false; - } + #[pymethod(name = "isspace")] + fn isspace(self, vm: &VirtualMachine) -> PyResult { + self.inner.borrow().isspace(vm) + } - let mut iter = bytes.iter().peekable(); - let mut prev_cased = false; - - while let Some(c) = iter.next() { - let current = char::from(*c); - let next = if let Some(k) = iter.peek() { - char::from(**k) - } else if current.is_uppercase() { - return !prev_cased; - } else { - return prev_cased; - }; - - if (is_cased(current) && next.is_uppercase() && !prev_cased) - || (!is_cased(current) && next.is_lowercase()) - { - return false; - } - - prev_cased = is_cased(current); - } + #[pymethod(name = "isupper")] + fn isupper(self, vm: &VirtualMachine) -> PyResult { + self.inner.borrow().isupper(vm) + } - true + #[pymethod(name = "istitle")] + fn istitle(self, vm: &VirtualMachine) -> PyResult { + self.inner.borrow().istitle(vm) } - fn repr(self, _vm: &VirtualMachine) -> String { - let bytes = self.value.borrow(); - let data = String::from_utf8(bytes.to_vec()).unwrap_or_else(|_| to_hex(&bytes.to_vec())); - format!("bytearray(b'{}')", data) + #[pymethod(name = "lower")] + fn lower(self, vm: &VirtualMachine) -> PyResult { + Ok(vm.ctx.new_bytearray(self.inner.borrow().lower(vm))) } - fn clear(self, _vm: &VirtualMachine) { - self.value.borrow_mut().clear(); + #[pymethod(name = "upper")] + fn upper(self, vm: &VirtualMachine) -> PyResult { + Ok(vm.ctx.new_bytearray(self.inner.borrow().upper(vm))) } - fn append(self, x: u8, _vm: &VirtualMachine) { - self.value.borrow_mut().push(x); + #[pymethod(name = "capitalize")] + fn capitalize(self, vm: &VirtualMachine) -> PyResult { + Ok(vm.ctx.new_bytearray(self.inner.borrow().capitalize(vm))) } - fn pop(self, vm: &VirtualMachine) -> PyResult { - let mut bytes = self.value.borrow_mut(); - bytes - .pop() - .ok_or_else(|| vm.new_index_error("pop from empty bytearray".to_string())) + #[pymethod(name = "swapcase")] + fn swapcase(self, vm: &VirtualMachine) -> PyResult { + Ok(vm.ctx.new_bytearray(self.inner.borrow().swapcase(vm))) } - fn lower(self, _vm: &VirtualMachine) -> PyByteArray { - let bytes = self.value.borrow().clone().to_ascii_lowercase(); - PyByteArray { - value: RefCell::new(bytes), - } + #[pymethod(name = "hex")] + fn hex(self, vm: &VirtualMachine) -> PyResult { + self.inner.borrow().hex(vm) } - fn upper(self, _vm: &VirtualMachine) -> PyByteArray { - let bytes = self.value.borrow().clone().to_ascii_uppercase(); - PyByteArray { - value: RefCell::new(bytes), + fn fromhex(string: PyStringRef, vm: &VirtualMachine) -> PyResult { + Ok(vm + .ctx + .new_bytearray(PyByteInner::fromhex(string.as_str(), vm)?)) + } + + #[pymethod(name = "center")] + fn center(self, options: ByteInnerPaddingOptions, vm: &VirtualMachine) -> PyResult { + Ok(vm + .ctx + .new_bytearray(self.inner.borrow().center(options, vm)?)) + } + + #[pymethod(name = "ljust")] + fn ljust(self, options: ByteInnerPaddingOptions, vm: &VirtualMachine) -> PyResult { + Ok(vm + .ctx + .new_bytearray(self.inner.borrow().ljust(options, vm)?)) + } + + #[pymethod(name = "rjust")] + fn rjust(self, options: ByteInnerPaddingOptions, vm: &VirtualMachine) -> PyResult { + Ok(vm + .ctx + .new_bytearray(self.inner.borrow().rjust(options, vm)?)) + } + + #[pymethod(name = "count")] + fn count(self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { + self.inner.borrow().count(options, vm) + } + + #[pymethod(name = "join")] + fn join(self, iter: PyIterable, vm: &VirtualMachine) -> PyResult { + self.inner.borrow().join(iter, vm) + } + + #[pymethod(name = "endswith")] + fn endswith( + self, + suffix: Either, + start: OptionalArg, + end: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult { + self.inner + .borrow() + .startsendswith(suffix, start, end, true, vm) + } + + #[pymethod(name = "startswith")] + fn startswith( + self, + prefix: Either, + start: OptionalArg, + end: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult { + self.inner + .borrow() + .startsendswith(prefix, start, end, false, vm) + } + + #[pymethod(name = "find")] + fn find(self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { + self.inner.borrow().find(options, false, vm) + } + + #[pymethod(name = "index")] + fn index(self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { + let res = self.inner.borrow().find(options, false, vm)?; + if res == -1 { + return Err(vm.new_value_error("substring not found".to_string())); } + Ok(res) } - fn iter(self, _vm: &VirtualMachine) -> PyByteArrayIterator { - PyByteArrayIterator { - position: Cell::new(0), - bytearray: self, + #[pymethod(name = "rfind")] + fn rfind(self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { + self.inner.borrow().find(options, true, vm) + } + + #[pymethod(name = "rindex")] + fn rindex(self, options: ByteInnerFindOptions, vm: &VirtualMachine) -> PyResult { + let res = self.inner.borrow().find(options, true, vm)?; + if res == -1 { + return Err(vm.new_value_error("substring not found".to_string())); } + Ok(res) } -} -// helper function for istitle -fn is_cased(c: char) -> bool { - c.to_uppercase().next().unwrap() != c || c.to_lowercase().next().unwrap() != c -} + #[pymethod(name = "translate")] + fn translate(self, options: ByteInnerTranslateOptions, vm: &VirtualMachine) -> PyResult { + self.inner.borrow().translate(options, vm) + } -/* -fn getitem(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { - arg_check!( - vm, - args, - required = [(obj, Some(vm.ctx.bytearray_type())), (needle, None)] - ); - let elements = get_elements(obj); - get_item(vm, list, &, needle.clone()) -} -*/ -/* -fn set_value(obj: &PyObjectRef, value: Vec) { - obj.borrow_mut().kind = PyObjectPayload::Bytes { value }; -} -*/ - -/// Return a lowercase hex representation of a bytearray -fn to_hex(bytearray: &[u8]) -> String { - bytearray.iter().fold(String::new(), |mut s, b| { - let _ = write!(s, "\\x{:02x}", b); - s - }) -} + #[pymethod(name = "strip")] + fn strip(self, chars: OptionalArg, vm: &VirtualMachine) -> PyResult { + Ok(vm.ctx.new_bytes( + self.inner + .borrow() + .strip(chars, ByteInnerPosition::All, vm)?, + )) + } -#[cfg(test)] -mod tests { - use super::*; + #[pymethod(name = "lstrip")] + fn lstrip(self, chars: OptionalArg, vm: &VirtualMachine) -> PyResult { + Ok(vm.ctx.new_bytes( + self.inner + .borrow() + .strip(chars, ByteInnerPosition::Left, vm)?, + )) + } + + #[pymethod(name = "rstrip")] + fn rstrip(self, chars: OptionalArg, vm: &VirtualMachine) -> PyResult { + Ok(vm.ctx.new_bytes( + self.inner + .borrow() + .strip(chars, ByteInnerPosition::Right, vm)?, + )) + } + + #[pymethod(name = "clear")] + fn clear(self, _vm: &VirtualMachine) { + self.inner.borrow_mut().elements.clear(); + } - #[test] - fn bytearray_to_hex_formatting() { - assert_eq!(&to_hex(&[11u8, 222u8]), "\\x0b\\xde"); + #[pymethod(name = "append")] + fn append(self, x: PyIntRef, vm: &VirtualMachine) -> Result<(), PyObjectRef> { + self.inner + .borrow_mut() + .elements + .push(x.as_bigint().byte_or(vm)?); + Ok(()) + } + #[pymethod(name = "pop")] + fn pop(self, vm: &VirtualMachine) -> PyResult { + let bytes = &mut self.inner.borrow_mut().elements; + bytes + .pop() + .ok_or_else(|| vm.new_index_error("pop from empty bytearray".to_string())) } } +// fn set_value(obj: &PyObjectRef, value: Vec) { +// obj.borrow_mut().kind = PyObjectPayload::Bytes { value }; +// } + #[pyclass] #[derive(Debug)] pub struct PyByteArrayIterator { @@ -300,8 +394,8 @@ impl PyValue for PyByteArrayIterator { impl PyByteArrayIterator { #[pymethod(name = "__next__")] fn next(&self, vm: &VirtualMachine) -> PyResult { - if self.position.get() < self.bytearray.value.borrow().len() { - let ret = self.bytearray.value.borrow()[self.position.get()]; + if self.position.get() < self.bytearray.inner.borrow().len() { + let ret = self.bytearray.inner.borrow().elements[self.position.get()]; self.position.set(self.position.get() + 1); Ok(ret) } else { diff --git a/vm/src/obj/objbyteinner.rs b/vm/src/obj/objbyteinner.rs index b75f8be697..9e3473b939 100644 --- a/vm/src/obj/objbyteinner.rs +++ b/vm/src/obj/objbyteinner.rs @@ -27,7 +27,7 @@ use crate::obj::objint::PyInt; use num_integer::Integer; use num_traits::ToPrimitive; -use super::objbytearray::{get_value as get_value_bytearray, PyByteArray}; +use super::objbytearray::PyByteArray; use super::objbytes::PyBytes; use super::objmemory::PyMemoryView; @@ -43,7 +43,7 @@ impl TryFromObject for PyByteInner { match_class!(obj, i @ PyBytes => Ok(PyByteInner{elements: i.get_value().to_vec()}), - j @ PyByteArray => Ok(PyByteInner{elements: get_value_bytearray(&j.as_object()).to_vec()}), + j @ PyByteArray => Ok(PyByteInner{elements: j.inner.borrow().elements.to_vec()}), k @ PyMemoryView => Ok(PyByteInner{elements: k.get_obj_value().unwrap()}), obj => Err(vm.new_type_error(format!( "a bytes-like object is required, not {}", @@ -261,6 +261,7 @@ impl PyByteInner { 0..=8 => res.push_str(&format!("\\x0{}", i)), 9 => res.push_str("\\t"), 10 => res.push_str("\\n"), + 11 => res.push_str(&format!("\\x0{:x}", i)), 13 => res.push_str("\\r"), 32..=126 => res.push(*(i) as char), _ => res.push_str(&format!("\\x{:x}", i)), @@ -277,43 +278,43 @@ impl PyByteInner { self.elements.len() == 0 } - pub fn eq(&self, other: &PyByteInner, vm: &VirtualMachine) -> PyResult { - if self.elements == other.elements { - Ok(vm.new_bool(true)) + pub fn eq(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + if let Ok(other) = PyByteInner::try_from_object(vm, other) { + Ok(vm.new_bool(self.elements == other.elements)) } else { - Ok(vm.new_bool(false)) + Ok(vm.ctx.not_implemented()) } } - pub fn ge(&self, other: &PyByteInner, vm: &VirtualMachine) -> PyResult { - if self.elements >= other.elements { - Ok(vm.new_bool(true)) + pub fn ge(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + if let Ok(other) = PyByteInner::try_from_object(vm, other) { + Ok(vm.new_bool(self.elements >= other.elements)) } else { - Ok(vm.new_bool(false)) + Ok(vm.ctx.not_implemented()) } } - pub fn le(&self, other: &PyByteInner, vm: &VirtualMachine) -> PyResult { - if self.elements <= other.elements { - Ok(vm.new_bool(true)) + pub fn le(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + if let Ok(other) = PyByteInner::try_from_object(vm, other) { + Ok(vm.new_bool(self.elements <= other.elements)) } else { - Ok(vm.new_bool(false)) + Ok(vm.ctx.not_implemented()) } } - pub fn gt(&self, other: &PyByteInner, vm: &VirtualMachine) -> PyResult { - if self.elements > other.elements { - Ok(vm.new_bool(true)) + pub fn gt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + if let Ok(other) = PyByteInner::try_from_object(vm, other) { + Ok(vm.new_bool(self.elements > other.elements)) } else { - Ok(vm.new_bool(false)) + Ok(vm.ctx.not_implemented()) } } - pub fn lt(&self, other: &PyByteInner, vm: &VirtualMachine) -> PyResult { - if self.elements < other.elements { - Ok(vm.new_bool(true)) + pub fn lt(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { + if let Ok(other) = PyByteInner::try_from_object(vm, other) { + Ok(vm.new_bool(self.elements < other.elements)) } else { - Ok(vm.new_bool(false)) + Ok(vm.ctx.not_implemented()) } } @@ -323,14 +324,12 @@ impl PyByteInner { hasher.finish() as usize } - pub fn add(&self, other: &PyByteInner, _vm: &VirtualMachine) -> Vec { - let elements: Vec = self - .elements + pub fn add(&self, other: PyByteInner) -> Vec { + self.elements .iter() .chain(other.elements.iter()) .cloned() - .collect(); - elements + .collect::>() } pub fn contains(&self, needle: Either, vm: &VirtualMachine) -> PyResult { @@ -763,7 +762,7 @@ pub fn try_as_byte(obj: &PyObjectRef) -> Option> { match_class!(obj.clone(), i @ PyBytes => Some(i.get_value().to_vec()), - j @ PyByteArray => Some(get_value_bytearray(&j.as_object()).to_vec()), + j @ PyByteArray => Some(j.inner.borrow().elements.to_vec()), _ => None) } diff --git a/vm/src/obj/objbytes.rs b/vm/src/obj/objbytes.rs index 59b0731f4f..30ad6623e0 100644 --- a/vm/src/obj/objbytes.rs +++ b/vm/src/obj/objbytes.rs @@ -9,7 +9,9 @@ use core::cell::Cell; use std::ops::Deref; use crate::function::OptionalArg; -use crate::pyobject::{PyClassImpl, PyContext, PyIterable, PyObjectRef, PyRef, PyResult, PyValue}; +use crate::pyobject::{ + PyClassImpl, PyContext, PyIterable, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, +}; use super::objbyteinner::{ ByteInnerFindOptions, ByteInnerNewOptions, ByteInnerPaddingOptions, ByteInnerPosition, @@ -100,35 +102,25 @@ impl PyBytesRef { #[pymethod(name = "__eq__")] fn eq(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - match_class!(other, - bytes @ PyBytes => self.inner.eq(&bytes.inner, vm), - _ => Ok(vm.ctx.not_implemented())) + self.inner.eq(other, vm) } - #[pymethod(name = "__ge__")] fn ge(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - match_class!(other, - bytes @ PyBytes => self.inner.ge(&bytes.inner, vm), - _ => Ok(vm.ctx.not_implemented())) + self.inner.ge(other, vm) } #[pymethod(name = "__le__")] fn le(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - match_class!(other, - bytes @ PyBytes => self.inner.le(&bytes.inner, vm), - _ => Ok(vm.ctx.not_implemented())) + self.inner.le(other, vm) } #[pymethod(name = "__gt__")] fn gt(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - match_class!(other, - bytes @ PyBytes => self.inner.gt(&bytes.inner, vm), - _ => Ok(vm.ctx.not_implemented())) + self.inner.gt(other, vm) } #[pymethod(name = "__lt__")] fn lt(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - match_class!(other, - bytes @ PyBytes => self.inner.lt(&bytes.inner, vm), - _ => Ok(vm.ctx.not_implemented())) + self.inner.lt(other, vm) } + #[pymethod(name = "__hash__")] fn hash(self, _vm: &VirtualMachine) -> usize { self.inner.hash() @@ -144,9 +136,11 @@ impl PyBytesRef { #[pymethod(name = "__add__")] fn add(self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult { - match_class!(other, - bytes @ PyBytes => Ok(vm.ctx.new_bytes(self.inner.add(&bytes.inner, vm))), - _ => Ok(vm.ctx.not_implemented())) + if let Ok(other) = PyByteInner::try_from_object(vm, other) { + Ok(vm.ctx.new_bytearray(self.inner.add(other))) + } else { + Ok(vm.ctx.not_implemented()) + } } #[pymethod(name = "__contains__")] diff --git a/vm/src/stdlib/io.rs b/vm/src/stdlib/io.rs index b8e72cc15d..e9b49ec442 100644 --- a/vm/src/stdlib/io.rs +++ b/vm/src/stdlib/io.rs @@ -124,7 +124,7 @@ fn buffered_reader_read(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { //Copy bytes from the buffer vector into the results vector if let Some(bytes) = buffer.payload::() { - result.extend_from_slice(&bytes.value.borrow()); + result.extend_from_slice(&bytes.inner.borrow().elements); }; let py_len = vm.call_method(&buffer, "__len__", PyFuncArgs::default())?; @@ -207,9 +207,9 @@ fn file_io_readinto(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { if let Some(bytes) = obj.payload::() { //TODO: Implement for MemoryView - let mut value_mut = bytes.value.borrow_mut(); + let value_mut = &mut bytes.inner.borrow_mut().elements; value_mut.clear(); - match f.read_to_end(&mut value_mut) { + match f.read_to_end(value_mut) { Ok(_) => {} Err(_) => return Err(vm.new_value_error("Error reading from Take".to_string())), } @@ -237,7 +237,7 @@ fn file_io_write(vm: &VirtualMachine, args: PyFuncArgs) -> PyResult { match obj.payload::() { Some(bytes) => { - let value_mut = bytes.value.borrow(); + let value_mut = &mut bytes.inner.borrow_mut().elements; match handle.write(&value_mut[..]) { Ok(len) => { //reset raw fd on the FileIO object