Skip to content

[3.14] gh-132983: Minor fixes and clean up for the _zstd module (GH-134930) #134998

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 77 additions & 15 deletions Lib/test/test_zstd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1139,27 +1139,41 @@ def test_invalid_dict(self):
ZstdDecompressor(zd)

# wrong type
with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
ZstdCompressor(zstd_dict=(zd, b'123'))
with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
ZstdCompressor(zstd_dict=[zd, 1])
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
ZstdCompressor(zstd_dict=(zd, 1.0))
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
ZstdCompressor(zstd_dict=(zd,))
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
ZstdCompressor(zstd_dict=(zd, 1, 2))
with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
ZstdCompressor(zstd_dict=(zd, -1))
with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
ZstdCompressor(zstd_dict=(zd, 3))

with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
ZstdDecompressor(zstd_dict=(zd, b'123'))
with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
with self.assertRaises(OverflowError):
ZstdCompressor(zstd_dict=(zd, 2**1000))
with self.assertRaises(OverflowError):
ZstdCompressor(zstd_dict=(zd, -2**1000))

with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
ZstdDecompressor(zstd_dict=[zd, 1])
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
ZstdDecompressor(zstd_dict=(zd, 1.0))
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
ZstdDecompressor((zd,))
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
ZstdDecompressor((zd, 1, 2))
with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
ZstdDecompressor((zd, -1))
with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
ZstdDecompressor((zd, 3))
with self.assertRaises(OverflowError):
ZstdDecompressor((zd, 2**1000))
with self.assertRaises(OverflowError):
ZstdDecompressor((zd, -2**1000))

def test_train_dict(self):


TRAINED_DICT = train_dict(SAMPLES, DICT_SIZE1)
ZstdDict(TRAINED_DICT.dict_content, is_raw=False)

Expand Down Expand Up @@ -1240,18 +1254,37 @@ def test_train_dict_c(self):
# argument wrong type
with self.assertRaises(TypeError):
_zstd.train_dict({}, (), 100)
with self.assertRaises(TypeError):
_zstd.train_dict(bytearray(), (), 100)
with self.assertRaises(TypeError):
_zstd.train_dict(b'', 99, 100)
with self.assertRaises(TypeError):
_zstd.train_dict(b'', [], 100)
with self.assertRaises(TypeError):
_zstd.train_dict(b'', (), 100.1)
with self.assertRaises(TypeError):
_zstd.train_dict(b'', (99.1,), 100)
with self.assertRaises(ValueError):
_zstd.train_dict(b'abc', (4, -1), 100)
with self.assertRaises(ValueError):
_zstd.train_dict(b'abc', (2,), 100)
with self.assertRaises(ValueError):
_zstd.train_dict(b'', (99,), 100)

# size > size_t
with self.assertRaises(ValueError):
_zstd.train_dict(b'', (2**64+1,), 100)
_zstd.train_dict(b'', (2**1000,), 100)
with self.assertRaises(ValueError):
_zstd.train_dict(b'', (-2**1000,), 100)

# dict_size <= 0
with self.assertRaises(ValueError):
_zstd.train_dict(b'', (), 0)
with self.assertRaises(ValueError):
_zstd.train_dict(b'', (), -1)

with self.assertRaises(ZstdError):
_zstd.train_dict(b'', (), 1)

def test_finalize_dict_c(self):
with self.assertRaises(TypeError):
Expand All @@ -1260,22 +1293,51 @@ def test_finalize_dict_c(self):
# argument wrong type
with self.assertRaises(TypeError):
_zstd.finalize_dict({}, b'', (), 100, 5)
with self.assertRaises(TypeError):
_zstd.finalize_dict(bytearray(TRAINED_DICT.dict_content), b'', (), 100, 5)
with self.assertRaises(TypeError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, {}, (), 100, 5)
with self.assertRaises(TypeError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, bytearray(), (), 100, 5)
with self.assertRaises(TypeError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', 99, 100, 5)
with self.assertRaises(TypeError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', [], 100, 5)
with self.assertRaises(TypeError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100.1, 5)
with self.assertRaises(TypeError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 5.1)

with self.assertRaises(ValueError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'abc', (4, -1), 100, 5)
with self.assertRaises(ValueError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'abc', (2,), 100, 5)
with self.assertRaises(ValueError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (99,), 100, 5)

# size > size_t
with self.assertRaises(ValueError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (2**64+1,), 100, 5)
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (2**1000,), 100, 5)
with self.assertRaises(ValueError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (-2**1000,), 100, 5)

# dict_size <= 0
with self.assertRaises(ValueError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 0, 5)
with self.assertRaises(ValueError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), -1, 5)
with self.assertRaises(OverflowError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 2**1000, 5)
with self.assertRaises(OverflowError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), -2**1000, 5)

with self.assertRaises(OverflowError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 2**1000)
with self.assertRaises(OverflowError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, -2**1000)

with self.assertRaises(ZstdError):
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 5)

def test_train_buffer_protocol_samples(self):
def _nbytes(dat):
Expand Down
87 changes: 60 additions & 27 deletions Modules/_zstd/_zstdmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include "Python.h"

#include "_zstdmodule.h"
#include "zstddict.h"

#include <zstd.h> // ZSTD_*()
#include <zdict.h> // ZDICT_*()
Expand All @@ -20,14 +19,52 @@ module _zstd
#include "clinic/_zstdmodule.c.h"


