Logit scale #3753

Merged
merged 2 commits into from
Mar 3, 2015
Logit scale
Fabio Zanini committed Mar 2, 2015
commit 76840ea91d4d8a6e09b4fa1ad6fd36b2b179971e
1 change: 1 addition & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
@@ -126,6 +126,7 @@
 ('color', 'Color'),
 ('text_labels_and_annotations', 'Text, labels, and annotations'),
 ('ticks_and_spines', 'Ticks and spines'),
+('scales', 'Axis scales'),
 ('subplots_axes_and_figures', 'Subplots, axes, and figures'),
 ('style_sheets', 'Style sheets'),
 ('specialty_plots', 'Specialty plots'),
43 changes: 43 additions & 0 deletions doc/pyplots/pyplot_scales.py
@@ -0,0 +1,43 @@
import numpy as np
import matplotlib.pyplot as plt

# make up some data in the interval ]0, 1[
y = np.random.normal(loc=0.5, scale=0.4, size=1000)
y = y[(y > 0) & (y < 1)]
y.sort()
x = np.arange(len(y))

# plot with various axes scales
plt.figure(1)

# linear
plt.subplot(221)
plt.plot(x, y)
plt.yscale('linear')
plt.title('linear')
plt.grid(True)


# log
plt.subplot(222)
plt.plot(x, y)
plt.yscale('log')
plt.title('log')
plt.grid(True)


# symmetric log
plt.subplot(223)
plt.plot(x, y - y.mean())
plt.yscale('symlog', linthreshy=0.05)
plt.title('symlog')
plt.grid(True)

# logit
plt.subplot(224)
plt.plot(x, y)
plt.yscale('logit')
plt.title('logit')
plt.grid(True)

plt.show()
19 changes: 19 additions & 0 deletions doc/users/pyplot_tutorial.rst
@@ -280,3 +280,22 @@ variety of other coordinate systems one can choose -- see
:ref:`annotations-tutorial` and :ref:`plotting-guide-annotation` for
details. More examples can be found in
:ref:`pylab_examples-annotation_demo`.


Logarithmic and other nonlinear axes
====================================

:mod:`matplotlib.pyplot` supports not only linear axis scales, but also
logarithmic and logit scales. These are commonly used when the data span many
orders of magnitude. Changing the scale of an axis is easy::

    plt.xscale('log')

An example of four plots with the same data and different scales for the y axis
is shown below.

.. plot:: pyplots/pyplot_scales.py
   :include-source:

It is also possible to add your own scale; see :ref:`adding-new-scales` for
details.
4 changes: 4 additions & 0 deletions doc/users/whats_new/updated_scale.rst
@@ -0,0 +1,4 @@
Logit Scale
-----------
Added support for the 'logit' axis scale, a nonlinear transformation
`x -> log10(x / (1-x))` for data strictly between 0 and 1.
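
As a quick numerical check of the transformation described above, the following sketch evaluates the base-10 logit and its inverse (the logistic function) using only the standard library. The names `logit10` and `logistic10` are illustrative and are not part of matplotlib.

```python
import math


def logit10(x):
    """Base-10 logit: log10(x / (1 - x)), defined for 0 < x < 1."""
    if not 0.0 < x < 1.0:
        raise ValueError("logit is only defined on the open interval (0, 1)")
    return math.log10(x / (1.0 - x))


def logistic10(z):
    """Inverse of logit10: 1 / (1 + 10**(-z))."""
    return 1.0 / (1.0 + 10.0 ** (-z))


# Print each probability, its logit, and the round trip back.
for p in (0.01, 0.1, 0.5, 0.9, 0.99):
    z = logit10(p)
    print(p, round(z, 4), round(logistic10(z), 4))
```

The scale is antisymmetric about 0.5: `logit10(p)` and `logit10(1 - p)` differ only in sign, which is why the tick layout mirrors around the midpoint.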
47 changes: 47 additions & 0 deletions examples/scales/scales.py
@@ -0,0 +1,47 @@
"""
Illustrate the scale transformations applied to axes, e.g. log, symlog, logit.
"""
import numpy as np
import matplotlib.pyplot as plt

# make up some data in the interval ]0, 1[
y = np.random.normal(loc=0.5, scale=0.4, size=1000)
y = y[(y > 0) & (y < 1)]
y.sort()
x = np.arange(len(y))

# plot with various axes scales
fig, axs = plt.subplots(2, 2)

# linear
ax = axs[0, 0]
ax.plot(x, y)
ax.set_yscale('linear')
ax.set_title('linear')
ax.grid(True)


# log
ax = axs[0, 1]
ax.plot(x, y)
ax.set_yscale('log')
ax.set_title('log')
ax.grid(True)


# symmetric log
ax = axs[1, 0]
ax.plot(x, y - y.mean())
ax.set_yscale('symlog', linthreshy=0.05)
ax.set_title('symlog')
ax.grid(True)

# logit
ax = axs[1, 1]
ax.plot(x, y)
ax.set_yscale('logit')
ax.set_title('logit')
ax.grid(True)


plt.show()
107 changes: 104 additions & 3 deletions lib/matplotlib/scale.py
@@ -8,9 +8,9 @@

from matplotlib.cbook import dedent
 from matplotlib.ticker import (NullFormatter, ScalarFormatter,
-                               LogFormatterMathtext)
+                               LogFormatterMathtext, LogitFormatter)
 from matplotlib.ticker import (NullLocator, LogLocator, AutoLocator,
-                               SymmetricalLogLocator)
+                               SymmetricalLogLocator, LogitLocator)
from matplotlib.transforms import Transform, IdentityTransform
from matplotlib import docstring

@@ -478,10 +478,111 @@ def get_transform(self):
        return self._transform


def _mask_non_logit(a):
    """
    Return a Numpy masked array where all values outside ]0, 1[ are
Member

This docstring does not match what the function is doing. It is always returning a copy of the input array in which any values outside ]0, 1[ are replaced by nan.

    masked. If all values are inside ]0, 1[, the original array is
    returned.
    """
    a = a.copy()
Member

It looks like you switched in midstream from using masked arrays to using nans. Does this need to handle masked array inputs? If so, and if you can subsequently use only nans for bad values, you can do this:

a = np.ma.masked_invalid(a).filled(np.nan)

Regardless of input, that will leave you with an ndarray with nothing but valid numbers and nans. It handles inf as well as nan.

Author

Ok, so I think keeping the logit _mask and _clip functions analogous to the log ones (_mask_non_positives etc) is good for the codebase, because they have the exact same function.

But the *_non_positives functions have bugs, i.e. the docstring is wrong (it mentions masked arrays) and the _clip function modifies user input. So I'll correct the _non_positives functions as well; I am against a PR with the logit/log parts of the code doing different things.

    mask = (a <= 0.0) | (a >= 1.0)
    a[mask] = np.nan
    return a


def _clip_non_logit(a):
Member

You should make a copy of the input array either here or where _handle_non_affine is called. Modifying user objects is bad practice.

Member

I think you could achieve what you want in one line with return np.clip(a, 1e-300, 1 - 1e-300). Either way, though, you are assuming a is double precision. Is this a safe assumption?
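
The precision concern raised in this comment can be checked directly. In IEEE 754 double precision the spacing of floats just below 1.0 is about 1.1e-16, so `1 - 1e-300` rounds to exactly 1.0 and the upper clip value does not lie strictly inside ]0, 1[. The sketch below uses only the standard library; `upper_clip` and `largest_below_one` are illustrative names.

```python
import math
import sys

# The upper clip value used in the diff: the subtraction rounds back to 1.0.
upper_clip = 1 - 1e-300
print(upper_clip == 1.0)

# In plain Python the logit of that value divides by zero.
# (NumPy arrays would give inf with a warning instead of raising.)
try:
    math.log10(upper_clip / (1.0 - upper_clip))
except ZeroDivisionError:
    print("logit of the clipped value divides by zero")

# Largest double strictly below 1.0: half a machine epsilon under 1.
largest_below_one = 1.0 - sys.float_info.epsilon / 2
print(largest_below_one < 1.0)

# Its base-10 logit is finite.
print(math.log10(largest_below_one / (1.0 - largest_below_one)))
```

A clip bound derived from the dtype's machine epsilon, rather than a hard-coded `1e-300`, would stay inside the interval for both single and double precision; whether that is what the final code does is not shown here.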

    a = a.copy()
    a[a <= 0.0] = 1e-300
    a[a >= 1.0] = 1 - 1e-300
    return a


class LogitTransform(Transform):
    input_dims = 1
    output_dims = 1
    is_separable = True
    has_inverse = True

    def __init__(self, nonpos):
        Transform.__init__(self)
        if nonpos == 'mask':
            self._handle_nonpos = _mask_non_logit
        else:
            self._handle_nonpos = _clip_non_logit
        self._nonpos = nonpos

    def transform_non_affine(self, a):
        """logit transform (base 10), masked or clipped"""
        a = self._handle_nonpos(a)
        if isinstance(a, ma.MaskedArray):
Member

It can't be a masked array here.

Author

see below: the code for the LogTransforms has the same problem all over the place, I was just copying from there to keep the codebase uniform. But now if mpl is moving away from masked arrays for whatever reason, I'll just correct all the LogTransforms classes too.

            return ma.log10(1.0 * a / (1.0 - a))
        return np.log10(1.0 * a / (1.0 - a))

    def inverted(self):
        return LogisticTransform(self._nonpos)


class LogisticTransform(Transform):
    input_dims = 1
    output_dims = 1
    is_separable = True
    has_inverse = True

    def __init__(self, nonpos='mask'):
        Transform.__init__(self)
        self._nonpos = nonpos

    def transform_non_affine(self, a):
        """logistic transform (base 10)"""
        return 1.0 / (1 + 10**(-a))

    def inverted(self):
        return LogitTransform(self._nonpos)


class LogitScale(ScaleBase):
    """
    Logit scale for data between zero and one, both excluded.

    This scale is similar to a log scale close to zero and to one, and almost
    linear around 0.5. It maps the interval ]0, 1[ onto ]-infty, +infty[.
    """
    name = 'logit'

    def __init__(self, axis, nonpos='mask'):
        """
        *nonpos*: ['mask' | 'clip' ]
          values beyond ]0, 1[ can be masked as invalid, or clipped to a number
          very close to 0 or 1
        """
        if nonpos not in ['mask', 'clip']:
            raise ValueError("nonposx, nonposy kwarg must be 'mask' or 'clip'")

        self._transform = LogitTransform(nonpos)

    def get_transform(self):
        """
        Return a :class:`LogitTransform` instance.
        """
        return self._transform

    def set_default_locators_and_formatters(self, axis):
        # ..., 0.01, 0.1, 0.5, 0.9, 0.99, ...
        axis.set_major_locator(LogitLocator())
        axis.set_major_formatter(LogitFormatter())
        axis.set_minor_locator(LogitLocator(minor=True))
        axis.set_minor_formatter(LogitFormatter())

    def limit_range_for_scale(self, vmin, vmax, minpos):
        return (vmin <= 0 and minpos or vmin,
Member

I find it hard to parse this expression; and it needs a docstring at least to explain what minpos is.

Author

I'll add a short docstring, but it's the same as the analogous function for LogTransforms. Again, I'm just being consistent with the codebase.

minpos is not documented in the LogTransform either, but it is documented in the base class. In my opinion, if it gets explained in LogitTransform, it should be explained in LogTransform too.

                vmax >= 1 and (1 - minpos) or vmax)
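
The `cond and a or b` idiom in `limit_range_for_scale` predates Python's conditional expressions, which is part of why the reviewer found it hard to parse. A sketch of a more explicit formulation follows. `limit_range_for_logit` is a hypothetical standalone function, not the method in the diff, and `minpos` is assumed to be the smallest positive value in the data.

```python
def limit_range_for_logit(vmin, vmax, minpos):
    """Restrict (vmin, vmax) to the open interval (0, 1).

    minpos is assumed to be the smallest positive value in the data.
    """
    # Replace a non-positive lower limit with the smallest positive datum.
    lower = minpos if vmin <= 0 else vmin
    # Mirror the same trick at the upper end of the interval.
    upper = 1 - minpos if vmax >= 1 else vmax
    return lower, upper


print(limit_range_for_logit(-0.2, 1.5, 1e-3))
print(limit_range_for_logit(0.05, 0.8, 1e-3))
```

The two forms differ in one edge case: with `and`/`or`, a falsy `minpos` (exactly 0) falls through to `vmin`, whereas the conditional expression returns `minpos` unchanged.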


 _scale_mapping = {
     'linear': LinearScale,
     'log': LogScale,
-    'symlog': SymmetricalLogScale
+    'symlog': SymmetricalLogScale,
+    'logit': LogitScale,
     }


14 changes: 14 additions & 0 deletions lib/matplotlib/tests/test_scale.py
@@ -14,6 +14,20 @@ def test_log_scales():
    ax.axhline(24.1)


@image_comparison(baseline_images=['logit_scales'], remove_text=True,
                  extensions=['png'])
def test_logit_scales():
    ax = plt.subplot(111, xscale='logit')

    # Typical exctinction curve for logit
Member

Typo: -> extinction

Author

fixed

    x = np.array([0.001, 0.003, 0.01, 0.03, 0.1, 0.2, 0.3, 0.4, 0.5,
                  0.6, 0.7, 0.8, 0.9, 0.97, 0.99, 0.997, 0.999])
    y = 1.0 / x

    ax.plot(x, y)
    ax.grid(True)


@cleanup
def test_log_scatter():
    """Issue #1799"""
102 changes: 102 additions & 0 deletions lib/matplotlib/ticker.py
@@ -807,6 +807,26 @@ def __call__(self, x, pos=None):
nearest_long(fx))


class LogitFormatter(Formatter):
    '''Probability formatter (using Math text)'''
    def __call__(self, x, pos=None):
        s = ''
        if 0.01 <= x <= 0.99:
            if x in [.01, 0.1, 0.5, 0.9, 0.99]:
Member

Does this have to be so restrictive? Only those few values? Silently returning an empty string otherwise?

Author

fair enough, fixed

                s = '{:.2f}'.format(x)
        elif x < 0.01:
            if is_decade(x):
                s = '$10^{%.0f}$' % np.log10(x)
        elif x > 0.99:
            if is_decade(1-x):
                s = '$1-10^{%.0f}$' % np.log10(1-x)
        return s

    def format_data_short(self, value):
        'return a short formatted string representation of a number'
        return '%-12g' % value
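
The labelling rules in `LogitFormatter.__call__` can be sketched without matplotlib. The version below is a hypothetical standalone function, not the code that was merged: `logit_tick_label` and `_is_power_of_ten` are made-up names, and the tolerance-based comparisons are an assumption standing in for the exact float tests and `is_decade` calls in the diff.

```python
import math


def _is_power_of_ten(x, tol=1e-9):
    """True if x is, within tol, an integer power of ten."""
    exponent = math.log10(x)
    return abs(exponent - round(exponent)) < tol


def logit_tick_label(x):
    """Label for a tick at probability x, with 0 < x < 1."""
    if 0.01 <= x <= 0.99:
        # Plain decimals in the central region, only at the major ticks.
        if any(math.isclose(x, t) for t in (0.01, 0.1, 0.5, 0.9, 0.99)):
            return '{:.2f}'.format(x)
    elif x < 0.01:
        # Near zero: label decades as powers of ten.
        if _is_power_of_ten(x):
            return '$10^{%.0f}$' % math.log10(x)
    elif _is_power_of_ten(1 - x):
        # Near one: label by the distance from one.
        return '$1-10^{%.0f}$' % math.log10(1 - x)
    return ''


for value in (0.001, 0.1, 0.3, 0.5, 0.999):
    print(value, repr(logit_tick_label(value)))
```

Comparing with a tolerance avoids relying on exact float equality for values such as `1 - 0.999`, which is not exactly `0.001` in binary floating point.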


class EngFormatter(Formatter):
    """
    Formats axis values using engineering prefixes to represent powers of 1000,
@@ -1694,6 +1714,88 @@ def view_limits(self, vmin, vmax):
        return result


class LogitLocator(Locator):
    """
    Determine the tick locations for logit axes
    """

    def __init__(self, minor=False):
        """
        place ticks on the logit locations
        """
        self.minor = minor

    def __call__(self):
        'Return the locations of the ticks'
        vmin, vmax = self.axis.get_view_interval()
        return self.tick_values(vmin, vmax)

    def tick_values(self, vmin, vmax):
        # dummy axis has no axes attribute
        if hasattr(self.axis, 'axes') and self.axis.axes.name == 'polar':
            raise NotImplementedError('Polar axis cannot be logit scaled yet')

        # what to do if a window beyond ]0, 1[ is chosen
        if vmin <= 0.0:
            if self.axis is not None:
                vmin = self.axis.get_minpos()

            if (vmin <= 0.0) or (not np.isfinite(vmin)):
                raise ValueError(
                    "Data has no values in ]0, 1[ and therefore can not be "
                    "logit-scaled.")

        # NOTE: for vmax, we should query a property similar to get_minpos, but
        # related to the maximal, less-than-one data point. Unfortunately,
        # get_minpos is defined very deep in the BBox and updated with data,
        # so for now we use the trick below.
        if vmax >= 1.0:
            if self.axis is not None:
                vmax = 1 - self.axis.get_minpos()

            if (vmax >= 1.0) or (not np.isfinite(vmax)):
                raise ValueError(
                    "Data has no values in ]0, 1[ and therefore can not be "
                    "logit-scaled.")

        if vmax < vmin:
            vmin, vmax = vmax, vmin

        vmin = np.log10(vmin / (1 - vmin))
        vmax = np.log10(vmax / (1 - vmax))

        decade_min = np.floor(vmin)
        decade_max = np.ceil(vmax)

        # major ticks
        if not self.minor:
            ticklocs = []
            if (decade_min <= -1):
                expo = np.arange(decade_min, min(0, decade_max + 1))
                ticklocs.extend(list(10**expo))
            if (decade_min <= 0) and (decade_max >= 0):
                ticklocs.append(0.5)
            if (decade_max >= 1):
                expo = -np.arange(max(1, decade_min), decade_max + 1)
                ticklocs.extend(list(1 - 10**expo))

        # minor ticks
        else:
            ticklocs = []
            if (decade_min <= -2):
                expo = np.arange(decade_min, min(-1, decade_max))
                newticks = np.outer(np.arange(2, 10), 10**expo).ravel()
                ticklocs.extend(list(newticks))
            if (decade_min <= 0) and (decade_max >= 0):
                ticklocs.extend([0.2, 0.3, 0.4, 0.6, 0.7, 0.8])
            if (decade_max >= 2):
                expo = -np.arange(max(2, decade_min), decade_max + 1)
                newticks = 1 - np.outer(np.arange(2, 10), 10**expo).ravel()
                ticklocs.extend(list(newticks))

        return self.raise_if_exceeds(np.array(ticklocs))
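
To make the major-tick branch of `tick_values` easier to follow, here is a minimal pure-Python sketch of the same arithmetic. It assumes the view limits already lie strictly inside ]0, 1[, so it skips the `get_minpos` fallbacks, and `logit_major_ticks` is a hypothetical name rather than a matplotlib function.

```python
import math


def logit_major_ticks(vmin, vmax):
    """Major tick positions for a view interval inside (0, 1)."""
    if not 0.0 < vmin < vmax < 1.0:
        raise ValueError("expected 0 < vmin < vmax < 1")

    # Work in logit space, rounded outwards to whole decades.
    decade_min = math.floor(math.log10(vmin / (1 - vmin)))
    decade_max = math.ceil(math.log10(vmax / (1 - vmax)))

    ticks = []
    # Decades below 0.5: ..., 0.001, 0.01, 0.1
    if decade_min <= -1:
        ticks.extend(10.0 ** e
                     for e in range(decade_min, min(0, decade_max + 1)))
    # The midpoint, where the logit is zero.
    if decade_min <= 0 <= decade_max:
        ticks.append(0.5)
    # Decades above 0.5: 0.9, 0.99, 0.999, ...
    if decade_max >= 1:
        ticks.extend(1 - 10.0 ** -e
                     for e in range(max(1, decade_min), decade_max + 1))
    return ticks


print(logit_major_ticks(0.001, 0.999))
```

Minor ticks follow the same pattern, with multiples 2 through 9 of each decade placed on both sides of the midpoint.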


class AutoLocator(MaxNLocator):
    def __init__(self):
        MaxNLocator.__init__(self, nbins=9, steps=[1, 2, 5, 10])