diff --git a/extra_tests/snippets/test_marshal.py b/extra_tests/snippets/test_marshal.py new file mode 100644 index 0000000000..eee468333b --- /dev/null +++ b/extra_tests/snippets/test_marshal.py @@ -0,0 +1,42 @@ +import unittest +import marshal + +class MarshalTests(unittest.TestCase): + """ + Testing the (incomplete) marshal module. + """ + + def dump_then_load(self, data): + return marshal.loads(marshal.dumps(data)) + + def test_dump_and_load_int(self): + self.assertEqual(self.dump_then_load(0), 0) + self.assertEqual(self.dump_then_load(-1), -1) + self.assertEqual(self.dump_then_load(1), 1) + self.assertEqual(self.dump_then_load(100000000), 100000000) + + def test_dump_and_load_int(self): + self.assertEqual(self.dump_then_load(0.0), 0.0) + self.assertEqual(self.dump_then_load(-10.0), -10.0) + self.assertEqual(self.dump_then_load(10), 10) + + def test_dump_and_load_str(self): + self.assertEqual(self.dump_then_load(""), "") + self.assertEqual(self.dump_then_load("Hello, World"), "Hello, World") + + def test_dump_and_load_list(self): + self.assertEqual(self.dump_then_load([]), []) + self.assertEqual(self.dump_then_load([1, "hello", 1.0]), [1, "hello", 1.0]) + self.assertEqual(self.dump_then_load([[0], ['a','b']]),[[0], ['a','b']]) + + def test_dump_and_load_tuple(self): + self.assertEqual(self.dump_then_load(()), ()) + self.assertEqual(self.dump_then_load((1, "hello", 1.0)), (1, "hello", 1.0)) + + def test_dump_and_load_dict(self): + self.assertEqual(self.dump_then_load({}), {}) + self.assertEqual(self.dump_then_load({'a':1, 1:'a'}), {'a':1, 1:'a'}) + self.assertEqual(self.dump_then_load({'a':{'b':2}, 'c':[0.0, 4.0, 6, 9]}), {'a':{'b':2}, 'c':[0.0, 4.0, 6, 9]}) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/vm/src/builtins/dict.rs b/vm/src/builtins/dict.rs index d5d1c6e09d..b589e121a2 100644 --- a/vm/src/builtins/dict.rs +++ b/vm/src/builtins/dict.rs @@ -69,6 +69,10 @@ impl PyDict { &self.entries } + pub(crate) fn from_entries(entries: DictContentType) -> Self { + Self { entries } + } + #[pyslot] fn slot_new(cls: PyTypeRef, _args: FuncArgs, vm: &VirtualMachine) -> PyResult { PyDict::default().into_pyresult_with_type(vm, cls) diff --git a/vm/src/dictdatatype.rs b/vm/src/dictdatatype.rs index e29b95b1ae..6278df5591 100644 --- a/vm/src/dictdatatype.rs +++ b/vm/src/dictdatatype.rs @@ -114,6 +114,22 @@ struct DictEntry { value: T, } +impl DictEntry { + pub(crate) fn as_tuple(&self) -> (PyObjectRef, T) { + (self.key.clone(), self.value.clone()) + } +} + +impl Dict { + pub(crate) fn as_kvpairs(&self) -> Vec<(PyObjectRef, T)> { + let entries = &self.inner.read().entries; + entries + .iter() + .filter_map(|entry| entry.as_ref().map(|dict_entry| dict_entry.as_tuple())) + .collect() + } +} + #[derive(Debug, PartialEq)] pub struct DictSize { indices_size: usize, diff --git a/vm/src/protocol/buffer.rs b/vm/src/protocol/buffer.rs index c69b4efc2f..6ca3ed3800 100644 --- a/vm/src/protocol/buffer.rs +++ b/vm/src/protocol/buffer.rs @@ -10,7 +10,7 @@ use crate::{ }, sliceable::wrap_index, types::{Constructor, Unconstructible}, - PyObject, PyObjectPayload, PyObjectRef, PyObjectView, PyObjectWrap, PyRef, PyResult, + PyObject, PyObjectPayload, PyObjectRef, PyObjectView, PyObjectWrap, PyRef, PyResult, PyValue, TryFromBorrowedObject, TypeProtocol, VirtualMachine, }; use std::{borrow::Cow, fmt::Debug, ops::Range}; @@ -63,6 +63,15 @@ impl PyBuffer { .then(|| unsafe { self.contiguous_mut_unchecked() }) } + pub fn from_byte_vector(bytes: Vec, vm: &VirtualMachine) -> Self { + let bytes_len = bytes.len(); + PyBuffer::new( + PyValue::into_object(VecBuffer::from(bytes), vm), + BufferDescriptor::simple(bytes_len, true), + &VEC_BUFFER_METHODS, + ) + } + /// # Safety /// assume the buffer is contiguous pub unsafe fn contiguous_unchecked(&self) -> BorrowedValue<[u8]> { diff --git a/vm/src/stdlib/marshal.rs b/vm/src/stdlib/marshal.rs index f7ae60eb1c..d5a84f0c12 100644 --- a/vm/src/stdlib/marshal.rs +++ b/vm/src/stdlib/marshal.rs @@ -2,27 +2,118 @@ pub(crate) use decl::make_module; #[pymodule(name = "marshal")] mod decl { + /// TODO add support for Booleans, Sets, etc + use ascii::AsciiStr; + use num_bigint::{BigInt, Sign}; + use std::slice::Iter; + use crate::{ - builtins::{PyBytes, PyCode}, + builtins::{ + dict::DictContentType, PyBytes, PyCode, PyDict, PyFloat, PyInt, PyList, PyStr, PyTuple, + }, bytecode, - function::ArgBytesLike, + function::{ArgBytesLike, IntoPyObject}, + protocol::PyBuffer, PyObjectRef, PyResult, TryFromObject, VirtualMachine, }; - #[pyfunction] - fn dumps(value: PyObjectRef, vm: &VirtualMachine) -> PyResult { + const STR_BYTE: u8 = b's'; + const INT_BYTE: u8 = b'i'; + const FLOAT_BYTE: u8 = b'f'; + const LIST_BYTE: u8 = b'['; + const TUPLE_BYTE: u8 = b'('; + const DICT_BYTE: u8 = b','; + + /// Safely convert usize to 4 le bytes + fn size_to_bytes(x: usize, vm: &VirtualMachine) -> PyResult<[u8; 4]> { + // For marshalling we want to convert lengths to bytes. To save space + // we limit the size to u32 to keep marshalling smaller. + match u32::try_from(x) { + Ok(n) => Ok(n.to_le_bytes()), + Err(_) => { + Err(vm.new_value_error("Size exceeds 2^32 capacity for marshaling.".to_owned())) + } + } + } + + /// Dumps a iterator of objects into binary vector. + fn dump_list(pyobjs: Iter, vm: &VirtualMachine) -> PyResult> { + let mut byte_list = size_to_bytes(pyobjs.len(), vm)?.to_vec(); + // For each element, dump into binary, then add its length and value. + for element in pyobjs { + let element_bytes: Vec = _dumps(element.clone(), vm)?; + byte_list.extend(size_to_bytes(element_bytes.len(), vm)?); + byte_list.extend(element_bytes) + } + Ok(byte_list) + } + + fn _dumps(value: PyObjectRef, vm: &VirtualMachine) -> PyResult> { let r = match_class!(match value { + pyint @ PyInt => { + let (sign, mut int_bytes) = pyint.as_bigint().to_bytes_le(); + let sign_byte = match sign { + Sign::Minus => b'-', + Sign::NoSign => b'0', + Sign::Plus => b'+', + }; + // Return as [TYPE, SIGN, uint bytes] + int_bytes.insert(0, sign_byte); + int_bytes.push(INT_BYTE); + int_bytes + } + pyfloat @ PyFloat => { + let mut float_bytes = pyfloat.to_f64().to_le_bytes().to_vec(); + float_bytes.push(FLOAT_BYTE); + float_bytes + } + pystr @ PyStr => { + let mut str_bytes = pystr.as_str().as_bytes().to_vec(); + str_bytes.push(STR_BYTE); + str_bytes + } + pylist @ PyList => { + let pylist_items = pylist.borrow_vec(); + let mut list_bytes = dump_list(pylist_items.iter(), vm)?; + list_bytes.push(LIST_BYTE); + list_bytes + } + pytuple @ PyTuple => { + let mut tuple_bytes = dump_list(pytuple.as_slice().iter(), vm)?; + tuple_bytes.push(TUPLE_BYTE); + tuple_bytes + } + pydict @ PyDict => { + let key_value_pairs = pydict._as_dict_inner().clone().as_kvpairs(); + // Converts list of tuples to PyObjectRefs of tuples + let elements: Vec = key_value_pairs + .into_iter() + .map(|(k, v)| PyTuple::new_ref(vec![k, v], &vm.ctx).into_pyobject(vm)) + .collect(); + // Converts list of tuples to list, dump into binary + let mut dict_bytes = dump_list(elements.iter(), vm)?; + dict_bytes.push(LIST_BYTE); + dict_bytes.push(DICT_BYTE); + dict_bytes + } co @ PyCode => { - PyBytes::from(co.code.map_clone_bag(&bytecode::BasicBag).to_bytes()) + // Code is default, doesn't have prefix. + co.code.map_clone_bag(&bytecode::BasicBag).to_bytes() } - _ => + _ => { return Err(vm.new_not_implemented_error( - "TODO: not implemented yet or marshal unsupported type".to_owned() - )), + "TODO: not implemented yet or marshal unsupported type".to_owned(), + )); + } }); Ok(r) } + #[pyfunction] + fn dumps(value: PyObjectRef, vm: &VirtualMachine) -> PyResult { + Ok(PyBytes::from(_dumps(value, vm)?)) + } + #[pyfunction] fn dump(value: PyObjectRef, f: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { let dumped = dumps(value, vm)?; @@ -30,25 +121,155 @@ mod decl { Ok(()) } + /// Read the next 4 bytes of a slice, read as u32, pass as usize. + /// Returns the rest of buffer with the value. + fn eat_length<'a>(bytes: &'a [u8], vm: &VirtualMachine) -> PyResult<(usize, &'a [u8])> { + let (u32_bytes, rest) = bytes.split_at(4); + let length = u32::from_le_bytes(u32_bytes.try_into().map_err(|_| { + vm.new_value_error("Could not read u32 size from byte array".to_owned()) + })?); + Ok((length as usize, rest)) + } + + /// Reads next element from a python list. First by getting element size + /// then by building a pybuffer and "loading" the pyobject. + /// Returns rest of buffer with object. + fn next_element_of_list<'a>( + buf: &'a [u8], + vm: &VirtualMachine, + ) -> PyResult<(PyObjectRef, &'a [u8])> { + let (element_length, element_and_rest) = eat_length(buf, vm)?; + let (element_buff, rest) = element_and_rest.split_at(element_length); + let pybuffer = PyBuffer::from_byte_vector(element_buff.to_vec(), vm); + Ok((loads(pybuffer, vm)?, rest)) + } + + /// Reads a list (or tuple) from a buffer. + fn read_list(buf: &[u8], vm: &VirtualMachine) -> PyResult> { + let (expected_array_len, mut buffer) = eat_length(buf, vm)?; + let mut elements: Vec = Vec::new(); + while !buffer.is_empty() { + let (element, rest_of_buffer) = next_element_of_list(buffer, vm)?; + elements.push(element); + buffer = rest_of_buffer; + } + debug_assert!(expected_array_len == elements.len()); + Ok(elements) + } + + /// Builds a PyDict from iterator of tuple objects + pub fn from_tuples(iterable: Iter, vm: &VirtualMachine) -> PyResult { + let dict = DictContentType::default(); + for elem in iterable { + let items = match_class!(match elem.clone() { + pytuple @ PyTuple => pytuple.as_slice().to_vec(), + _ => + return Err(vm.new_value_error( + "Couldn't unmarshal key:value pair of dictionary".to_owned() + )), + }); + // Marshalled tuples are always in format key:value. + dict.insert( + vm, + items.get(0).unwrap().clone(), + items.get(1).unwrap().clone(), + )?; + } + Ok(PyDict::from_entries(dict)) + } + #[pyfunction] - fn loads(code_bytes: ArgBytesLike, vm: &VirtualMachine) -> PyResult { - let buf = &*code_bytes.borrow_buf(); - let code = bytecode::CodeObject::from_bytes(buf).map_err(|e| match e { - bytecode::CodeDeserializeError::Eof => vm.new_exception_msg( + fn loads(pybuffer: PyBuffer, vm: &VirtualMachine) -> PyResult { + let full_buff = pybuffer.as_contiguous().ok_or_else(|| { + vm.new_buffer_error("Buffer provided to marshal.loads() is not contiguous".to_owned()) + })?; + let (type_indicator, buf) = full_buff.split_last().ok_or_else(|| { + vm.new_exception_msg( vm.ctx.exceptions.eof_error.clone(), - "end of file while deserializing bytecode".to_owned(), - ), - _ => vm.new_value_error("Couldn't deserialize python bytecode".to_owned()), + "EOF where object expected.".to_owned(), + ) })?; - Ok(PyCode { - code: vm.map_codeobj(code), - }) + match *type_indicator { + INT_BYTE => { + let (sign_byte, uint_bytes) = buf + .split_first() + .ok_or_else(|| vm.new_value_error("EOF where object expected.".to_owned()))?; + let sign = match sign_byte { + b'-' => Sign::Minus, + b'0' => Sign::NoSign, + b'+' => Sign::Plus, + _ => { + return Err(vm.new_value_error( + "Unknown sign byte when trying to unmarshal integer".to_owned(), + )) + } + }; + let pyint = BigInt::from_bytes_le(sign, uint_bytes); + Ok(pyint.into_pyobject(vm)) + } + FLOAT_BYTE => { + let number = f64::from_le_bytes(match buf[..].try_into() { + Ok(byte_array) => byte_array, + Err(e) => { + return Err(vm.new_value_error(format!( + "Expected float, could not load from bytes. {}", + e + ))) + } + }); + let pyfloat = PyFloat::from(number); + Ok(pyfloat.into_pyobject(vm)) + } + STR_BYTE => { + let pystr = PyStr::from(match AsciiStr::from_ascii(buf) { + Ok(ascii_str) => ascii_str, + Err(e) => { + return Err( + vm.new_value_error(format!("Cannot unmarshal bytes to string, {}", e)) + ) + } + }); + Ok(pystr.into_pyobject(vm)) + } + LIST_BYTE => { + let elements = read_list(buf, vm)?; + Ok(elements.into_pyobject(vm)) + } + TUPLE_BYTE => { + let elements = read_list(buf, vm)?; + let pytuple = PyTuple::new_ref(elements, &vm.ctx).into_pyobject(vm); + Ok(pytuple) + } + DICT_BYTE => { + let pybuffer = PyBuffer::from_byte_vector(buf[..].to_vec(), vm); + let pydict = match_class!(match loads(pybuffer, vm)? { + pylist @ PyList => from_tuples(pylist.borrow_vec().iter(), vm)?, + _ => + return Err(vm.new_value_error("Couldn't unmarshal dicitionary.".to_owned())), + }); + Ok(pydict.into_pyobject(vm)) + } + _ => { + // If prefix is not identifiable, assume CodeObject, error out if it doesn't match. + let code = bytecode::CodeObject::from_bytes(&full_buff).map_err(|e| match e { + bytecode::CodeDeserializeError::Eof => vm.new_exception_msg( + vm.ctx.exceptions.eof_error.clone(), + "End of file while deserializing bytecode".to_owned(), + ), + _ => vm.new_value_error("Couldn't deserialize python bytecode".to_owned()), + })?; + Ok(PyCode { + code: vm.map_codeobj(code), + } + .into_pyobject(vm)) + } + } } #[pyfunction] - fn load(f: PyObjectRef, vm: &VirtualMachine) -> PyResult { + fn load(f: PyObjectRef, vm: &VirtualMachine) -> PyResult { let read_res = vm.call_method(&f, "read", ())?; let bytes = ArgBytesLike::try_from_object(vm, read_res)?; - loads(bytes, vm) + loads(PyBuffer::from(bytes), vm) } }