ZstdDict *
_Py_parse_zstd_dict(const _zstd_state *state, PyObject *dict, int *ptype)
{
if (state == NULL) {
return NULL;
}

/* Check ZstdDict */
if (PyObject_TypeCheck(dict, state->ZstdDict_type)) {
return (ZstdDict*)dict;
}

/* Check (ZstdDict, type) */
if (PyTuple_CheckExact(dict) && PyTuple_GET_SIZE(dict) == 2
&& PyObject_TypeCheck(PyTuple_GET_ITEM(dict, 0), state->ZstdDict_type)
&& PyLong_Check(PyTuple_GET_ITEM(dict, 1)))
{
int type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1));
if (type == -1 && PyErr_Occurred()) {
return NULL;
}
if (type == DICT_TYPE_DIGESTED
|| type == DICT_TYPE_UNDIGESTED
|| type == DICT_TYPE_PREFIX)
{
*ptype = type;
return (ZstdDict*)PyTuple_GET_ITEM(dict, 0);
}
}

/* Wrong type */
PyErr_SetString(PyExc_TypeError,
"zstd_dict argument should be a ZstdDict object.");
return NULL;
}

/* Format error message and set ZstdError. */
void
set_zstd_error(const _zstd_state* const state,
error_type type, size_t zstd_ret)
set_zstd_error(const _zstd_state *state, error_type type, size_t zstd_ret)
{
char *msg;
const char *msg;
assert(ZSTD_isError(zstd_ret));

if (state == NULL) {
return;
}
switch (type) {
case ERR_DECOMPRESS:
msg = "Unable to decompress Zstandard data: %s";
Expand Down Expand Up @@ -174,7 +211,7 @@ calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes,
Py_ssize_t sizes_sum;
Py_ssize_t i;

chunks_number = Py_SIZE(samples_sizes);
chunks_number = PyTuple_GET_SIZE(samples_sizes);
if ((size_t) chunks_number > UINT32_MAX) {
PyErr_Format(PyExc_ValueError,
"The number of samples should be <= %u.", UINT32_MAX);
Expand All @@ -188,20 +225,24 @@ calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes,
return -1;
}

sizes_sum = 0;
sizes_sum = PyBytes_GET_SIZE(samples_bytes);
for (i = 0; i < chunks_number; i++) {
PyObject *size = PyTuple_GetItem(samples_sizes, i);
(*chunk_sizes)[i] = PyLong_AsSize_t(size);
if ((*chunk_sizes)[i] == (size_t)-1 && PyErr_Occurred()) {
PyErr_Format(PyExc_ValueError,
"Items in samples_sizes should be an int "
"object, with a value between 0 and %u.", SIZE_MAX);
size_t size = PyLong_AsSize_t(PyTuple_GET_ITEM(samples_sizes, i));
(*chunk_sizes)[i] = size;
if (size == (size_t)-1 && PyErr_Occurred()) {
if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
goto sum_error;
}
return -1;
}
sizes_sum += (*chunk_sizes)[i];
if ((size_t)sizes_sum < size) {
goto sum_error;
}
sizes_sum -= size;
}

if (sizes_sum != Py_SIZE(samples_bytes)) {
if (sizes_sum != 0) {
sum_error:
PyErr_SetString(PyExc_ValueError,
"The samples size tuple doesn't match the "
"concatenation's size.");
Expand Down Expand Up @@ -257,7 +298,7 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes,

/* Train the dictionary */
char *dst_dict_buffer = PyBytes_AS_STRING(dst_dict_bytes);
char *samples_buffer = PyBytes_AS_STRING(samples_bytes);
const char *samples_buffer = PyBytes_AS_STRING(samples_bytes);
Py_BEGIN_ALLOW_THREADS
zstd_ret = ZDICT_trainFromBuffer(dst_dict_buffer, dict_size,
samples_buffer,
Expand Down Expand Up @@ -507,17 +548,10 @@ _zstd_set_parameter_types_impl(PyObject *module, PyObject *c_parameter_type,
{
_zstd_state* mod_state = get_zstd_state(module);

if (!PyType_Check(c_parameter_type) || !PyType_Check(d_parameter_type)) {
PyErr_SetString(PyExc_ValueError,
"The two arguments should be CompressionParameter and "
"DecompressionParameter types.");
return NULL;
}

Py_XSETREF(
mod_state->CParameter_type, (PyTypeObject*)Py_NewRef(c_parameter_type));
Py_XSETREF(
mod_state->DParameter_type, (PyTypeObject*)Py_NewRef(d_parameter_type));
Py_INCREF(c_parameter_type);
Py_XSETREF(mod_state->CParameter_type, (PyTypeObject*)c_parameter_type);
Py_INCREF(d_parameter_type);
Py_XSETREF(mod_state->DParameter_type, (PyTypeObject*)d_parameter_type);

Py_RETURN_NONE;
}
Expand Down Expand Up @@ -580,7 +614,6 @@ do { \
return -1;
}
if (PyModule_AddType(m, (PyTypeObject *)mod_state->ZstdError) < 0) {
Py_DECREF(mod_state->ZstdError);
return -1;
}

Expand Down
10 changes: 8 additions & 2 deletions Modules/_zstd/_zstdmodule.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#ifndef ZSTD_MODULE_H
#define ZSTD_MODULE_H

#include "zstddict.h"

/* Type specs */
extern PyType_Spec zstd_dict_type_spec;
extern PyType_Spec zstd_compressor_type_spec;
Expand Down Expand Up @@ -43,10 +45,14 @@ typedef enum {
DICT_TYPE_PREFIX = 2
} dictionary_type;

extern ZstdDict *
_Py_parse_zstd_dict(const _zstd_state *state,
PyObject *dict, int *type);

/* Format error message and set ZstdError. */
extern void
set_zstd_error(const _zstd_state* const state,
const error_type type, size_t zstd_ret);
set_zstd_error(const _zstd_state *state,
error_type type, size_t zstd_ret);

extern void
set_parameter_error(int is_compress, int key_v, int value_v);
Expand Down
Loading
Loading