Skip to content

Commit b595237

Browse files
gh-132983: Minor fixes and clean up for the _zstd module (GH-134930)
1 parent fe6f8a3 commit b595237

File tree

6 files changed

+166
-160
lines changed

6 files changed

+166
-160
lines changed

Lib/test/test_zstd.py

Lines changed: 77 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,27 +1138,41 @@ def test_invalid_dict(self):
11381138
ZstdDecompressor(zd)
11391139

11401140
# wrong type
1141-
with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
1142-
ZstdCompressor(zstd_dict=(zd, b'123'))
1143-
with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
1141+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
1142+
ZstdCompressor(zstd_dict=[zd, 1])
1143+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
1144+
ZstdCompressor(zstd_dict=(zd, 1.0))
1145+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
1146+
ZstdCompressor(zstd_dict=(zd,))
1147+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
11441148
ZstdCompressor(zstd_dict=(zd, 1, 2))
1145-
with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
1149+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
11461150
ZstdCompressor(zstd_dict=(zd, -1))
1147-
with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
1151+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
11481152
ZstdCompressor(zstd_dict=(zd, 3))
1149-
1150-
with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
1151-
ZstdDecompressor(zstd_dict=(zd, b'123'))
1152-
with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
1153+
with self.assertRaises(OverflowError):
1154+
ZstdCompressor(zstd_dict=(zd, 2**1000))
1155+
with self.assertRaises(OverflowError):
1156+
ZstdCompressor(zstd_dict=(zd, -2**1000))
1157+
1158+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
1159+
ZstdDecompressor(zstd_dict=[zd, 1])
1160+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
1161+
ZstdDecompressor(zstd_dict=(zd, 1.0))
1162+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
1163+
ZstdDecompressor((zd,))
1164+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
11531165
ZstdDecompressor((zd, 1, 2))
1154-
with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
1166+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
11551167
ZstdDecompressor((zd, -1))
1156-
with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'):
1168+
with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'):
11571169
ZstdDecompressor((zd, 3))
1170+
with self.assertRaises(OverflowError):
1171+
ZstdDecompressor((zd, 2**1000))
1172+
with self.assertRaises(OverflowError):
1173+
ZstdDecompressor((zd, -2**1000))
11581174

11591175
def test_train_dict(self):
1160-
1161-
11621176
TRAINED_DICT = train_dict(SAMPLES, DICT_SIZE1)
11631177
ZstdDict(TRAINED_DICT.dict_content, is_raw=False)
11641178

@@ -1239,18 +1253,37 @@ def test_train_dict_c(self):
12391253
# argument wrong type
12401254
with self.assertRaises(TypeError):
12411255
_zstd.train_dict({}, (), 100)
1256+
with self.assertRaises(TypeError):
1257+
_zstd.train_dict(bytearray(), (), 100)
12421258
with self.assertRaises(TypeError):
12431259
_zstd.train_dict(b'', 99, 100)
1260+
with self.assertRaises(TypeError):
1261+
_zstd.train_dict(b'', [], 100)
12441262
with self.assertRaises(TypeError):
12451263
_zstd.train_dict(b'', (), 100.1)
1264+
with self.assertRaises(TypeError):
1265+
_zstd.train_dict(b'', (99.1,), 100)
1266+
with self.assertRaises(ValueError):
1267+
_zstd.train_dict(b'abc', (4, -1), 100)
1268+
with self.assertRaises(ValueError):
1269+
_zstd.train_dict(b'abc', (2,), 100)
1270+
with self.assertRaises(ValueError):
1271+
_zstd.train_dict(b'', (99,), 100)
12461272

12471273
# size > size_t
12481274
with self.assertRaises(ValueError):
1249-
_zstd.train_dict(b'', (2**64+1,), 100)
1275+
_zstd.train_dict(b'', (2**1000,), 100)
1276+
with self.assertRaises(ValueError):
1277+
_zstd.train_dict(b'', (-2**1000,), 100)
12501278

12511279
# dict_size <= 0
12521280
with self.assertRaises(ValueError):
12531281
_zstd.train_dict(b'', (), 0)
1282+
with self.assertRaises(ValueError):
1283+
_zstd.train_dict(b'', (), -1)
1284+
1285+
with self.assertRaises(ZstdError):
1286+
_zstd.train_dict(b'', (), 1)
12541287

