-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Marshaling support for ints, floats, strs, lists, dicts, tuples #3506
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import unittest | ||
import marshal | ||
|
||
class MarshalTests(unittest.TestCase):
    """
    Round-trip tests for the (incomplete) marshal module.

    Each test serializes a value with ``marshal.dumps`` and checks that
    ``marshal.loads`` returns an equal value.
    """

    def dump_then_load(self, data):
        """Serialize ``data`` with marshal and immediately deserialize it."""
        return marshal.loads(marshal.dumps(data))

    def test_dump_and_load_int(self):
        self.assertEqual(self.dump_then_load(0), 0)
        self.assertEqual(self.dump_then_load(-1), -1)
        self.assertEqual(self.dump_then_load(1), 1)
        self.assertEqual(self.dump_then_load(100000000), 100000000)

    def test_dump_and_load_float(self):
        # This test used to share the name ``test_dump_and_load_int``; the
        # second definition replaced the first in the class namespace, so the
        # integer test above was never collected or run.
        self.assertEqual(self.dump_then_load(0.0), 0.0)
        self.assertEqual(self.dump_then_load(-10.0), -10.0)
        self.assertEqual(self.dump_then_load(10.0), 10.0)

    def test_dump_and_load_str(self):
        self.assertEqual(self.dump_then_load(""), "")
        self.assertEqual(self.dump_then_load("Hello, World"), "Hello, World")

    def test_dump_and_load_list(self):
        self.assertEqual(self.dump_then_load([]), [])
        self.assertEqual(self.dump_then_load([1, "hello", 1.0]), [1, "hello", 1.0])
        self.assertEqual(self.dump_then_load([[0], ['a','b']]),[[0], ['a','b']])

    def test_dump_and_load_tuple(self):
        self.assertEqual(self.dump_then_load(()), ())
        self.assertEqual(self.dump_then_load((1, "hello", 1.0)), (1, "hello", 1.0))

    def test_dump_and_load_dict(self):
        self.assertEqual(self.dump_then_load({}), {})
        self.assertEqual(self.dump_then_load({'a':1, 1:'a'}), {'a':1, 1:'a'})
        self.assertEqual(self.dump_then_load({'a':{'b':2}, 'c':[0.0, 4.0, 6, 9]}), {'a':{'b':2}, 'c':[0.0, 4.0, 6, 9]})
