Skip to content

Commit eb9a5fa

Browse files
authored
Merge pull request #4006 from youknowone/marshal
a few more marshal
2 parents ba766e1 + ec9fa50 commit eb9a5fa

File tree

2 files changed

+126
-108
lines changed

2 files changed

+126
-108
lines changed

Lib/test/test_marshal.py

+1-21
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,6 @@ def test_bool(self):
6464
self.helper(b)
6565

6666
class FloatTestCase(unittest.TestCase, HelperMixin):
67-
# TODO: RUSTPYTHON
68-
@unittest.expectedFailure
6967
def test_floats(self):
7068
# Test a few floats
7169
small = 1e-25
@@ -101,8 +99,6 @@ def test_string(self):
10199
for s in ["", "Andr\xe8 Previn", "abc", " "*10000]:
102100
self.helper(s)
103101

104-
# TODO: RUSTPYTHON
105-
@unittest.expectedFailure
106102
def test_bytes(self):
107103
for s in [b"", b"Andr\xe8 Previn", b"abc", b" "*10000]:
108104
self.helper(s)
@@ -202,14 +198,11 @@ def test_patch_873224(self):
202198
self.assertRaises(Exception, marshal.loads, b'f')
203199
self.assertRaises(Exception, marshal.loads, marshal.dumps(2**65)[:-1])
204200

205-
# TODO: RUSTPYTHON
206-
@unittest.expectedFailure
207201
def test_version_argument(self):
208202
# Python 2.4.0 crashes for any call to marshal.dumps(x, y)
209203
self.assertEqual(marshal.loads(marshal.dumps(5, 0)), 5)
210204
self.assertEqual(marshal.loads(marshal.dumps(5, 1)), 5)
211205

212-
@unittest.skip("TODO: RUSTPYTHON; panic")
213206
def test_fuzz(self):
214207
# simple test that it's at least not *totally* trivial to
215208
# crash from bad marshal data
@@ -337,8 +330,6 @@ def readinto(self, buf):
337330
self.assertRaises(ValueError, marshal.load,
338331
BadReader(marshal.dumps(value)))
339332

340-
# TODO: RUSTPYTHON
341-
@unittest.expectedFailure
342333
def test_eof(self):
343334
data = marshal.dumps(("hello", "dolly", None))
344335
for i in range(len(data)):
@@ -509,8 +500,7 @@ def testModule(self):
509500
self.helper(code)
510501
self.helper3(code)
511502

512-
# TODO: RUSTPYTHON
513-
@unittest.expectedFailure
503+
@unittest.skip("TODO: RUSTPYTHON")
514504
def testRecursion(self):
515505
obj = 1.2345
516506
d = {"hello": obj, "goodbye": obj, obj: "hello"}
@@ -529,23 +519,15 @@ def _test(self, version):
529519
data = marshal.dumps(code, version)
530520
marshal.loads(data)
531521

532-
# TODO: RUSTPYTHON
533-
@unittest.expectedFailure
534522
def test0To3(self):
535523
self._test(0)
536524

537-
# TODO: RUSTPYTHON
538-
@unittest.expectedFailure
539525
def test1To3(self):
540526
self._test(1)
541527

542-
# TODO: RUSTPYTHON
543-
@unittest.expectedFailure
544528
def test2To3(self):
545529
self._test(2)
546530

547-
# TODO: RUSTPYTHON
548-
@unittest.expectedFailure
549531
def test3To3(self):
550532
self._test(3)
551533

@@ -562,8 +544,6 @@ def testIntern(self):
562544
s2 = sys.intern(s)
563545
self.assertEqual(id(s2), id(s))
564546

565-
# TODO: RUSTPYTHON
566-
@unittest.expectedFailure
567547
def testNoIntern(self):
568548
s = marshal.loads(marshal.dumps(self.strobj, 2))
569549
self.assertEqual(s, self.strobj)

vm/src/stdlib/marshal.rs