12551288
def test_finalize_dict_c(self):
12561289
with self.assertRaises(TypeError):
@@ -1259,22 +1292,51 @@ def test_finalize_dict_c(self):
12591292
# argument wrong type
12601293
with self.assertRaises(TypeError):
12611294
_zstd.finalize_dict({}, b'', (), 100, 5)
1295+
with self.assertRaises(TypeError):
1296+
_zstd.finalize_dict(bytearray(TRAINED_DICT.dict_content), b'', (), 100, 5)
12621297
with self.assertRaises(TypeError):
12631298
_zstd.finalize_dict(TRAINED_DICT.dict_content, {}, (), 100, 5)
1299+
with self.assertRaises(TypeError):
1300+
_zstd.finalize_dict(TRAINED_DICT.dict_content, bytearray(), (), 100, 5)
12641301
with self.assertRaises(TypeError):
12651302
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', 99, 100, 5)
1303+
with self.assertRaises(TypeError):
1304+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', [], 100, 5)
12661305
with self.assertRaises(TypeError):
12671306
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100.1, 5)
12681307
with self.assertRaises(TypeError):
12691308
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 5.1)
12701309

1310+
with self.assertRaises(ValueError):
1311+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'abc', (4, -1), 100, 5)
1312+
with self.assertRaises(ValueError):
1313+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'abc', (2,), 100, 5)
1314+
with self.assertRaises(ValueError):
1315+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (99,), 100, 5)
1316+
12711317
# size > size_t
12721318
with self.assertRaises(ValueError):
1273-
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (2**64+1,), 100, 5)
1319+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (2**1000,), 100, 5)
1320+
with self.assertRaises(ValueError):
1321+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (-2**1000,), 100, 5)
12741322

12751323
# dict_size <= 0
12761324
with self.assertRaises(ValueError):
12771325
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 0, 5)
1326+
with self.assertRaises(ValueError):
1327+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), -1, 5)
1328+
with self.assertRaises(OverflowError):
1329+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 2**1000, 5)
1330+
with self.assertRaises(OverflowError):
1331+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), -2**1000, 5)
1332+
1333+
with self.assertRaises(OverflowError):
1334+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 2**1000)
1335+
with self.assertRaises(OverflowError):
1336+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, -2**1000)
1337+
1338+
with self.assertRaises(ZstdError):
1339+
_zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 5)
12781340

12791341
def test_train_buffer_protocol_samples(self):
12801342
def _nbytes(dat):

Modules/_zstd/_zstdmodule.c

Lines changed: 60 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
#include "Python.h"
88

99
#include "_zstdmodule.h"
10-
#include "zstddict.h"
1110

1211
#include <zstd.h> // ZSTD_*()
1312
#include <zdict.h> // ZDICT_*()
@@ -20,14 +19,52 @@ module _zstd
2019
#include "clinic/_zstdmodule.c.h"
2120

2221

