Skip to content

Commit 98474c8

Browse files
authored
Merge pull request #16457 from anntzer/scale-norm
Build lognorm/symlognorm from corresponding scales.
2 parents 73f8cea + 0d1d95c commit 98474c8

File tree

2 files changed

+98
-165
lines changed

2 files changed

+98
-165
lines changed

lib/matplotlib/colors.py

+98-165
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
import base64
6969
from collections.abc import Sized
7070
import functools
71+
import inspect
7172
import io
7273
import itertools
7374
from numbers import Number
@@ -77,8 +78,7 @@
7778

7879
import matplotlib as mpl
7980
import numpy as np
80-
import matplotlib.cbook as cbook
81-
from matplotlib import docstring
81+
from matplotlib import cbook, docstring, scale
8282
from ._color_data import BASE_COLORS, TABLEAU_COLORS, CSS4_COLORS, XKCD_COLORS
8383

8484

@@ -1203,60 +1203,84 @@ class DivergingNorm(TwoSlopeNorm):
12031203
...
12041204

12051205

1206-
class LogNorm(Normalize):
1207-
"""Normalize a given value to the 0-1 range on a log scale."""
1206+
def _make_norm_from_scale(scale_cls, base_norm_cls=None, *, init=None):
1207+
"""
1208+
Decorator for building a `.Normalize` subclass from a `.Scale` subclass.
12081209
1209-
def _check_vmin_vmax(self):
1210-
if self.vmin > self.vmax:
1211-
raise ValueError("minvalue must be less than or equal to maxvalue")
1212-
elif self.vmin <= 0:
1213-
raise ValueError("minvalue must be positive")
1210+
After ::
12141211
1215-
def __call__(self, value, clip=None):
1216-
if clip is None:
1217-
clip = self.clip
1212+
@_make_norm_from_scale(scale_cls)
1213+
class base_norm_cls(Normalize):
1214+
...
12181215
1219-
result, is_scalar = self.process_value(value)
1216+
*base_norm_cls* is filled with methods so that normalization computations
1217+
are forwarded to *scale_cls* (i.e., *scale_cls* is the scale that would be
1218+
used for the colorbar of a mappable normalized with *base_norm_cls*).
12201219
1221-
result = np.ma.masked_less_equal(result, 0, copy=False)
1220+
The constructor signature of *base_norm_cls* is derived from the
1221+
constructor signature of *scale_cls*, but can be overridden using *init*
1222+
(a callable which is *only* used for its signature).
1223+
"""
12221224