+125-87
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,22 @@ mod decl {
99
},
1010
bytecode,
1111
convert::ToPyObject,
12-
function::ArgBytesLike,
12+
function::{ArgBytesLike, OptionalArg},
1313
object::AsObject,
1414
protocol::PyBuffer,
1515
PyObjectRef, PyResult, TryFromObject, VirtualMachine,
1616
};
17-
/// TODO
18-
/// PyBytes: Currently getting recursion error with match_class!
1917
use num_bigint::{BigInt, Sign};
2018
use num_traits::Zero;
2119

2220
#[repr(u8)]
2321
enum Type {
2422
// Null = b'0',
25-
// None = b'N',
23+
None = b'N',
2624
False = b'F',
2725
True = b'T',
2826
// StopIter = b'S',
29-
// Ellipsis = b'.',
27+
Ellipsis = b'.',
3028
Int = b'i',
3129
Float = b'g',
3230
// Complex = b'y',
@@ -38,11 +36,11 @@ mod decl {
3836
List = b'[',
3937
Dict = b'{',
4038
Code = b'c',
41-
Str = b'u', // = TYPE_UNICODE
39+
Unicode = b'u',
4240
// Unknown = b'?',
4341
Set = b'<',
4442
FrozenSet = b'>',
45-
// Ascii = b'a',
43+
Ascii = b'a',
4644
// AsciiInterned = b'A',
4745
// SmallTuple = b')',
4846
// ShortAscii = b'z',
@@ -56,11 +54,11 @@ mod decl {
5654
use Type::*;
5755
Ok(match value {
5856
// b'0' => Null,
59-
// b'N' => None,
57+
b'N' => None,
6058
b'F' => False,
6159
b'T' => True,
6260
// b'S' => StopIter,
63-
// b'.' => Ellipsis,
61+
b'.' => Ellipsis,
6462
b'i' => Int,
6563
b'g' => Float,
6664
// b'y' => Complex,
@@ -72,11 +70,11 @@ mod decl {
7270
b'[' => List,
7371
b'{' => Dict,
7472
b'c' => Code,
75-
b'u' => Str,
73+
b'u' => Unicode,
7674
// b'?' => Unknown,
7775
b'<' => Set,
7876
b'>' => FrozenSet,
79-
// b'a' => Ascii,
77+
b'a' => Ascii,
8078
// b'A' => AsciiInterned,
8179
// b')' => SmallTuple,
8280
// b'z' => ShortAscii,
@@ -86,6 +84,9 @@ mod decl {
8684
}
8785
}
8886

87+
#[pyattr(name = "version")]
88+
const VERSION: u32 = 4;
89+
8990
fn too_short_error(vm: &VirtualMachine) -> PyBaseExceptionRef {
9091
vm.new_exception_msg(
9192
vm.ctx.exceptions.eof_error.to_owned(),
@@ -109,93 +110,118 @@ mod decl {
109110

110111
/// Dumping helper function to turn a value into bytes.
111112
fn dump_obj(buf: &mut Vec<u8>, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
112-
match_class!(match value {
113-
pyint @ PyInt => {
114-
if pyint.class().is(vm.ctx.types.bool_type) {
115-
let typ = if pyint.as_bigint().is_zero() {
116-
Type::False
113+
if vm.is_none(&value) {
114+
buf.push(Type::None as u8);
115+
} else if value.is(&vm.ctx.ellipsis) {
116+
buf.push(Type::Ellipsis as u8);
117+
} else {
118+
match_class!(match value {
119+
pyint @ PyInt => {
120+
if pyint.class().is(vm.ctx.types.bool_type) {
121+
let typ = if pyint.as_bigint().is_zero() {
122+
Type::False
123+
} else {
124+
Type::True
125+
};
126+
buf.push(typ as u8);
117127
} else {
118-
Type::True
119-
};
120-
buf.push(typ as u8);
121-
} else {
122-
buf.push(Type::Int as u8);
123-
let (sign, int_bytes) = pyint.as_bigint().to_bytes_le();
124-
let mut len = int_bytes.len() as i32;
125-
if sign == Sign::Minus {
126-
len = -len;
128+
buf.push(Type::Int as u8);
129+
let (sign, int_bytes) = pyint.as_bigint().to_bytes_le();
130+
let mut len = int_bytes.len() as i32;
131+
if sign == Sign::Minus {
132+
len = -len;
133+
}
134+
buf.extend(len.to_le_bytes());
135+
buf.extend(int_bytes);
127136
}
128-
buf.extend(len.to_le_bytes());
129-
buf.extend(int_bytes);
130137
}
131-
}
132-
pyfloat @ PyFloat => {
133-
buf.push(Type::Float as u8);
134-
buf.extend(pyfloat.to_f64().to_le_bytes());
135-
}
136-
pystr @ PyStr => {
137-
buf.push(Type::Str as u8);
138-
write_size(buf, pystr.as_str().len(), vm)?;
139-
buf.extend(pystr.as_str().as_bytes());
140-
}
141-
pylist @ PyList => {
142-
buf.push(Type::List as u8);
143-
let pylist_items = pylist.borrow_vec();
144-
dump_seq(buf, pylist_items.iter(), vm)?;
145-
}
146-
pyset @ PySet => {
147-
buf.push(Type::Set as u8);
148-
let elements = pyset.elements();
149-
dump_seq(buf, elements.iter(), vm)?;
150-
}
151-
pyfrozen @ PyFrozenSet => {
152-
buf.push(Type::FrozenSet as u8);
153-
let elements = pyfrozen.elements();
154-
dump_seq(buf, elements.iter(), vm)?;
155-
}
156-
pytuple @ PyTuple => {
157-
buf.push(Type::Tuple as u8);
158-
dump_seq(buf, pytuple.iter(), vm)?;
159-
}
160-
pydict @ PyDict => {
161-
buf.push(Type::Dict as u8);
162-
write_size(buf, pydict.len(), vm)?;
163-
for (key, value) in pydict {
164-
dump_obj(buf, key, vm)?;
165-
dump_obj(buf, value, vm)?;
138+
pyfloat @ PyFloat => {
139+
buf.push(Type::Float as u8);
140+
buf.extend(pyfloat.to_f64().to_le_bytes());
166141
}
167-
}
168-
bytes @ PyByteArray => {
169-
buf.push(Type::Bytes as u8);
170-
let data = bytes.borrow_buf();
171-
write_size(buf, data.len(), vm)?;
172-
buf.extend(&*data);
173-
}
174-
co @ PyCode => {
175-
buf.push(Type::Code as u8);
176-
let bytes = co.code.map_clone_bag(&bytecode::BasicBag).to_bytes();
177-
write_size(buf, bytes.len(), vm)?;
178-
buf.extend(bytes);
179-
}
180-
_ => {
181-
return Err(vm.new_not_implemented_error(
182-
"TODO: not implemented yet or marshal unsupported type".to_owned(),
183-
));
184-
}
185-
});
142+
pystr @ PyStr => {
143+
buf.push(if pystr.is_ascii() {
144+
Type::Ascii
145+
} else {
146+
Type::Unicode
147+
} as u8);
148+
write_size(buf, pystr.as_str().len(), vm)?;
149+
buf.extend(pystr.as_str().as_bytes());
150+
}
151+
pylist @ PyList => {
152+
buf.push(Type::List as u8);
153+
let pylist_items = pylist.borrow_vec();
154+
dump_seq(buf, pylist_items.iter(), vm)?;
155+
}
156+
pyset @ PySet => {
157+
buf.push(Type::Set as u8);
158+
let elements = pyset.elements();
159+
dump_seq(buf, elements.iter(), vm)?;
160+
}
161+
pyfrozen @ PyFrozenSet => {
162+
buf.push(Type::FrozenSet as u8);
163+
let elements = pyfrozen.elements();
164+
dump_seq(buf, elements.iter(), vm)?;
165+
}
166+
pytuple @ PyTuple => {
167+
buf.push(Type::Tuple as u8);
168+
dump_seq(buf, pytuple.iter(), vm)?;
169+
}
170+
pydict @ PyDict => {
171+
buf.push(Type::Dict as u8);
172+
write_size(buf, pydict.len(), vm)?;
173+
for (key, value) in pydict {
174+
dump_obj(buf, key, vm)?;
175+
dump_obj(buf, value, vm)?;
176+
}
177+
}
178+
bytes @ PyBytes => {
179+
buf.push(Type::Bytes as u8);
180+
let data = bytes.as_bytes();
181+
write_size(buf, data.len(), vm)?;
182+
buf.extend(&*data);
183+
}
184+
bytes @ PyByteArray => {
185+
buf.push(Type::Bytes as u8);
186+
let data = bytes.borrow_buf();
187+
write_size(buf, data.len(), vm)?;
188+
buf.extend(&*data);
189+
}
190+
co @ PyCode => {
191+
buf.push(Type::Code as u8);
192+
let bytes = co.code.map_clone_bag(&bytecode::BasicBag).to_bytes();
193+
write_size(buf, bytes.len(), vm)?;
194+
buf.extend(bytes);
195+
}
196+
_ => {
197+
return Err(vm.new_not_implemented_error(
198+
"TODO: not implemented yet or marshal unsupported type".to_owned(),
199+
));
200+
}
201+
})
202+
}
186203
Ok(())
187204
}
188205

189206
#[pyfunction]
190-
fn dumps(value: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyBytes> {
207+
fn dumps(
208+
value: PyObjectRef,
209+
_version: OptionalArg<i32>,
210+
vm: &VirtualMachine,
211+
) -> PyResult<PyBytes> {
191212
let mut buf = Vec::new();
192213
dump_obj(&mut buf, value, vm)?;
193214
Ok(PyBytes::from(buf))
194215
}
195216

196217
#[pyfunction]
197-
fn dump(value: PyObjectRef, f: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
198-
let dumped = dumps(value, vm)?;
218+
fn dump(
219+
value: PyObjectRef,
220+
f: PyObjectRef,
221+
version: OptionalArg<i32>,
222+
vm: &VirtualMachine,
223+
) -> PyResult<()> {
224+
let dumped = dumps(value, version, vm)?;
199225
vm.call_method(&f, "write", (dumped,))?;
200226
Ok(())
201227
}
@@ -248,8 +274,10 @@ mod decl {
248274
let typ = Type::try_from(*type_indicator)
249275
.map_err(|_| vm.new_value_error("bad marshal data (unknown type code)".to_owned()))?;
250276
let (obj, buf) = match typ {
251-
Type::True => ((true).to_pyobject(vm), buf),
252-
Type::False => ((false).to_pyobject(vm), buf),
277+
Type::True => (true.to_pyobject(vm), buf),
278+
Type::False => (false.to_pyobject(vm), buf),
279+
Type::None => (vm.ctx.none(), buf),
280+
Type::Ellipsis => (vm.ctx.ellipsis(), buf),
253281
Type::Int => {
254282
if buf.len() < 4 {
255283
return Err(too_short_error(vm));
@@ -276,7 +304,17 @@ mod decl {
276304
let number = f64::from_le_bytes(bytes.try_into().unwrap());
277305
(vm.ctx.new_float(number).into(), buf)
278306
}
279-
Type::Str => {
307+
Type::Ascii => {
308+
let (len, buf) = read_size(buf, vm)?;
309+
if buf.len() < len {
310+
return Err(too_short_error(vm));
311+
}
312+
let (bytes, buf) = buf.split_at(len);
313+
let s = String::from_utf8(bytes.to_vec())
314+
.map_err(|_| vm.new_value_error("invalid utf8 data".to_owned()))?;
315+
(s.to_pyobject(vm), buf)
316+
}
317+
Type::Unicode => {
280318
let (len, buf) = read_size(buf, vm)?;
281319
if buf.len() < len {
282320
return Err(too_short_error(vm));

0 commit comments

Comments
 (0)