22+
ZstdDict *
23+
_Py_parse_zstd_dict(const _zstd_state *state, PyObject *dict, int *ptype)
24+
{
25+
if (state == NULL) {
26+
return NULL;
27+
}
28+
29+
/* Check ZstdDict */
30+
if (PyObject_TypeCheck(dict, state->ZstdDict_type)) {
31+
return (ZstdDict*)dict;
32+
}
33+
34+
/* Check (ZstdDict, type) */
35+
if (PyTuple_CheckExact(dict) && PyTuple_GET_SIZE(dict) == 2
36+
&& PyObject_TypeCheck(PyTuple_GET_ITEM(dict, 0), state->ZstdDict_type)
37+
&& PyLong_Check(PyTuple_GET_ITEM(dict, 1)))
38+
{
39+
int type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1));
40+
if (type == -1 && PyErr_Occurred()) {
41+
return NULL;
42+
}
43+
if (type == DICT_TYPE_DIGESTED
44+
|| type == DICT_TYPE_UNDIGESTED
45+
|| type == DICT_TYPE_PREFIX)
46+
{
47+
*ptype = type;
48+
return (ZstdDict*)PyTuple_GET_ITEM(dict, 0);
49+
}
50+
}
51+
52+
/* Wrong type */
53+
PyErr_SetString(PyExc_TypeError,
54+
"zstd_dict argument should be a ZstdDict object.");
55+
return NULL;
56+
}
57+
2358
/* Format error message and set ZstdError. */
2459
void
25-
set_zstd_error(const _zstd_state* const state,
26-
error_type type, size_t zstd_ret)
60+
set_zstd_error(const _zstd_state *state, error_type type, size_t zstd_ret)
2761
{
28-
char *msg;
62+
const char *msg;
2963
assert(ZSTD_isError(zstd_ret));
3064

65+
if (state == NULL) {
66+
return;
67+
}
3168
switch (type) {
3269
case ERR_DECOMPRESS:
3370
msg = "Unable to decompress Zstandard data: %s";
@@ -174,7 +211,7 @@ calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes,
174211
Py_ssize_t sizes_sum;
175212
Py_ssize_t i;
176213

177-
chunks_number = Py_SIZE(samples_sizes);
214+
chunks_number = PyTuple_GET_SIZE(samples_sizes);
178215
if ((size_t) chunks_number > UINT32_MAX) {
179216
PyErr_Format(PyExc_ValueError,
180217
"The number of samples should be <= %u.", UINT32_MAX);
@@ -188,20 +225,24 @@ calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes,
188225
return -1;
189226
}
190227

191-
sizes_sum = 0;
228+
sizes_sum = PyBytes_GET_SIZE(samples_bytes);
192229
for (i = 0; i < chunks_number; i++) {
193-
PyObject *size = PyTuple_GetItem(samples_sizes, i);
194-
(*chunk_sizes)[i] = PyLong_AsSize_t(size);
195-
if ((*chunk_sizes)[i] == (size_t)-1 && PyErr_Occurred()) {
196-
PyErr_Format(PyExc_ValueError,
197-
"Items in samples_sizes should be an int "
198-
"object, with a value between 0 and %u.", SIZE_MAX);
230+
size_t size = PyLong_AsSize_t(PyTuple_GET_ITEM(samples_sizes, i));
231+
(*chunk_sizes)[i] = size;
232+
if (size == (size_t)-1 && PyErr_Occurred()) {
233+
if (PyErr_ExceptionMatches(PyExc_OverflowError)) {
234+
goto sum_error;
235+
}
199236
return -1;
200237
}
201-
sizes_sum += (*chunk_sizes)[i];
238+
if ((size_t)sizes_sum < size) {
239+
goto sum_error;
240+
}
241+
sizes_sum -= size;
202242
}
203243

204-
if (sizes_sum != Py_SIZE(samples_bytes)) {
244+
if (sizes_sum != 0) {
245+
sum_error:
205246
PyErr_SetString(PyExc_ValueError,
206247
"The samples size tuple doesn't match the "
207248
"concatenation's size.");
@@ -257,7 +298,7 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes,
257298

258299
/* Train the dictionary */
259300
char *dst_dict_buffer = PyBytes_AS_STRING(dst_dict_bytes);
260-
char *samples_buffer = PyBytes_AS_STRING(samples_bytes);
301+
const char *samples_buffer = PyBytes_AS_STRING(samples_bytes);
261302
Py_BEGIN_ALLOW_THREADS
262303
zstd_ret = ZDICT_trainFromBuffer(dst_dict_buffer, dict_size,
263304
samples_buffer,
@@ -507,17 +548,10 @@ _zstd_set_parameter_types_impl(PyObject *module, PyObject *c_parameter_type,
507548
{
508549
_zstd_state* mod_state = get_zstd_state(module);
509550

510-
if (!PyType_Check(c_parameter_type) || !PyType_Check(d_parameter_type)) {
511-
PyErr_SetString(PyExc_ValueError,
512-
"The two arguments should be CompressionParameter and "
513-
"DecompressionParameter types.");
514-
return NULL;
515-
}
516-
517-
Py_XSETREF(
518-
mod_state->CParameter_type, (PyTypeObject*)Py_NewRef(c_parameter_type));
519-
Py_XSETREF(
520-
mod_state->DParameter_type, (PyTypeObject*)Py_NewRef(d_parameter_type));
551+
Py_INCREF(c_parameter_type);
552+
Py_XSETREF(mod_state->CParameter_type, (PyTypeObject*)c_parameter_type);
553+
Py_INCREF(d_parameter_type);
554+
Py_XSETREF(mod_state->DParameter_type, (PyTypeObject*)d_parameter_type);
521555

522556
Py_RETURN_NONE;
523557
}
@@ -580,7 +614,6 @@ do { \
580614
return -1;
581615
}
582616
if (PyModule_AddType(m, (PyTypeObject *)mod_state->ZstdError) < 0) {
583-
Py_DECREF(mod_state->ZstdError);
584617
return -1;
585618
}
586619

Modules/_zstd/_zstdmodule.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#ifndef ZSTD_MODULE_H
66
#define ZSTD_MODULE_H
77

8+
#include "zstddict.h"
9+
810
/* Type specs */
911
extern PyType_Spec zstd_dict_type_spec;
1012
extern PyType_Spec zstd_compressor_type_spec;
@@ -43,10 +45,14 @@ typedef enum {
4345
DICT_TYPE_PREFIX = 2
4446
} dictionary_type;
4547

48+
extern ZstdDict *
49+
_Py_parse_zstd_dict(const _zstd_state *state,
50+
PyObject *dict, int *type);
51+
4652
/* Format error message and set ZstdError. */
4753
extern void
48-
set_zstd_error(const _zstd_state* const state,
49-
const error_type type, size_t zstd_ret);
54+
set_zstd_error(const _zstd_state *state,
55+
error_type type, size_t zstd_ret);
5056

5157
extern void
5258
set_parameter_error(int is_compress, int key_v, int value_v);

0 commit comments

Comments
 (0)