1223-
self.autoscale_None(result)
1224-
self._check_vmin_vmax()
1225-
vmin, vmax = self.vmin, self.vmax
1226-
if vmin == vmax:
1227-
result.fill(0)
1228-
else:
1225+
if base_norm_cls is None:
1226+
return functools.partial(_make_norm_from_scale, scale_cls, init=init)
1227+
1228+
if init is None:
1229+
def init(vmin=None, vmax=None, clip=False): pass
1230+
init_signature = inspect.signature(init)
1231+
1232+
class Norm(base_norm_cls):
1233+
1234+
def __init__(self, *args, **kwargs):
1235+
ba = init_signature.bind(*args, **kwargs)
1236+
ba.apply_defaults()
1237+
super().__init__(
1238+
**{k: ba.arguments.pop(k) for k in ["vmin", "vmax", "clip"]})
1239+
self._scale = scale_cls(axis=None, **ba.arguments)
1240+
self._trf = self._scale.get_transform()
1241+
self._inv_trf = self._trf.inverted()
1242+
1243+
def __call__(self, value, clip=None):
1244+
value, is_scalar = self.process_value(value)
1245+
self.autoscale_None(value)
1246+
if self.vmin > self.vmax:
1247+
raise ValueError("vmin must be less or equal to vmax")
1248+
if self.vmin == self.vmax:
1249+
return np.full_like(value, 0)
1250+
if clip is None:
1251+
clip = self.clip
12291252
if clip:
1230-
mask = np.ma.getmask(result)
1231-
result = np.ma.array(np.clip(result.filled(vmax), vmin, vmax),
1232-
mask=mask)
1233-
# in-place equivalent of above can be much faster
1234-
resdat = result.data
1235-
mask = result.mask
1236-
if mask is np.ma.nomask:
1237-
mask = (resdat <= 0)
1238-
else:
1239-
mask |= resdat <= 0
1240-
np.copyto(resdat, 1, where=mask)
1241-
np.log(resdat, resdat)
1242-
resdat -= np.log(vmin)
1243-
resdat /= (np.log(vmax) - np.log(vmin))
1244-
result = np.ma.array(resdat, mask=mask, copy=False)
1245-
if is_scalar:
1246-
result = result[0]
1247-
return result
1248-
1249-
def inverse(self, value):
1250-
if not self.scaled():
1251-
raise ValueError("Not invertible until scaled")
1252-
self._check_vmin_vmax()
1253-
vmin, vmax = self.vmin, self.vmax
1254-
1255-
if np.iterable(value):
1256-
val = np.ma.asarray(value)
1257-
return vmin * np.ma.power((vmax / vmin), val)
1258-
else:
1259-
return vmin * pow((vmax / vmin), value)
1253+
value = np.clip(value, self.vmin, self.vmax)
1254+
t_value = self._trf.transform(value).reshape(np.shape(value))
1255+
t_vmin, t_vmax = self._trf.transform([self.vmin, self.vmax])
1256+
if not np.isfinite([t_vmin, t_vmax]).all():
1257+
raise ValueError("Invalid vmin or vmax")
1258+
t_value -= t_vmin
1259+
t_value /= (t_vmax - t_vmin)
1260+
t_value = np.ma.masked_invalid(t_value, copy=False)
1261+
return t_value[0] if is_scalar else t_value
1262+
1263+
def inverse(self, value):
1264+
if not self.scaled():
1265+
raise ValueError("Not invertible until scaled")
1266+
if self.vmin > self.vmax:
1267+
raise ValueError("vmin must be less or equal to vmax")
1268+
t_vmin, t_vmax = self._trf.transform([self.vmin, self.vmax])
1269+
if not np.isfinite([t_vmin, t_vmax]).all():
1270+
raise ValueError("Invalid vmin or vmax")
1271+
rescaled = value * (t_vmax - t_vmin)
1272+
rescaled += t_vmin
1273+
return self._inv_trf.transform(rescaled).reshape(np.shape(value))
1274+
1275+
Norm.__name__ = base_norm_cls.__name__
1276+
Norm.__qualname__ = base_norm_cls.__qualname__
1277+
Norm.__module__ = base_norm_cls.__module__
1278+
return Norm
1279+
1280+
1281+
@_make_norm_from_scale(functools.partial(scale.LogScale, nonpositive="mask"))
1282+
class LogNorm(Normalize):
1283+
"""Normalize a given value to the 0-1 range on a log scale."""
12601284

12611285
def autoscale(self, A):
12621286
# docstring inherited.
@@ -1267,6 +1291,10 @@ def autoscale_None(self, A):
12671291
super().autoscale_None(np.ma.masked_less_equal(A, 0, copy=False))
12681292

12691293

