Skip to content

Commit 76840ea

Browse files
author
Fabio Zanini
committed
Logit scale
1 parent ee086de commit 76840ea

File tree

9 files changed

+334
-3
lines changed

9 files changed

+334
-3
lines changed

doc/conf.py

+1
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@
126126
('color', 'Color'),
127127
('text_labels_and_annotations', 'Text, labels, and annotations'),
128128
('ticks_and_spines', 'Ticks and spines'),
129+
('scales', 'Axis scales'),
129130
('subplots_axes_and_figures', 'Subplots, axes, and figures'),
130131
('style_sheets', 'Style sheets'),
131132
('specialty_plots', 'Specialty plots'),

doc/pyplots/pyplot_scales.py

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import numpy as np
2+
import matplotlib.pyplot as plt
3+
4+
# make up some data in the interval ]0, 1[
5+
y = np.random.normal(loc=0.5, scale=0.4, size=1000)
6+
y = y[(y > 0) & (y < 1)]
7+
y.sort()
8+
x = np.arange(len(y))
9+
10+
# plot with various axes scales
11+
plt.figure(1)
12+
13+
# linear
14+
plt.subplot(221)
15+
plt.plot(x, y)
16+
plt.yscale('linear')
17+
plt.title('linear')
18+
plt.grid(True)
19+
20+
21+
# log
22+
plt.subplot(222)
23+
plt.plot(x, y)
24+
plt.yscale('log')
25+
plt.title('log')
26+
plt.grid(True)
27+
28+
29+
# symmetric log
30+
plt.subplot(223)
31+
plt.plot(x, y - y.mean())
32+
plt.yscale('symlog', linthreshy=0.05)
33+
plt.title('symlog')
34+
plt.grid(True)
35+
36+
# logit
37+
plt.subplot(223)
38+
plt.plot(x, y)
39+
plt.yscale('logit')
40+
plt.title('logit')
41+
plt.grid(True)
42+
43+
plt.show()

doc/users/pyplot_tutorial.rst

+19
Original file line numberDiff line numberDiff line change
@@ -280,3 +280,22 @@ variety of other coordinate systems one can choose -- see
280280
:ref:`annotations-tutorial` and :ref:`plotting-guide-annotation` for
281281
details. More examples can be found in
282282
:ref:`pylab_examples-annotation_demo`.
283+
284+
285+
Logarithmic and other nonlinear axis
286+
====================================
287+
288+
:mod:`matplotlib.pyplot` supports not only linear axis scales, but also
289+
logarithmic and logit scales. This is commonly used if data spans many orders
290+
of magnitude. Changing the scale of an axis is easy:
291+
292+
plt.xscale('log')
293+
294+
An example of four plots with the same data and different scales for the y axis
295+
is shown below.
296+
297+
.. plot:: pyplots/pyplot_scales.py
298+
:include-source:
299+
300+
It is also possible to add your own scale, see :ref:`adding-new-scales` for
301+
details.

doc/users/whats_new/updated_scale.rst

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Logit Scale
2+
-----------
3+
Added support for the 'logit' axis scale, a nonlinear transformation
4+
`x -> log10(x / (1-x))` for data between 0 and 1 excluded.

