diff --git a/Lib/test/test_marshal.py b/Lib/test/test_marshal.py index 35412f5bec..6b6c367324 100644 --- a/Lib/test/test_marshal.py +++ b/Lib/test/test_marshal.py @@ -64,8 +64,6 @@ def test_bool(self): self.helper(b) class FloatTestCase(unittest.TestCase, HelperMixin): - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_floats(self): # Test a few floats small = 1e-25 @@ -101,8 +99,6 @@ def test_string(self): for s in ["", "Andr\xe8 Previn", "abc", " "*10000]: self.helper(s) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_bytes(self): for s in [b"", b"Andr\xe8 Previn", b"abc", b" "*10000]: self.helper(s) @@ -202,14 +198,11 @@ def test_patch_873224(self): self.assertRaises(Exception, marshal.loads, b'f') self.assertRaises(Exception, marshal.loads, marshal.dumps(2**65)[:-1]) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_version_argument(self): # Python 2.4.0 crashes for any call to marshal.dumps(x, y) self.assertEqual(marshal.loads(marshal.dumps(5, 0)), 5) self.assertEqual(marshal.loads(marshal.dumps(5, 1)), 5) - @unittest.skip("TODO: RUSTPYTHON; panic") def test_fuzz(self): # simple test that it's at least not *totally* trivial to # crash from bad marshal data @@ -337,8 +330,6 @@ def readinto(self, buf): self.assertRaises(ValueError, marshal.load, BadReader(marshal.dumps(value))) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_eof(self): data = marshal.dumps(("hello", "dolly", None)) for i in range(len(data)): @@ -509,8 +500,7 @@ def testModule(self): self.helper(code) self.helper3(code) - # TODO: RUSTPYTHON - @unittest.expectedFailure + @unittest.skip("TODO: RUSTPYTHON") def testRecursion(self): obj = 1.2345 d = {"hello": obj, "goodbye": obj, obj: "hello"} @@ -529,23 +519,15 @@ def _test(self, version): data = marshal.dumps(code, version) marshal.loads(data) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test0To3(self): self._test(0) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test1To3(self): self._test(1) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test2To3(self): self._test(2) - # TODO: RUSTPYTHON - @unittest.expectedFailure def test3To3(self): self._test(3) @@ -562,8 +544,6 @@ def testIntern(self): s2 = sys.intern(s) self.assertEqual(id(s2), id(s)) - # TODO: RUSTPYTHON - @unittest.expectedFailure def testNoIntern(self): s = marshal.loads(marshal.dumps(self.strobj, 2)) self.assertEqual(s, self.strobj) diff --git a/vm/src/stdlib/marshal.rs b/vm/src/stdlib/marshal.rs index ae1cfa27b9..21d575e3be 100644 --- a/vm/src/stdlib/marshal.rs +++ b/vm/src/stdlib/marshal.rs @@ -9,24 +9,22 @@ mod decl { }, bytecode, convert::ToPyObject, - function::ArgBytesLike, + function::{ArgBytesLike, OptionalArg}, object::AsObject, protocol::PyBuffer, PyObjectRef, PyResult, TryFromObject, VirtualMachine, }; - /// TODO - /// PyBytes: Currently getting recursion error with match_class! use num_bigint::{BigInt, Sign}; use num_traits::Zero; #[repr(u8)] enum Type { // Null = b'0', - // None = b'N', + None = b'N', False = b'F', True = b'T', // StopIter = b'S', - // Ellipsis = b'.', + Ellipsis = b'.', Int = b'i', Float = b'g', // Complex = b'y', @@ -38,11 +36,11 @@ mod decl { List = b'[', Dict = b'{', Code = b'c', - Str = b'u', // = TYPE_UNICODE + Unicode = b'u', // Unknown = b'?', Set = b'<', FrozenSet = b'>', - // Ascii = b'a', + Ascii = b'a', // AsciiInterned = b'A', // SmallTuple = b')', // ShortAscii = b'z', @@ -56,11 +54,11 @@ mod decl { use Type::*; Ok(match value { // b'0' => Null, - // b'N' => None, + b'N' => None, b'F' => False, b'T' => True, // b'S' => StopIter, - // b'.' => Ellipsis, + b'.' => Ellipsis, b'i' => Int, b'g' => Float, // b'y' => Complex, @@ -72,11 +70,11 @@ mod decl { b'[' => List, b'{' => Dict, b'c' => Code, - b'u' => Str, + b'u' => Unicode, // b'?' => Unknown, b'<' => Set, b'>' => FrozenSet, - // b'a' => Ascii, + b'a' => Ascii, // b'A' => AsciiInterned, // b')' => SmallTuple, // b'z' => ShortAscii, @@ -86,6 +84,9 @@ mod decl { } } + #[pyattr(name = "version")] + const VERSION: u32 = 4; + fn too_short_error(vm: &VirtualMachine) -> PyBaseExceptionRef { vm.new_exception_msg( vm.ctx.exceptions.eof_error.to_owned(), @@ -109,93 +110,118 @@ mod decl { /// Dumping helper function to turn a value into bytes. fn dump_obj(buf: &mut Vec, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - match_class!(match value { - pyint @ PyInt => { - if pyint.class().is(vm.ctx.types.bool_type) { - let typ = if pyint.as_bigint().is_zero() { - Type::False + if vm.is_none(&value) { + buf.push(Type::None as u8); + } else if value.is(&vm.ctx.ellipsis) { + buf.push(Type::Ellipsis as u8); + } else { + match_class!(match value { + pyint @ PyInt => { + if pyint.class().is(vm.ctx.types.bool_type) { + let typ = if pyint.as_bigint().is_zero() { + Type::False + } else { + Type::True + }; + buf.push(typ as u8); } else { - Type::True - }; - buf.push(typ as u8); - } else { - buf.push(Type::Int as u8); - let (sign, int_bytes) = pyint.as_bigint().to_bytes_le(); - let mut len = int_bytes.len() as i32; - if sign == Sign::Minus { - len = -len; + buf.push(Type::Int as u8); + let (sign, int_bytes) = pyint.as_bigint().to_bytes_le(); + let mut len = int_bytes.len() as i32; + if sign == Sign::Minus { + len = -len; + } + buf.extend(len.to_le_bytes()); + buf.extend(int_bytes); } - buf.extend(len.to_le_bytes()); - buf.extend(int_bytes); } - } - pyfloat @ PyFloat => { - buf.push(Type::Float as u8); - buf.extend(pyfloat.to_f64().to_le_bytes()); - } - pystr @ PyStr => { - buf.push(Type::Str as u8); - write_size(buf, pystr.as_str().len(), vm)?; - buf.extend(pystr.as_str().as_bytes()); - } - pylist @ PyList => { - buf.push(Type::List as u8); - let pylist_items = pylist.borrow_vec(); - dump_seq(buf, pylist_items.iter(), vm)?; - } - pyset @ PySet => { - buf.push(Type::Set as u8); - let elements = pyset.elements(); - dump_seq(buf, elements.iter(), vm)?; - } - pyfrozen @ PyFrozenSet => { - buf.push(Type::FrozenSet as u8); - let elements = pyfrozen.elements(); - dump_seq(buf, elements.iter(), vm)?; - } - pytuple @ PyTuple => { - buf.push(Type::Tuple as u8); - dump_seq(buf, pytuple.iter(), vm)?; - } - pydict @ PyDict => { - buf.push(Type::Dict as u8); - write_size(buf, pydict.len(), vm)?; - for (key, value) in pydict { - dump_obj(buf, key, vm)?; - dump_obj(buf, value, vm)?; + pyfloat @ PyFloat => { + buf.push(Type::Float as u8); + buf.extend(pyfloat.to_f64().to_le_bytes()); } - } - bytes @ PyByteArray => { - buf.push(Type::Bytes as u8); - let data = bytes.borrow_buf(); - write_size(buf, data.len(), vm)?; - buf.extend(&*data); - } - co @ PyCode => { - buf.push(Type::Code as u8); - let bytes = co.code.map_clone_bag(&bytecode::BasicBag).to_bytes(); - write_size(buf, bytes.len(), vm)?; - buf.extend(bytes); - } - _ => { - return Err(vm.new_not_implemented_error( - "TODO: not implemented yet or marshal unsupported type".to_owned(), - )); - } - }); + pystr @ PyStr => { + buf.push(if pystr.is_ascii() { + Type::Ascii + } else { + Type::Unicode + } as u8); + write_size(buf, pystr.as_str().len(), vm)?; + buf.extend(pystr.as_str().as_bytes()); + } + pylist @ PyList => { + buf.push(Type::List as u8); + let pylist_items = pylist.borrow_vec(); + dump_seq(buf, pylist_items.iter(), vm)?; + } + pyset @ PySet => { + buf.push(Type::Set as u8); + let elements = pyset.elements(); + dump_seq(buf, elements.iter(), vm)?; + } + pyfrozen @ PyFrozenSet => { + buf.push(Type::FrozenSet as u8); + let elements = pyfrozen.elements(); + dump_seq(buf, elements.iter(), vm)?; + } + pytuple @ PyTuple => { + buf.push(Type::Tuple as u8); + dump_seq(buf, pytuple.iter(), vm)?; + } + pydict @ PyDict => { + buf.push(Type::Dict as u8); + write_size(buf, pydict.len(), vm)?; + for (key, value) in pydict { + dump_obj(buf, key, vm)?; + dump_obj(buf, value, vm)?; + } + } + bytes @ PyBytes => { + buf.push(Type::Bytes as u8); + let data = bytes.as_bytes(); + write_size(buf, data.len(), vm)?; + buf.extend(&*data); + } + bytes @ PyByteArray => { + buf.push(Type::Bytes as u8); + let data = bytes.borrow_buf(); + write_size(buf, data.len(), vm)?; + buf.extend(&*data); + } + co @ PyCode => { + buf.push(Type::Code as u8); + let bytes = co.code.map_clone_bag(&bytecode::BasicBag).to_bytes(); + write_size(buf, bytes.len(), vm)?; + buf.extend(bytes); + } + _ => { + return Err(vm.new_not_implemented_error( + "TODO: not implemented yet or marshal unsupported type".to_owned(), + )); + } + }) + } Ok(()) } #[pyfunction] - fn dumps(value: PyObjectRef, vm: &VirtualMachine) -> PyResult { + fn dumps( + value: PyObjectRef, + _version: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult { let mut buf = Vec::new(); dump_obj(&mut buf, value, vm)?; Ok(PyBytes::from(buf)) } #[pyfunction] - fn dump(value: PyObjectRef, f: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { - let dumped = dumps(value, vm)?; + fn dump( + value: PyObjectRef, + f: PyObjectRef, + version: OptionalArg, + vm: &VirtualMachine, + ) -> PyResult<()> { + let dumped = dumps(value, version, vm)?; vm.call_method(&f, "write", (dumped,))?; Ok(()) } @@ -248,8 +274,10 @@ mod decl { let typ = Type::try_from(*type_indicator) .map_err(|_| vm.new_value_error("bad marshal data (unknown type code)".to_owned()))?; let (obj, buf) = match typ { - Type::True => ((true).to_pyobject(vm), buf), - Type::False => ((false).to_pyobject(vm), buf), + Type::True => (true.to_pyobject(vm), buf), + Type::False => (false.to_pyobject(vm), buf), + Type::None => (vm.ctx.none(), buf), + Type::Ellipsis => (vm.ctx.ellipsis(), buf), Type::Int => { if buf.len() < 4 { return Err(too_short_error(vm)); @@ -276,7 +304,17 @@ mod decl { let number = f64::from_le_bytes(bytes.try_into().unwrap()); (vm.ctx.new_float(number).into(), buf) } - Type::Str => { + Type::Ascii => { + let (len, buf) = read_size(buf, vm)?; + if buf.len() < len { + return Err(too_short_error(vm)); + } + let (bytes, buf) = buf.split_at(len); + let s = String::from_utf8(bytes.to_vec()) + .map_err(|_| vm.new_value_error("invalid utf8 data".to_owned()))?; + (s.to_pyobject(vm), buf) + } + Type::Unicode => { let (len, buf) = read_size(buf, vm)?; if buf.len() < len { return Err(too_short_error(vm));