From 25d2426ec3a2f8d53e8721eb5eced14ceca3023b Mon Sep 17 00:00:00 2001 From: Jake Armendariz Date: Sun, 19 Dec 2021 22:18:23 -0800 Subject: [PATCH 1/2] Add marshalling support for ints, floats, strs, lists, dict --- Lib/test/test_marshal.py | 91 +++++++++++++ vm/src/builtins/dict.rs | 4 + vm/src/dictdatatype.rs | 22 ++++ vm/src/protocol/buffer.rs | 11 +- vm/src/stdlib/marshal.rs | 262 +++++++++++++++++++++++++++++++++++--- 5 files changed, 370 insertions(+), 20 deletions(-) create mode 100644 Lib/test/test_marshal.py diff --git a/Lib/test/test_marshal.py b/Lib/test/test_marshal.py new file mode 100644 index 0000000000..a0642258c4 --- /dev/null +++ b/Lib/test/test_marshal.py @@ -0,0 +1,91 @@ +import unittest +import marshal + +class MarshalTests(unittest.TestCase): + """ + Testing each data type is done with two tests + Test dumps data == expected_bytes + Test load(dumped data) == data + """ + + def dump_then_load(self, data): + return marshal.loads(marshal.dumps(data)) + + def test_dumps_int(self): + self.assertEqual(marshal.dumps(0), b'i0\x00') + self.assertEqual(marshal.dumps(-1), b'i-\x01') + self.assertEqual(marshal.dumps(1), b'i+\x01') + self.assertEqual(marshal.dumps(100000000), b'i+\x00\xe1\xf5\x05') + + def test_dump_and_load_int(self): + self.assertEqual(self.dump_then_load(0), 0) + self.assertEqual(self.dump_then_load(-1), -1) + self.assertEqual(self.dump_then_load(1), 1) + self.assertEqual(self.dump_then_load(100000000), 100000000) + + def test_dumps_float(self): + self.assertEqual(marshal.dumps(0.0), b'f\x00\x00\x00\x00\x00\x00\x00\x00') + self.assertEqual(marshal.dumps(-10.0), b'f\x00\x00\x00\x00\x00\x00$\xc0') + self.assertEqual(marshal.dumps(10.0), b'f\x00\x00\x00\x00\x00\x00$@') + + def test_dump_and_load_int(self): + self.assertEqual(self.dump_then_load(0.0), 0.0) + self.assertEqual(self.dump_then_load(-10.0), -10.0) + self.assertEqual(self.dump_then_load(10), 10) + + def test_dumps_str(self): + self.assertEqual(marshal.dumps(""), b's') + self.assertEqual(marshal.dumps("Hello, World"), b'sHello, World') + + def test_dump_and_load_str(self): + self.assertEqual(self.dump_then_load(""), "") + self.assertEqual(self.dump_then_load("Hello, World"), "Hello, World") + + def test_dumps_list(self): + # Lists have to print the length of every element + # so when marshelling and unmarshelling we know how many bytes to search + # all usize values are converted to u32 to handle different architecture sizes. + self.assertEqual(marshal.dumps([]), b'[\x00\x00\x00\x00') + self.assertEqual( + marshal.dumps([1, "hello", 1.0]), + b'[\x03\x00\x00\x00\x03\x00\x00\x00i+\x01\x06\x00\x00\x00shello\t\x00\x00\x00f\x00\x00\x00\x00\x00\x00\xf0?', + ) + self.assertEqual( + marshal.dumps([[0], ['a','b']]), + b'[\x02\x00\x00\x00\x0c\x00\x00\x00[\x01\x00\x00\x00\x03\x00\x00\x00i0\x00\x11\x00\x00\x00[\x02\x00\x00\x00\x02\x00\x00\x00sa\x02\x00\x00\x00sb', + ) + + def test_dump_and_load_list(self): + self.assertEqual(self.dump_then_load([]), []) + self.assertEqual(self.dump_then_load([1, "hello", 1.0]), [1, "hello", 1.0]) + self.assertEqual(self.dump_then_load([[0], ['a','b']]),[[0], ['a','b']]) + + def test_dumps_tuple(self): + self.assertEqual(marshal.dumps(()), b'(\x00\x00\x00\x00') + self.assertEqual( + marshal.dumps((1, "hello", 1.0)), + b'(\x03\x00\x00\x00\x03\x00\x00\x00i+\x01\x06\x00\x00\x00shello\t\x00\x00\x00f\x00\x00\x00\x00\x00\x00\xf0?' + ) + + def test_dump_and_load_tuple(self): + self.assertEqual(self.dump_then_load(()), ()) + self.assertEqual(self.dump_then_load((1, "hello", 1.0)), (1, "hello", 1.0)) + + def test_dumps_dict(self): + self.assertEqual(marshal.dumps({}), b',[\x00\x00\x00\x00') + self.assertEqual( + marshal.dumps({'a':1, 1:'a'}), + b',[\x02\x00\x00\x00\x12\x00\x00\x00(\x02\x00\x00\x00\x02\x00\x00\x00sa\x03\x00\x00\x00i+\x01\x12\x00\x00\x00(\x02\x00\x00\x00\x03\x00\x00\x00i+\x01\x02\x00\x00\x00sa' + ) + self.assertEqual( + marshal.dumps({'a':{'b':2}, 'c':[0.0, 4.0, 6, 9]}), + b',[\x02\x00\x00\x00+\x00\x00\x00(\x02\x00\x00\x00\x02\x00\x00\x00sa\x1c\x00\x00\x00,[\x01\x00\x00\x00\x12\x00\x00\x00(\x02\x00\x00\x00\x02\x00\x00\x00sb\x03\x00\x00\x00i+\x02<\x00\x00\x00(\x02\x00\x00\x00\x02\x00\x00\x00sc-\x00\x00\x00[\x04\x00\x00\x00\t\x00\x00\x00f\x00\x00\x00\x00\x00\x00\x00\x00\t\x00\x00\x00f\x00\x00\x00\x00\x00\x00\x10@\x03\x00\x00\x00i+\x06\x03\x00\x00\x00i+\t' + ) + + def test_dump_and_load_dict(self): + self.assertEqual(self.dump_then_load({}), {}) + self.assertEqual(self.dump_then_load({'a':1, 1:'a'}), {'a':1, 1:'a'}) + self.assertEqual(self.dump_then_load({'a':{'b':2}, 'c':[0.0, 4.0, 6, 9]}), {'a':{'b':2}, 'c':[0.0, 4.0, 6, 9]}) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/vm/src/builtins/dict.rs b/vm/src/builtins/dict.rs index d5d1c6e09d..b589e121a2 100644 --- a/vm/src/builtins/dict.rs +++ b/vm/src/builtins/dict.rs @@ -69,6 +69,10 @@ impl PyDict { &self.entries } + pub(crate) fn from_entries(entries: DictContentType) -> Self { + Self { entries } + } + #[pyslot] fn slot_new(cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult { PyDict::default().into_pyresult_with_type(vm, cls) diff --git a/vm/src/dictdatatype.rs b/vm/src/dictdatatype.rs index e29b95b1ae..e3784fb4a8 100644 --- a/vm/src/dictdatatype.rs +++ b/vm/src/dictdatatype.rs @@ -114,6 +114,28 @@ struct DictEntry { value: T, } +impl DictEntry { + pub(crate) fn as_tuple(&self) -> (PyObjectRef, T) { + (self.key.clone(), self.value.clone()) + } +} + +impl Dict { + pub(crate) fn as_kvpairs(&self) -> Vec<(PyObjectRef, T)> { + let entries = &self.inner.read().entries; + entries + .into_iter() + .filter_map(|entry| { + if let Some(dict_entry) = entry { + Some(dict_entry.as_tuple()) + } else { + None + } + }) + .collect() + } +} + #[derive(Debug, PartialEq)] pub struct DictSize { indices_size: usize, diff --git a/vm/src/protocol/buffer.rs b/vm/src/protocol/buffer.rs index c69b4efc2f..6ca3ed3800 100644 --- a/vm/src/protocol/buffer.rs +++ b/vm/src/protocol/buffer.rs @@ -10,7 +10,7 @@ use crate::{ }, sliceable::wrap_index, types::{Constructor, Unconstructible}, - PyObject, PyObjectPayload, PyObjectRef, PyObjectView, PyObjectWrap, PyRef, PyResult, + PyObject, PyObjectPayload, PyObjectRef, PyObjectView, PyObjectWrap, PyRef, PyResult, PyValue, TryFromBorrowedObject, TypeProtocol, VirtualMachine, }; use std::{borrow::Cow, fmt::Debug, ops::Range}; @@ -63,6 +63,15 @@ impl PyBuffer { .then(|| unsafe { self.contiguous_mut_unchecked() }) } + pub fn from_byte_vector(bytes: Vec, vm: &VirtualMachine) -> Self { + let bytes_len = bytes.len(); + PyBuffer::new( + PyValue::into_object(VecBuffer::from(bytes), vm), + BufferDescriptor::simple(bytes_len, true), + &VEC_BUFFER_METHODS, + ) + } + /// # Safety /// assume the buffer is contiguous pub unsafe fn contiguous_unchecked(&self) -> BorrowedValue<[u8]> { diff --git a/vm/src/stdlib/marshal.rs b/vm/src/stdlib/marshal.rs index f7ae60eb1c..211d1713af 100644 --- a/vm/src/stdlib/marshal.rs +++ b/vm/src/stdlib/marshal.rs @@ -2,23 +2,113 @@ pub(crate) use decl::make_module; #[pymodule(name = "marshal")] mod decl { + /// TODO add support for Booleans, Sets, etc + use ascii::AsciiStr; + use num_bigint::{BigInt, Sign}; + use std::ops::Deref; + use std::slice::Iter; + use crate::{ - builtins::{PyBytes, PyCode}, + builtins::{ + dict::DictContentType, PyBytes, PyCode, PyDict, PyFloat, PyInt, PyList, PyStr, PyTuple, + }, bytecode, - function::ArgBytesLike, + common::borrow::BorrowedValue, + function::{ArgBytesLike, IntoPyObject}, + protocol::PyBuffer, PyObjectRef, PyResult, TryFromObject, VirtualMachine, }; + const STR_BYTE: u8 = b's'; + const INT_BYTE: u8 = b'i'; + const FLOAT_BYTE: u8 = b'f'; + const LIST_BYTE: u8 = b'['; + const TUPLE_BYTE: u8 = b'('; + const DICT_BYTE: u8 = b','; + + /// Safely convert usize to 4 le bytes + fn size_to_bytes(x: usize, vm: &VirtualMachine) -> PyResult<[u8; 4]> { + // For marshalling we want to convert lengths to bytes. To save space + // we limit the size to u32 to keep marshalling smaller. + match u32::try_from(x) { + Ok(n) => Ok(n.to_le_bytes()), + Err(_) => { + Err(vm.new_value_error("Size exceeds 2^32 capacity for marshalling.".to_owned())) + } + } + } + + /// Dumps a iterator of objects into binary vector. + fn dump_list(pyobjs: Iter, vm: &VirtualMachine) -> PyResult> { + let mut byte_list = size_to_bytes(pyobjs.len(), vm)?.to_vec(); + // For each element, dump into binary, then add its length and value. + for element in pyobjs { + let element_bytes: PyBytes = dumps(element.clone(), vm)?; + byte_list.extend(size_to_bytes(element_bytes.len(), vm)?); + byte_list.extend_from_slice(element_bytes.deref()) + } + Ok(byte_list) + } + #[pyfunction] fn dumps(value: PyObjectRef, vm: &VirtualMachine) -> PyResult { let r = match_class!(match value { + pyint @ PyInt => { + let (sign, uint_bytes) = pyint.as_bigint().to_bytes_le(); + let sign_byte = match sign { + Sign::Minus => b'-', + Sign::NoSign => b'0', + Sign::Plus => b'+', + }; + // Return as [TYPE, SIGN, uint bytes] + PyBytes::from([vec![INT_BYTE, sign_byte], uint_bytes].concat()) + } + pyfloat @ PyFloat => { + let mut float_bytes = pyfloat.to_f64().to_le_bytes().to_vec(); + float_bytes.insert(0, FLOAT_BYTE); + PyBytes::from(float_bytes) + } + pystr @ PyStr => { + let mut str_bytes = pystr.as_str().as_bytes().to_vec(); + str_bytes.insert(0, STR_BYTE); + PyBytes::from(str_bytes) + } + pylist @ PyList => { + let pylist_items = pylist.borrow_vec(); + let mut list_bytes = dump_list(pylist_items.iter(), vm)?; + list_bytes.insert(0, LIST_BYTE); + PyBytes::from(list_bytes) + } + pytuple @ PyTuple => { + let mut tuple_bytes = dump_list(pytuple.as_slice().iter(), vm)?; + tuple_bytes.insert(0, TUPLE_BYTE); + PyBytes::from(tuple_bytes) + } + pydict @ PyDict => { + let key_value_pairs = pydict._as_dict_inner().clone().as_kvpairs(); + // Converts list of tuples to PyObjectRefs of tuples + let elements: Vec = key_value_pairs + .into_iter() + .map(|(k, v)| { + PyTuple::new_ref(vec![k, v], &vm.ctx).into_pyobject(vm) + }) + .collect(); + // Converts list of tuples to list, dump into binary + let mut dict_bytes = dump_list(elements.iter(), vm)?; + dict_bytes.insert(0, LIST_BYTE); + dict_bytes.insert(0, DICT_BYTE); + PyBytes::from(dict_bytes) + } co @ PyCode => { - PyBytes::from(co.code.map_clone_bag(&bytecode::BasicBag).to_bytes()) + // Code is default, doesn't have prefix. + let code_bytes = co.code.map_clone_bag(&bytecode::BasicBag).to_bytes(); + PyBytes::from(code_bytes) } - _ => + _ => { return Err(vm.new_not_implemented_error( - "TODO: not implemented yet or marshal unsupported type".to_owned() - )), + "TODO: not implemented yet or marshal unsupported type".to_owned(), + )); + } }); Ok(r) } @@ -30,25 +120,159 @@ mod decl { Ok(()) } + /// Read the next 4 bytes of a slice, convert to u32. + /// Side effect: increasing position pointer by 4. + fn eat_u32(bytes: &[u8], position: &mut usize, vm: &VirtualMachine) -> PyResult { + let length_as_u32 = + u32::from_le_bytes(match bytes[*position..(*position + 4)].try_into() { + Ok(length_as_u32) => length_as_u32, + Err(_) => { + return Err( + vm.new_buffer_error("Could not read u32 size from byte array".to_owned()) + ) + } + }); + *position += 4; + Ok(length_as_u32) + } + + /// Reads next element from a python list. First by getting element size + /// then by building a pybuffer and "loading" the pyobject. + /// Moves the position pointer past the element. + fn next_element_of_list( + buf: &BorrowedValue<[u8]>, + position: &mut usize, + vm: &VirtualMachine, + ) -> PyResult { + // Read size of the current element from buffer. + let element_length = eat_u32(buf, position, vm)? as usize; + // Create pybuffer consisting of the data in the next element. + let pybuffer = + PyBuffer::from_byte_vector(buf[*position..(*position + element_length)].to_vec(), vm); + // Move position pointer past element. + *position += element_length; + // Return marshalled element. + loads(pybuffer, vm) + } + + /// Reads a list (or tuple) from a buffer. + fn read_list(buf: &BorrowedValue<[u8]>, vm: &VirtualMachine) -> PyResult> { + let mut position = 1; + let expected_array_len = eat_u32(buf, &mut position, vm)? as usize; + // Read each element in list, incrementing position pointer to reflect position in the buffer. + let mut elements: Vec = Vec::new(); + while position < buf.len() { + elements.push(next_element_of_list(buf, &mut position, vm)?); + } + debug_assert!(expected_array_len == elements.len()); + debug_assert!(buf.len() == position); + Ok(elements) + } + + /// Builds a PyDict from iterator of tuple objects + pub fn from_tuples(iterable: Iter, vm: &VirtualMachine) -> PyResult { + let dict = DictContentType::default(); + for elem in iterable { + let items = match_class!(match elem.clone() { + pytuple @ PyTuple => pytuple.as_slice().to_vec(), + _ => + return Err(vm.new_value_error( + "Couldn't unmarshal key:value pair of dictionary".to_owned() + )), + }); + // Marshalled tuples are always in format key:value. + dict.insert( + vm, + items.get(0).unwrap().clone(), + items.get(1).unwrap().clone(), + )?; + } + Ok(PyDict::from_entries(dict)) + } + #[pyfunction] - fn loads(code_bytes: ArgBytesLike, vm: &VirtualMachine) -> PyResult { - let buf = &*code_bytes.borrow_buf(); - let code = bytecode::CodeObject::from_bytes(buf).map_err(|e| match e { - bytecode::CodeDeserializeError::Eof => vm.new_exception_msg( - vm.ctx.exceptions.eof_error.clone(), - "end of file while deserializing bytecode".to_owned(), - ), - _ => vm.new_value_error("Couldn't deserialize python bytecode".to_owned()), + fn loads(pybuffer: PyBuffer, vm: &VirtualMachine) -> PyResult { + let buf = &pybuffer.as_contiguous().ok_or_else(|| { + vm.new_buffer_error("Buffer provided to marshal.loads() is not contiguous".to_owned()) })?; - Ok(PyCode { - code: vm.map_codeobj(code), - }) + match buf[0] { + INT_BYTE => { + let sign = match buf[1] { + b'-' => Sign::Minus, + b'0' => Sign::NoSign, + b'+' => Sign::Plus, + _ => { + return Err(vm.new_value_error( + "Unknown sign byte when trying to unmarshal integer".to_owned(), + )) + } + }; + let pyint = BigInt::from_bytes_le(sign, &buf[2..buf.len()]); + Ok(pyint.into_pyobject(vm)) + } + FLOAT_BYTE => { + let number = f64::from_le_bytes(match buf[1..buf.len()].try_into() { + Ok(byte_array) => byte_array, + Err(e) => { + return Err(vm.new_value_error(format!( + "Expected float, could not load from bytes. {}", + e + ))) + } + }); + let pyfloat = PyFloat::from(number); + Ok(pyfloat.into_pyobject(vm)) + } + STR_BYTE => { + let pystr = PyStr::from(match AsciiStr::from_ascii(&buf[1..buf.len()]) { + Ok(ascii_str) => ascii_str, + Err(e) => { + return Err( + vm.new_value_error(format!("Cannot unmarshal bytes to string, {}", e)) + ) + } + }); + Ok(pystr.into_pyobject(vm)) + } + LIST_BYTE => { + let elements = read_list(buf, vm)?; + Ok(elements.into_pyobject(vm)) + } + TUPLE_BYTE => { + let elements = read_list(buf, vm)?; + let pytuple = PyTuple::new_ref(elements, &vm.ctx).into_pyobject(vm); + Ok(pytuple) + } + DICT_BYTE => { + let pybuffer = PyBuffer::from_byte_vector(buf[1..buf.len()].to_vec(), vm); + let pydict = match_class!(match loads(pybuffer, vm)? { + pylist @ PyList => from_tuples(pylist.borrow_vec().iter(), vm)?, + _ => + return Err(vm.new_value_error("Couldn't unmarshal dicitionary.".to_owned())), + }); + Ok(pydict.into_pyobject(vm)) + } + _ => { + // If prefix is not identifiable, assume CodeObject, error out if it doesn't match. + let code = bytecode::CodeObject::from_bytes(&buf).map_err(|e| match e { + bytecode::CodeDeserializeError::Eof => vm.new_exception_msg( + vm.ctx.exceptions.eof_error.clone(), + "End of file while deserializing bytecode".to_owned(), + ), + _ => vm.new_value_error("Couldn't deserialize python bytecode".to_owned()), + })?; + Ok(PyCode { + code: vm.map_codeobj(code), + } + .into_pyobject(vm)) + } + } } #[pyfunction] - fn load(f: PyObjectRef, vm: &VirtualMachine) -> PyResult { + fn load(f: PyObjectRef, vm: &VirtualMachine) -> PyResult { let read_res = vm.call_method(&f, "read", ())?; let bytes = ArgBytesLike::try_from_object(vm, read_res)?; - loads(bytes, vm) + loads(PyBuffer::from(bytes), vm) } } From 502d5c217d562ba63fc06d502fd77b5455145b43 Mon Sep 17 00:00:00 2001 From: Jake Armendariz Date: Wed, 19 Jan 2022 00:14:51 -0800 Subject: [PATCH 2/2] Changes to code/style of marshaling module --- Lib/test/test_marshal.py | 91 ------------------ extra_tests/snippets/test_marshal.py | 42 +++++++++ vm/src/dictdatatype.rs | 10 +- vm/src/stdlib/marshal.rs | 133 +++++++++++++-------------- 4 files changed, 109 insertions(+), 167 deletions(-) delete mode 100644 Lib/test/test_marshal.py create mode 100644 extra_tests/snippets/test_marshal.py diff --git a/Lib/test/test_marshal.py b/Lib/test/test_marshal.py deleted file mode 100644 index a0642258c4..0000000000 --- a/Lib/test/test_marshal.py +++ /dev/null @@ -1,91 +0,0 @@ -import unittest -import marshal - -class MarshalTests(unittest.TestCase): - """ - Testing each data type is done with two tests - Test dumps data == expected_bytes - Test load(dumped data) == data - """ - - def dump_then_load(self, data): - return marshal.loads(marshal.dumps(data)) - - def test_dumps_int(self): - self.assertEqual(marshal.dumps(0), b'i0\x00') - self.assertEqual(marshal.dumps(-1), b'i-\x01') - self.assertEqual(marshal.dumps(1), b'i+\x01') - self.assertEqual(marshal.dumps(100000000), b'i+\x00\xe1\xf5\x05') - - def test_dump_and_load_int(self): - self.assertEqual(self.dump_then_load(0), 0) - self.assertEqual(self.dump_then_load(-1), -1) - self.assertEqual(self.dump_then_load(1), 1) - self.assertEqual(self.dump_then_load(100000000), 100000000) - - def test_dumps_float(self): - self.assertEqual(marshal.dumps(0.0), b'f\x00\x00\x00\x00\x00\x00\x00\x00') - self.assertEqual(marshal.dumps(-10.0), b'f\x00\x00\x00\x00\x00\x00$\xc0') - self.assertEqual(marshal.dumps(10.0), b'f\x00\x00\x00\x00\x00\x00$@') - - def test_dump_and_load_int(self): - self.assertEqual(self.dump_then_load(0.0), 0.0) - self.assertEqual(self.dump_then_load(-10.0), -10.0) - self.assertEqual(self.dump_then_load(10), 10) - - def test_dumps_str(self): - self.assertEqual(marshal.dumps(""), b's') - self.assertEqual(marshal.dumps("Hello, World"), b'sHello, World') - - def test_dump_and_load_str(self): - self.assertEqual(self.dump_then_load(""), "") - self.assertEqual(self.dump_then_load("Hello, World"), "Hello, World") - - def test_dumps_list(self): - # Lists have to print the length of every element - # so when marshelling and unmarshelling we know how many bytes to search - # all usize values are converted to u32 to handle different architecture sizes. - self.assertEqual(marshal.dumps([]), b'[\x00\x00\x00\x00') - self.assertEqual( - marshal.dumps([1, "hello", 1.0]), - b'[\x03\x00\x00\x00\x03\x00\x00\x00i+\x01\x06\x00\x00\x00shello\t\x00\x00\x00f\x00\x00\x00\x00\x00\x00\xf0?', - ) - self.assertEqual( - marshal.dumps([[0], ['a','b']]), - b'[\x02\x00\x00\x00\x0c\x00\x00\x00[\x01\x00\x00\x00\x03\x00\x00\x00i0\x00\x11\x00\x00\x00[\x02\x00\x00\x00\x02\x00\x00\x00sa\x02\x00\x00\x00sb', - ) - - def test_dump_and_load_list(self): - self.assertEqual(self.dump_then_load([]), []) - self.assertEqual(self.dump_then_load([1, "hello", 1.0]), [1, "hello", 1.0]) - self.assertEqual(self.dump_then_load([[0], ['a','b']]),[[0], ['a','b']]) - - def test_dumps_tuple(self): - self.assertEqual(marshal.dumps(()), b'(\x00\x00\x00\x00') - self.assertEqual( - marshal.dumps((1, "hello", 1.0)), - b'(\x03\x00\x00\x00\x03\x00\x00\x00i+\x01\x06\x00\x00\x00shello\t\x00\x00\x00f\x00\x00\x00\x00\x00\x00\xf0?' - ) - - def test_dump_and_load_tuple(self): - self.assertEqual(self.dump_then_load(()), ()) - self.assertEqual(self.dump_then_load((1, "hello", 1.0)), (1, "hello", 1.0)) - - def test_dumps_dict(self): - self.assertEqual(marshal.dumps({}), b',[\x00\x00\x00\x00') - self.assertEqual( - marshal.dumps({'a':1, 1:'a'}), - b',[\x02\x00\x00\x00\x12\x00\x00\x00(\x02\x00\x00\x00\x02\x00\x00\x00sa\x03\x00\x00\x00i+\x01\x12\x00\x00\x00(\x02\x00\x00\x00\x03\x00\x00\x00i+\x01\x02\x00\x00\x00sa' - ) - self.assertEqual( - marshal.dumps({'a':{'b':2}, 'c':[0.0, 4.0, 6, 9]}), - b',[\x02\x00\x00\x00+\x00\x00\x00(\x02\x00\x00\x00\x02\x00\x00\x00sa\x1c\x00\x00\x00,[\x01\x00\x00\x00\x12\x00\x00\x00(\x02\x00\x00\x00\x02\x00\x00\x00sb\x03\x00\x00\x00i+\x02<\x00\x00\x00(\x02\x00\x00\x00\x02\x00\x00\x00sc-\x00\x00\x00[\x04\x00\x00\x00\t\x00\x00\x00f\x00\x00\x00\x00\x00\x00\x00\x00\t\x00\x00\x00f\x00\x00\x00\x00\x00\x00\x10@\x03\x00\x00\x00i+\x06\x03\x00\x00\x00i+\t' - ) - - def test_dump_and_load_dict(self): - self.assertEqual(self.dump_then_load({}), {}) - self.assertEqual(self.dump_then_load({'a':1, 1:'a'}), {'a':1, 1:'a'}) - self.assertEqual(self.dump_then_load({'a':{'b':2}, 'c':[0.0, 4.0, 6, 9]}), {'a':{'b':2}, 'c':[0.0, 4.0, 6, 9]}) - -if __name__ == "__main__": - unittest.main() \ No newline at end of file diff --git a/extra_tests/snippets/test_marshal.py b/extra_tests/snippets/test_marshal.py new file mode 100644 index 0000000000..eee468333b --- /dev/null +++ b/extra_tests/snippets/test_marshal.py @@ -0,0 +1,42 @@ +import unittest +import marshal + +class MarshalTests(unittest.TestCase): + """ + Testing the (incomplete) marshal module. + """ + + def dump_then_load(self, data): + return marshal.loads(marshal.dumps(data)) + + def test_dump_and_load_int(self): + self.assertEqual(self.dump_then_load(0), 0) + self.assertEqual(self.dump_then_load(-1), -1) + self.assertEqual(self.dump_then_load(1), 1) + self.assertEqual(self.dump_then_load(100000000), 100000000) + + def test_dump_and_load_int(self): + self.assertEqual(self.dump_then_load(0.0), 0.0) + self.assertEqual(self.dump_then_load(-10.0), -10.0) + self.assertEqual(self.dump_then_load(10), 10) + + def test_dump_and_load_str(self): + self.assertEqual(self.dump_then_load(""), "") + self.assertEqual(self.dump_then_load("Hello, World"), "Hello, World") + + def test_dump_and_load_list(self): + self.assertEqual(self.dump_then_load([]), []) + self.assertEqual(self.dump_then_load([1, "hello", 1.0]), [1, "hello", 1.0]) + self.assertEqual(self.dump_then_load([[0], ['a','b']]),[[0], ['a','b']]) + + def test_dump_and_load_tuple(self): + self.assertEqual(self.dump_then_load(()), ()) + self.assertEqual(self.dump_then_load((1, "hello", 1.0)), (1, "hello", 1.0)) + + def test_dump_and_load_dict(self): + self.assertEqual(self.dump_then_load({}), {}) + self.assertEqual(self.dump_then_load({'a':1, 1:'a'}), {'a':1, 1:'a'}) + self.assertEqual(self.dump_then_load({'a':{'b':2}, 'c':[0.0, 4.0, 6, 9]}), {'a':{'b':2}, 'c':[0.0, 4.0, 6, 9]}) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/vm/src/dictdatatype.rs b/vm/src/dictdatatype.rs index e3784fb4a8..6278df5591 100644 --- a/vm/src/dictdatatype.rs +++ b/vm/src/dictdatatype.rs @@ -124,14 +124,8 @@ impl Dict { pub(crate) fn as_kvpairs(&self) -> Vec<(PyObjectRef, T)> { let entries = &self.inner.read().entries; entries - .into_iter() - .filter_map(|entry| { - if let Some(dict_entry) = entry { - Some(dict_entry.as_tuple()) - } else { - None - } - }) + .iter() + .filter_map(|entry| entry.as_ref().map(|dict_entry| dict_entry.as_tuple())) .collect() } } diff --git a/vm/src/stdlib/marshal.rs b/vm/src/stdlib/marshal.rs index 211d1713af..d5a84f0c12 100644 --- a/vm/src/stdlib/marshal.rs +++ b/vm/src/stdlib/marshal.rs @@ -5,7 +5,6 @@ mod decl { /// TODO add support for Booleans, Sets, etc use ascii::AsciiStr; use num_bigint::{BigInt, Sign}; - use std::ops::Deref; use std::slice::Iter; use crate::{ @@ -13,7 +12,6 @@ mod decl { dict::DictContentType, PyBytes, PyCode, PyDict, PyFloat, PyInt, PyList, PyStr, PyTuple, }, bytecode, - common::borrow::BorrowedValue, function::{ArgBytesLike, IntoPyObject}, protocol::PyBuffer, PyObjectRef, PyResult, TryFromObject, VirtualMachine, @@ -33,7 +31,7 @@ mod decl { match u32::try_from(x) { Ok(n) => Ok(n.to_le_bytes()), Err(_) => { - Err(vm.new_value_error("Size exceeds 2^32 capacity for marshalling.".to_owned())) + Err(vm.new_value_error("Size exceeds 2^32 capacity for marshaling.".to_owned())) } } } @@ -43,66 +41,64 @@ mod decl { let mut byte_list = size_to_bytes(pyobjs.len(), vm)?.to_vec(); // For each element, dump into binary, then add its length and value. for element in pyobjs { - let element_bytes: PyBytes = dumps(element.clone(), vm)?; + let element_bytes: Vec = _dumps(element.clone(), vm)?; byte_list.extend(size_to_bytes(element_bytes.len(), vm)?); - byte_list.extend_from_slice(element_bytes.deref()) + byte_list.extend(element_bytes) } Ok(byte_list) } - #[pyfunction] - fn dumps(value: PyObjectRef, vm: &VirtualMachine) -> PyResult { + fn _dumps(value: PyObjectRef, vm: &VirtualMachine) -> PyResult> { let r = match_class!(match value { pyint @ PyInt => { - let (sign, uint_bytes) = pyint.as_bigint().to_bytes_le(); + let (sign, mut int_bytes) = pyint.as_bigint().to_bytes_le(); let sign_byte = match sign { Sign::Minus => b'-', Sign::NoSign => b'0', Sign::Plus => b'+', }; // Return as [TYPE, SIGN, uint bytes] - PyBytes::from([vec![INT_BYTE, sign_byte], uint_bytes].concat()) + int_bytes.insert(0, sign_byte); + int_bytes.push(INT_BYTE); + int_bytes } pyfloat @ PyFloat => { let mut float_bytes = pyfloat.to_f64().to_le_bytes().to_vec(); - float_bytes.insert(0, FLOAT_BYTE); - PyBytes::from(float_bytes) + float_bytes.push(FLOAT_BYTE); + float_bytes } pystr @ PyStr => { let mut str_bytes = pystr.as_str().as_bytes().to_vec(); - str_bytes.insert(0, STR_BYTE); - PyBytes::from(str_bytes) + str_bytes.push(STR_BYTE); + str_bytes } pylist @ PyList => { let pylist_items = pylist.borrow_vec(); let mut list_bytes = dump_list(pylist_items.iter(), vm)?; - list_bytes.insert(0, LIST_BYTE); - PyBytes::from(list_bytes) + list_bytes.push(LIST_BYTE); + list_bytes } pytuple @ PyTuple => { let mut tuple_bytes = dump_list(pytuple.as_slice().iter(), vm)?; - tuple_bytes.insert(0, TUPLE_BYTE); - PyBytes::from(tuple_bytes) + tuple_bytes.push(TUPLE_BYTE); + tuple_bytes } pydict @ PyDict => { let key_value_pairs = pydict._as_dict_inner().clone().as_kvpairs(); // Converts list of tuples to PyObjectRefs of tuples let elements: Vec = key_value_pairs .into_iter() - .map(|(k, v)| { - PyTuple::new_ref(vec![k, v], &vm.ctx).into_pyobject(vm) - }) + .map(|(k, v)| PyTuple::new_ref(vec![k, v], &vm.ctx).into_pyobject(vm)) .collect(); // Converts list of tuples to list, dump into binary let mut dict_bytes = dump_list(elements.iter(), vm)?; - dict_bytes.insert(0, LIST_BYTE); - dict_bytes.insert(0, DICT_BYTE); - PyBytes::from(dict_bytes) + dict_bytes.push(LIST_BYTE); + dict_bytes.push(DICT_BYTE); + dict_bytes } co @ PyCode => { // Code is default, doesn't have prefix. - let code_bytes = co.code.map_clone_bag(&bytecode::BasicBag).to_bytes(); - PyBytes::from(code_bytes) + co.code.map_clone_bag(&bytecode::BasicBag).to_bytes() } _ => { return Err(vm.new_not_implemented_error( @@ -113,6 +109,11 @@ mod decl { Ok(r) } + #[pyfunction] + fn dumps(value: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Ok(PyBytes::from(_dumps(value, vm)?)) + } + #[pyfunction] fn dump(value: PyObjectRef, f: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { let dumped = dumps(value, vm)?; @@ -120,52 +121,39 @@ mod decl { Ok(()) } - /// Read the next 4 bytes of a slice, convert to u32. - /// Side effect: increasing position pointer by 4. - fn eat_u32(bytes: &[u8], position: &mut usize, vm: &VirtualMachine) -> PyResult { - let length_as_u32 = - u32::from_le_bytes(match bytes[*position..(*position + 4)].try_into() { - Ok(length_as_u32) => length_as_u32, - Err(_) => { - return Err( - vm.new_buffer_error("Could not read u32 size from byte array".to_owned()) - ) - } - }); - *position += 4; - Ok(length_as_u32) + /// Read the next 4 bytes of a slice, read as u32, pass as usize. + /// Returns the rest of buffer with the value. + fn eat_length<'a>(bytes: &'a [u8], vm: &VirtualMachine) -> PyResult<(usize, &'a [u8])> { + let (u32_bytes, rest) = bytes.split_at(4); + let length = u32::from_le_bytes(u32_bytes.try_into().map_err(|_| { + vm.new_value_error("Could not read u32 size from byte array".to_owned()) + })?); + Ok((length as usize, rest)) } /// Reads next element from a python list. First by getting element size /// then by building a pybuffer and "loading" the pyobject. - /// Moves the position pointer past the element. - fn next_element_of_list( - buf: &BorrowedValue<[u8]>, - position: &mut usize, + /// Returns rest of buffer with object. + fn next_element_of_list<'a>( + buf: &'a [u8], vm: &VirtualMachine, - ) -> PyResult { - // Read size of the current element from buffer. - let element_length = eat_u32(buf, position, vm)? as usize; - // Create pybuffer consisting of the data in the next element. - let pybuffer = - PyBuffer::from_byte_vector(buf[*position..(*position + element_length)].to_vec(), vm); - // Move position pointer past element. - *position += element_length; - // Return marshalled element. - loads(pybuffer, vm) + ) -> PyResult<(PyObjectRef, &'a [u8])> { + let (element_length, element_and_rest) = eat_length(buf, vm)?; + let (element_buff, rest) = element_and_rest.split_at(element_length); + let pybuffer = PyBuffer::from_byte_vector(element_buff.to_vec(), vm); + Ok((loads(pybuffer, vm)?, rest)) } /// Reads a list (or tuple) from a buffer. - fn read_list(buf: &BorrowedValue<[u8]>, vm: &VirtualMachine) -> PyResult> { - let mut position = 1; - let expected_array_len = eat_u32(buf, &mut position, vm)? as usize; - // Read each element in list, incrementing position pointer to reflect position in the buffer. + fn read_list(buf: &[u8], vm: &VirtualMachine) -> PyResult> { + let (expected_array_len, mut buffer) = eat_length(buf, vm)?; let mut elements: Vec = Vec::new(); - while position < buf.len() { - elements.push(next_element_of_list(buf, &mut position, vm)?); + while !buffer.is_empty() { + let (element, rest_of_buffer) = next_element_of_list(buffer, vm)?; + elements.push(element); + buffer = rest_of_buffer; } debug_assert!(expected_array_len == elements.len()); - debug_assert!(buf.len() == position); Ok(elements) } @@ -192,12 +180,21 @@ mod decl { #[pyfunction] fn loads(pybuffer: PyBuffer, vm: &VirtualMachine) -> PyResult { - let buf = &pybuffer.as_contiguous().ok_or_else(|| { + let full_buff = pybuffer.as_contiguous().ok_or_else(|| { vm.new_buffer_error("Buffer provided to marshal.loads() is not contiguous".to_owned()) })?; - match buf[0] { + let (type_indicator, buf) = full_buff.split_last().ok_or_else(|| { + vm.new_exception_msg( + vm.ctx.exceptions.eof_error.clone(), + "EOF where object expected.".to_owned(), + ) + })?; + match *type_indicator { INT_BYTE => { - let sign = match buf[1] { + let (sign_byte, uint_bytes) = buf + .split_first() + .ok_or_else(|| vm.new_value_error("EOF where object expected.".to_owned()))?; + let sign = match sign_byte { b'-' => Sign::Minus, b'0' => Sign::NoSign, b'+' => Sign::Plus, @@ -207,11 +204,11 @@ mod decl { )) } }; - let pyint = BigInt::from_bytes_le(sign, &buf[2..buf.len()]); + let pyint = BigInt::from_bytes_le(sign, uint_bytes); Ok(pyint.into_pyobject(vm)) } FLOAT_BYTE => { - let number = f64::from_le_bytes(match buf[1..buf.len()].try_into() { + let number = f64::from_le_bytes(match buf[..].try_into() { Ok(byte_array) => byte_array, Err(e) => { return Err(vm.new_value_error(format!( @@ -224,7 +221,7 @@ mod decl { Ok(pyfloat.into_pyobject(vm)) } STR_BYTE => { - let pystr = PyStr::from(match AsciiStr::from_ascii(&buf[1..buf.len()]) { + let pystr = PyStr::from(match AsciiStr::from_ascii(buf) { Ok(ascii_str) => ascii_str, Err(e) => { return Err( @@ -244,7 +241,7 @@ mod decl { Ok(pytuple) } DICT_BYTE => { - let pybuffer = PyBuffer::from_byte_vector(buf[1..buf.len()].to_vec(), vm); + let pybuffer = PyBuffer::from_byte_vector(buf[..].to_vec(), vm); let pydict = match_class!(match loads(pybuffer, vm)? { pylist @ PyList => from_tuples(pylist.borrow_vec().iter(), vm)?, _ => @@ -254,7 +251,7 @@ mod decl { } _ => { // If prefix is not identifiable, assume CodeObject, error out if it doesn't match. - let code = bytecode::CodeObject::from_bytes(&buf).map_err(|e| match e { + let code = bytecode::CodeObject::from_bytes(&full_buff).map_err(|e| match e { bytecode::CodeDeserializeError::Eof => vm.new_exception_msg( vm.ctx.exceptions.eof_error.clone(), "End of file while deserializing bytecode".to_owned(),