Skip to content

a few more marshal #4006

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 1 addition & 21 deletions Lib/test/test_marshal.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,6 @@ def test_bool(self):
self.helper(b)

class FloatTestCase(unittest.TestCase, HelperMixin):
# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_floats(self):
# Test a few floats
small = 1e-25
Expand Down Expand Up @@ -101,8 +99,6 @@ def test_string(self):
for s in ["", "Andr\xe8 Previn", "abc", " "*10000]:
self.helper(s)

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_bytes(self):
for s in [b"", b"Andr\xe8 Previn", b"abc", b" "*10000]:
self.helper(s)
Expand Down Expand Up @@ -202,14 +198,11 @@ def test_patch_873224(self):
self.assertRaises(Exception, marshal.loads, b'f')
self.assertRaises(Exception, marshal.loads, marshal.dumps(2**65)[:-1])

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_version_argument(self):
# Python 2.4.0 crashes for any call to marshal.dumps(x, y)
self.assertEqual(marshal.loads(marshal.dumps(5, 0)), 5)
self.assertEqual(marshal.loads(marshal.dumps(5, 1)), 5)

@unittest.skip("TODO: RUSTPYTHON; panic")
def test_fuzz(self):
# simple test that it's at least not *totally* trivial to
# crash from bad marshal data
Expand Down Expand Up @@ -337,8 +330,6 @@ def readinto(self, buf):
self.assertRaises(ValueError, marshal.load,
BadReader(marshal.dumps(value)))

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_eof(self):
data = marshal.dumps(("hello", "dolly", None))
for i in range(len(data)):
Expand Down Expand Up @@ -509,8 +500,7 @@ def testModule(self):
self.helper(code)
self.helper3(code)

# TODO: RUSTPYTHON
@unittest.expectedFailure
@unittest.skip("TODO: RUSTPYTHON")
def testRecursion(self):
obj = 1.2345
d = {"hello": obj, "goodbye": obj, obj: "hello"}
Expand All @@ -529,23 +519,15 @@ def _test(self, version):
data = marshal.dumps(code, version)
marshal.loads(data)

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test0To3(self):
self._test(0)

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test1To3(self):
self._test(1)

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test2To3(self):
self._test(2)

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test3To3(self):
self._test(3)

Expand All @@ -562,8 +544,6 @@ def testIntern(self):
s2 = sys.intern(s)
self.assertEqual(id(s2), id(s))

# TODO: RUSTPYTHON
@unittest.expectedFailure
def testNoIntern(self):
s = marshal.loads(marshal.dumps(self.strobj, 2))
self.assertEqual(s, self.strobj)
Expand Down
212 changes: 125 additions & 87 deletions vm/src/stdlib/marshal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,22 @@ mod decl {
},
bytecode,
convert::ToPyObject,
function::ArgBytesLike,
function::{ArgBytesLike, OptionalArg},
object::AsObject,
protocol::PyBuffer,
PyObjectRef, PyResult, TryFromObject, VirtualMachine,
};
/// TODO
/// PyBytes: Currently getting recursion error with match_class!
use num_bigint::{BigInt, Sign};
use num_traits::Zero;

