Skip to content

Commit 0e4f3cc

Browse files
committed
Add 16-bit PNG support to _png module.
1 parent 14a4386 commit 0e4f3cc

File tree

4 files changed

+47
-8
lines changed

4 files changed

+47
-8
lines changed

lib/matplotlib/testing/compare.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,8 @@ def compare_images( expected, actual, tol, in_decorator=False ):
210210
expected = convert(expected)
211211

212212
# open the image files and remove the alpha channel (if it exists)
213-
expectedImage = _png.read_png_uint8( expected )
214-
actualImage = _png.read_png_uint8( actual )
213+
expectedImage = _png.read_png_int( expected )
214+
actualImage = _png.read_png_int( actual )
215215

216216
actualImage, expectedImage = crop_to_same(actual, actualImage, expected, expectedImage)
217217

lib/matplotlib/tests/test_png.py

+10
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import matplotlib.cm as cm
44
import glob
55
import os
6+
import numpy as np
67

78
@image_comparison(baseline_images=['pngsuite'], extensions=['png'])
89
def test_pngsuite():
@@ -25,3 +26,12 @@ def test_pngsuite():
2526

2627
plt.gca().get_frame().set_facecolor("#ddffff")
2728
plt.gca().set_xlim(0, len(files))
29+
30+
31+
def test_imread_png_uint16():
32+
from matplotlib import _png
33+
img = _png.read_png_int(os.path.join(os.path.dirname(__file__),
34+
'baseline_images/test_png/pngtest16rgba.png'))
35+
assert (img.dtype == np.uint16)
36+
37+
assert np.sum(img.flatten()) == 104855776

src/_png.cpp

+35-6
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ class _png_module : public Py::ExtensionModule<_png_module>
4848
"read_png_float(fileobj)");
4949
add_varargs_method("read_png_uint8", &_png_module::read_png_uint8,
5050
"read_png_uint8(fileobj)");
51+
add_varargs_method("read_png_int", &_png_module::read_png_int,
52+
"read_png_int(fileobj)");
5153
initialize("Module to write PNG files");
5254
}
5355

@@ -57,7 +59,8 @@ class _png_module : public Py::ExtensionModule<_png_module>
5759
Py::Object write_png(const Py::Tuple& args);
5860
Py::Object read_png_uint8(const Py::Tuple& args);
5961
Py::Object read_png_float(const Py::Tuple& args);
60-
PyObject* _read_png(const Py::Object& py_fileobj, const bool float_result);
62+
Py::Object read_png_int(const Py::Tuple& args);
63+
PyObject* _read_png(const Py::Object& py_fileobj, const bool float_result, const int bit_depth = -1);
6164
};
6265

6366
static void write_png_data(png_structp png_ptr, png_bytep data, png_size_t length)
@@ -297,7 +300,8 @@ static void read_png_data(png_structp png_ptr, png_bytep data, png_size_t length
297300
}
298301

299302
PyObject*
300-
_png_module::_read_png(const Py::Object& py_fileobj, const bool float_result)
303+
_png_module::_read_png(const Py::Object& py_fileobj, const bool float_result,
304+
const int result_bit_depth)
301305
{
302306
png_byte header[8]; // 8 is the maximum size that can be checked
303307
FILE* fp = NULL;
@@ -502,7 +506,18 @@ _png_module::_read_png(const Py::Object& py_fileobj, const bool float_result)
502506
}
503507
}
504508
} else {
505-
A = (PyArrayObject *) PyArray_SimpleNew(num_dims, dimensions, NPY_UBYTE);
509+
if (result_bit_depth == 8) {
510+
A = (PyArrayObject *) PyArray_SimpleNew(num_dims, dimensions, NPY_UBYTE);
511+
} else {
512+
if (bit_depth == 8) {
513+
A = (PyArrayObject *) PyArray_SimpleNew(num_dims, dimensions, NPY_UBYTE);
514+
} else if (bit_depth == 16) {
515+
A = (PyArrayObject *) PyArray_SimpleNew(num_dims, dimensions, NPY_UINT16);
516+
} else {
517+
throw Py::RuntimeError(
518+
"_image_module::readpng: image has unknown bit depth");
519+
}
520+
}
506521

507522
if (A == NULL)
508523
{
@@ -518,9 +533,17 @@ _png_module::_read_png(const Py::Object& py_fileobj, const bool float_result)
518533
if (bit_depth == 16)
519534
{
520535
png_uint_16* ptr = &reinterpret_cast<png_uint_16*>(row)[x * dimensions[2]];
521-
for (png_uint_32 p = 0; p < (png_uint_32)dimensions[2]; p++)
522-
{
523-
*(png_byte*)(A->data + offset + p*A->strides[2]) = ptr[p] >> 8;
536+
537+
if (bit_depth == 16) {
538+
for (png_uint_32 p = 0; p < (png_uint_32)dimensions[2]; p++)
539+
{
540+
*(png_uint_16*)(A->data + offset + p*A->strides[2]) = ptr[p];
541+
}
542+
} else {
543+
for (png_uint_32 p = 0; p < (png_uint_32)dimensions[2]; p++)
544+
{
545+
*(png_byte*)(A->data + offset + p*A->strides[2]) = ptr[p] >> 8;
546+
}
524547
}
525548
}
526549
else
@@ -569,6 +592,12 @@ _png_module::read_png_float(const Py::Tuple& args)
569592

570593
Py::Object
571594
_png_module::read_png_uint8(const Py::Tuple& args)
595+
{
596+
throw Py::RuntimeError("read_png_uint8 is deprecated. Use read_png_int instead.");
597+
}
598+
599+
Py::Object
600+
_png_module::read_png_int(const Py::Tuple& args)
572601
{
573602
args.verify_length(1);
574603
return Py::asObject(_read_png(args[0], false));

0 commit comments

Comments
 (0)