Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ per-file-ignores =
examples/lines_bars_and_markers/filled_step.py: E402
examples/lines_bars_and_markers/joinstyle.py: E402
examples/lines_bars_and_markers/scatter_piecharts.py: E402
examples/lines_bars_and_markers/scatter_with_legend.py: E402
examples/lines_bars_and_markers/span_regions.py: E402
examples/lines_bars_and_markers/step_demo.py: E402
examples/misc/agg_buffer.py: E402
Expand Down
19 changes: 19 additions & 0 deletions doc/users/next_whats_new/legend_for_scatter.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
:orphan:

Legend for scatter
------------------

A new method for creating legends for scatter plots has been introduced.
Previously, in order to obtain a legend for a :meth:`~.axes.Axes.scatter`
plot, one could either plot several scatters, each with an individual label,
or create proxy artists to show in the legend manually.
Now, :class:`~.collections.PathCollection` provides a method
:meth:`~.collections.PathCollection.legend_elements` to obtain the handles and labels
for a scatter plot in an automated way. This makes creating a legend for a
scatter plot as easy as::

scatter = plt.scatter([1,2,3], [4,5,6], c=[7,2,3])
plt.legend(*scatter.legend_elements())

An example can be found in
:ref:`automatedlegendcreation`.
104 changes: 98 additions & 6 deletions examples/lines_bars_and_markers/scatter_with_legend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,116 @@
Scatter plots with a legend
===========================

Also demonstrates how transparency of the markers
can be adjusted by giving ``alpha`` a value between
0 and 1.
To create a scatter plot with a legend one may use a loop and create one
`~.Axes.scatter` plot per item to appear in the legend and set the ``label``
accordingly.

The following also demonstrates how transparency of the markers
can be adjusted by giving ``alpha`` a value between 0 and 1.
"""

import numpy as np
np.random.seed(19680801)
import matplotlib.pyplot as plt
from numpy.random import rand


fig, ax = plt.subplots()
for color in ['tab:blue', 'tab:orange', 'tab:green']:
n = 750
x, y = rand(2, n)
scale = 200.0 * rand(n)
x, y = np.random.rand(2, n)
scale = 200.0 * np.random.rand(n)
ax.scatter(x, y, c=color, s=scale, label=color,
alpha=0.3, edgecolors='none')

ax.legend()
ax.grid(True)

plt.show()


##############################################################################
# .. _automatedlegendcreation:
#
# Automated legend creation
# -------------------------
#
# Another option for creating a legend for a scatter is to use the
# :class:`~matplotlib.collections.PathCollection`'s
# :meth:`~.PathCollection.legend_elements` method.
# It will automatically try to determine a useful number of legend entries
# to be shown and return a tuple of handles and labels. Those can be passed
# to the call to :meth:`~.axes.Axes.legend`.


N = 45
x, y = np.random.rand(2, N)
c = np.random.randint(1, 5, size=N)
s = np.random.randint(10, 220, size=N)

fig, ax = plt.subplots()

scatter = ax.scatter(x, y, c=c, s=s)

# produce a legend with the unique colors from the scatter
legend1 = ax.legend(*scatter.legend_elements(),
loc="lower left", title="Classes")
ax.add_artist(legend1)

# produce a legend with a cross section of sizes from the scatter
handles, labels = scatter.legend_elements(prop="sizes", alpha=0.6)
legend2 = ax.legend(handles, labels, loc="upper right", title="Sizes")

plt.show()


##############################################################################
# Further arguments to the :meth:`~.PathCollection.legend_elements` method
# can be used to steer how many legend entries are to be created and how they
# should be labeled. The following shows how to use some of them.
#

volume = np.random.rayleigh(27, size=40)
amount = np.random.poisson(10, size=40)
ranking = np.random.normal(size=40)
price = np.random.uniform(1, 10, size=40)

fig, ax = plt.subplots()

# Because the price is much too small when being provided as size for ``s``,
# we normalize it to some useful point sizes, s=0.3*(price*3)**2
scatter = ax.scatter(volume, amount, c=ranking, s=0.3*(price*3)**2,
vmin=-3, vmax=3, cmap="Spectral")

# Produce a legend for the ranking (colors). Even though there are 40 different
# rankings, we only want to show 5 of them in the legend.
legend1 = ax.legend(*scatter.legend_elements(num=5),
loc="upper left", title="Ranking")
ax.add_artist(legend1)

# Produce a legend for the price (sizes). Because we want to show the prices
# in dollars, we use the *func* argument to supply the inverse of the function
# used to calculate the sizes from above. The *fmt* ensures to show the price
# in dollars. Note how we target at 5 elements here, but obtain only 4 in the
# created legend due to the automatic round prices that are chosen for us.
kw = dict(prop="sizes", num=5, color=scatter.cmap(0.7), fmt="$ {x:.2f}",
func=lambda s: np.sqrt(s/.3)/3)
legend2 = ax.legend(*scatter.legend_elements(**kw),
loc="lower right", title="Price")

plt.show()

#############################################################################
#
# ------------
#
# References
# """"""""""
#
# The usage of the following functions and methods is shown in this example:

