Skip to content

Commit 7a8ef8c

Browse files
committed
Split out common compression routines into separate file
1 parent ecbc6f7 commit 7a8ef8c

File tree

5 files changed

+284
-284
lines changed

5 files changed

+284
-284
lines changed

stdlib/src/bz2.rs renamed to stdlib/src/compression/bz2.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ pub(crate) use _bz2::make_module;
44

55
#[pymodule]
66
mod _bz2 {
7+
use super::super::{
8+
DecompressArgs, DecompressError, DecompressState, DecompressStatus, Decompressor,
9+
};
710
use crate::common::lock::PyMutex;
811
use crate::vm::{
912
VirtualMachine,
@@ -12,9 +15,6 @@ mod _bz2 {
1215
object::{PyPayload, PyResult},
1316
types::Constructor,
1417
};
15-
use crate::zlib::{
16-
DecompressArgs, DecompressError, DecompressState, DecompressStatus, Decompressor,
17-
};
1818
use bzip2::{Decompress, Status, write::BzEncoder};
1919
use rustpython_vm::convert::ToPyException;
2020
use std::{fmt, io::Write};
@@ -74,7 +74,7 @@ mod _bz2 {
7474
impl BZ2Decompressor {
7575
#[pymethod]
7676
fn decompress(&self, args: DecompressArgs, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
77-
let max_length = args.max_length();
77+
let max_length = args.max_length_negative_is_none();
7878
let data = &*args.data();
7979

8080
let mut state = self.state.lock();

stdlib/src/compression/generic.rs

Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
use crate::vm::{
2+
VirtualMachine,
3+
builtins::{PyBaseExceptionRef, PyBytesRef},
4+
convert::ToPyException,
5+
function::{ArgBytesLike, ArgSize, OptionalArg},
6+
};
7+
8+
#[derive(FromArgs)]
9+
pub(super) struct DecompressArgs {
10+
#[pyarg(positional)]
11+
data: ArgBytesLike,
12+
#[pyarg(any, optional)]
13+
pub max_length: OptionalArg<ArgSize>,
14+
}
15+
16+
impl DecompressArgs {
17+
pub fn data(&self) -> crate::common::borrow::BorrowedValue<'_, [u8]> {
18+
self.data.borrow_buf()
19+
}
20+
pub fn max_length_negative_is_none(&self) -> Option<usize> {
21+
self.max_length
22+
.into_option()
23+
.and_then(|ArgSize { value }| usize::try_from(value).ok())
24+
}
25+
}
26+
27+
pub(super) trait Decompressor {
28+
type Flush: FlushKind;
29+
type Status: DecompressStatus;
30+
type Error;
31+
32+
fn total_in(&self) -> u64;
33+
fn decompress_vec(
34+
&mut self,
35+
input: &[u8],
36+
output: &mut Vec<u8>,
37+
flush: Self::Flush,
38+
) -> Result<Self::Status, Self::Error>;
39+
fn maybe_set_dict(&mut self, err: Self::Error) -> Result<(), Self::Error> {
40+
Err(err)
41+
}
42+
}
43+
44+
pub(super) trait DecompressStatus {
45+
fn is_stream_end(&self) -> bool;
46+
}
47+
48+
pub(super) trait FlushKind: Copy {
49+
const SYNC: Self;
50+
}
51+
52+
impl FlushKind for () {
53+
const SYNC: Self = ();
54+
}
55+
56+
pub(super) fn flush_sync<T: FlushKind>(_final_chunk: bool) -> T {
57+
T::SYNC
58+
}
59+
60+
pub(super) const CHUNKSIZE: usize = u32::MAX as usize;
61+
62+
#[derive(Clone)]
63+
pub(super) struct Chunker<'a> {
64+
data1: &'a [u8],
65+
data2: &'a [u8],
66+
}
67+
impl<'a> Chunker<'a> {
68+
pub fn new(data: &'a [u8]) -> Self {
69+
Self {
70+
data1: data,
71+
data2: &[],
72+
}
73+
}
74+
pub fn chain(data1: &'a [u8], data2: &'a [u8]) -> Self {
75+
if data1.is_empty() {
76+
Self {
77+
data1: data2,
78+
data2: &[],
79+
}
80+
} else {
81+
Self { data1, data2 }
82+
}
83+
}
84+
pub fn len(&self) -> usize {
85+
self.data1.len() + self.data2.len()
86+
}
87+
pub fn is_empty(&self) -> bool {
88+
self.data1.is_empty()
89+
}
90+
pub fn to_vec(&self) -> Vec<u8> {
91+
[self.data1, self.data2].concat()
92+
}
93+
pub fn chunk(&self) -> &'a [u8] {
94+
self.data1.get(..CHUNKSIZE).unwrap_or(self.data1)
95+
}
96+
pub fn advance(&mut self, consumed: usize) {
97+
self.data1 = &self.data1[consumed..];
98+
if self.data1.is_empty() {
99+
self.data1 = std::mem::take(&mut self.data2);
100+
}
101+
}
102+
}
103+
104+
pub(super) fn _decompress<D: Decompressor>(
105+
data: &[u8],
106+
d: &mut D,
107+
bufsize: usize,
108+
max_length: Option<usize>,
109+
calc_flush: impl Fn(bool) -> D::Flush,
110+
) -> Result<(Vec<u8>, bool), D::Error> {
111+
let mut data = Chunker::new(data);
112+
_decompress_chunks(&mut data, d, bufsize, max_length, calc_flush)
113+
}
114+
115+
/// Drives `d` over the input in `data`, growing the output buffer in steps of
/// at most `bufsize` bytes.
///
/// * `data` – input cursor; it is advanced past every byte the backend
///   consumed, so the caller can inspect what is left afterwards.
/// * `max_length` – upper bound for the output buffer's capacity; `None`
///   means unbounded.
/// * `calc_flush` – picks the flush mode for a chunk; its argument is `true`
///   when the chunk is everything that remains in `data`.
///
/// Returns the decompressed bytes and a flag telling whether the backend
/// reported the end of the stream. Special cases:
///
/// * empty `data` yields `(vec![], true)` without calling the backend;
/// * hitting the output limit yields `(buf, false)`.
///
/// # Errors
///
/// Propagates a backend error unless [`Decompressor::maybe_set_dict`] handles
/// it, in which case processing continues with the next chunk.
pub(super) fn _decompress_chunks<D: Decompressor>(
    data: &mut Chunker<'_>,
    d: &mut D,
    bufsize: usize,
    max_length: Option<usize>,
    calc_flush: impl Fn(bool) -> D::Flush,
) -> Result<(Vec<u8>, bool), D::Error> {
    if data.is_empty() {
        return Ok((Vec::new(), true));
    }
    let max_length = max_length.unwrap_or(usize::MAX);
    let mut buf = Vec::new();

    'outer: loop {
        let chunk = data.chunk();
        let flush = calc_flush(chunk.len() == data.len());
        loop {
            // `reserve_exact` is allowed to hand back more capacity than was
            // requested, so the capacity may overshoot `max_length`. Saturate
            // instead of underflowing: a plain subtraction would panic in
            // debug builds and wrap to a huge allowance in release builds.
            let remaining = max_length.saturating_sub(buf.capacity());
            let additional = std::cmp::min(bufsize, remaining);
            if additional == 0 {
                // Output limit reached (or `bufsize` is zero).
                return Ok((buf, false));
            }
            buf.reserve_exact(additional);

            let prev_in = d.total_in();
            let res = d.decompress_vec(chunk, &mut buf, flush);
            let consumed = d.total_in() - prev_in;

            // `consumed` cannot exceed `chunk.len()`, so the cast is lossless.
            data.advance(consumed as usize);

            match res {
                Ok(status) => {
                    let stream_end = status.is_stream_end();
                    if stream_end || data.is_empty() {
                        // we've reached the end of the stream, we're done
                        buf.shrink_to_fit();
                        return Ok((buf, stream_end));
                    } else if !chunk.is_empty() && consumed == 0 {
                        // we're gonna need a bigger buffer
                        continue;
                    } else {
                        // next chunk
                        continue 'outer;
                    }
                }
                Err(e) => {
                    d.maybe_set_dict(e)?;
                    // now try the next chunk
                    continue 'outer;
                }
            };
        }
    }
}
168+
169+
#[derive(Debug)]
170+
pub(super) struct DecompressState<D> {
171+
decompress: D,
172+
unused_data: PyBytesRef,
173+
input_buffer: Vec<u8>,
174+
eof: bool,
175+
needs_input: bool,
176+
}
177+
178+
impl<D: Decompressor> DecompressState<D> {
179+
pub fn new(decompress: D, vm: &VirtualMachine) -> Self {
180+
Self {
181+
decompress,
182+
unused_data: vm.ctx.empty_bytes.clone(),
183+
input_buffer: Vec::new(),
184+
eof: false,
185+
needs_input: true,
186+
}
187+
}
188+
189+
pub fn eof(&self) -> bool {
190+
self.eof
191+
}
192+
193+
pub fn unused_data(&self) -> PyBytesRef {
194+
self.unused_data.clone()
195+
}
196+
197+
pub fn needs_input(&self) -> bool {
198+
self.needs_input
199+
}
200+
201+
pub fn decompress(
202+
&mut self,
203+
data: &[u8],
204+
max_length: Option<usize>,
205+
bufsize: usize,
206+
vm: &VirtualMachine,
207+
) -> Result<Vec<u8>, DecompressError<D::Error>> {
208+
if self.eof {
209+
return Err(DecompressError::Eof(EofError));
210+
}
211+
212+
let input_buffer = &mut self.input_buffer;
213+
let d = &mut self.decompress;
214+
215+
let mut chunks = Chunker::chain(input_buffer, data);
216+
217+
let prev_len = chunks.len();
218+
let (ret, stream_end) =
219+
match _decompress_chunks(&mut chunks, d, bufsize, max_length, flush_sync) {
220+
Ok((buf, stream_end)) => (Ok(buf), stream_end),
221+
Err(err) => (Err(err), false),
222+
};
223+
let consumed = prev_len - chunks.len();
224+
225+
self.eof |= stream_end;
226+
227+
if self.eof {
228+
self.needs_input = false;
229+
if !chunks.is_empty() {
230+
self.unused_data = vm.ctx.new_bytes(chunks.to_vec());
231+
}
232+
} else if chunks.is_empty() {
233+
input_buffer.clear();
234+
self.needs_input = true;
235+
} else {
236+
self.needs_input = false;
237+
if let Some(n_consumed_from_data) = consumed.checked_sub(input_buffer.len()) {
238+
input_buffer.clear();
239+
input_buffer.extend_from_slice(&data[n_consumed_from_data..]);
240+
} else {
241+
input_buffer.drain(..consumed);
242+
input_buffer.extend_from_slice(data);
243+
}
244+
}
245+
246+
ret.map_err(DecompressError::Decompress)
247+
}
248+
}
249+
250+
pub(super) enum DecompressError<E> {
251+
Decompress(E),
252+
Eof(EofError),
253+
}
254+
255+
impl<E> From<E> for DecompressError<E> {
256+
fn from(err: E) -> Self {
257+
Self::Decompress(err)
258+
}
259+
}
260+
261+
pub(super) struct EofError;
262+
263+
impl ToPyException for EofError {
264+
fn to_pyexception(&self, vm: &VirtualMachine) -> PyBaseExceptionRef {
265+
vm.new_eof_error("End of stream already reached".to_owned())
266+
}
267+
}

stdlib/src/compression/mod.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
mod generic;
2+
use generic::*;
3+
4+
pub mod bz2;
5+
pub mod zlib;

0 commit comments

Comments
 (0)