examples/scales/scales.py

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""
2+
Illustrate the scale transformations applied to axes, e.g. log, symlog, logit.
3+
"""
4+
import numpy as np
5+
import matplotlib.pyplot as plt
6+
7+
# make up some data in the interval ]0, 1[
8+
y = np.random.normal(loc=0.5, scale=0.4, size=1000)
9+
y = y[(y > 0) & (y < 1)]
10+
y.sort()
11+
x = np.arange(len(y))
12+
13+
# plot with various axes scales
14+
fig, axs = plt.subplots(2, 2)
15+
16+
# linear
17+
ax = axs[0, 0]
18+
ax.plot(x, y)
19+
ax.set_yscale('linear')
20+
ax.set_title('linear')
21+
ax.grid(True)
22+
23+
24+
# log
25+
ax = axs[0, 1]
26+
ax.plot(x, y)
27+
ax.set_yscale('log')
28+
ax.set_title('log')
29+
ax.grid(True)
30+
31+
32+
# symmetric log
33+
ax = axs[1, 0]
34+
ax.plot(x, y - y.mean())
35+
ax.set_yscale('symlog', linthreshy=0.05)
36+
ax.set_title('symlog')
37+
ax.grid(True)
38+
39+
# logit
40+
ax = axs[1, 1]
41+
ax.plot(x, y)
42+
ax.set_yscale('logit')
43+
ax.set_title('logit')
44+
ax.grid(True)
45+
46+
47+
plt.show()

lib/matplotlib/scale.py

+104-3
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88

99
from matplotlib.cbook import dedent
1010
from matplotlib.ticker import (NullFormatter, ScalarFormatter,
11-
LogFormatterMathtext)
11+
LogFormatterMathtext, LogitFormatter)
1212
from matplotlib.ticker import (NullLocator, LogLocator, AutoLocator,
13-
SymmetricalLogLocator)
13+
SymmetricalLogLocator, LogitLocator)
1414
from matplotlib.transforms import Transform, IdentityTransform
1515
from matplotlib import docstring
1616

@@ -478,10 +478,111 @@ def get_transform(self):
478478
return self._transform
479479

480480

481+
def _mask_non_logit(a):
482+
"""
483+
Return a Numpy masked array where all values outside ]0, 1[ are
484+
masked. If all values are inside ]0, 1[, the original array is
485+
returned.
486+
"""
487+
a = a.copy()
488+
mask = (a <= 0.0) | (a >= 1.0)
489+
a[mask] = np.nan
490+
return a
491+
492+
493+
def _clip_non_logit(a):
494+
a = a.copy()
495+
a[a <= 0.0] = 1e-300
496+
a[a >= 1.0] = 1 - 1e-300
497+
return a
498+
499+
500+
class LogitTransform(Transform):
501+
input_dims = 1
502+
output_dims = 1
503+
is_separable = True
504+
has_inverse = True
505+
506+
def __init__(self, nonpos):
507+
Transform.__init__(self)
508+
if nonpos == 'mask':
509+
self._handle_nonpos = _mask_non_logit
510+
else:
511+
self._handle_nonpos = _clip_non_logit
512+
self._nonpos = nonpos
513+
514+
def transform_non_affine(self, a):
515+
"""logit transform (base 10), masked or clipped"""
516+
a = self._handle_nonpos(a)
517+
if isinstance(a, ma.MaskedArray):
518+
return ma.log10(1.0 * a / (1.0 - a))
519+
return np.log10(1.0 * a / (1.0 - a))
520+
521+
def inverted(self):
522+
return LogisticTransform(self._nonpos)
523+
524+
525+
class LogisticTransform(Transform):
526+
input_dims = 1
527+
output_dims = 1
528+
is_separable = True
529+
has_inverse = True
530+
531+
def __init__(self, nonpos='mask'):
532+
Transform.__init__(self)
533+
self._nonpos = nonpos
534+
535+
def transform_non_affine(self, a):
536+
"""logistic transform (base 10)"""
537+
return 1.0 / (1 + 10**(-a))
538+
539+
def inverted(self):
540+
return LogitTransform(self._nonpos)
541+
542+
543+
class LogitScale(ScaleBase):
544+
"""
545+
Logit scale for data between zero and one, both excluded.
546+
547+
This scale is similar to a log scale close to zero and to one, and almost
548+
linear around 0.5. It maps the interval ]0, 1[ onto ]-infty, +infty[.
549+
"""
550+
name = 'logit'
551+
552+
def __init__(self, axis, nonpos='mask'):
553+
"""
554+
*nonpos*: ['mask' | 'clip' ]
555+
values beyond ]0, 1[ can be masked as invalid, or clipped to a number
556+
very close to 0 or 1
557+
"""
558+
if nonpos not in ['mask', 'clip']:
559+
raise ValueError("nonposx, nonposy kwarg must be 'mask' or 'clip'")
560+
561+
self._transform = LogitTransform(nonpos)
562+
563+
def get_transform(self):
564+
"""
565+
Return a :class:`LogitTransform` instance.
566+
"""
567+
return self._transform
568+
569+
def set_default_locators_and_formatters(self, axis):
570+
# ..., 0.01, 0.1, 0.5, 0.9, 0.99, ...
571+
axis.set_major_locator(LogitLocator())
572+
axis.set_major_formatter(LogitFormatter())
573+
axis.set_minor_locator(LogitLocator(minor=True))
574+
axis.set_minor_formatter(LogitFormatter())
575+
576+
def limit_range_for_scale(self, vmin, vmax, minpos):
577+
return (vmin <= 0 and minpos or vmin,
578+
vmax >= 1 and (1 - minpos) or vmax)
579+
580+
481581
_scale_mapping = {
482582
'linear': LinearScale,
483583
'log': LogScale,
484-
'symlog': SymmetricalLogScale
584+
'symlog': SymmetricalLogScale,
585+
'logit': LogitScale,
485586
}
486587

487588

lib/matplotlib/tests/test_scale.py

+14
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,20 @@ def test_log_scales():
1414
ax.axhline(24.1)
1515

1616

17+
@image_comparison(baseline_images=['logit_scales'], remove_text=True,
18+
extensions=['png'])
19+
def test_logit_scales():
20+
ax = plt.subplot(111, xscale='logit')
21+
22+
# Typical exctinction curve for logit
23+
x = np.array([0.001, 0.003, 0.01, 0.03, 0.1, 0.2, 0.3, 0.4, 0.5,
24+
0.6, 0.7, 0.8, 0.9, 0.97, 0.99, 0.997, 0.999])
25+
y = 1.0 / x
26+
27+
ax.plot(x, y)
28+
ax.grid(True)
29+
30+
1731
@cleanup
1832
def test_log_scatter():
1933
"""Issue #1799"""

lib/matplotlib/ticker.py

+102
Original file line numberDiff line numberDiff line change
@@ -807,6 +807,26 @@ def __call__(self, x, pos=None):
807807
nearest_long(fx))
808808

