From 1fdd4b67596a77677a81ccad528d4f610a2dd668 Mon Sep 17 00:00:00 2001 From: Antony Lee Date: Sat, 4 Feb 2023 15:07:56 +0100 Subject: [PATCH] Group shape/dtype validation logic in image_resample. Move it all to a single place rather than having some of it interspersed with the dtype dispatch later. Also reorder the dtype dispatch to be consistent in the 2D and 3D cases, and remove _array from many local variable names. --- lib/matplotlib/tests/test_image.py | 15 +++ src/_image_wrapper.cpp | 162 ++++++++++++++--------------- 2 files changed, 91 insertions(+), 86 deletions(-) diff --git a/lib/matplotlib/tests/test_image.py b/lib/matplotlib/tests/test_image.py index 64a97ed58502..3ab99104c7ee 100644 --- a/lib/matplotlib/tests/test_image.py +++ b/lib/matplotlib/tests/test_image.py @@ -1,5 +1,6 @@ from contextlib import ExitStack from copy import copy +import functools import io import os from pathlib import Path @@ -1453,3 +1454,17 @@ def test_str_norms(fig_test, fig_ref): assert type(axts[0].images[0].norm) == colors.LogNorm # Exactly that class with pytest.raises(ValueError): axts[0].imshow(t, norm="foobar") + + +def test__resample_valid_output(): + resample = functools.partial(mpl._image.resample, transform=Affine2D()) + with pytest.raises(ValueError, match="must be a NumPy array"): + resample(np.zeros((9, 9)), None) + with pytest.raises(ValueError, match="different dimensionalities"): + resample(np.zeros((9, 9)), np.zeros((9, 9, 4))) + with pytest.raises(ValueError, match="must be RGBA"): + resample(np.zeros((9, 9, 4)), np.zeros((9, 9, 3))) + with pytest.raises(ValueError, match="Mismatched types"): + resample(np.zeros((9, 9), np.uint8), np.zeros((9, 9))) + with pytest.raises(ValueError, match="must be C-contiguous"): + resample(np.zeros((9, 9)), np.zeros((9, 9)).T) diff --git a/src/_image_wrapper.cpp b/src/_image_wrapper.cpp index 9eba0249d3e9..ba2f6cbf8651 100644 --- a/src/_image_wrapper.cpp +++ b/src/_image_wrapper.cpp @@ -9,7 +9,7 @@ * */ const char* image_resample__doc__ = -"resample(input_array, output_array, matrix, interpolation=NEAREST, alpha=1.0, norm=False, radius=1)\n" +"resample(input_array, output_array, transform, interpolation=NEAREST, alpha=1.0, norm=False, radius=1)\n" "--\n\n" "Resample input_array, blending it in-place into output_array, using an\n" @@ -121,14 +121,15 @@ resample(PyArrayObject* input, PyArrayObject* output, resample_params_t params) static PyObject * image_resample(PyObject *self, PyObject* args, PyObject *kwargs) { - PyObject *py_input_array = NULL; - PyObject *py_output_array = NULL; + PyObject *py_input = NULL; + PyObject *py_output = NULL; PyObject *py_transform = NULL; resample_params_t params; - PyArrayObject *input_array = NULL; - PyArrayObject *output_array = NULL; - PyArrayObject *transform_mesh_array = NULL; + PyArrayObject *input = NULL; + PyArrayObject *output = NULL; + PyArrayObject *transform_mesh = NULL; + int ndim; params.interpolation = NEAREST; params.transform_mesh = NULL; @@ -143,36 +144,52 @@ image_resample(PyObject *self, PyObject* args, PyObject *kwargs) if (!PyArg_ParseTupleAndKeywords( args, kwargs, "OOO|iO&dO&d:resample", (char **)kwlist, - &py_input_array, &py_output_array, &py_transform, + &py_input, &py_output, &py_transform, ¶ms.interpolation, &convert_bool, ¶ms.resample, ¶ms.alpha, &convert_bool, ¶ms.norm, ¶ms.radius)) { return NULL; } if (params.interpolation < 0 || params.interpolation >= _n_interpolation) { - PyErr_Format(PyExc_ValueError, "invalid interpolation value %d", + PyErr_Format(PyExc_ValueError, "Invalid interpolation value %d", params.interpolation); goto error; } - input_array = (PyArrayObject *)PyArray_FromAny( - py_input_array, NULL, 2, 3, NPY_ARRAY_C_CONTIGUOUS, NULL); - if (input_array == NULL) { + input = (PyArrayObject *)PyArray_FromAny( + py_input, NULL, 2, 3, NPY_ARRAY_C_CONTIGUOUS, NULL); + if (!input) { goto error; } + ndim = PyArray_NDIM(input); - if (!PyArray_Check(py_output_array)) { - PyErr_SetString(PyExc_ValueError, "output array must be a NumPy array"); + if (!PyArray_Check(py_output)) { + PyErr_SetString(PyExc_ValueError, "Output array must be a NumPy array"); goto error; } - output_array = (PyArrayObject *)py_output_array; - if (!PyArray_IS_C_CONTIGUOUS(output_array)) { - PyErr_SetString(PyExc_ValueError, "output array must be C-contiguous"); + output = (PyArrayObject *)py_output; + if (PyArray_NDIM(output) != ndim) { + PyErr_Format( + PyExc_ValueError, + "Input (%dD) and output (%dD) have different dimensionalities.", + ndim, PyArray_NDIM(output)); + goto error; + } + // PyArray_FromAny above checks that input is 2D or 3D. + if (ndim == 3 && (PyArray_DIM(input, 2) != 4 || PyArray_DIM(output, 2) != 4)) { + PyErr_Format( + PyExc_ValueError, + "If 3D, input and output arrays must be RGBA with shape (M, N, 4); " + "got trailing dimensions of %" NPY_INTP_FMT " and %" NPY_INTP_FMT + " respectively", PyArray_DIM(input, 2), PyArray_DIM(output, 2)); goto error; } - if (PyArray_NDIM(output_array) < 2 || PyArray_NDIM(output_array) > 3) { - PyErr_SetString(PyExc_ValueError, - "output array must be 2- or 3-dimensional"); + if (PyArray_TYPE(input) != PyArray_TYPE(output)) { + PyErr_SetString(PyExc_ValueError, "Mismatched types"); + goto error; + } + if (!PyArray_IS_C_CONTIGUOUS(output)) { + PyErr_SetString(PyExc_ValueError, "Output array must be C-contiguous"); goto error; } @@ -182,7 +199,7 @@ image_resample(PyObject *self, PyObject* args, PyObject *kwargs) PyObject *py_is_affine; int py_is_affine2; py_is_affine = PyObject_GetAttrString(py_transform, "is_affine"); - if (py_is_affine == NULL) { + if (!py_is_affine) { goto error; } @@ -197,96 +214,69 @@ image_resample(PyObject *self, PyObject* args, PyObject *kwargs) } params.is_affine = true; } else { - transform_mesh_array = _get_transform_mesh( - py_transform, PyArray_DIMS(output_array)); - if (transform_mesh_array == NULL) { + transform_mesh = _get_transform_mesh( + py_transform, PyArray_DIMS(output)); + if (!transform_mesh) { goto error; } - params.transform_mesh = (double *)PyArray_DATA(transform_mesh_array); + params.transform_mesh = (double *)PyArray_DATA(transform_mesh); params.is_affine = false; } } - if (PyArray_NDIM(input_array) != PyArray_NDIM(output_array)) { - PyErr_Format( - PyExc_ValueError, - "Mismatched number of dimensions. Got %d and %d.", - PyArray_NDIM(input_array), PyArray_NDIM(output_array)); - goto error; - } - - if (PyArray_TYPE(input_array) != PyArray_TYPE(output_array)) { - PyErr_SetString(PyExc_ValueError, "Mismatched types"); - goto error; - } - - if (PyArray_NDIM(input_array) == 3) { - if (PyArray_DIM(output_array, 2) != 4) { + if (ndim == 3) { + switch (PyArray_TYPE(input)) { + case NPY_UINT8: + case NPY_INT8: + resample(input, output, params); + break; + case NPY_UINT16: + case NPY_INT16: + resample(input, output, params); + break; + case NPY_FLOAT32: + resample(input, output, params); + break; + case NPY_FLOAT64: + resample(input, output, params); + break; + default: PyErr_SetString( PyExc_ValueError, - "Output array must be RGBA"); - goto error; - } - - if (PyArray_DIM(input_array, 2) == 4) { - switch (PyArray_TYPE(input_array)) { - case NPY_UINT8: - case NPY_INT8: - resample(input_array, output_array, params); - break; - case NPY_UINT16: - case NPY_INT16: - resample(input_array, output_array, params); - break; - case NPY_FLOAT32: - resample(input_array, output_array, params); - break; - case NPY_FLOAT64: - resample(input_array, output_array, params); - break; - default: - PyErr_SetString( - PyExc_ValueError, - "3-dimensional arrays must be of dtype unsigned byte, " - "unsigned short, float32 or float64"); - goto error; - } - } else { - PyErr_Format( - PyExc_ValueError, - "If 3-dimensional, array must be RGBA. Got %" NPY_INTP_FMT " planes.", - PyArray_DIM(input_array, 2)); + "arrays must be of dtype byte, short, float32 or float64"); goto error; } - } else { // NDIM == 2 - switch (PyArray_TYPE(input_array)) { - case NPY_DOUBLE: - resample(input_array, output_array, params); - break; - case NPY_FLOAT: - resample(input_array, output_array, params); - break; + } else { // ndim == 2 + switch (PyArray_TYPE(input)) { case NPY_UINT8: case NPY_INT8: - resample(input_array, output_array, params); + resample(input, output, params); break; case NPY_UINT16: case NPY_INT16: - resample(input_array, output_array, params); + resample(input, output, params); + break; + case NPY_FLOAT32: + resample(input, output, params); + break; + case NPY_FLOAT64: + resample(input, output, params); break; default: - PyErr_SetString(PyExc_ValueError, "Unsupported dtype"); + PyErr_SetString( + PyExc_ValueError, + "arrays must be of dtype byte, short, float32 or float64"); goto error; } } - Py_DECREF(input_array); - Py_XDECREF(transform_mesh_array); + Py_DECREF(input); + Py_XDECREF(transform_mesh); Py_RETURN_NONE; error: - Py_XDECREF(input_array); - Py_XDECREF(transform_mesh_array); + Py_XDECREF(input); + Py_XDECREF(transform_mesh); return NULL; }