|
||
if __name__ == "__main__": | ||
unittest.main() |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -114,6 +114,22 @@ struct DictEntry<T> { | |
value: T, | ||
} | ||
|
||
impl<T: Clone> DictEntry<T> { | ||
pub(crate) fn as_tuple(&self) -> (PyObjectRef, T) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We don't need this method because the next function is not required. But for information, |
||
(self.key.clone(), self.value.clone()) | ||
} | ||
} | ||
|
||
impl<T: Clone> Dict<T> { | ||
pub(crate) fn as_kvpairs(&self) -> Vec<(PyObjectRef, T)> { | ||
let entries = &self.inner.read().entries; | ||
entries | ||
.iter() | ||
.filter_map(|entry| entry.as_ref().map(|dict_entry| dict_entry.as_tuple())) | ||
.collect() | ||
} | ||
} | ||
|
||
#[derive(Debug, PartialEq)] | ||
pub struct DictSize { | ||
indices_size: usize, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,7 +10,7 @@ use crate::{ | |
}, | ||
sliceable::wrap_index, | ||
types::{Constructor, Unconstructible}, | ||
PyObject, PyObjectPayload, PyObjectRef, PyObjectView, PyObjectWrap, PyRef, PyResult, | ||
PyObject, PyObjectPayload, PyObjectRef, PyObjectView, PyObjectWrap, PyRef, PyResult, PyValue, | ||
TryFromBorrowedObject, TypeProtocol, VirtualMachine, | ||
}; | ||
use std::{borrow::Cow, fmt::Debug, ops::Range}; | ||
|
@@ -63,6 +63,15 @@ impl PyBuffer { | |
.then(|| unsafe { self.contiguous_mut_unchecked() }) | ||
} | ||
|
||
pub fn from_byte_vector(bytes: Vec<u8>, vm: &VirtualMachine) -> Self { | ||
let bytes_len = bytes.len(); | ||
PyBuffer::new( | ||
PyValue::into_object(VecBuffer::from(bytes), vm), | ||
BufferDescriptor::simple(bytes_len, true), | ||
&VEC_BUFFER_METHODS, | ||
) | ||
} | ||
Comment on lines
+66
to
+73
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @qingshi163 could you check whether this is the correct usage of VecBuffer? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think it could be simplified as:
|
||
|
||
/// # Safety | ||
/// assume the buffer is contiguous | ||
pub unsafe fn contiguous_unchecked(&self) -> BorrowedValue<[u8]> { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,53 +2,274 @@ pub(crate) use decl::make_module; | |
|
||
#[pymodule(name = "marshal")] | ||
mod decl { | ||
/// TODO add support for Booleans, Sets, etc | ||
use ascii::AsciiStr; | ||
use num_bigint::{BigInt, Sign}; | ||
use std::slice::Iter; | ||
|
||
use crate::{ | ||
builtins::{PyBytes, PyCode}, | ||
builtins::{ | ||
dict::DictContentType, PyBytes, PyCode, PyDict, PyFloat, PyInt, PyList, PyStr, PyTuple, | ||
}, | ||
bytecode, | ||
function::ArgBytesLike, | ||
function::{ArgBytesLike, IntoPyObject}, | ||
protocol::PyBuffer, | ||
PyObjectRef, PyResult, TryFromObject, VirtualMachine, | ||
}; | ||
|
||
#[pyfunction] | ||
fn dumps(value: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyBytes> { | ||
const STR_BYTE: u8 = b's'; | ||
const INT_BYTE: u8 = b'i'; | ||
const FLOAT_BYTE: u8 = b'f'; | ||
const LIST_BYTE: u8 = b'['; | ||
const TUPLE_BYTE: u8 = b'('; | ||
const DICT_BYTE: u8 = b','; | ||
|
||
/// Safely convert usize to 4 le bytes | ||
fn size_to_bytes(x: usize, vm: &VirtualMachine) -> PyResult<[u8; 4]> { | ||
// For marshalling we want to convert lengths to bytes. To save space | ||
// we limit the size to u32 to keep marshalling smaller. | ||
match u32::try_from(x) { | ||
Ok(n) => Ok(n.to_le_bytes()), | ||
Err(_) => { | ||
Err(vm.new_value_error("Size exceeds 2^32 capacity for marshaling.".to_owned())) | ||
} | ||
} | ||
} | ||
|
||
/// Dumps a iterator of objects into binary vector. | ||
fn dump_list(pyobjs: Iter<PyObjectRef>, vm: &VirtualMachine) -> PyResult<Vec<u8>> { | ||
let mut byte_list = size_to_bytes(pyobjs.len(), vm)?.to_vec(); | ||
// For each element, dump into binary, then add its length and value. | ||
for element in pyobjs { | ||
let element_bytes: Vec<u8> = _dumps(element.clone(), vm)?; | ||
byte_list.extend(size_to_bytes(element_bytes.len(), vm)?); | ||
byte_list.extend(element_bytes) | ||
} | ||
Ok(byte_list) | ||
} | ||
|
||
fn _dumps(value: PyObjectRef, vm: &VirtualMachine) -> PyResult<Vec<u8>> { | ||
let r = match_class!(match value { | ||
pyint @ PyInt => { | ||
let (sign, mut int_bytes) = pyint.as_bigint().to_bytes_le(); | ||
let sign_byte = match sign { | ||
Sign::Minus => b'-', | ||
Sign::NoSign => b'0', | ||
Sign::Plus => b'+', | ||
}; | ||
                // Return as [SIGN, uint bytes, TYPE] | ||
int_bytes.insert(0, sign_byte); | ||
int_bytes.push(INT_BYTE); | ||
int_bytes | ||
} | ||
pyfloat @ PyFloat => { | ||
let mut float_bytes = pyfloat.to_f64().to_le_bytes().to_vec(); | ||
float_bytes.push(FLOAT_BYTE); | ||
float_bytes | ||
} | ||
pystr @ PyStr => { | ||
let mut str_bytes = pystr.as_str().as_bytes().to_vec(); | ||
str_bytes.push(STR_BYTE); | ||
str_bytes | ||
} | ||
pylist @ PyList => { | ||
let pylist_items = pylist.borrow_vec(); | ||
let mut list_bytes = dump_list(pylist_items.iter(), vm)?; | ||
list_bytes.push(LIST_BYTE); | ||
list_bytes | ||
} | ||
pytuple @ PyTuple => { | ||
let mut tuple_bytes = dump_list(pytuple.as_slice().iter(), vm)?; | ||
tuple_bytes.push(TUPLE_BYTE); | ||
tuple_bytes | ||
} | ||
pydict @ PyDict => { | ||
let key_value_pairs = pydict._as_dict_inner().clone().as_kvpairs(); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. as To iterate PyDict as key-value pairs, let elements = pydict.into_iter().map(|(k, v)| PyTuple::new_ref(vec![k, v], vm)).collect(); Please let me know if it doesn't work. |
||
// Converts list of tuples to PyObjectRefs of tuples | ||
let elements: Vec<PyObjectRef> = key_value_pairs | ||
.into_iter() | ||
.map(|(k, v)| PyTuple::new_ref(vec![k, v], &vm.ctx).into_pyobject(vm)) | ||
.collect(); | ||
// Converts list of tuples to list, dump into binary | ||
let mut dict_bytes = dump_list(elements.iter(), vm)?; | ||
dict_bytes.push(LIST_BYTE); | ||
dict_bytes.push(DICT_BYTE); | ||
dict_bytes | ||
} | ||
co @ PyCode => { | ||
PyBytes::from(co.code.map_clone_bag(&bytecode::BasicBag).to_bytes()) | ||
// Code is default, doesn't have prefix. | ||
co.code.map_clone_bag(&bytecode::BasicBag).to_bytes() | ||
} | ||
_ => | ||
_ => { | ||
return Err(vm.new_not_implemented_error( | ||
"TODO: not implemented yet or marshal unsupported type".to_owned() | ||
)), | ||
"TODO: not implemented yet or marshal unsupported type".to_owned(), | ||
)); | ||
} | ||
}); | ||
Ok(r) | ||
} | ||
|
||
#[pyfunction] | ||
fn dumps(value: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyBytes> { | ||
Ok(PyBytes::from(_dumps(value, vm)?)) | ||
} | ||
|
||
#[pyfunction] | ||
fn dump(value: PyObjectRef, f: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { | ||
let dumped = dumps(value, vm)?; | ||
vm.call_method(&f, "write", (dumped,))?; | ||
Ok(()) | ||
} | ||
|
||
/// Read the next 4 bytes of a slice, read as u32, pass as usize. | ||
/// Returns the rest of buffer with the value. | ||
fn eat_length<'a>(bytes: &'a [u8], vm: &VirtualMachine) -> PyResult<(usize, &'a [u8])> { | ||
let (u32_bytes, rest) = bytes.split_at(4); | ||
let length = u32::from_le_bytes(u32_bytes.try_into().map_err(|_| { | ||
vm.new_value_error("Could not read u32 size from byte array".to_owned()) | ||
})?); | ||
Ok((length as usize, rest)) | ||
} | ||
|
||
/// Reads next element from a python list. First by getting element size | ||
/// then by building a pybuffer and "loading" the pyobject. | ||
/// Returns rest of buffer with object. | ||
fn next_element_of_list<'a>( | ||
buf: &'a [u8], | ||
vm: &VirtualMachine, | ||
) -> PyResult<(PyObjectRef, &'a [u8])> { | ||
let (element_length, element_and_rest) = eat_length(buf, vm)?; | ||
let (element_buff, rest) = element_and_rest.split_at(element_length); | ||
let pybuffer = PyBuffer::from_byte_vector(element_buff.to_vec(), vm); | ||
Ok((loads(pybuffer, vm)?, rest)) | ||
} | ||
|
||
/// Reads a list (or tuple) from a buffer. | ||
fn read_list(buf: &[u8], vm: &VirtualMachine) -> PyResult<Vec<PyObjectRef>> { | ||
let (expected_array_len, mut buffer) = eat_length(buf, vm)?; | ||
let mut elements: Vec<PyObjectRef> = Vec::new(); | ||
while !buffer.is_empty() { | ||
let (element, rest_of_buffer) = next_element_of_list(buffer, vm)?; | ||
elements.push(element); | ||
buffer = rest_of_buffer; | ||
} | ||
debug_assert!(expected_array_len == elements.len()); | ||
Ok(elements) | ||
} | ||
|
||
/// Builds a PyDict from iterator of tuple objects | ||
pub fn from_tuples(iterable: Iter<PyObjectRef>, vm: &VirtualMachine) -> PyResult<PyDict> { | ||
let dict = DictContentType::default(); | ||
for elem in iterable { | ||
let items = match_class!(match elem.clone() { | ||
pytuple @ PyTuple => pytuple.as_slice().to_vec(), | ||
_ => | ||
return Err(vm.new_value_error( | ||
"Couldn't unmarshal key:value pair of dictionary".to_owned() | ||
)), | ||
}); | ||
// Marshalled tuples are always in format key:value. | ||
dict.insert( | ||
vm, | ||
items.get(0).unwrap().clone(), | ||
items.get(1).unwrap().clone(), | ||
)?; | ||
} | ||
Ok(PyDict::from_entries(dict)) | ||
} | ||
|
||
#[pyfunction] | ||
fn loads(code_bytes: ArgBytesLike, vm: &VirtualMachine) -> PyResult<PyCode> { | ||
let buf = &*code_bytes.borrow_buf(); | ||
let code = bytecode::CodeObject::from_bytes(buf).map_err(|e| match e { | ||
bytecode::CodeDeserializeError::Eof => vm.new_exception_msg( | ||
fn loads(pybuffer: PyBuffer, vm: &VirtualMachine) -> PyResult<PyObjectRef> { | ||
let full_buff = pybuffer.as_contiguous().ok_or_else(|| { | ||
vm.new_buffer_error("Buffer provided to marshal.loads() is not contiguous".to_owned()) | ||
})?; | ||
let (type_indicator, buf) = full_buff.split_last().ok_or_else(|| { | ||
vm.new_exception_msg( | ||
vm.ctx.exceptions.eof_error.clone(), | ||
"end of file while deserializing bytecode".to_owned(), | ||
), | ||
_ => vm.new_value_error("Couldn't deserialize python bytecode".to_owned()), | ||
"EOF where object expected.".to_owned(), | ||
) | ||
})?; | ||
Ok(PyCode { | ||
code: vm.map_codeobj(code), | ||
}) | ||
match *type_indicator { | ||
INT_BYTE => { | ||
let (sign_byte, uint_bytes) = buf | ||
.split_first() | ||
.ok_or_else(|| vm.new_value_error("EOF where object expected.".to_owned()))?; | ||
let sign = match sign_byte { | ||
b'-' => Sign::Minus, | ||
b'0' => Sign::NoSign, | ||
b'+' => Sign::Plus, | ||
_ => { | ||
return Err(vm.new_value_error( | ||
"Unknown sign byte when trying to unmarshal integer".to_owned(), | ||
)) | ||
} | ||
}; | ||
let pyint = BigInt::from_bytes_le(sign, uint_bytes); | ||
Ok(pyint.into_pyobject(vm)) | ||
} | ||
FLOAT_BYTE => { | ||
let number = f64::from_le_bytes(match buf[..].try_into() { | ||
Ok(byte_array) => byte_array, | ||
Err(e) => { | ||
return Err(vm.new_value_error(format!( | ||
"Expected float, could not load from bytes. {}", | ||
e | ||
))) | ||
} | ||
}); | ||
let pyfloat = PyFloat::from(number); | ||
Ok(pyfloat.into_pyobject(vm)) | ||
} | ||
STR_BYTE => { | ||
let pystr = PyStr::from(match AsciiStr::from_ascii(buf) { | ||
Ok(ascii_str) => ascii_str, | ||
Err(e) => { | ||
return Err( | ||
vm.new_value_error(format!("Cannot unmarshal bytes to string, {}", e)) | ||
) | ||
} | ||
}); | ||
Ok(pystr.into_pyobject(vm)) | ||
} | ||
LIST_BYTE => { | ||
let elements = read_list(buf, vm)?; | ||
Ok(elements.into_pyobject(vm)) | ||
} | ||
TUPLE_BYTE => { | ||
let elements = read_list(buf, vm)?; | ||
let pytuple = PyTuple::new_ref(elements, &vm.ctx).into_pyobject(vm); | ||
Ok(pytuple) | ||
} | ||
DICT_BYTE => { | ||
let pybuffer = PyBuffer::from_byte_vector(buf[..].to_vec(), vm); | ||
let pydict = match_class!(match loads(pybuffer, vm)? { | ||
pylist @ PyList => from_tuples(pylist.borrow_vec().iter(), vm)?, | ||
_ => | ||
return Err(vm.new_value_error("Couldn't unmarshal dicitionary.".to_owned())), | ||
}); | ||
Ok(pydict.into_pyobject(vm)) | ||
} | ||
_ => { | ||
// If prefix is not identifiable, assume CodeObject, error out if it doesn't match. | ||
let code = bytecode::CodeObject::from_bytes(&full_buff).map_err(|e| match e { | ||
bytecode::CodeDeserializeError::Eof => vm.new_exception_msg( | ||
vm.ctx.exceptions.eof_error.clone(), | ||
"End of file while deserializing bytecode".to_owned(), | ||
), | ||
_ => vm.new_value_error("Couldn't deserialize python bytecode".to_owned()), | ||
})?; | ||
Ok(PyCode { | ||
code: vm.map_codeobj(code), | ||
} | ||
.into_pyobject(vm)) | ||
} | ||
} | ||
} | ||
|
||
#[pyfunction] | ||
fn load(f: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyCode> { | ||
fn load(f: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> { | ||
let read_res = vm.call_method(&f, "read", ())?; | ||
let bytes = ArgBytesLike::try_from_object(vm, read_res)?; | ||
loads(bytes, vm) | ||
loads(PyBuffer::from(bytes), vm) | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we are encapsulating
DictContentType
from outside. If we can avoid exposing this type as much as possible, I would prefer that. Could you check if you can refactor
PyDict::merge_object
a little bit to expose the merge-from-iterator part? It starts at line 108. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Still working on this, everything else is updated. Thanks for all the feedback btw!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great! Please let me know if you hit any blockers.