Skip to content

Commit 7dbd4ed

Browse files
committed
Merge pull request #915 from mdboom/png-16bit
_png extension loads 16-bit PNGs as 8-bit
2 parents 14a4386 + c2872fa commit 7dbd4ed

File tree

4 files changed

+57
-11
lines changed

4 files changed

+57
-11
lines changed

lib/matplotlib/testing/compare.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,8 @@ def compare_images( expected, actual, tol, in_decorator=False ):
210210
expected = convert(expected)
211211

212212
# open the image files and remove the alpha channel (if it exists)
213-
expectedImage = _png.read_png_uint8( expected )
214-
actualImage = _png.read_png_uint8( actual )
213+
expectedImage = _png.read_png_int( expected )
214+
actualImage = _png.read_png_int( actual )
215215

216216
actualImage, expectedImage = crop_to_same(actual, actualImage, expected, expectedImage)
217217

Loading

lib/matplotlib/tests/test_png.py

+10
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import matplotlib.cm as cm
44
import glob
55
import os
6+
import numpy as np
67

78
@image_comparison(baseline_images=['pngsuite'], extensions=['png'])
89
def test_pngsuite():
@@ -25,3 +26,12 @@ def test_pngsuite():
2526

2627
plt.gca().get_frame().set_facecolor("#ddffff")
2728
plt.gca().set_xlim(0, len(files))
29+
30+
31+
def test_imread_png_uint16():
32+
from matplotlib import _png
33+
img = _png.read_png_int(os.path.join(os.path.dirname(__file__),
34+
'baseline_images/test_png/uint16.png'))
35+
36+
assert (img.dtype == np.uint16)
37+
assert np.sum(img.flatten()) == 134184960

src/_png.cpp

+45-9
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ class _png_module : public Py::ExtensionModule<_png_module>
4848
"read_png_float(fileobj)");
4949
add_varargs_method("read_png_uint8", &_png_module::read_png_uint8,
5050
"read_png_uint8(fileobj)");
51+
add_varargs_method("read_png_int", &_png_module::read_png_int,
52+
"read_png_int(fileobj)");
5153
initialize("Module to write PNG files");
5254
}
5355

@@ -57,7 +59,8 @@ class _png_module : public Py::ExtensionModule<_png_module>
5759
Py::Object write_png(const Py::Tuple& args);
5860
Py::Object read_png_uint8(const Py::Tuple& args);
5961
Py::Object read_png_float(const Py::Tuple& args);
60-
PyObject* _read_png(const Py::Object& py_fileobj, const bool float_result);
62+
Py::Object read_png_int(const Py::Tuple& args);
63+
PyObject* _read_png(const Py::Object& py_fileobj, const bool float_result, int result_bit_depth = -1);
6164
};
6265

6366
static void write_png_data(png_structp png_ptr, png_bytep data, png_size_t length)
@@ -297,7 +300,8 @@ static void read_png_data(png_structp png_ptr, png_bytep data, png_size_t length
297300
}
298301

299302
PyObject*
300-
_png_module::_read_png(const Py::Object& py_fileobj, const bool float_result)
303+
_png_module::_read_png(const Py::Object& py_fileobj, const bool float_result,
304+
int result_bit_depth)
301305
{
302306
png_byte header[8]; // 8 is the maximum size that can be checked
303307
FILE* fp = NULL;
@@ -502,7 +506,18 @@ _png_module::_read_png(const Py::Object& py_fileobj, const bool float_result)
502506
}
503507
}
504508
} else {
505-
A = (PyArrayObject *) PyArray_SimpleNew(num_dims, dimensions, NPY_UBYTE);
509+
if (result_bit_depth < 0) {
510+
result_bit_depth = bit_depth;
511+
}
512+
513+
if (result_bit_depth == 8) {
514+
A = (PyArrayObject *) PyArray_SimpleNew(num_dims, dimensions, NPY_UBYTE);
515+
} else if (result_bit_depth == 16) {
516+
A = (PyArrayObject *) PyArray_SimpleNew(num_dims, dimensions, NPY_UINT16);
517+
} else {
518+
throw Py::RuntimeError(
519+
"_image_module::readpng: image has unknown bit depth");
520+
}
506521

507522
if (A == NULL)
508523
{
@@ -518,17 +533,32 @@ _png_module::_read_png(const Py::Object& py_fileobj, const bool float_result)
518533
if (bit_depth == 16)
519534
{
520535
png_uint_16* ptr = &reinterpret_cast<png_uint_16*>(row)[x * dimensions[2]];
521-
for (png_uint_32 p = 0; p < (png_uint_32)dimensions[2]; p++)
522-
{
523-
*(png_byte*)(A->data + offset + p*A->strides[2]) = ptr[p] >> 8;
536+
537+
if (result_bit_depth == 16) {
538+
for (png_uint_32 p = 0; p < (png_uint_32)dimensions[2]; p++)
539+
{
540+
*(png_uint_16*)(A->data + offset + p*A->strides[2]) = ptr[p];
541+
}
542+
} else {
543+
for (png_uint_32 p = 0; p < (png_uint_32)dimensions[2]; p++)
544+
{
545+
*(png_byte*)(A->data + offset + p*A->strides[2]) = ptr[p] >> 8;
546+
}
524547
}
525548
}
526549
else
527550
{
528551
png_byte* ptr = &(row[x * dimensions[2]]);
529-
for (png_uint_32 p = 0; p < (png_uint_32)dimensions[2]; p++)
530-
{
531-
*(png_byte*)(A->data + offset + p*A->strides[2]) = ptr[p];
552+
if (result_bit_depth == 16) {
553+
for (png_uint_32 p = 0; p < (png_uint_32)dimensions[2]; p++)
554+
{
555+
*(png_uint_16*)(A->data + offset + p*A->strides[2]) = ptr[p];
556+
}
557+
} else {
558+
for (png_uint_32 p = 0; p < (png_uint_32)dimensions[2]; p++)
559+
{
560+
*(png_byte*)(A->data + offset + p*A->strides[2]) = ptr[p];
561+
}
532562
}
533563
}
534564
}
@@ -569,6 +599,12 @@ _png_module::read_png_float(const Py::Tuple& args)
569599

570600
Py::Object
571601
_png_module::read_png_uint8(const Py::Tuple& args)
602+
{
603+
throw Py::RuntimeError("read_png_uint8 is deprecated. Use read_png_int instead.");
604+
}
605+
606+
Py::Object
607+
_png_module::read_png_int(const Py::Tuple& args)
572608
{
573609
args.verify_length(1);
574610
return Py::asObject(_read_png(args[0], false));

0 commit comments

Comments
 (0)