Skip to content

Switch to libbz2-rs-sys and finish bz2 impl #5709

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ flame-it = ["rustpython-vm/flame-it", "flame", "flamescope"]
freeze-stdlib = ["stdlib", "rustpython-vm/freeze-stdlib", "rustpython-pylib?/freeze-stdlib"]
jit = ["rustpython-vm/jit"]
threading = ["rustpython-vm/threading", "rustpython-stdlib/threading"]
bz2 = ["stdlib", "rustpython-stdlib/bz2"]
sqlite = ["rustpython-stdlib/sqlite"]
ssl = ["rustpython-stdlib/ssl"]
ssl-vendor = ["ssl", "rustpython-stdlib/ssl-vendor"]
Expand Down
6 changes: 6 additions & 0 deletions Lib/test/test_bz2.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,8 @@ def testCompress4G(self, size):
finally:
data = None

# TODO: RUSTPYTHON
@unittest.expectedFailure
def testPickle(self):
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
with self.assertRaises(TypeError):
Expand Down Expand Up @@ -734,6 +736,8 @@ def testDecompress4G(self, size):
compressed = None
decompressed = None

# TODO: RUSTPYTHON
@unittest.expectedFailure
def testPickle(self):
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
with self.assertRaises(TypeError):
Expand Down Expand Up @@ -1001,6 +1005,8 @@ def test_encoding_error_handler(self):
as f:
self.assertEqual(f.read(), "foobar")

# TODO: RUSTPYTHON
@unittest.expectedFailure
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@coolreader18 @arihant2math what cause this regression? could this be fixed in future?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A flag probably needs to be added to the decompressor state. I believe the issue is the wt and rt formats. Although I'm not to sure, I suppose looking at the cpython source might help.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason this needed to be flagged now is because the whole test file wasn't running at all before - bz2 was not included in the --features=stdlib,threading,... flag in CI, so the module wasn't even getting compiled or tested at all.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, didn't noticed that. Thanks!

def test_newline(self):
# Test with explicit newline (universal newline mode disabled).
text = self.TEXT.decode("ascii")
Expand Down
3 changes: 1 addition & 2 deletions stdlib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ license.workspace = true
default = ["compiler"]
compiler = ["rustpython-vm/compiler"]
threading = ["rustpython-common/threading", "rustpython-vm/threading"]
bz2 = ["bzip2"]
sqlite = ["dep:libsqlite3-sys"]
ssl = ["openssl", "openssl-sys", "foreign-types-shared", "openssl-probe"]
ssl-vendor = ["ssl", "openssl/vendored"]
Expand Down Expand Up @@ -80,7 +79,7 @@ adler32 = "1.2.0"
crc32fast = "1.3.2"
flate2 = { version = "1.1", default-features = false, features = ["zlib-rs"] }
libz-sys = { package = "libz-rs-sys", version = "0.5" }
bzip2 = { version = "0.4", optional = true }
bzip2 = { version = "0.5", features = ["libbz2-rs-sys"] }

# tkinter
tk-sys = { git = "https://github.com/arihant2math/tkinter.git", tag = "v0.1.0", optional = true }
Expand Down
151 changes: 46 additions & 105 deletions stdlib/src/bz2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,48 @@ mod _bz2 {
object::{PyPayload, PyResult},
types::Constructor,
};
use crate::zlib::{
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since there is at-least 1 other decompression module that uses the same format (_lzma). I think this should be moved someplace common.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree - I was planning on doing that in a followup, to reduce the amount of code movement in this pr.

DecompressArgs, DecompressError, DecompressState, DecompressStatus, Decompressor,
};
use bzip2::{Decompress, Status, write::BzEncoder};
use rustpython_vm::convert::ToPyException;
use std::{fmt, io::Write};

// const BUFSIZ: i32 = 8192;

struct DecompressorState {
decoder: Decompress,
eof: bool,
needs_input: bool,
// input_buffer: Vec<u8>,
// output_buffer: Vec<u8>,
}
const BUFSIZ: usize = 8192;

#[pyattr]
#[pyclass(name = "BZ2Decompressor")]
#[derive(PyPayload)]
struct BZ2Decompressor {
state: PyMutex<DecompressorState>,
state: PyMutex<DecompressState<Decompress>>,
}

impl Decompressor for Decompress {
type Flush = ();
type Status = Status;
type Error = bzip2::Error;

fn total_in(&self) -> u64 {
self.total_in()
}
fn decompress_vec(
&mut self,
input: &[u8],
output: &mut Vec<u8>,
(): Self::Flush,
) -> Result<Self::Status, Self::Error> {
self.decompress_vec(input, output)
}
}

impl DecompressStatus for Status {
fn is_stream_end(&self) -> bool {
*self == Status::StreamEnd
}
}

