diff --git a/Cargo.lock b/Cargo.lock index 87cfb55109..0255535c1b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -239,12 +239,12 @@ checksum = "b6b1fc10dbac614ebc03540c9dbd60e83887fda27794998c6528f1782047d540" [[package]] name = "bzip2" -version = "0.4.4" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8" +checksum = "49ecfb22d906f800d4fe833b6282cf4dc1c298f5057ca0b5445e5c209735ca47" dependencies = [ "bzip2-sys", - "libc", + "libbz2-rs-sys", ] [[package]] @@ -1322,6 +1322,12 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fa0e2a1fcbe2f6be6c42e342259976206b383122fc152e872795338b5a3f3a7" +[[package]] +name = "libbz2-rs-sys" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0864a00c8d019e36216b69c2c4ce50b83b7bd966add3cf5ba554ec44f8bebcf5" + [[package]] name = "libc" version = "0.2.171" diff --git a/Cargo.toml b/Cargo.toml index 0a3e3b89ff..fb066e5dfd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,6 @@ flame-it = ["rustpython-vm/flame-it", "flame", "flamescope"] freeze-stdlib = ["stdlib", "rustpython-vm/freeze-stdlib", "rustpython-pylib?/freeze-stdlib"] jit = ["rustpython-vm/jit"] threading = ["rustpython-vm/threading", "rustpython-stdlib/threading"] -bz2 = ["stdlib", "rustpython-stdlib/bz2"] sqlite = ["rustpython-stdlib/sqlite"] ssl = ["rustpython-stdlib/ssl"] ssl-vendor = ["ssl", "rustpython-stdlib/ssl-vendor"] diff --git a/Lib/test/test_bz2.py b/Lib/test/test_bz2.py index 1f0b9adc36..b716d6016b 100644 --- a/Lib/test/test_bz2.py +++ b/Lib/test/test_bz2.py @@ -676,6 +676,8 @@ def testCompress4G(self, size): finally: data = None + # TODO: RUSTPYTHON + @unittest.expectedFailure def testPickle(self): for proto in range(pickle.HIGHEST_PROTOCOL + 1): with self.assertRaises(TypeError): @@ -734,6 +736,8 @@ def testDecompress4G(self, size): compressed = None decompressed = None + # TODO: RUSTPYTHON + @unittest.expectedFailure def testPickle(self): for proto in range(pickle.HIGHEST_PROTOCOL + 1): with self.assertRaises(TypeError): @@ -1001,6 +1005,8 @@ def test_encoding_error_handler(self): as f: self.assertEqual(f.read(), "foobar") + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_newline(self): # Test with explicit newline (universal newline mode disabled). text = self.TEXT.decode("ascii") diff --git a/stdlib/Cargo.toml b/stdlib/Cargo.toml index eb6b9fe4dd..d29bd3b21e 100644 --- a/stdlib/Cargo.toml +++ b/stdlib/Cargo.toml @@ -14,7 +14,6 @@ license.workspace = true default = ["compiler"] compiler = ["rustpython-vm/compiler"] threading = ["rustpython-common/threading", "rustpython-vm/threading"] -bz2 = ["bzip2"] sqlite = ["dep:libsqlite3-sys"] ssl = ["openssl", "openssl-sys", "foreign-types-shared", "openssl-probe"] ssl-vendor = ["ssl", "openssl/vendored"] @@ -80,7 +79,7 @@ adler32 = "1.2.0" crc32fast = "1.3.2" flate2 = { version = "1.1", default-features = false, features = ["zlib-rs"] } libz-sys = { package = "libz-rs-sys", version = "0.5" } -bzip2 = { version = "0.4", optional = true } +bzip2 = { version = "0.5", features = ["libbz2-rs-sys"] } # tkinter tk-sys = { git = "https://github.com/arihant2math/tkinter.git", tag = "v0.1.0", optional = true } diff --git a/stdlib/src/bz2.rs b/stdlib/src/bz2.rs index ba74a38db1..6339a44a24 100644 --- a/stdlib/src/bz2.rs +++ b/stdlib/src/bz2.rs @@ -12,28 +12,48 @@ mod _bz2 { object::{PyPayload, PyResult}, types::Constructor, }; + use crate::zlib::{ + DecompressArgs, DecompressError, DecompressState, DecompressStatus, Decompressor, + }; use bzip2::{Decompress, Status, write::BzEncoder}; + use rustpython_vm::convert::ToPyException; use std::{fmt, io::Write}; - // const BUFSIZ: i32 = 8192; - - struct DecompressorState { - decoder: Decompress, - eof: bool, - needs_input: bool, - // input_buffer: Vec, - // output_buffer: Vec, - } + const BUFSIZ: usize = 8192; #[pyattr] #[pyclass(name = "BZ2Decompressor")] #[derive(PyPayload)] struct BZ2Decompressor { - state: PyMutex, + state: PyMutex>, + } + + impl Decompressor for Decompress { + type Flush = (); + type Status = Status; + type Error = bzip2::Error; + + fn total_in(&self) -> u64 { + self.total_in() + } + fn decompress_vec( + &mut self, + input: &[u8], + output: &mut Vec, + (): Self::Flush, + ) -> Result { + self.decompress_vec(input, output) + } + } + + impl DecompressStatus for Status { + fn is_stream_end(&self) -> bool { + *self == Status::StreamEnd + } } impl fmt::Debug for BZ2Decompressor { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "_bz2.BZ2Decompressor") } } @@ -43,13 +63,7 @@ mod _bz2 { fn py_new(cls: PyTypeRef, _: Self::Args, vm: &VirtualMachine) -> PyResult { Self { - state: PyMutex::new(DecompressorState { - decoder: Decompress::new(false), - eof: false, - needs_input: true, - // input_buffer: Vec::new(), - // output_buffer: Vec::new(), - }), + state: PyMutex::new(DecompressState::new(Decompress::new(false), vm)), } .into_ref_with_type(vm, cls) .map(Into::into) @@ -59,107 +73,34 @@ mod _bz2 { #[pyclass(with(Constructor))] impl BZ2Decompressor { #[pymethod] - fn decompress( - &self, - data: ArgBytesLike, - // TODO: PyIntRef - max_length: OptionalArg, - vm: &VirtualMachine, - ) -> PyResult { - let max_length = max_length.unwrap_or(-1); - if max_length >= 0 { - return Err(vm.new_not_implemented_error( - "the max_value argument is not implemented yet".to_owned(), - )); - } - // let max_length = if max_length < 0 || max_length >= BUFSIZ { - // BUFSIZ - // } else { - // max_length - // }; + fn decompress(&self, args: DecompressArgs, vm: &VirtualMachine) -> PyResult> { + let max_length = args.max_length(); + let data = &*args.data(); let mut state = self.state.lock(); - let DecompressorState { - decoder, - eof, - .. - // needs_input, - // input_buffer, - // output_buffer, - } = &mut *state; - - if *eof { - return Err(vm.new_exception_msg( - vm.ctx.exceptions.eof_error.to_owned(), - "End of stream already reached".to_owned(), - )); - } - - // data.with_ref(|data| input_buffer.extend(data)); - - // If max_length is negative: - // read the input X bytes at a time, compress it and append it to output. - // Once you're out of input, setting needs_input to true and return the - // output as bytes. - // - // TODO: - // If max_length is non-negative: - // Read the input X bytes at a time, compress it and append it to - // the output. If output reaches `max_length` in size, return - // it (up to max_length), and store the rest of the output - // for later. - - // TODO: arbitrary choice, not the right way to do it. - let mut buf = Vec::with_capacity(data.len() * 32); - - let before = decoder.total_in(); - let res = data.with_ref(|data| decoder.decompress_vec(data, &mut buf)); - let _written = (decoder.total_in() - before) as usize; - - let res = match res { - Ok(x) => x, - // TODO: error message - _ => return Err(vm.new_os_error("Invalid data stream".to_owned())), - }; - - if res == Status::StreamEnd { - *eof = true; - } - Ok(vm.ctx.new_bytes(buf.to_vec())) + state + .decompress(data, max_length, BUFSIZ, vm) + .map_err(|e| match e { + DecompressError::Decompress(err) => vm.new_os_error(err.to_string()), + DecompressError::Eof(err) => err.to_pyexception(vm), + }) } #[pygetset] fn eof(&self) -> bool { - let state = self.state.lock(); - state.eof + self.state.lock().eof() } #[pygetset] - fn unused_data(&self, vm: &VirtualMachine) -> PyBytesRef { - // Data found after the end of the compressed stream. - // If this attribute is accessed before the end of the stream - // has been reached, its value will be b''. - vm.ctx.new_bytes(b"".to_vec()) - // alternatively, be more honest: - // Err(vm.new_not_implemented_error( - // "unused_data isn't implemented yet".to_owned(), - // )) - // - // TODO - // let state = self.state.lock(); - // if state.eof { - // vm.ctx.new_bytes(state.input_buffer.to_vec()) - // else { - // vm.ctx.new_bytes(b"".to_vec()) - // } + fn unused_data(&self) -> PyBytesRef { + self.state.lock().unused_data() } #[pygetset] fn needs_input(&self) -> bool { // False if the decompress() method can provide more // decompressed data before requiring new uncompressed input. - let state = self.state.lock(); - state.needs_input + self.state.lock().needs_input() } // TODO: mro()? @@ -178,7 +119,7 @@ mod _bz2 { } impl fmt::Debug for BZ2Compressor { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "_bz2.BZ2Compressor") } } diff --git a/stdlib/src/lib.rs b/stdlib/src/lib.rs index 1e81606f3c..254aa39322 100644 --- a/stdlib/src/lib.rs +++ b/stdlib/src/lib.rs @@ -36,7 +36,6 @@ mod statistics; mod suggestions; // TODO: maybe make this an extension module, if we ever get those // mod re; -#[cfg(feature = "bz2")] mod bz2; #[cfg(not(target_arch = "wasm32"))] pub mod socket; @@ -112,6 +111,7 @@ pub fn get_module_inits() -> impl Iterator, StdlibInit "array" => array::make_module, "binascii" => binascii::make_module, "_bisect" => bisect::make_module, + "_bz2" => bz2::make_module, "cmath" => cmath::make_module, "_contextvars" => contextvars::make_module, "_csv" => csv::make_module, @@ -158,10 +158,6 @@ pub fn get_module_inits() -> impl Iterator, StdlibInit { "_ssl" => ssl::make_module, } - #[cfg(feature = "bz2")] - { - "_bz2" => bz2::make_module, - } #[cfg(windows)] { "_overlapped" => overlapped::make_module, diff --git a/stdlib/src/zlib.rs b/stdlib/src/zlib.rs index 9c19b74066..0578f20c86 100644 --- a/stdlib/src/zlib.rs +++ b/stdlib/src/zlib.rs @@ -1,14 +1,17 @@ // spell-checker:ignore compressobj decompressobj zdict chunksize zlibmodule miniz chunker -pub(crate) use zlib::make_module; +pub(crate) use zlib::{DecompressArgs, make_module}; #[pymodule] mod zlib { + use super::generic::{ + DecompressError, DecompressState, DecompressStatus, Decompressor, FlushKind, flush_sync, + }; use crate::vm::{ PyObject, PyPayload, PyResult, VirtualMachine, builtins::{PyBaseExceptionRef, PyBytesRef, PyIntRef, PyTypeRef}, common::lock::PyMutex, - convert::TryFromBorrowedObject, + convert::{ToPyException, TryFromBorrowedObject}, function::{ArgBytesLike, ArgPrimitiveIndex, ArgSize, OptionalArg}, types::Constructor, }; @@ -142,18 +145,18 @@ mod zlib { } #[derive(Clone)] - struct Chunker<'a> { + pub(crate) struct Chunker<'a> { data1: &'a [u8], data2: &'a [u8], } impl<'a> Chunker<'a> { - fn new(data: &'a [u8]) -> Self { + pub(crate) fn new(data: &'a [u8]) -> Self { Self { data1: data, data2: &[], } } - fn chain(data1: &'a [u8], data2: &'a [u8]) -> Self { + pub(crate) fn chain(data1: &'a [u8], data2: &'a [u8]) -> Self { if data1.is_empty() { Self { data1: data2, @@ -163,19 +166,19 @@ mod zlib { Self { data1, data2 } } } - fn len(&self) -> usize { + pub(crate) fn len(&self) -> usize { self.data1.len() + self.data2.len() } - fn is_empty(&self) -> bool { + pub(crate) fn is_empty(&self) -> bool { self.data1.is_empty() } - fn to_vec(&self) -> Vec { + pub(crate) fn to_vec(&self) -> Vec { [self.data1, self.data2].concat() } - fn chunk(&self) -> &'a [u8] { + pub(crate) fn chunk(&self) -> &'a [u8] { self.data1.get(..CHUNKSIZE).unwrap_or(self.data1) } - fn advance(&mut self, consumed: usize) { + pub(crate) fn advance(&mut self, consumed: usize) { self.data1 = &self.data1[consumed..]; if self.data1.is_empty() { self.data1 = std::mem::take(&mut self.data2); @@ -183,28 +186,24 @@ mod zlib { } } - fn _decompress( + fn _decompress( data: &[u8], - d: &mut Decompress, + d: &mut D, bufsize: usize, max_length: Option, - is_flush: bool, - zdict: Option<&ArgBytesLike>, - vm: &VirtualMachine, - ) -> PyResult<(Vec, bool)> { + calc_flush: impl Fn(bool) -> D::Flush, + ) -> Result<(Vec, bool), D::Error> { let mut data = Chunker::new(data); - _decompress_chunks(&mut data, d, bufsize, max_length, is_flush, zdict, vm) + _decompress_chunks(&mut data, d, bufsize, max_length, calc_flush) } - fn _decompress_chunks( + pub(super) fn _decompress_chunks( data: &mut Chunker<'_>, - d: &mut Decompress, + d: &mut D, bufsize: usize, max_length: Option, - is_flush: bool, - zdict: Option<&ArgBytesLike>, - vm: &VirtualMachine, - ) -> PyResult<(Vec, bool)> { + calc_flush: impl Fn(bool) -> D::Flush, + ) -> Result<(Vec, bool), D::Error> { if data.is_empty() { return Ok((Vec::new(), true)); } @@ -213,16 +212,7 @@ mod zlib { 'outer: loop { let chunk = data.chunk(); - let flush = if is_flush { - // if this is the final chunk, finish it - if chunk.len() == data.len() { - FlushDecompress::Finish - } else { - FlushDecompress::None - } - } else { - FlushDecompress::Sync - }; + let flush = calc_flush(chunk.len() == data.len()); loop { let additional = std::cmp::min(bufsize, max_length - buf.capacity()); if additional == 0 { @@ -238,7 +228,7 @@ mod zlib { match res { Ok(status) => { - let stream_end = status == Status::StreamEnd; + let stream_end = status.is_stream_end(); if stream_end || data.is_empty() { // we've reached the end of the stream, we're done buf.shrink_to_fit(); @@ -252,11 +242,7 @@ mod zlib { } } Err(e) => { - let Some(zdict) = e.needs_dictionary().and(zdict) else { - return Err(new_zlib_error(&e.to_string(), vm)); - }; - d.set_dictionary(&zdict.borrow_buf()) - .map_err(|_| new_zlib_error("failed to set dictionary", vm))?; + d.maybe_set_dict(e)?; // now try the next chunk continue 'outer; } @@ -285,8 +271,8 @@ mod zlib { } = args; data.with_ref(|data| { let mut d = InitOptions::new(wbits.value, vm)?.decompress(); - let (buf, stream_end) = - _decompress(data, &mut d, bufsize.value, None, false, None, vm)?; + let (buf, stream_end) = _decompress(data, &mut d, bufsize.value, None, flush_sync) + .map_err(|e| new_zlib_error(e.to_string(), vm))?; if !stream_end { return Err(new_zlib_error( "Error -5 while decompressing data: incomplete or truncated stream", @@ -316,9 +302,8 @@ mod zlib { } } let inner = PyDecompressInner { - decompress: Some(decompress), + decompress: Some(DecompressWithDict { decompress, zdict }), eof: false, - zdict, unused_data: vm.ctx.empty_bytes.clone(), unconsumed_tail: vm.ctx.empty_bytes.clone(), }; @@ -329,8 +314,7 @@ mod zlib { #[derive(Debug)] struct PyDecompressInner { - decompress: Option, - zdict: Option, + decompress: Option, eof: bool, unused_data: PyBytesRef, unconsumed_tail: PyBytesRef, @@ -370,14 +354,25 @@ mod zlib { return Err(new_zlib_error(USE_AFTER_FINISH_ERR, vm)); }; - let zdict = if is_flush { None } else { inner.zdict.as_ref() }; - let prev_in = d.total_in(); - let (ret, stream_end) = - match _decompress(data, d, bufsize, max_length, is_flush, zdict, vm) { - Ok((buf, stream_end)) => (Ok(buf), stream_end), - Err(err) => (Err(err), false), + let res = if is_flush { + // if is_flush: ignore zdict, finish if final chunk + let calc_flush = |final_chunk| { + if final_chunk { + FlushDecompress::Finish + } else { + FlushDecompress::None + } }; + _decompress(data, &mut d.decompress, bufsize, max_length, calc_flush) + } else { + _decompress(data, d, bufsize, max_length, flush_sync) + } + .map_err(|e| new_zlib_error(e.to_string(), vm)); + let (ret, stream_end) = match res { + Ok((buf, stream_end)) => (Ok(buf), stream_end), + Err(err) => (Err(err), false), + }; let consumed = (d.total_in() - prev_in) as usize; // save unused input @@ -404,7 +399,7 @@ mod zlib { .try_into() .map_err(|_| vm.new_value_error("must be non-negative".to_owned()))?; let max_length = (max_length != 0).then_some(max_length); - let data = &*args.data.borrow_buf(); + let data = &*args.data(); let inner = &mut *self.inner.lock(); @@ -440,13 +435,24 @@ mod zlib { } #[derive(FromArgs)] - struct DecompressArgs { + pub(crate) struct DecompressArgs { #[pyarg(positional)] data: ArgBytesLike, #[pyarg(any, optional)] max_length: OptionalArg, } + impl DecompressArgs { + pub(crate) fn data(&self) -> crate::common::borrow::BorrowedValue<'_, [u8]> { + self.data.borrow_buf() + } + pub(crate) fn max_length(&self) -> Option { + self.max_length + .into_option() + .and_then(|ArgSize { value }| usize::try_from(value).ok()) + } + } + #[derive(FromArgs)] #[allow(dead_code)] // FIXME: use args struct CompressobjArgs { @@ -588,8 +594,8 @@ mod zlib { } } - fn new_zlib_error(message: &str, vm: &VirtualMachine) -> PyBaseExceptionRef { - vm.new_exception_msg(vm.class("zlib", "error"), message.to_owned()) + fn new_zlib_error(message: impl Into, vm: &VirtualMachine) -> PyBaseExceptionRef { + vm.new_exception_msg(vm.class("zlib", "error"), message.into()) } const USE_AFTER_FINISH_ERR: &str = "Error -2: inconsistent stream state"; @@ -626,19 +632,68 @@ mod zlib { #[pyclass(name = "_ZlibDecompressor")] #[derive(Debug, PyPayload)] struct ZlibDecompressor { - inner: PyMutex, + inner: PyMutex>, } #[derive(Debug)] - struct ZlibDecompressorInner { + struct DecompressWithDict { decompress: Decompress, - unused_data: PyBytesRef, - input_buffer: Vec, zdict: Option, - eof: bool, - needs_input: bool, } + impl DecompressStatus for Status { + fn is_stream_end(&self) -> bool { + *self == Status::StreamEnd + } + } + + impl FlushKind for FlushDecompress { + const SYNC: Self = FlushDecompress::Sync; + } + + impl Decompressor for Decompress { + type Flush = FlushDecompress; + type Status = Status; + type Error = flate2::DecompressError; + + fn total_in(&self) -> u64 { + self.total_in() + } + fn decompress_vec( + &mut self, + input: &[u8], + output: &mut Vec, + flush: Self::Flush, + ) -> Result { + self.decompress_vec(input, output, flush) + } + } + + impl Decompressor for DecompressWithDict { + type Flush = FlushDecompress; + type Status = Status; + type Error = flate2::DecompressError; + + fn total_in(&self) -> u64 { + self.decompress.total_in() + } + fn decompress_vec( + &mut self, + input: &[u8], + output: &mut Vec, + flush: Self::Flush, + ) -> Result { + self.decompress.decompress_vec(input, output, flush) + } + fn maybe_set_dict(&mut self, err: Self::Error) -> Result<(), Self::Error> { + let zdict = err.needs_dictionary().and(self.zdict.as_ref()).ok_or(err)?; + self.decompress.set_dictionary(&zdict.borrow_buf())?; + Ok(()) + } + } + + // impl Deconstruct + impl Constructor for ZlibDecompressor { type Args = DecompressobjArgs; @@ -651,14 +706,7 @@ mod zlib { .map_err(|_| new_zlib_error("failed to set dictionary", vm))?; } } - let inner = ZlibDecompressorInner { - decompress, - unused_data: vm.ctx.empty_bytes.clone(), - input_buffer: Vec::new(), - zdict, - eof: false, - needs_input: true, - }; + let inner = DecompressState::new(DecompressWithDict { decompress, zdict }, vm); Self { inner: PyMutex::new(inner), } @@ -671,61 +719,151 @@ mod zlib { impl ZlibDecompressor { #[pygetset] fn eof(&self) -> bool { - self.inner.lock().eof + self.inner.lock().eof() } #[pygetset] fn unused_data(&self) -> PyBytesRef { - self.inner.lock().unused_data.clone() + self.inner.lock().unused_data() } #[pygetset] fn needs_input(&self) -> bool { - self.inner.lock().needs_input + self.inner.lock().needs_input() } #[pymethod] fn decompress(&self, args: DecompressArgs, vm: &VirtualMachine) -> PyResult> { - let max_length = args - .max_length - .into_option() - .and_then(|ArgSize { value }| usize::try_from(value).ok()); - let data = &*args.data.borrow_buf(); + let max_length = args.max_length(); + let data = &*args.data(); let inner = &mut *self.inner.lock(); - if inner.eof { - return Err(vm.new_eof_error("End of stream already reached".to_owned())); + inner + .decompress(data, max_length, DEF_BUF_SIZE, vm) + .map_err(|e| match e { + DecompressError::Decompress(err) => new_zlib_error(err.to_string(), vm), + DecompressError::Eof(err) => err.to_pyexception(vm), + }) + } + + // TODO: Wait for getstate pyslot to be fixed + // #[pyslot] + // fn getstate(zelf: &PyObject, vm: &VirtualMachine) -> PyResult { + // Err(vm.new_type_error("cannot serialize '_ZlibDecompressor' object".to_owned())) + // } + } +} + +mod generic { + use super::zlib::{_decompress_chunks, Chunker}; + use crate::vm::{ + VirtualMachine, + builtins::{PyBaseExceptionRef, PyBytesRef}, + convert::ToPyException, + }; + + pub(crate) trait Decompressor { + type Flush: FlushKind; + type Status: DecompressStatus; + type Error; + + fn total_in(&self) -> u64; + fn decompress_vec( + &mut self, + input: &[u8], + output: &mut Vec, + flush: Self::Flush, + ) -> Result; + fn maybe_set_dict(&mut self, err: Self::Error) -> Result<(), Self::Error> { + Err(err) + } + } + + pub(crate) trait DecompressStatus { + fn is_stream_end(&self) -> bool; + } + + pub(crate) trait FlushKind: Copy { + const SYNC: Self; + } + + impl FlushKind for () { + const SYNC: Self = (); + } + + pub(super) fn flush_sync(_final_chunk: bool) -> T { + T::SYNC + } + + #[derive(Debug)] + pub(crate) struct DecompressState { + decompress: D, + unused_data: PyBytesRef, + input_buffer: Vec, + eof: bool, + needs_input: bool, + } + + impl DecompressState { + pub(crate) fn new(decompress: D, vm: &VirtualMachine) -> Self { + Self { + decompress, + unused_data: vm.ctx.empty_bytes.clone(), + input_buffer: Vec::new(), + eof: false, + needs_input: true, + } + } + + pub(crate) fn eof(&self) -> bool { + self.eof + } + + pub(crate) fn unused_data(&self) -> PyBytesRef { + self.unused_data.clone() + } + + pub(crate) fn needs_input(&self) -> bool { + self.needs_input + } + + pub(crate) fn decompress( + &mut self, + data: &[u8], + max_length: Option, + bufsize: usize, + vm: &VirtualMachine, + ) -> Result, DecompressError> { + if self.eof { + return Err(DecompressError::Eof(EofError)); } - let input_buffer = &mut inner.input_buffer; - let d = &mut inner.decompress; + let input_buffer = &mut self.input_buffer; + let d = &mut self.decompress; let mut chunks = Chunker::chain(input_buffer, data); - let zdict = inner.zdict.as_ref(); - let bufsize = DEF_BUF_SIZE; - let prev_len = chunks.len(); let (ret, stream_end) = - match _decompress_chunks(&mut chunks, d, bufsize, max_length, false, zdict, vm) { + match _decompress_chunks(&mut chunks, d, bufsize, max_length, flush_sync) { Ok((buf, stream_end)) => (Ok(buf), stream_end), Err(err) => (Err(err), false), }; let consumed = prev_len - chunks.len(); - inner.eof |= stream_end; + self.eof |= stream_end; - if inner.eof { - inner.needs_input = false; + if self.eof { + self.needs_input = false; if !chunks.is_empty() { - inner.unused_data = vm.ctx.new_bytes(chunks.to_vec()); + self.unused_data = vm.ctx.new_bytes(chunks.to_vec()); } } else if chunks.is_empty() { input_buffer.clear(); - inner.needs_input = true; + self.needs_input = true; } else { - inner.needs_input = false; + self.needs_input = false; if let Some(n_consumed_from_data) = consumed.checked_sub(input_buffer.len()) { input_buffer.clear(); input_buffer.extend_from_slice(&data[n_consumed_from_data..]); @@ -735,13 +873,28 @@ mod zlib { } } - ret + ret.map_err(DecompressError::Decompress) } + } - // TODO: Wait for getstate pyslot to be fixed - // #[pyslot] - // fn getstate(zelf: &PyObject, vm: &VirtualMachine) -> PyResult { - // Err(vm.new_type_error("cannot serialize '_ZlibDecompressor' object".to_owned())) - // } + pub(crate) enum DecompressError { + Decompress(E), + Eof(EofError), + } + + impl From for DecompressError { + fn from(err: E) -> Self { + Self::Decompress(err) + } + } + + pub(crate) struct EofError; + + impl ToPyException for EofError { + fn to_pyexception(&self, vm: &VirtualMachine) -> PyBaseExceptionRef { + vm.new_eof_error("End of stream already reached".to_owned()) + } } } + +pub(crate) use generic::{DecompressError, DecompressState, DecompressStatus, Decompressor}; diff --git a/vm/src/function/number.rs b/vm/src/function/number.rs index 0e36f57ad1..bead82123e 100644 --- a/vm/src/function/number.rs +++ b/vm/src/function/number.rs @@ -158,7 +158,7 @@ impl TryFromObject for ArgIndex { } } -#[derive(Debug)] +#[derive(Debug, Copy, Clone)] #[repr(transparent)] pub struct ArgPrimitiveIndex { pub value: T,