#[repr(u8)]
enum Type {
// Null = b'0',
// None = b'N',
None = b'N',
False = b'F',
True = b'T',
// StopIter = b'S',
// Ellipsis = b'.',
Ellipsis = b'.',
Int = b'i',
Float = b'g',
// Complex = b'y',
Expand All @@ -38,11 +36,11 @@ mod decl {
List = b'[',
Dict = b'{',
Code = b'c',
Str = b'u', // = TYPE_UNICODE
Unicode = b'u',
// Unknown = b'?',
Set = b'<',
FrozenSet = b'>',
// Ascii = b'a',
Ascii = b'a',
// AsciiInterned = b'A',
// SmallTuple = b')',
// ShortAscii = b'z',
Expand All @@ -56,11 +54,11 @@ mod decl {
use Type::*;
Ok(match value {
// b'0' => Null,
// b'N' => None,
b'N' => None,
b'F' => False,
b'T' => True,
// b'S' => StopIter,
// b'.' => Ellipsis,
b'.' => Ellipsis,
b'i' => Int,
b'g' => Float,
// b'y' => Complex,
Expand All @@ -72,11 +70,11 @@ mod decl {
b'[' => List,
b'{' => Dict,
b'c' => Code,
b'u' => Str,
b'u' => Unicode,
// b'?' => Unknown,
b'<' => Set,
b'>' => FrozenSet,
// b'a' => Ascii,
b'a' => Ascii,
// b'A' => AsciiInterned,
// b')' => SmallTuple,
// b'z' => ShortAscii,
Expand All @@ -86,6 +84,9 @@ mod decl {
}
}

#[pyattr(name = "version")]
const VERSION: u32 = 4;

fn too_short_error(vm: &VirtualMachine) -> PyBaseExceptionRef {
vm.new_exception_msg(
vm.ctx.exceptions.eof_error.to_owned(),
Expand All @@ -109,93 +110,118 @@ mod decl {

/// Dumping helper function to turn a value into bytes.
fn dump_obj(buf: &mut Vec<u8>, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
match_class!(match value {
pyint @ PyInt => {
if pyint.class().is(vm.ctx.types.bool_type) {
let typ = if pyint.as_bigint().is_zero() {
Type::False
if vm.is_none(&value) {
buf.push(Type::None as u8);
} else if value.is(&vm.ctx.ellipsis) {
buf.push(Type::Ellipsis as u8);
} else {
match_class!(match value {
pyint @ PyInt => {
if pyint.class().is(vm.ctx.types.bool_type) {
let typ = if pyint.as_bigint().is_zero() {
Type::False
} else {
Type::True
};
buf.push(typ as u8);
} else {
Type::True
};
buf.push(typ as u8);
} else {
buf.push(Type::Int as u8);
let (sign, int_bytes) = pyint.as_bigint().to_bytes_le();
let mut len = int_bytes.len() as i32;
if sign == Sign::Minus {
len = -len;
buf.push(Type::Int as u8);
let (sign, int_bytes) = pyint.as_bigint().to_bytes_le();
let mut len = int_bytes.len() as i32;
if sign == Sign::Minus {
len = -len;
}
buf.extend(len.to_le_bytes());
buf.extend(int_bytes);
}
buf.extend(len.to_le_bytes());
buf.extend(int_bytes);
}
}
pyfloat @ PyFloat => {
buf.push(Type::Float as u8);
buf.extend(pyfloat.to_f64().to_le_bytes());
}
pystr @ PyStr => {
buf.push(Type::Str as u8);
write_size(buf, pystr.as_str().len(), vm)?;
buf.extend(pystr.as_str().as_bytes());
}
pylist @ PyList => {
buf.push(Type::List as u8);
let pylist_items = pylist.borrow_vec();
dump_seq(buf, pylist_items.iter(), vm)?;
}
pyset @ PySet => {
buf.push(Type::Set as u8);
let elements = pyset.elements();
dump_seq(buf, elements.iter(), vm)?;
}
pyfrozen @ PyFrozenSet => {
buf.push(Type::FrozenSet as u8);
let elements = pyfrozen.elements();
dump_seq(buf, elements.iter(), vm)?;
}
pytuple @ PyTuple => {
buf.push(Type::Tuple as u8);
dump_seq(buf, pytuple.iter(), vm)?;
}
pydict @ PyDict => {
buf.push(Type::Dict as u8);
write_size(buf, pydict.len(), vm)?;
for (key, value) in pydict {
dump_obj(buf, key, vm)?;
dump_obj(buf, value, vm)?;
pyfloat @ PyFloat => {
buf.push(Type::Float as u8);
buf.extend(pyfloat.to_f64().to_le_bytes());
}
}
bytes @ PyByteArray => {
buf.push(Type::Bytes as u8);
let data = bytes.borrow_buf();
write_size(buf, data.len(), vm)?;
buf.extend(&*data);
}
co @ PyCode => {
buf.push(Type::Code as u8);
let bytes = co.code.map_clone_bag(&bytecode::BasicBag).to_bytes();
write_size(buf, bytes.len(), vm)?;
buf.extend(bytes);
}
_ => {
return Err(vm.new_not_implemented_error(
"TODO: not implemented yet or marshal unsupported type".to_owned(),
));
}
});
pystr @ PyStr => {
buf.push(if pystr.is_ascii() {
Type::Ascii
} else {
Type::Unicode
} as u8);
write_size(buf, pystr.as_str().len(), vm)?;
buf.extend(pystr.as_str().as_bytes());
}
pylist @ PyList => {
buf.push(Type::List as u8);
let pylist_items = pylist.borrow_vec();
dump_seq(buf, pylist_items.iter(), vm)?;
}
pyset @ PySet => {
buf.push(Type::Set as u8);
let elements = pyset.elements();
dump_seq(buf, elements.iter(), vm)?;
}
pyfrozen @ PyFrozenSet => {
buf.push(Type::FrozenSet as u8);
let elements = pyfrozen.elements();
dump_seq(buf, elements.iter(), vm)?;
}
pytuple @ PyTuple => {
buf.push(Type::Tuple as u8);
dump_seq(buf, pytuple.iter(), vm)?;
}
pydict @ PyDict => {
buf.push(Type::Dict as u8);
write_size(buf, pydict.len(), vm)?;
for (key, value) in pydict {
dump_obj(buf, key, vm)?;
dump_obj(buf, value, vm)?;
}
}
bytes @ PyBytes => {
buf.push(Type::Bytes as u8);
let data = bytes.as_bytes();
write_size(buf, data.len(), vm)?;
buf.extend(&*data);
}
bytes @ PyByteArray => {
buf.push(Type::Bytes as u8);
let data = bytes.borrow_buf();
write_size(buf, data.len(), vm)?;
buf.extend(&*data);
}
co @ PyCode => {
buf.push(Type::Code as u8);
let bytes = co.code.map_clone_bag(&bytecode::BasicBag).to_bytes();
write_size(buf, bytes.len(), vm)?;
buf.extend(bytes);
}
_ => {
return Err(vm.new_not_implemented_error(
"TODO: not implemented yet or marshal unsupported type".to_owned(),
));
}
})
}
Ok(())
}

#[pyfunction]
fn dumps(value: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyBytes> {
fn dumps(
value: PyObjectRef,
_version: OptionalArg<i32>,
vm: &VirtualMachine,
) -> PyResult<PyBytes> {
let mut buf = Vec::new();
dump_obj(&mut buf, value, vm)?;
Ok(PyBytes::from(buf))
}

#[pyfunction]
fn dump(value: PyObjectRef, f: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
let dumped = dumps(value, vm)?;
fn dump(
value: PyObjectRef,
f: PyObjectRef,
version: OptionalArg<i32>,
vm: &VirtualMachine,
) -> PyResult<()> {
let dumped = dumps(value, version, vm)?;
vm.call_method(&f, "write", (dumped,))?;
Ok(())
}
Expand Down Expand Up @@ -248,8 +274,10 @@ mod decl {
let typ = Type::try_from(*type_indicator)
.map_err(|_| vm.new_value_error("bad marshal data (unknown type code)".to_owned()))?;
let (obj, buf) = match typ {
Type::True => ((true).to_pyobject(vm), buf),
Type::False => ((false).to_pyobject(vm), buf),
Type::True => (true.to_pyobject(vm), buf),
Type::False => (false.to_pyobject(vm), buf),
Type::None => (vm.ctx.none(), buf),
Type::Ellipsis => (vm.ctx.ellipsis(), buf),
Type::Int => {
if buf.len() < 4 {
return Err(too_short_error(vm));
Expand All @@ -276,7 +304,17 @@ mod decl {
let number = f64::from_le_bytes(bytes.try_into().unwrap());
(vm.ctx.new_float(number).into(), buf)
}
Type::Str => {
Type::Ascii => {
let (len, buf) = read_size(buf, vm)?;
if buf.len() < len {
return Err(too_short_error(vm));
}
let (bytes, buf) = buf.split_at(len);
let s = String::from_utf8(bytes.to_vec())
.map_err(|_| vm.new_value_error("invalid utf8 data".to_owned()))?;
(s.to_pyobject(vm), buf)
}
Type::Unicode => {
let (len, buf) = read_size(buf, vm)?;
if buf.len() < len {
return Err(too_short_error(vm));
Expand Down