Skip to content

Commit d9b709d

Browse files
committed
Add unit conversion for pcolor methods
1 parent 0dd5343 commit d9b709d

File tree

1 file changed

+31
-8
lines changed

1 file changed

+31
-8
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from numpy import ma
99

1010
import matplotlib as mpl
11+
import matplotlib.cm as cm
1112
import matplotlib.category # Register category unit converter as side effect.
1213
import matplotlib.cbook as cbook
1314
import matplotlib.collections as mcoll
@@ -5795,12 +5796,28 @@ def imshow(self, X, cmap=None, norm=None, *, aspect=None,
57955796
self.add_image(im)
57965797
return im
57975798

5799+
@staticmethod
5800+
def _convert_C_units(C):
5801+
"""
5802+
Remove any units attached to C, and return the units and converter used to do
5803+
the conversion.
5804+
"""
5805+
sm = cm.ScalarMappable()
5806+
C = sm._strip_units(C)
5807+
converter = sm._converter
5808+
units = sm._units
5809+
5810+
C = np.asanyarray(C)
5811+
C = cbook.safe_masked_invalid(C, copy=True)
5812+
return C, units, converter
5813+
57985814
def _pcolorargs(self, funcname, *args, shading='auto', **kwargs):
57995815
# - create X and Y if not present;
58005816
# - reshape X and Y as needed if they are 1-D;
58015817
# - check for proper sizes based on `shading` kwarg;
58025818
# - reset shading if shading='auto' to flat or nearest
58035819
# depending on size;
5820+
# - if C has units, get the converter
58045821

58055822
_valid_shading = ['gouraud', 'nearest', 'flat', 'auto']
58065823
try:
@@ -5812,19 +5829,19 @@ def _pcolorargs(self, funcname, *args, shading='auto', **kwargs):
58125829
shading = 'auto'
58135830

58145831
if len(args) == 1:
5815-
C = np.asanyarray(args[0])
5832+
C, units, converter = self._convert_C_units(args[0])
58165833
nrows, ncols = C.shape[:2]
58175834
if shading in ['gouraud', 'nearest']:
58185835
X, Y = np.meshgrid(np.arange(ncols), np.arange(nrows))
58195836
else:
58205837
X, Y = np.meshgrid(np.arange(ncols + 1), np.arange(nrows + 1))
58215838
shading = 'flat'
58225839
C = cbook.safe_masked_invalid(C, copy=True)
5823-
return X, Y, C, shading
5840+
return X, Y, C, shading, units, converter
58245841

58255842
if len(args) == 3:
58265843
# Check x and y for bad data...
5827-
C = np.asanyarray(args[2])
5844+
C, units, converter = self._convert_C_units(args[2])
58285845
# unit conversion allows e.g. datetime objects as axis values
58295846
X, Y = args[:2]
58305847
X, Y = self._process_unit_info([("x", X), ("y", Y)], kwargs)
@@ -5905,7 +5922,7 @@ def _interp_grid(X):
59055922
shading = 'flat'
59065923

59075924
C = cbook.safe_masked_invalid(C, copy=True)
5908-
return X, Y, C, shading
5925+
return X, Y, C, shading, units, converter
59095926

59105927
@_preprocess_data()
59115928
@_docstring.dedent_interpd
@@ -6057,8 +6074,9 @@ def pcolor(self, *args, shading=None, alpha=None, norm=None, cmap=None,
60576074
if shading is None:
60586075
shading = mpl.rcParams['pcolor.shading']
60596076
shading = shading.lower()
6060-
X, Y, C, shading = self._pcolorargs('pcolor', *args, shading=shading,
6061-
kwargs=kwargs)
6077+
X, Y, C, shading, units, converter = self._pcolorargs(
6078+
'pcolor', *args, shading=shading, kwargs=kwargs
6079+
)
60626080
linewidths = (0.25,)
60636081
if 'linewidth' in kwargs:
60646082
kwargs['linewidths'] = kwargs.pop('linewidth')
@@ -6094,6 +6112,8 @@ def pcolor(self, *args, shading=None, alpha=None, norm=None, cmap=None,
60946112

60956113
collection = mcoll.PolyQuadMesh(
60966114
coords, array=C, cmap=cmap, norm=norm, alpha=alpha, **kwargs)
6115+
collection._units = units
6116+
collection._converter = converter
60976117
collection._scale_norm(norm, vmin, vmax)
60986118

60996119
# Transform from native to data coordinates?
@@ -6313,15 +6333,18 @@ def pcolormesh(self, *args, alpha=None, norm=None, cmap=None, vmin=None,
63136333
shading = shading.lower()
63146334
kwargs.setdefault('edgecolors', 'none')
63156335

6316-
X, Y, C, shading = self._pcolorargs('pcolormesh', *args,
6317-
shading=shading, kwargs=kwargs)
6336+
X, Y, C, shading, units, converter = self._pcolorargs(
6337+
'pcolormesh', *args, shading=shading, kwargs=kwargs
6338+
)
63186339
coords = np.stack([X, Y], axis=-1)
63196340

63206341
kwargs.setdefault('snap', mpl.rcParams['pcolormesh.snap'])
63216342

63226343
collection = mcoll.QuadMesh(
63236344
coords, antialiased=antialiased, shading=shading,
63246345
array=C, cmap=cmap, norm=norm, alpha=alpha, **kwargs)
6346+
collection._units = units
6347+
collection._converter = converter
63256348
collection._scale_norm(norm, vmin, vmax)
63266349

63276350
coords = coords.reshape(-1, 2) # flatten the grid structure; keep x, y

0 commit comments

Comments
 (0)