-
-
Notifications
You must be signed in to change notification settings - Fork 7.9k
Logit scale #3753
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Logit scale #3753
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
|
||
# make up some data in the interval ]0, 1[ | ||
y = np.random.normal(loc=0.5, scale=0.4, size=1000) | ||
y = y[(y > 0) & (y < 1)] | ||
y.sort() | ||
x = np.arange(len(y)) | ||
|
||
# plot with various axes scales | ||
plt.figure(1) | ||
|
||
# linear | ||
plt.subplot(221) | ||
plt.plot(x, y) | ||
plt.yscale('linear') | ||
plt.title('linear') | ||
plt.grid(True) | ||
|
||
|
||
# log | ||
plt.subplot(222) | ||
plt.plot(x, y) | ||
plt.yscale('log') | ||
plt.title('log') | ||
plt.grid(True) | ||
|
||
|
||
# symmetric log | ||
plt.subplot(223) | ||
plt.plot(x, y - y.mean()) | ||
plt.yscale('symlog', linthreshy=0.05) | ||
plt.title('symlog') | ||
plt.grid(True) | ||
|
||
# logit | ||
plt.subplot(223) | ||
plt.plot(x, y) | ||
plt.yscale('logit') | ||
plt.title('logit') | ||
plt.grid(True) | ||
|
||
plt.show() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
Logit Scale | ||
----------- | ||
Added support for the 'logit' axis scale, a nonlinear transformation | ||
`x -> log10(x / (1-x))` for data between 0 and 1 excluded. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
""" | ||
Illustrate the scale transformations applied to axes, e.g. log, symlog, logit. | ||
""" | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
|
||
# make up some data in the interval ]0, 1[ | ||
y = np.random.normal(loc=0.5, scale=0.4, size=1000) | ||
y = y[(y > 0) & (y < 1)] | ||
y.sort() | ||
x = np.arange(len(y)) | ||
|
||
# plot with various axes scales | ||
fig, axs = plt.subplots(2, 2) | ||
|
||
# linear | ||
ax = axs[0, 0] | ||
ax.plot(x, y) | ||
ax.set_yscale('linear') | ||
ax.set_title('linear') | ||
ax.grid(True) | ||
|
||
|
||
# log | ||
ax = axs[0, 1] | ||
ax.plot(x, y) | ||
ax.set_yscale('log') | ||
ax.set_title('log') | ||
ax.grid(True) | ||
|
||
|
||
# symmetric log | ||
ax = axs[1, 0] | ||
ax.plot(x, y - y.mean()) | ||
ax.set_yscale('symlog', linthreshy=0.05) | ||
ax.set_title('symlog') | ||
ax.grid(True) | ||
|
||
# logit | ||
ax = axs[1, 1] | ||
ax.plot(x, y) | ||
ax.set_yscale('logit') | ||
ax.set_title('logit') | ||
ax.grid(True) | ||
|
||
|
||
plt.show() |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,9 +8,9 @@ | |
|
||
from matplotlib.cbook import dedent | ||
from matplotlib.ticker import (NullFormatter, ScalarFormatter, | ||
LogFormatterMathtext) | ||
LogFormatterMathtext, LogitFormatter) | ||
from matplotlib.ticker import (NullLocator, LogLocator, AutoLocator, | ||
SymmetricalLogLocator) | ||
SymmetricalLogLocator, LogitLocator) | ||
from matplotlib.transforms import Transform, IdentityTransform | ||
from matplotlib import docstring | ||
|
||
|
@@ -478,10 +478,111 @@ def get_transform(self): | |
return self._transform | ||
|
||
|
||
def _mask_non_logit(a): | ||
""" | ||
Return a Numpy masked array where all values outside ]0, 1[ are | ||
masked. If all values are inside ]0, 1[, the original array is | ||
returned. | ||
""" | ||
a = a.copy() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like you switched in midstream from using masked arrays to using nans. Does this need to handle masked array inputs? If so, and if you can subsequently use only nans for bad values, you can do this: a = np.ma.masked_invalid(a).filled(np.nan) Regardless of input, that will leave you with an ndarray with nothing but valid numbers and nans. It handles inf as well as nan. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, so I think keeping the logit _mask and _clip functions analogous to the log ones (_mask_non_positives etc) is good for the codebase, because they have the exact same funcion. But the *_non_positives functions have bugs, i.e. the docstring is wrong (it mentions masked arrays) and the _clip function modifies user input. So I'll correct the _non_positive function as well, I am against a PR with the logit/log parts of the code doing different things. |
||
mask = (a <= 0.0) | (a >= 1.0) | ||
a[mask] = np.nan | ||
return a | ||
|
||
|
||
def _clip_non_logit(a): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You should make a copy of the input array either here or where There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you could achieve what you want in one line with |
||
a = a.copy() | ||
a[a <= 0.0] = 1e-300 | ||
a[a >= 1.0] = 1 - 1e-300 | ||
return a | ||
|
||
|
||
class LogitTransform(Transform): | ||
input_dims = 1 | ||
output_dims = 1 | ||
is_separable = True | ||
has_inverse = True | ||
|
||
def __init__(self, nonpos): | ||
Transform.__init__(self) | ||
if nonpos == 'mask': | ||
self._handle_nonpos = _mask_non_logit | ||
else: | ||
self._handle_nonpos = _clip_non_logit | ||
self._nonpos = nonpos | ||
|
||
def transform_non_affine(self, a): | ||
"""logit transform (base 10), masked or clipped""" | ||
a = self._handle_nonpos(a) | ||
if isinstance(a, ma.MaskedArray): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It can't be a masked array here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. see below: the code for the LogTransforms has the same problem all over the place, I was just copying from there to keep the codebase uniform. But now if mpl is moving away from masked arrays for whatever reason, I'll just correct all the LogTransforms classes too. |
||
return ma.log10(1.0 * a / (1.0 - a)) | ||
return np.log10(1.0 * a / (1.0 - a)) | ||
|
||
def inverted(self): | ||
return LogisticTransform(self._nonpos) | ||
|
||
|
||
class LogisticTransform(Transform): | ||
input_dims = 1 | ||
output_dims = 1 | ||
is_separable = True | ||
has_inverse = True | ||
|
||
def __init__(self, nonpos='mask'): | ||
Transform.__init__(self) | ||
self._nonpos = nonpos | ||
|
||
def transform_non_affine(self, a): | ||
"""logistic transform (base 10)""" | ||
return 1.0 / (1 + 10**(-a)) | ||
|
||
def inverted(self): | ||
return LogitTransform(self._nonpos) | ||
|
||
|
||
class LogitScale(ScaleBase): | ||
""" | ||
Logit scale for data between zero and one, both excluded. | ||
|
||
This scale is similar to a log scale close to zero and to one, and almost | ||
linear around 0.5. It maps the interval ]0, 1[ onto ]-infty, +infty[. | ||
""" | ||
name = 'logit' | ||
|
||
def __init__(self, axis, nonpos='mask'): | ||
""" | ||
*nonpos*: ['mask' | 'clip' ] | ||
values beyond ]0, 1[ can be masked as invalid, or clipped to a number | ||
very close to 0 or 1 | ||
""" | ||
if nonpos not in ['mask', 'clip']: | ||
raise ValueError("nonposx, nonposy kwarg must be 'mask' or 'clip'") | ||
|
||
self._transform = LogitTransform(nonpos) | ||
|
||
def get_transform(self): | ||
""" | ||
Return a :class:`LogitTransform` instance. | ||
""" | ||
return self._transform | ||
|
||
def set_default_locators_and_formatters(self, axis): | ||
# ..., 0.01, 0.1, 0.5, 0.9, 0.99, ... | ||
axis.set_major_locator(LogitLocator()) | ||
axis.set_major_formatter(LogitFormatter()) | ||
axis.set_minor_locator(LogitLocator(minor=True)) | ||
axis.set_minor_formatter(LogitFormatter()) | ||
|
||
def limit_range_for_scale(self, vmin, vmax, minpos): | ||
return (vmin <= 0 and minpos or vmin, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I find it hard to parse this expression; and it needs a docstring at least to explain what There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll add a short docstring, but it's the same like the analog function for LogTransforms. Again, I'm just being consistent with the codebase. minpos is not documented in the LogTransform either, but it is documented in the base class. If it gets explained in LogitTransform, it will be in LogTransform too, is my opinion. |
||
vmax >= 1 and (1 - minpos) or vmax) | ||
|
||
|
||
_scale_mapping = { | ||
'linear': LinearScale, | ||
'log': LogScale, | ||
'symlog': SymmetricalLogScale | ||
'symlog': SymmetricalLogScale, | ||
'logit': LogitScale, | ||
} | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,20 @@ def test_log_scales(): | |
ax.axhline(24.1) | ||
|
||
|
||
@image_comparison(baseline_images=['logit_scales'], remove_text=True, | ||
extensions=['png']) | ||
def test_logit_scales(): | ||
ax = plt.subplot(111, xscale='logit') | ||
|
||
# Typical exctinction curve for logit | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Typo: -> extinction There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed |
||
x = np.array([0.001, 0.003, 0.01, 0.03, 0.1, 0.2, 0.3, 0.4, 0.5, | ||
0.6, 0.7, 0.8, 0.9, 0.97, 0.99, 0.997, 0.999]) | ||
y = 1.0 / x | ||
|
||
ax.plot(x, y) | ||
ax.grid(True) | ||
|
||
|
||
@cleanup | ||
def test_log_scatter(): | ||
"""Issue #1799""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -807,6 +807,26 @@ def __call__(self, x, pos=None): | |
nearest_long(fx)) | ||
|
||
|
||
class LogitFormatter(Formatter): | ||
'''Probability formatter (using Math text)''' | ||
def __call__(self, x, pos=None): | ||
s = '' | ||
if 0.01 <= x <= 0.99: | ||
if x in [.01, 0.1, 0.5, 0.9, 0.99]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this have to be so restrictive? Only those few values? Silently returning an empty string otherwise? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fair enough, fixed |
||
s = '{:.2f}'.format(x) | ||
elif x < 0.01: | ||
if is_decade(x): | ||
s = '$10^{%.0f}$' % np.log10(x) | ||
elif x > 0.99: | ||
if is_decade(1-x): | ||
s = '$1-10^{%.0f}$' % np.log10(1-x) | ||
return s | ||
|
||
def format_data_short(self, value): | ||
'return a short formatted string representation of a number' | ||
return '%-12g' % value | ||
|
||
|
||
class EngFormatter(Formatter): | ||
""" | ||
Formats axis values using engineering prefixes to represent powers of 1000, | ||
|
@@ -1694,6 +1714,88 @@ def view_limits(self, vmin, vmax): | |
return result | ||
|
||
|
||
class LogitLocator(Locator): | ||
""" | ||
Determine the tick locations for logit axes | ||
""" | ||
|
||
def __init__(self, minor=False): | ||
""" | ||
place ticks on the logit locations | ||
""" | ||
self.minor = minor | ||
|
||
def __call__(self): | ||
'Return the locations of the ticks' | ||
vmin, vmax = self.axis.get_view_interval() | ||
return self.tick_values(vmin, vmax) | ||
|
||
def tick_values(self, vmin, vmax): | ||
# dummy axis has no axes attribute | ||
if hasattr(self.axis, 'axes') and self.axis.axes.name == 'polar': | ||
raise NotImplementedError('Polar axis cannot be logit scaled yet') | ||
|
||
# what to do if a window beyond ]0, 1[ is chosen | ||
if vmin <= 0.0: | ||
if self.axis is not None: | ||
vmin = self.axis.get_minpos() | ||
|
||
if (vmin <= 0.0) or (not np.isfinite(vmin)): | ||
raise ValueError( | ||
"Data has no values in ]0, 1[ and therefore can not be " | ||
"logit-scaled.") | ||
|
||
# NOTE: for vmax, we should query a property similar to get_minpos, but | ||
# related to the maximal, less-than-one data point. Unfortunately, | ||
# get_minpos is defined very deep in the BBox and updated with data, | ||
# so for now we use the trick below. | ||
if vmax >= 1.0: | ||
if self.axis is not None: | ||
vmax = 1 - self.axis.get_minpos() | ||
|
||
if (vmax >= 1.0) or (not np.isfinite(vmax)): | ||
raise ValueError( | ||
"Data has no values in ]0, 1[ and therefore can not be " | ||
"logit-scaled.") | ||
|
||
if vmax < vmin: | ||
vmin, vmax = vmax, vmin | ||
|
||
vmin = np.log10(vmin / (1 - vmin)) | ||
vmax = np.log10(vmax / (1 - vmax)) | ||
|
||
decade_min = np.floor(vmin) | ||
decade_max = np.ceil(vmax) | ||
|
||
# major ticks | ||
if not self.minor: | ||
ticklocs = [] | ||
if (decade_min <= -1): | ||
expo = np.arange(decade_min, min(0, decade_max + 1)) | ||
ticklocs.extend(list(10**expo)) | ||
if (decade_min <= 0) and (decade_max >= 0): | ||
ticklocs.append(0.5) | ||
if (decade_max >= 1): | ||
expo = -np.arange(max(1, decade_min), decade_max + 1) | ||
ticklocs.extend(list(1 - 10**expo)) | ||
|
||
# minor ticks | ||
else: | ||
ticklocs = [] | ||
if (decade_min <= -2): | ||
expo = np.arange(decade_min, min(-1, decade_max)) | ||
newticks = np.outer(np.arange(2, 10), 10**expo).ravel() | ||
ticklocs.extend(list(newticks)) | ||
if (decade_min <= 0) and (decade_max >= 0): | ||
ticklocs.extend([0.2, 0.3, 0.4, 0.6, 0.7, 0.8]) | ||
if (decade_max >= 2): | ||
expo = -np.arange(max(2, decade_min), decade_max + 1) | ||
newticks = 1 - np.outer(np.arange(2, 10), 10**expo).ravel() | ||
ticklocs.extend(list(newticks)) | ||
|
||
return self.raise_if_exceeds(np.array(ticklocs)) | ||
|
||
|
||
class AutoLocator(MaxNLocator): | ||
def __init__(self): | ||
MaxNLocator.__init__(self, nbins=9, steps=[1, 2, 5, 10]) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This docstring does not match what the function is doing. It is always returning a copy of the input array in which any values outside ]0, 1[ are replaced by nan.