809809

810+
class LogitFormatter(Formatter):
811+
'''Probability formatter (using Math text)'''
812+
def __call__(self, x, pos=None):
813+
s = ''
814+
if 0.01 <= x <= 0.99:
815+
if x in [.01, 0.1, 0.5, 0.9, 0.99]:
816+
s = '{:.2f}'.format(x)
817+
elif x < 0.01:
818+
if is_decade(x):
819+
s = '$10^{%.0f}$' % np.log10(x)
820+
elif x > 0.99:
821+
if is_decade(1-x):
822+
s = '$1-10^{%.0f}$' % np.log10(1-x)
823+
return s
824+
825+
def format_data_short(self, value):
826+
'return a short formatted string representation of a number'
827+
return '%-12g' % value
828+
829+
810830
class EngFormatter(Formatter):
811831
"""
812832
Formats axis values using engineering prefixes to represent powers of 1000,
@@ -1694,6 +1714,88 @@ def view_limits(self, vmin, vmax):
16941714
return result
16951715

16961716

1717+
class LogitLocator(Locator):
1718+
"""
1719+
Determine the tick locations for logit axes
1720+
"""
1721+
1722+
def __init__(self, minor=False):
1723+
"""
1724+
place ticks on the logit locations
1725+
"""
1726+
self.minor = minor
1727+
1728+
def __call__(self):
1729+
'Return the locations of the ticks'
1730+
vmin, vmax = self.axis.get_view_interval()
1731+
return self.tick_values(vmin, vmax)
1732+
1733+
def tick_values(self, vmin, vmax):
1734+
# dummy axis has no axes attribute
1735+
if hasattr(self.axis, 'axes') and self.axis.axes.name == 'polar':
1736+
raise NotImplementedError('Polar axis cannot be logit scaled yet')
1737+
1738+
# what to do if a window beyond ]0, 1[ is chosen
1739+
if vmin <= 0.0:
1740+
if self.axis is not None:
1741+
vmin = self.axis.get_minpos()
1742+
1743+
if (vmin <= 0.0) or (not np.isfinite(vmin)):
1744+
raise ValueError(
1745+
"Data has no values in ]0, 1[ and therefore can not be "
1746+
"logit-scaled.")
1747+
1748+
# NOTE: for vmax, we should query a property similar to get_minpos, but
1749+
# related to the maximal, less-than-one data point. Unfortunately,
1750+
# get_minpos is defined very deep in the BBox and updated with data,
1751+
# so for now we use the trick below.
1752+
if vmax >= 1.0:
1753+
if self.axis is not None:
1754+
vmax = 1 - self.axis.get_minpos()
1755+
1756+
if (vmax >= 1.0) or (not np.isfinite(vmax)):
1757+
raise ValueError(
1758+
"Data has no values in ]0, 1[ and therefore can not be "
1759+
"logit-scaled.")
1760+
1761+
if vmax < vmin:
1762+
vmin, vmax = vmax, vmin
1763+
1764+
vmin = np.log10(vmin / (1 - vmin))
1765+
vmax = np.log10(vmax / (1 - vmax))
1766+
1767+
decade_min = np.floor(vmin)
1768+
decade_max = np.ceil(vmax)
1769+
1770+
# major ticks
1771+
if not self.minor:
1772+
ticklocs = []
1773+
if (decade_min <= -1):
1774+
expo = np.arange(decade_min, min(0, decade_max + 1))
1775+
ticklocs.extend(list(10**expo))
1776+
if (decade_min <= 0) and (decade_max >= 0):
1777+
ticklocs.append(0.5)
1778+
if (decade_max >= 1):
1779+
expo = -np.arange(max(1, decade_min), decade_max + 1)
1780+
ticklocs.extend(list(1 - 10**expo))
1781+
1782+
# minor ticks
1783+
else:
1784+
ticklocs = []
1785+
if (decade_min <= -2):
1786+
expo = np.arange(decade_min, min(-1, decade_max))
1787+
newticks = np.outer(np.arange(2, 10), 10**expo).ravel()
1788+
ticklocs.extend(list(newticks))
1789+
if (decade_min <= 0) and (decade_max >= 0):
1790+
ticklocs.extend([0.2, 0.3, 0.4, 0.6, 0.7, 0.8])
1791+
if (decade_max >= 2):
1792+
expo = -np.arange(max(2, decade_min), decade_max + 1)
1793+
newticks = 1 - np.outer(np.arange(2, 10), 10**expo).ravel()
1794+
ticklocs.extend(list(newticks))
1795+
1796+
return self.raise_if_exceeds(np.array(ticklocs))
1797+
1798+
16971799
class AutoLocator(MaxNLocator):
16981800
def __init__(self):
16991801
MaxNLocator.__init__(self, nbins=9, steps=[1, 2, 5, 10])

0 commit comments

Comments
 (0)