Skip to content

Commit 0c9f7c2

Browse files
authored
Merge pull request #8447 from anntzer/imshow-float128
ENH: Let imshow handle float128 data.
2 parents 9335cc6 + df133ef commit 0c9f7c2

File tree

3 files changed

+27
-17
lines changed

3 files changed

+27
-17
lines changed

lib/matplotlib/colors.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def _to_rgba_no_colorcycle(c, alpha=None):
188188
raise ValueError("Invalid RGBA argument: {!r}".format(orig_c))
189189
# tuple color.
190190
c = np.array(c)
191-
if not np.can_cast(c.dtype, float) or c.ndim != 1:
191+
if not np.can_cast(c.dtype, float, "same_kind") or c.ndim != 1:
192192
# Test the dtype explicitly as `map(float, ...)`, `np.array(...,
193193
# float)` and `np.array(...).astype(float)` all convert "0.5" to 0.5.
194194
# Test dimensionality to reject single floats.
@@ -934,6 +934,11 @@ def __call__(self, value, clip=None):
934934
resdat -= vmin
935935
resdat /= (vmax - vmin)
936936
result = np.ma.array(resdat, mask=result.mask, copy=False)
937+
# Agg cannot handle float128. We actually only need 32-bit of
938+
# precision, but on Windows, `np.dtype(np.longdouble) == np.float64`,
939+
# so casting to float32 would lose precision on float64s as well.
940+
if result.dtype == np.longdouble:
941+
result = result.astype(np.float64)
937942
if is_scalar:
938943
result = result[0]
939944
return result

lib/matplotlib/image.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -299,9 +299,9 @@ def _make_image(self, A, in_bbox, out_bbox, clip_bbox, magnification=1.0,
299299
space.
300300
"""
301301
if A is None:
302-
raise RuntimeError('You must first set the image'
303-
' array or the image attribute')
304-
if any(s == 0 for s in A.shape):
302+
raise RuntimeError('You must first set the image '
303+
'array or the image attribute')
304+
if A.size == 0:
305305
raise RuntimeError("_make_image must get a non-empty image. "
306306
"Your Artist's draw method must filter before "
307307
"this method is called.")
@@ -359,7 +359,7 @@ def _make_image(self, A, in_bbox, out_bbox, clip_bbox, magnification=1.0,
359359
created_rgba_mask = False
360360

361361
if A.ndim not in (2, 3):
362-
raise ValueError("Invalid dimensions, got %s" % (A.shape,))
362+
raise ValueError("Invalid dimensions, got {}".format(A.shape))
363363

364364
if A.ndim == 2:
365365
A = self.norm(A)
@@ -591,11 +591,11 @@ def set_data(self, A):
591591
self._A = cbook.safe_masked_invalid(A, copy=True)
592592

593593
if (self._A.dtype != np.uint8 and
594-
not np.can_cast(self._A.dtype, float)):
595-
raise TypeError("Image data can not convert to float")
594+
not np.can_cast(self._A.dtype, float, "same_kind")):
595+
raise TypeError("Image data cannot be converted to float")
596596

597-
if (self._A.ndim not in (2, 3) or
598-
(self._A.ndim == 3 and self._A.shape[-1] not in (3, 4))):
597+
if not (self._A.ndim == 2
598+
or self._A.ndim == 3 and self._A.shape[-1] in [3, 4]):
599599
raise TypeError("Invalid dimensions for image data")
600600

601601
self._imcache = None

lib/matplotlib/tests/test_image.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -505,9 +505,9 @@ def test_jpeg_alpha():
505505
def test_nonuniformimage_setdata():
506506
ax = plt.gca()
507507
im = NonUniformImage(ax)
508-
x = np.arange(3, dtype=np.float64)
509-
y = np.arange(4, dtype=np.float64)
510-
z = np.arange(12, dtype=np.float64).reshape((4, 3))
508+
x = np.arange(3, dtype=float)
509+
y = np.arange(4, dtype=float)
510+
z = np.arange(12, dtype=float).reshape((4, 3))
511511
im.set_data(x, y, z)
512512
x[0] = y[0] = z[0, 0] = 9.9
513513
assert im._A[0, 0] == im._Ax[0] == im._Ay[0] == 0, 'value changed'
@@ -516,7 +516,7 @@ def test_nonuniformimage_setdata():
516516
def test_axesimage_setdata():
517517
ax = plt.gca()
518518
im = AxesImage(ax)
519-
z = np.arange(12, dtype=np.float64).reshape((4, 3))
519+
z = np.arange(12, dtype=float).reshape((4, 3))
520520
im.set_data(z)
521521
z[0, 0] = 9.9
522522
assert im._A[0, 0] == 0, 'value changed'
@@ -525,7 +525,7 @@ def test_axesimage_setdata():
525525
def test_figureimage_setdata():
526526
fig = plt.gcf()
527527
im = FigureImage(fig)
528-
z = np.arange(12, dtype=np.float64).reshape((4, 3))
528+
z = np.arange(12, dtype=float).reshape((4, 3))
529529
im.set_data(z)
530530
z[0, 0] = 9.9
531531
assert im._A[0, 0] == 0, 'value changed'
@@ -534,9 +534,9 @@ def test_figureimage_setdata():
534534
def test_pcolorimage_setdata():
535535
ax = plt.gca()
536536
im = PcolorImage(ax)
537-
x = np.arange(3, dtype=np.float64)
538-
y = np.arange(4, dtype=np.float64)
539-
z = np.arange(6, dtype=np.float64).reshape((3, 2))
537+
x = np.arange(3, dtype=float)
538+
y = np.arange(4, dtype=float)
539+
z = np.arange(6, dtype=float).reshape((3, 2))
540540
im.set_data(x, y, z)
541541
x[0] = y[0] = z[0, 0] = 9.9
542542
assert im._A[0, 0] == im._Ax[0] == im._Ay[0] == 0, 'value changed'
@@ -778,3 +778,8 @@ def test_empty_imshow():
778778

779779
with pytest.raises(RuntimeError):
780780
im.make_image(fig._cachedRenderer)
781+
782+
783+
def test_imshow_float128():
784+
fig, ax = plt.subplots()
785+
ax.imshow(np.zeros((3, 3), dtype=np.longdouble))

0 commit comments

Comments
 (0)