impl fmt::Debug for BZ2Decompressor {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "_bz2.BZ2Decompressor")
}
}
Expand All @@ -43,13 +63,7 @@ mod _bz2 {

fn py_new(cls: PyTypeRef, _: Self::Args, vm: &VirtualMachine) -> PyResult {
Self {
state: PyMutex::new(DecompressorState {
decoder: Decompress::new(false),
eof: false,
needs_input: true,
// input_buffer: Vec::new(),
// output_buffer: Vec::new(),
}),
state: PyMutex::new(DecompressState::new(Decompress::new(false), vm)),
}
.into_ref_with_type(vm, cls)
.map(Into::into)
Expand All @@ -59,107 +73,34 @@ mod _bz2 {
#[pyclass(with(Constructor))]
impl BZ2Decompressor {
#[pymethod]
fn decompress(
&self,
data: ArgBytesLike,
// TODO: PyIntRef
max_length: OptionalArg<i32>,
vm: &VirtualMachine,
) -> PyResult<PyBytesRef> {
let max_length = max_length.unwrap_or(-1);
if max_length >= 0 {
return Err(vm.new_not_implemented_error(
"the max_value argument is not implemented yet".to_owned(),
));
}
// let max_length = if max_length < 0 || max_length >= BUFSIZ {
// BUFSIZ
// } else {
// max_length
// };
fn decompress(&self, args: DecompressArgs, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
let max_length = args.max_length();
let data = &*args.data();

let mut state = self.state.lock();
let DecompressorState {
decoder,
eof,
..
// needs_input,
// input_buffer,
// output_buffer,
} = &mut *state;

if *eof {
return Err(vm.new_exception_msg(
vm.ctx.exceptions.eof_error.to_owned(),
"End of stream already reached".to_owned(),
));
}

// data.with_ref(|data| input_buffer.extend(data));

// If max_length is negative:
// read the input X bytes at a time, compress it and append it to output.
// Once you're out of input, setting needs_input to true and return the
// output as bytes.
//
// TODO:
// If max_length is non-negative:
// Read the input X bytes at a time, compress it and append it to
// the output. If output reaches `max_length` in size, return
// it (up to max_length), and store the rest of the output
// for later.

// TODO: arbitrary choice, not the right way to do it.
let mut buf = Vec::with_capacity(data.len() * 32);

let before = decoder.total_in();
let res = data.with_ref(|data| decoder.decompress_vec(data, &mut buf));
let _written = (decoder.total_in() - before) as usize;

let res = match res {
Ok(x) => x,
// TODO: error message
_ => return Err(vm.new_os_error("Invalid data stream".to_owned())),
};

if res == Status::StreamEnd {
*eof = true;
}
Ok(vm.ctx.new_bytes(buf.to_vec()))
state
.decompress(data, max_length, BUFSIZ, vm)
.map_err(|e| match e {
DecompressError::Decompress(err) => vm.new_os_error(err.to_string()),
DecompressError::Eof(err) => err.to_pyexception(vm),
})
}

#[pygetset]
fn eof(&self) -> bool {
let state = self.state.lock();
state.eof
self.state.lock().eof()
}

#[pygetset]
fn unused_data(&self, vm: &VirtualMachine) -> PyBytesRef {
// Data found after the end of the compressed stream.
// If this attribute is accessed before the end of the stream
// has been reached, its value will be b''.
vm.ctx.new_bytes(b"".to_vec())
// alternatively, be more honest:
// Err(vm.new_not_implemented_error(
// "unused_data isn't implemented yet".to_owned(),
// ))
//
// TODO
// let state = self.state.lock();
// if state.eof {
// vm.ctx.new_bytes(state.input_buffer.to_vec())
// else {
// vm.ctx.new_bytes(b"".to_vec())
// }
fn unused_data(&self) -> PyBytesRef {
self.state.lock().unused_data()
}

#[pygetset]
fn needs_input(&self) -> bool {
// False if the decompress() method can provide more
// decompressed data before requiring new uncompressed input.
let state = self.state.lock();
state.needs_input
self.state.lock().needs_input()
}

// TODO: mro()?
Expand All @@ -178,7 +119,7 @@ mod _bz2 {
}

impl fmt::Debug for BZ2Compressor {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "_bz2.BZ2Compressor")
}
}
Expand Down
6 changes: 1 addition & 5 deletions stdlib/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ mod statistics;
mod suggestions;
// TODO: maybe make this an extension module, if we ever get those
// mod re;
#[cfg(feature = "bz2")]
mod bz2;
#[cfg(not(target_arch = "wasm32"))]
pub mod socket;
Expand Down Expand Up @@ -112,6 +111,7 @@ pub fn get_module_inits() -> impl Iterator<Item = (Cow<'static, str>, StdlibInit
"array" => array::make_module,
"binascii" => binascii::make_module,
"_bisect" => bisect::make_module,
"_bz2" => bz2::make_module,
"cmath" => cmath::make_module,
"_contextvars" => contextvars::make_module,
"_csv" => csv::make_module,
Expand Down Expand Up @@ -158,10 +158,6 @@ pub fn get_module_inits() -> impl Iterator<Item = (Cow<'static, str>, StdlibInit
{
"_ssl" => ssl::make_module,
}
#[cfg(feature = "bz2")]
{
"_bz2" => bz2::make_module,
}
#[cfg(windows)]
{
"_overlapped" => overlapped::make_module,
Expand Down
Loading
Loading