import matplotlib
matplotlib.axes.Axes.scatter
matplotlib.pyplot.scatter
matplotlib.axes.Axes.legend
matplotlib.pyplot.legend
matplotlib.collections.PathCollection.legend_elements
129 changes: 129 additions & 0 deletions lib/matplotlib/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import matplotlib as mpl
from . import (_path, artist, cbook, cm, colors as mcolors, docstring,
lines as mlines, path as mpath, transforms)
import warnings


@cbook._define_aliases({
Expand Down Expand Up @@ -868,6 +869,7 @@ def draw(self, renderer):
class PathCollection(_CollectionWithSizes):
"""
This is the most basic :class:`Collection` subclass.
A :class:`PathCollection` is e.g. created by a :meth:`~.Axes.scatter` plot.
"""
@docstring.dedent_interpd
def __init__(self, paths, sizes=None, **kwargs):
Expand All @@ -890,6 +892,133 @@ def set_paths(self, paths):
def get_paths(self):
return self._paths

def legend_elements(self, prop="colors", num="auto",
fmt=None, func=lambda x: x, **kwargs):
"""
Creates legend handles and labels for a PathCollection. This is useful
for obtaining a legend for a :meth:`~.Axes.scatter` plot. E.g.::

scatter = plt.scatter([1,2,3], [4,5,6], c=[7,2,3])
plt.legend(*scatter.legend_elements())

Also see the :ref:`automatedlegendcreation` example.

Parameters
----------
prop : string, optional, default *"colors"*
Can be *"colors"* or *"sizes"*. In case of *"colors"*, the legend
handles will show the different colors of the collection. In case
of "sizes", the legend will show the different sizes.
num : int, None, "auto" (default), array-like, or `~.ticker.Locator`,
optional
Target number of elements to create.
If None, use all unique elements of the mappable array. If an
integer, target to use *num* elements in the normed range.
If *"auto"*, try to determine which option better suits the nature
of the data.
The number of created elements may slightly deviate from *num* due
to a `~.ticker.Locator` being used to find useful locations.
If a list or array, use exactly those elements for the legend.
Finally, a `~.ticker.Locator` can be provided.
fmt : string, `~matplotlib.ticker.Formatter`, or None (default)
The format or formatter to use for the labels. If a string must be
a valid input for a `~.StrMethodFormatter`. If None (the default),
use a `~.ScalarFormatter`.
func : function, default *lambda x: x*
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found this too confusing. What does *func* do and why do you need it? Maybe this needs an example...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

func would allow to calculate the values to be shown in the legend. This is mainly useful for "sizes", because unlike the c parameter in scatter, which takes numbers and maps them to color using a Normalize, the s parameter does not have a normalization. This means that usually you would encode your data into sizes and provide those sizes in the scatter call. However in the legend you might still want to show the original data's values. Example: You use the s to show price ranging between 0.5 and 10 dollars. Now you cannot use those numbers directly as scatter sizes, because that would result in tiny dots. So you convert them to sizes via s = lambda price : 0.5*(price*10)^2. Now func would allow you to provide price = lambda s: np.sqrt(2*s)/10 to have a nice legend using the actual prices in dollars.

I think you are right that an example might be useful here. I could add that in addition to the already extended example?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. I was particularly flummoxed by "This converts the initial values for color or size and needs to take a numpy array as input." I wasn't sure what "initial" meant.

For the doc string maybe: "Function to calculate the labels. Often the size (or color) argument to scatter will have been pre-processed by the user using a function s = f(x) to make the markers visible; eg size = np.log10(x)`. Providing this function allows that pre-processing to be inverted, so that the legend labels have the correct values. eg np.exp(x, 10)``"

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reformulated the docstring and also added an additional example now.

Function to calculate the labels. Often the size (or color)
argument to :meth:`~.Axes.scatter` will have been pre-processed
by the user using a function *s = f(x)* to make the markers
visible; e.g. *size = np.log10(x)*. Providing the inverse of this
function here allows that pre-processing to be inverted, so that
the legend labels have the correct values;
e.g. *func = np.exp(x, 10)*.
kwargs : further parameters
Allowed kwargs are *color* and *size*. E.g. it may be useful to
set the color of the markers if *prop="sizes"* is used; similarly
to set the size of the markers if *prop="colors"* is used.
Any further parameters are passed onto the `.Line2D` instance.
This may be useful to e.g. specify a different *markeredgecolor* or
*alpha* for the legend handles.

Returns
-------
tuple (handles, labels)
with *handles* being a list of `.Line2D` objects
and *labels* a matching list of strings.
"""
handles = []
labels = []
hasarray = self.get_array() is not None
if fmt is None:
fmt = mpl.ticker.ScalarFormatter(useOffset=False, useMathText=True)
elif isinstance(fmt, str):
fmt = mpl.ticker.StrMethodFormatter(fmt)
fmt.create_dummy_axis()

if prop == "colors":
if not hasarray:
warnings.warn("Collection without array used. Make sure to "
"specify the values to be colormapped via the "
"`c` argument.")
return handles, labels
u = np.unique(self.get_array())
size = kwargs.pop("size", mpl.rcParams["lines.markersize"])
elif prop == "sizes":
u = np.unique(self.get_sizes())
color = kwargs.pop("color", "k")
else:
raise ValueError("Valid values for `prop` are 'colors' or "
f"'sizes'. You supplied '{prop}' instead.")

fmt.set_bounds(func(u).min(), func(u).max())
if num == "auto":
num = 9
if len(u) <= num:
num = None
if num is None:
values = u
label_values = func(values)
else:
if prop == "colors":
arr = self.get_array()
elif prop == "sizes":
arr = self.get_sizes()
if isinstance(num, mpl.ticker.Locator):
loc = num
elif np.iterable(num):
loc = mpl.ticker.FixedLocator(num)
else:
num = int(num)
loc = mpl.ticker.MaxNLocator(nbins=num, min_n_ticks=num-1,
steps=[1, 2, 2.5, 3, 5, 6, 8, 10])
label_values = loc.tick_values(func(arr).min(), func(arr).max())
cond = ((label_values >= func(arr).min()) &
(label_values <= func(arr).max()))
label_values = label_values[cond]
xarr = np.linspace(arr.min(), arr.max(), 256)
values = np.interp(label_values, func(xarr), xarr)

kw = dict(markeredgewidth=self.get_linewidths()[0],
alpha=self.get_alpha())
kw.update(kwargs)

for val, lab in zip(values, label_values):
if prop == "colors":
color = self.cmap(self.norm(val))
elif prop == "sizes":
size = np.sqrt(val)
if np.isclose(size, 0.0):
continue
h = mlines.Line2D([0], [0], ls="", color=color, ms=size,
marker=self.get_paths()[0], **kw)
handles.append(h)
if hasattr(fmt, "set_locs"):
fmt.set_locs(label_values)
l = fmt(lab)
labels.append(l)

return handles, labels


class PolyCollection(_CollectionWithSizes):
@docstring.dedent_interpd
Expand Down
52 changes: 52 additions & 0 deletions lib/matplotlib/tests/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,3 +669,55 @@ def test_scatter_post_alpha():
# this needs to be here to update internal state
fig.canvas.draw()
sc.set_alpha(.1)


def test_pathcollection_legend_elements():
np.random.seed(19680801)
x, y = np.random.rand(2, 10)
y = np.random.rand(10)
c = np.random.randint(0, 5, size=10)
s = np.random.randint(10, 300, size=10)

fig, ax = plt.subplots()
sc = ax.scatter(x, y, c=c, s=s, cmap="jet", marker="o", linewidths=0)

h, l = sc.legend_elements(fmt="{x:g}")
assert len(h) == 5
assert_array_equal(np.array(l).astype(float), np.arange(5))
colors = np.array([line.get_color() for line in h])
colors2 = sc.cmap(np.arange(5)/4)
assert_array_equal(colors, colors2)
l1 = ax.legend(h, l, loc=1)

h2, lab2 = sc.legend_elements(num=9)
assert len(h2) == 9
l2 = ax.legend(h2, lab2, loc=2)

h, l = sc.legend_elements(prop="sizes", alpha=0.5, color="red")
alpha = np.array([line.get_alpha() for line in h])
assert_array_equal(alpha, 0.5)
color = np.array([line.get_markerfacecolor() for line in h])
assert_array_equal(color, "red")
l3 = ax.legend(h, l, loc=4)

h, l = sc.legend_elements(prop="sizes", num=4, fmt="{x:.2f}",
func=lambda x: 2*x)
actsizes = [line.get_markersize() for line in h]
labeledsizes = np.sqrt(np.array(l).astype(float)/2)
assert_array_almost_equal(actsizes, labeledsizes)
l4 = ax.legend(h, l, loc=3)

import matplotlib.ticker as mticker
loc = mticker.MaxNLocator(nbins=9, min_n_ticks=9-1,
steps=[1, 2, 2.5, 3, 5, 6, 8, 10])
h5, lab5 = sc.legend_elements(num=loc)
assert len(h2) == len(h5)

levels = [-1, 0, 55.4, 260]
h6, lab6 = sc.legend_elements(num=levels, prop="sizes", fmt="{x:g}")
assert_array_equal(np.array(lab6).astype(float), levels[2:])

for l in [l1, l2, l3, l4]:
ax.add_artist(l)

fig.canvas.draw()