1294+
@_make_norm_from_scale(
1295+
scale.SymmetricalLogScale,
1296+
init=lambda linthresh, linscale=1., vmin=None, vmax=None, clip=False, *,
1297+
base=10: None)
12701298
class SymLogNorm(Normalize):
12711299
"""
12721300
The symmetrical logarithmic scale is logarithmic in both the
@@ -1276,124 +1304,29 @@ class SymLogNorm(Normalize):
12761304
need to have a range around zero that is linear. The parameter
12771305
*linthresh* allows the user to specify the size of this range
12781306
(-*linthresh*, *linthresh*).
1279-
"""
1280-
def __init__(self, linthresh, linscale=1.0, vmin=None, vmax=None,
1281-
clip=False, *, base=None):
1282-
"""
1283-
Parameters
1284-
----------
1285-
linthresh : float
1286-
The range within which the plot is linear (to avoid having the plot
1287-
go to infinity around zero).
1288-
1289-
linscale : float, default: 1
1290-
This allows the linear range (-*linthresh* to *linthresh*)
1291-
to be stretched relative to the logarithmic range. Its
1292-
value is the number of powers of *base* to use for each
1293-
half of the linear range.
1294-
1295-
For example, when *linscale* == 1.0 (the default) and
1296-
``base=10``, then space used for the positive and negative
1297-
halves of the linear range will be equal to a decade in
1298-
the logarithmic.
1299-
1300-
base : float, default: None
1301-
If not given, defaults to ``np.e`` (consistent with prior
1302-
behavior) and warns.
1303-
1304-
In v3.3 the default value will change to 10 to be consistent with
1305-
`.SymLogNorm`.
1306-
1307-
To suppress the warning pass *base* as a keyword argument.
13081307
1309-
"""
1310-
Normalize.__init__(self, vmin, vmax, clip)
1311-
if base is None:
1312-
self._base = np.e
1313-
cbook.warn_deprecated(
1314-
"3.2", removal="3.4", message="default base will change from "
1315-
"np.e to 10 %(removal)s. To suppress this warning specify "
1316-
"the base keyword argument.")
1317-
else:
1318-
self._base = base
1319-
self._log_base = np.log(self._base)
1320-
1321-
self.linthresh = float(linthresh)
1322-
self._linscale_adj = (linscale / (1.0 - self._base ** -1))
1323-
if vmin is not None and vmax is not None:
1324-
self._transform_vmin_vmax()
1325-
1326-
def __call__(self, value, clip=None):
1327-
if clip is None:
1328-
clip = self.clip
1329-
1330-
result, is_scalar = self.process_value(value)
1331-
self.autoscale_None(result)
1332-
vmin, vmax = self.vmin, self.vmax
1333-
1334-
if vmin > vmax:
1335-
raise ValueError("minvalue must be less than or equal to maxvalue")
1336-
elif vmin == vmax:
1337-
result.fill(0)
1338-
else:
1339-
if clip:
1340-
mask = np.ma.getmask(result)
1341-
result = np.ma.array(np.clip(result.filled(vmax), vmin, vmax),
1342-
mask=mask)
1343-
# in-place equivalent of above can be much faster
1344-
resdat = self._transform(result.data)
1345-
resdat -= self._lower
1346-
resdat /= (self._upper - self._lower)
1347-
1348-
if is_scalar:
1349-
result = result[0]
1350-
return result
1351-
1352-
def _transform(self, a):
1353-
"""Inplace transformation."""
1354-
with np.errstate(invalid="ignore"):
1355-
masked = np.abs(a) > self.linthresh
1356-
sign = np.sign(a[masked])
1357-
log = (self._linscale_adj +
1358-
np.log(np.abs(a[masked]) / self.linthresh) / self._log_base)
1359-
log *= sign * self.linthresh
1360-
a[masked] = log
1361-
a[~masked] *= self._linscale_adj
1362-
return a
1363-
1364-
def _inv_transform(self, a):
1365-
"""Inverse inplace Transformation."""
1366-
masked = np.abs(a) > (self.linthresh * self._linscale_adj)
1367-
sign = np.sign(a[masked])
1368-
exp = np.power(self._base,
1369-
sign * a[masked] / self.linthresh - self._linscale_adj)
1370-
exp *= sign * self.linthresh
1371-
a[masked] = exp
1372-
a[~masked] /= self._linscale_adj
1373-
return a
1374-
1375-
def _transform_vmin_vmax(self):
1376-
"""Calculate vmin and vmax in the transformed system."""
1377-
vmin, vmax = self.vmin, self.vmax
1378-
arr = np.array([vmax, vmin]).astype(float)
1379-
self._upper, self._lower = self._transform(arr)
1380-
1381-
def inverse(self, value):
1382-
if not self.scaled():
1383-
raise ValueError("Not invertible until scaled")
1384-
val = np.ma.asarray(value)
1385-
val = val * (self._upper - self._lower) + self._lower
1386-
return self._inv_transform(val)
1308+
Parameters
1309+
----------
1310+
linthresh : float
1311+
The range within which the plot is linear (to avoid having the plot
1312+
go to infinity around zero).
1313+
linscale : float, default: 1
1314+
This allows the linear range (-*linthresh* to *linthresh*) to be
1315+
stretched relative to the logarithmic range. Its value is the
1316+
number of decades to use for each half of the linear range. For
1317+
example, when *linscale* == 1.0 (the default), the space used for
1318+
the positive and negative halves of the linear range will be equal
1319+
to one decade in the logarithmic range.
1320+
base : float, default: 10
1321+
"""
13871322

1388-
def autoscale(self, A):
1389-
# docstring inherited.
1390-
super().autoscale(A)
1391-
self._transform_vmin_vmax()
1323+
@property
1324+
def linthresh(self):
1325+
return self._scale.linthresh
13921326

1393-
def autoscale_None(self, A):
1394-
# docstring inherited.
1395-
super().autoscale_None(A)
1396-
self._transform_vmin_vmax()
1327+
@linthresh.setter
1328+
def linthresh(self, value):
1329+
self._scale.linthresh = value
13971330

13981331

13991332
class PowerNorm(Normalize):

0 commit comments

Comments
 (0)