diff --git a/Lib/test/test_zstd.py b/Lib/test/test_zstd.py index 1562c29d2e53a3..e54d86f02100ea 100644 --- a/Lib/test/test_zstd.py +++ b/Lib/test/test_zstd.py @@ -1139,27 +1139,41 @@ def test_invalid_dict(self): ZstdDecompressor(zd) # wrong type - with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): - ZstdCompressor(zstd_dict=(zd, b'123')) - with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'): + ZstdCompressor(zstd_dict=[zd, 1]) + with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'): + ZstdCompressor(zstd_dict=(zd, 1.0)) + with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'): + ZstdCompressor(zstd_dict=(zd,)) + with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'): ZstdCompressor(zstd_dict=(zd, 1, 2)) - with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'): ZstdCompressor(zstd_dict=(zd, -1)) - with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'): ZstdCompressor(zstd_dict=(zd, 3)) - - with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): - ZstdDecompressor(zstd_dict=(zd, b'123')) - with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + with self.assertRaises(OverflowError): + ZstdCompressor(zstd_dict=(zd, 2**1000)) + with self.assertRaises(OverflowError): + ZstdCompressor(zstd_dict=(zd, -2**1000)) + + with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'): + ZstdDecompressor(zstd_dict=[zd, 1]) + with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'): + ZstdDecompressor(zstd_dict=(zd, 1.0)) + with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'): + ZstdDecompressor((zd,)) + with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'): ZstdDecompressor((zd, 1, 2)) - with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'): ZstdDecompressor((zd, -1)) - with self.assertRaisesRegex(TypeError, r'should be ZstdDict object'): + with self.assertRaisesRegex(TypeError, r'should be a ZstdDict object'): ZstdDecompressor((zd, 3)) + with self.assertRaises(OverflowError): + ZstdDecompressor((zd, 2**1000)) + with self.assertRaises(OverflowError): + ZstdDecompressor((zd, -2**1000)) def test_train_dict(self): - - TRAINED_DICT = train_dict(SAMPLES, DICT_SIZE1) ZstdDict(TRAINED_DICT.dict_content, is_raw=False) @@ -1240,18 +1254,37 @@ def test_train_dict_c(self): # argument wrong type with self.assertRaises(TypeError): _zstd.train_dict({}, (), 100) + with self.assertRaises(TypeError): + _zstd.train_dict(bytearray(), (), 100) with self.assertRaises(TypeError): _zstd.train_dict(b'', 99, 100) + with self.assertRaises(TypeError): + _zstd.train_dict(b'', [], 100) with self.assertRaises(TypeError): _zstd.train_dict(b'', (), 100.1) + with self.assertRaises(TypeError): + _zstd.train_dict(b'', (99.1,), 100) + with self.assertRaises(ValueError): + _zstd.train_dict(b'abc', (4, -1), 100) + with self.assertRaises(ValueError): + _zstd.train_dict(b'abc', (2,), 100) + with self.assertRaises(ValueError): + _zstd.train_dict(b'', (99,), 100) # size > size_t with self.assertRaises(ValueError): - _zstd.train_dict(b'', (2**64+1,), 100) + _zstd.train_dict(b'', (2**1000,), 100) + with self.assertRaises(ValueError): + _zstd.train_dict(b'', (-2**1000,), 100) # dict_size <= 0 with self.assertRaises(ValueError): _zstd.train_dict(b'', (), 0) + with self.assertRaises(ValueError): + _zstd.train_dict(b'', (), -1) + + with self.assertRaises(ZstdError): + _zstd.train_dict(b'', (), 1) def test_finalize_dict_c(self): with self.assertRaises(TypeError): @@ -1260,22 +1293,51 @@ def test_finalize_dict_c(self): # argument wrong type with self.assertRaises(TypeError): _zstd.finalize_dict({}, b'', (), 100, 5) + with self.assertRaises(TypeError): + _zstd.finalize_dict(bytearray(TRAINED_DICT.dict_content), b'', (), 100, 5) with self.assertRaises(TypeError): _zstd.finalize_dict(TRAINED_DICT.dict_content, {}, (), 100, 5) + with self.assertRaises(TypeError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, bytearray(), (), 100, 5) with self.assertRaises(TypeError): _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', 99, 100, 5) + with self.assertRaises(TypeError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', [], 100, 5) with self.assertRaises(TypeError): _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100.1, 5) with self.assertRaises(TypeError): _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 5.1) + with self.assertRaises(ValueError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'abc', (4, -1), 100, 5) + with self.assertRaises(ValueError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'abc', (2,), 100, 5) + with self.assertRaises(ValueError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (99,), 100, 5) + # size > size_t with self.assertRaises(ValueError): - _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (2**64+1,), 100, 5) + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (2**1000,), 100, 5) + with self.assertRaises(ValueError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (-2**1000,), 100, 5) # dict_size <= 0 with self.assertRaises(ValueError): _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 0, 5) + with self.assertRaises(ValueError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), -1, 5) + with self.assertRaises(OverflowError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 2**1000, 5) + with self.assertRaises(OverflowError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), -2**1000, 5) + + with self.assertRaises(OverflowError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 2**1000) + with self.assertRaises(OverflowError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, -2**1000) + + with self.assertRaises(ZstdError): + _zstd.finalize_dict(TRAINED_DICT.dict_content, b'', (), 100, 5) def test_train_buffer_protocol_samples(self): def _nbytes(dat): diff --git a/Modules/_zstd/_zstdmodule.c b/Modules/_zstd/_zstdmodule.c index 986b3579479f0f..b0e50f873f4ca6 100644 --- a/Modules/_zstd/_zstdmodule.c +++ b/Modules/_zstd/_zstdmodule.c @@ -7,7 +7,6 @@ #include "Python.h" #include "_zstdmodule.h" -#include "zstddict.h" #include // ZSTD_*() #include // ZDICT_*() @@ -20,14 +19,52 @@ module _zstd #include "clinic/_zstdmodule.c.h" +ZstdDict * +_Py_parse_zstd_dict(const _zstd_state *state, PyObject *dict, int *ptype) +{ + if (state == NULL) { + return NULL; + } + + /* Check ZstdDict */ + if (PyObject_TypeCheck(dict, state->ZstdDict_type)) { + return (ZstdDict*)dict; + } + + /* Check (ZstdDict, type) */ + if (PyTuple_CheckExact(dict) && PyTuple_GET_SIZE(dict) == 2 + && PyObject_TypeCheck(PyTuple_GET_ITEM(dict, 0), state->ZstdDict_type) + && PyLong_Check(PyTuple_GET_ITEM(dict, 1))) + { + int type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1)); + if (type == -1 && PyErr_Occurred()) { + return NULL; + } + if (type == DICT_TYPE_DIGESTED + || type == DICT_TYPE_UNDIGESTED + || type == DICT_TYPE_PREFIX) + { + *ptype = type; + return (ZstdDict*)PyTuple_GET_ITEM(dict, 0); + } + } + + /* Wrong type */ + PyErr_SetString(PyExc_TypeError, + "zstd_dict argument should be a ZstdDict object."); + return NULL; +} + /* Format error message and set ZstdError. */ void -set_zstd_error(const _zstd_state* const state, - error_type type, size_t zstd_ret) +set_zstd_error(const _zstd_state *state, error_type type, size_t zstd_ret) { - char *msg; + const char *msg; assert(ZSTD_isError(zstd_ret)); + if (state == NULL) { + return; + } switch (type) { case ERR_DECOMPRESS: msg = "Unable to decompress Zstandard data: %s"; @@ -174,7 +211,7 @@ calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes, Py_ssize_t sizes_sum; Py_ssize_t i; - chunks_number = Py_SIZE(samples_sizes); + chunks_number = PyTuple_GET_SIZE(samples_sizes); if ((size_t) chunks_number > UINT32_MAX) { PyErr_Format(PyExc_ValueError, "The number of samples should be <= %u.", UINT32_MAX); @@ -188,20 +225,24 @@ calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes, return -1; } - sizes_sum = 0; + sizes_sum = PyBytes_GET_SIZE(samples_bytes); for (i = 0; i < chunks_number; i++) { - PyObject *size = PyTuple_GetItem(samples_sizes, i); - (*chunk_sizes)[i] = PyLong_AsSize_t(size); - if ((*chunk_sizes)[i] == (size_t)-1 && PyErr_Occurred()) { - PyErr_Format(PyExc_ValueError, - "Items in samples_sizes should be an int " - "object, with a value between 0 and %u.", SIZE_MAX); + size_t size = PyLong_AsSize_t(PyTuple_GET_ITEM(samples_sizes, i)); + (*chunk_sizes)[i] = size; + if (size == (size_t)-1 && PyErr_Occurred()) { + if (PyErr_ExceptionMatches(PyExc_OverflowError)) { + goto sum_error; + } return -1; } - sizes_sum += (*chunk_sizes)[i]; + if ((size_t)sizes_sum < size) { + goto sum_error; + } + sizes_sum -= size; } - if (sizes_sum != Py_SIZE(samples_bytes)) { + if (sizes_sum != 0) { +sum_error: PyErr_SetString(PyExc_ValueError, "The samples size tuple doesn't match the " "concatenation's size."); @@ -257,7 +298,7 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes, /* Train the dictionary */ char *dst_dict_buffer = PyBytes_AS_STRING(dst_dict_bytes); - char *samples_buffer = PyBytes_AS_STRING(samples_bytes); + const char *samples_buffer = PyBytes_AS_STRING(samples_bytes); Py_BEGIN_ALLOW_THREADS zstd_ret = ZDICT_trainFromBuffer(dst_dict_buffer, dict_size, samples_buffer, @@ -507,17 +548,10 @@ _zstd_set_parameter_types_impl(PyObject *module, PyObject *c_parameter_type, { _zstd_state* mod_state = get_zstd_state(module); - if (!PyType_Check(c_parameter_type) || !PyType_Check(d_parameter_type)) { - PyErr_SetString(PyExc_ValueError, - "The two arguments should be CompressionParameter and " - "DecompressionParameter types."); - return NULL; - } - - Py_XSETREF( - mod_state->CParameter_type, (PyTypeObject*)Py_NewRef(c_parameter_type)); - Py_XSETREF( - mod_state->DParameter_type, (PyTypeObject*)Py_NewRef(d_parameter_type)); + Py_INCREF(c_parameter_type); + Py_XSETREF(mod_state->CParameter_type, (PyTypeObject*)c_parameter_type); + Py_INCREF(d_parameter_type); + Py_XSETREF(mod_state->DParameter_type, (PyTypeObject*)d_parameter_type); Py_RETURN_NONE; } @@ -580,7 +614,6 @@ do { \ return -1; } if (PyModule_AddType(m, (PyTypeObject *)mod_state->ZstdError) < 0) { - Py_DECREF(mod_state->ZstdError); return -1; } diff --git a/Modules/_zstd/_zstdmodule.h b/Modules/_zstd/_zstdmodule.h index 1f4160f474f0b0..c73f15b3c5299b 100644 --- a/Modules/_zstd/_zstdmodule.h +++ b/Modules/_zstd/_zstdmodule.h @@ -5,6 +5,8 @@ #ifndef ZSTD_MODULE_H #define ZSTD_MODULE_H +#include "zstddict.h" + /* Type specs */ extern PyType_Spec zstd_dict_type_spec; extern PyType_Spec zstd_compressor_type_spec; @@ -43,10 +45,14 @@ typedef enum { DICT_TYPE_PREFIX = 2 } dictionary_type; +extern ZstdDict * +_Py_parse_zstd_dict(const _zstd_state *state, + PyObject *dict, int *type); + /* Format error message and set ZstdError. */ extern void -set_zstd_error(const _zstd_state* const state, - const error_type type, size_t zstd_ret); +set_zstd_error(const _zstd_state *state, + error_type type, size_t zstd_ret); extern void set_parameter_error(int is_compress, int key_v, int value_v); diff --git a/Modules/_zstd/compressor.c b/Modules/_zstd/compressor.c index 8ff2a3aadc1cd6..e1217635f60cb0 100644 --- a/Modules/_zstd/compressor.c +++ b/Modules/_zstd/compressor.c @@ -16,7 +16,6 @@ class _zstd.ZstdCompressor "ZstdCompressor *" "&zstd_compressor_type_spec" #include "_zstdmodule.h" #include "buffer.h" -#include "zstddict.h" #include "internal/pycore_lock.h" // PyMutex_IsLocked #include // offsetof() @@ -71,9 +70,6 @@ _zstd_set_c_level(ZstdCompressor *self, int level) /* Check error */ if (ZSTD_isError(zstd_ret)) { _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); - if (mod_state == NULL) { - return -1; - } set_zstd_error(mod_state, ERR_SET_C_LEVEL, zstd_ret); return -1; } @@ -265,56 +261,17 @@ static int _zstd_load_c_dict(ZstdCompressor *self, PyObject *dict) { _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); - if (mod_state == NULL) { - return -1; - } - ZstdDict *zd; - int type, ret; - - /* Check ZstdDict */ - ret = PyObject_IsInstance(dict, (PyObject*)mod_state->ZstdDict_type); - if (ret < 0) { + /* When compressing, use undigested dictionary by default. */ + int type = DICT_TYPE_UNDIGESTED; + ZstdDict *zd = _Py_parse_zstd_dict(mod_state, dict, &type); + if (zd == NULL) { return -1; } - else if (ret > 0) { - /* When compressing, use undigested dictionary by default. */ - zd = (ZstdDict*)dict; - type = DICT_TYPE_UNDIGESTED; - PyMutex_Lock(&zd->lock); - ret = _zstd_load_impl(self, zd, mod_state, type); - PyMutex_Unlock(&zd->lock); - return ret; - } - - /* Check (ZstdDict, type) */ - if (PyTuple_CheckExact(dict) && PyTuple_GET_SIZE(dict) == 2) { - /* Check ZstdDict */ - ret = PyObject_IsInstance(PyTuple_GET_ITEM(dict, 0), - (PyObject*)mod_state->ZstdDict_type); - if (ret < 0) { - return -1; - } - else if (ret > 0) { - /* type == -1 may indicate an error. */ - type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1)); - if (type == DICT_TYPE_DIGESTED - || type == DICT_TYPE_UNDIGESTED - || type == DICT_TYPE_PREFIX) - { - assert(type >= 0); - zd = (ZstdDict*)PyTuple_GET_ITEM(dict, 0); - PyMutex_Lock(&zd->lock); - ret = _zstd_load_impl(self, zd, mod_state, type); - PyMutex_Unlock(&zd->lock); - return ret; - } - } - } - - /* Wrong type */ - PyErr_SetString(PyExc_TypeError, - "zstd_dict argument should be ZstdDict object."); - return -1; + int ret; + PyMutex_Lock(&zd->lock); + ret = _zstd_load_impl(self, zd, mod_state, type); + PyMutex_Unlock(&zd->lock); + return ret; } /*[clinic input] @@ -481,9 +438,7 @@ compress_lock_held(ZstdCompressor *self, Py_buffer *data, /* Check error */ if (ZSTD_isError(zstd_ret)) { _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); - if (mod_state != NULL) { - set_zstd_error(mod_state, ERR_COMPRESS, zstd_ret); - } + set_zstd_error(mod_state, ERR_COMPRESS, zstd_ret); goto error; } @@ -553,9 +508,7 @@ compress_mt_continue_lock_held(ZstdCompressor *self, Py_buffer *data) /* Check error */ if (ZSTD_isError(zstd_ret)) { _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); - if (mod_state != NULL) { - set_zstd_error(mod_state, ERR_COMPRESS, zstd_ret); - } + set_zstd_error(mod_state, ERR_COMPRESS, zstd_ret); goto error; } diff --git a/Modules/_zstd/decompressor.c b/Modules/_zstd/decompressor.c index 26e568cf433308..c53d6e4cb05cf0 100644 --- a/Modules/_zstd/decompressor.c +++ b/Modules/_zstd/decompressor.c @@ -16,7 +16,6 @@ class _zstd.ZstdDecompressor "ZstdDecompressor *" "&zstd_decompressor_type_spec" #include "_zstdmodule.h" #include "buffer.h" -#include "zstddict.h" #include "internal/pycore_lock.h" // PyMutex_IsLocked #include // bool @@ -61,11 +60,6 @@ _get_DDict(ZstdDict *self) assert(PyMutex_IsLocked(&self->lock)); ZSTD_DDict *ret; - /* Already created */ - if (self->d_dict != NULL) { - return self->d_dict; - } - if (self->d_dict == NULL) { /* Create ZSTD_DDict instance from dictionary content */ Py_BEGIN_ALLOW_THREADS @@ -182,56 +176,17 @@ static int _zstd_load_d_dict(ZstdDecompressor *self, PyObject *dict) { _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); - if (mod_state == NULL) { - return -1; - } - ZstdDict *zd; - int type, ret; - - /* Check ZstdDict */ - ret = PyObject_IsInstance(dict, (PyObject*)mod_state->ZstdDict_type); - if (ret < 0) { + /* When decompressing, use digested dictionary by default. */ + int type = DICT_TYPE_DIGESTED; + ZstdDict *zd = _Py_parse_zstd_dict(mod_state, dict, &type); + if (zd == NULL) { return -1; } - else if (ret > 0) { - /* When decompressing, use digested dictionary by default. */ - zd = (ZstdDict*)dict; - type = DICT_TYPE_DIGESTED; - PyMutex_Lock(&zd->lock); - ret = _zstd_load_impl(self, zd, mod_state, type); - PyMutex_Unlock(&zd->lock); - return ret; - } - - /* Check (ZstdDict, type) */ - if (PyTuple_CheckExact(dict) && PyTuple_GET_SIZE(dict) == 2) { - /* Check ZstdDict */ - ret = PyObject_IsInstance(PyTuple_GET_ITEM(dict, 0), - (PyObject*)mod_state->ZstdDict_type); - if (ret < 0) { - return -1; - } - else if (ret > 0) { - /* type == -1 may indicate an error. */ - type = PyLong_AsInt(PyTuple_GET_ITEM(dict, 1)); - if (type == DICT_TYPE_DIGESTED - || type == DICT_TYPE_UNDIGESTED - || type == DICT_TYPE_PREFIX) - { - assert(type >= 0); - zd = (ZstdDict*)PyTuple_GET_ITEM(dict, 0); - PyMutex_Lock(&zd->lock); - ret = _zstd_load_impl(self, zd, mod_state, type); - PyMutex_Unlock(&zd->lock); - return ret; - } - } - } - - /* Wrong type */ - PyErr_SetString(PyExc_TypeError, - "zstd_dict argument should be ZstdDict object."); - return -1; + int ret; + PyMutex_Lock(&zd->lock); + ret = _zstd_load_impl(self, zd, mod_state, type); + PyMutex_Unlock(&zd->lock); + return ret; } /* @@ -282,9 +237,7 @@ decompress_lock_held(ZstdDecompressor *self, ZSTD_inBuffer *in, /* Check error */ if (ZSTD_isError(zstd_ret)) { _zstd_state* mod_state = PyType_GetModuleState(Py_TYPE(self)); - if (mod_state != NULL) { - set_zstd_error(mod_state, ERR_DECOMPRESS, zstd_ret); - } + set_zstd_error(mod_state, ERR_DECOMPRESS, zstd_ret); goto error; } diff --git a/Modules/_zstd/zstddict.c b/Modules/_zstd/zstddict.c index afc58b42e893d3..14f74aaed46ec5 100644 --- a/Modules/_zstd/zstddict.c +++ b/Modules/_zstd/zstddict.c @@ -15,7 +15,6 @@ class _zstd.ZstdDict "ZstdDict *" "&zstd_dict_type_spec" #include "Python.h" #include "_zstdmodule.h" -#include "zstddict.h" #include "clinic/zstddict.c.h" #include "internal/pycore_lock.h" // PyMutex_IsLocked