Skip to content

Commit 646e2c8

Browse files
legend-for-scatter
1 parent d31d102 commit 646e2c8

File tree

4 files changed

+285
-6
lines changed

4 files changed

+285
-6
lines changed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
:orphan:
2+
3+
Legend for scatter
4+
------------------
5+
6+
A new method for creating legends for scatter plots has been introduced.
7+
Previously, in order to obtain a legend for a :meth:`~.axes.Axes.scatter`
8+
plot, one could either plot several scatters, each with an individual label,
9+
or create proxy artists to show in the legend manually.
10+
Now, :class:`~.collections.PathCollection` provides a method
11+
:meth:`~.collections.PathCollection.legend_elements` to obtain the handles and labels
12+
for a scatter plot in an automated way. This makes creating a legend for a
13+
scatter plot as easy as::
14+
15+
scatter = plt.scatter([1,2,3], [4,5,6], c=[7,2,3])
16+
plt.legend(*scatter.legend_elements())
17+
18+
An example can be found in
19+
:ref:`automatedlegendcreation`.

examples/lines_bars_and_markers/scatter_with_legend.py

Lines changed: 98 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,116 @@
33
Scatter plots with a legend
44
===========================
55
6-
Also demonstrates how transparency of the markers
7-
can be adjusted by giving ``alpha`` a value between
8-
0 and 1.
6+
To create a scatter plot with a legend one may use a loop and create one
7+
`~.Axes.scatter` plot per item to appear in the legend and set the ``label``
8+
accordingly.
9+
10+
The following also demonstrates how transparency of the markers
11+
can be adjusted by giving ``alpha`` a value between 0 and 1.
912
"""
1013

14+
import numpy as np
15+
np.random.seed(19680801)
1116
import matplotlib.pyplot as plt
12-
from numpy.random import rand
1317

1418

1519
fig, ax = plt.subplots()
1620
for color in ['tab:blue', 'tab:orange', 'tab:green']:
1721
n = 750
18-
x, y = rand(2, n)
19-
scale = 200.0 * rand(n)
22+
x, y = np.random.rand(2, n)
23+
scale = 200.0 * np.random.rand(n)
2024
ax.scatter(x, y, c=color, s=scale, label=color,
2125
alpha=0.3, edgecolors='none')
2226

2327
ax.legend()
2428
ax.grid(True)
2529

2630
plt.show()
31+
32+
33+
##############################################################################
34+
# .. _automatedlegendcreation:
35+
#
36+
# Automated legend creation
37+
# -------------------------
38+
#
39+
# Another option for creating a legend for a scatter is to use the
40+
# :class:`~matplotlib.collections.PathCollection`'s
41+
# :meth:`~.PathCollection.legend_elements` method.
42+
# It will automatically try to determine a useful number of legend entries
43+
# to be shown and return a tuple of handles and labels. Those can be passed
44+
# to the call to :meth:`~.axes.Axes.legend`.
45+
46+
47+
N = 45
48+
x, y = np.random.rand(2, N)
49+
c = np.random.randint(1, 5, size=N)
50+
s = np.random.randint(10, 220, size=N)
51+
52+
fig, ax = plt.subplots()
53+
54+
scatter = ax.scatter(x, y, c=c, s=s)
55+
56+
# produce a legend with the unique colors from the scatter
57+
legend1 = ax.legend(*scatter.legend_elements(),
58+
loc="lower left", title="Classes")
59+
ax.add_artist(legend1)
60+
61+
# produce a legend with a cross section of sizes from the scatter
62+
handles, labels = scatter.legend_elements(prop="sizes", alpha=0.6)
63+
legend2 = ax.legend(handles, labels, loc="upper right", title="Sizes")
64+
65+
plt.show()
66+
67+
68+
##############################################################################
69+
# Further arguments to the :meth:`~.PathCollection.legend_elements` method
70+
# can be used to steer how many legend entries are to be created and how they
71+
# should be labeled. The following shows how to use some of them.
72+
#
73+
74+
volume = np.random.rayleigh(27, size=40)
75+
amount = np.random.poisson(10, size=40)
76+
ranking = np.random.normal(size=40)
77+
price = np.random.uniform(1, 10, size=40)
78+
79+
fig, ax = plt.subplots()
80+
81+
# Because the price is much too small when being provided as size for ``s``,
82+
# we normalize it to some useful point sizes, s=0.3*(price*3)**2
83+
scatter = ax.scatter(volume, amount, c=ranking, s=0.3*(price*3)**2,
84+
vmin=-3, vmax=3, cmap="Spectral")
85+
86+
# Produce a legend for the ranking (colors). Even though there are 40 different
87+
# rankings, we only want to show 5 of them in the legend.
88+
legend1 = ax.legend(*scatter.legend_elements(num=5),
89+
loc="upper left", title="Ranking")
90+
ax.add_artist(legend1)
91+
92+
# Produce a legend for the price (sizes). Because we want to show the prices
93+
# in dollars, we use the *func* argument to supply the inverse of the function
94+
# used to calculate the sizes from above. The *fmt* ensures the price is shown
95+
# in dollars. Note how we target 5 elements here, but obtain only 4 in the
96+
# created legend due to the automatic round prices that are chosen for us.
97+
kw = dict(prop="sizes", num=5, color=scatter.cmap(0.7), fmt="$ {x:.2f}",
98+
func=lambda s: np.sqrt(s/.3)/3)
99+
legend2 = ax.legend(*scatter.legend_elements(**kw),
100+
loc="lower right", title="Price")
101+
102+
plt.show()
103+
104+
#############################################################################
105+
#
106+
# ------------
107+
#
108+
# References
109+
# """"""""""
110+
#
111+
# The usage of the following functions and methods is shown in this example:
112+
113+
import matplotlib
114+
matplotlib.axes.Axes.scatter
115+
matplotlib.pyplot.scatter
116+
matplotlib.axes.Axes.legend
117+
matplotlib.pyplot.legend
118+
matplotlib.collections.PathCollection.legend_elements

lib/matplotlib/collections.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -868,6 +868,7 @@ def draw(self, renderer):
868868
class PathCollection(_CollectionWithSizes):
869869
"""
870870
This is the most basic :class:`Collection` subclass.
871+
A :class:`PathCollection` is e.g. created by a :meth:`~.Axes.scatter` plot.
871872
"""
872873
@docstring.dedent_interpd
873874
def __init__(self, paths, sizes=None, **kwargs):
@@ -890,6 +891,125 @@ def set_paths(self, paths):
890891
def get_paths(self):
891892
return self._paths
892893

894+
def legend_elements(self, prop="colors", num="auto",
895+
fmt=None, func=lambda x: x, **kwargs):
896+
"""
897+
Create legend handles and labels for a PathCollection. This is useful
898+
for obtaining a legend for a :meth:`~.Axes.scatter` plot. E.g.::
899+
900+
scatter = plt.scatter([1,2,3], [4,5,6], c=[7,2,3])
901+
plt.legend(*scatter.legend_elements())
902+
903+
Also see the :ref:`automatedlegendcreation` example.
904+
905+
Parameters
906+
----------
907+
prop : string, optional, default *"colors"*
908+
Can be *"colors"* or *"sizes"*. In case of *"colors"*, the legend
909+
handles will show the different colors of the collection. In case
910+
of *"sizes"*, the legend will show the different sizes.
911+
num : int, None, "auto" (default), or `~.ticker.Locator`, optional
912+
Target number of elements to create.
913+
If None, use all unique elements of the mappable array. If an
914+
integer, target to use *num* elements in the normed range.
915+
If *"auto"*, try to determine which option better suits the nature
916+
of the data.
917+
The number of created elements may slightly deviate from *num* due
918+
to a `~.ticker.Locator` being used to find useful locations.
919+
Finally, a `~.ticker.Locator` can be provided.
920+
fmt : string, `~matplotlib.ticker.Formatter`, or None (default)
921+
The format or formatter to use for the labels. If a string, must be
922+
a valid input for a `~.StrMethodFormatter`. If None (the default),
923+
use a `~.ScalarFormatter`.
924+
func : function, default *lambda x: x*
925+
Function to calculate the labels. Often the size (or color)
926+
argument to :meth:`~.Axes.scatter` will have been pre-processed
927+
by the user using a function *s = f(x)* to make the markers
928+
visible; e.g. *size = np.log10(x)*. Providing the inverse of this
929+
function here allows that pre-processing to be inverted, so that
930+
the legend labels have the correct values;
931+
e.g. ``func = lambda x: 10**x``.
932+
kwargs : further parameters
933+
Allowed kwargs are *color* and *size*. E.g. it may be useful to
934+
set the color of the markers if *prop="sizes"* is used; similarly
935+
to set the size of the markers if *prop="colors"* is used.
936+
Any further parameters are passed onto the `.Line2D` instance.
937+
This may be useful to e.g. specify a different *markeredgecolor* or
938+
*alpha* for the legend handles.
939+
940+
Returns
941+
-------
942+
tuple (handles, labels)
943+
with *handles* being a list of `.Line2D` objects
944+
and *labels* a matching list of strings.
945+
"""
946+
handles = []
947+
labels = []
948+
hasarray = self.get_array() is not None
949+
if fmt is None:
950+
fmt = mpl.ticker.ScalarFormatter(useOffset=False, useMathText=True)
951+
elif isinstance(fmt, str):
952+
fmt = mpl.ticker.StrMethodFormatter(fmt)
953+
fmt.create_dummy_axis()
954+
955+
if prop == "colors" and hasarray:
956+
u = np.unique(self.get_array())
957+
size = kwargs.pop("size", mpl.rcParams["lines.markersize"])
958+
elif prop == "sizes":
959+
u = np.unique(self.get_sizes())
960+
color = kwargs.pop("color", "k")
961+
else:
962+
warnings.warn("Invalid prop provided, or collection without "
963+
"array used.")
964+
return handles, labels
965+
966+
fmt.set_bounds(func(u).min(), func(u).max())
967+
if num == "auto":
968+
num = 9
969+
if len(u) <= num:
970+
num = None
971+
if num is None:
972+
values = u
973+
label_values = func(values)
974+
else:
975+
if prop == "colors":
976+
arr = self.get_array()
977+
elif prop == "sizes":
978+
arr = self.get_sizes()
979+
if isinstance(num, mpl.ticker.Locator):
980+
loc = num
981+
else:
982+
num = int(num)
983+
loc = mpl.ticker.MaxNLocator(nbins=num, min_n_ticks=num-1,
984+
steps=[1, 2, 2.5, 3, 5, 6, 8, 10])
985+
label_values = loc.tick_values(func(arr).min(), func(arr).max())
986+
cond = ((label_values >= func(arr).min()) &
987+
(label_values <= func(arr).max()))
988+
label_values = label_values[cond]
989+
xarr = np.linspace(arr.min(), arr.max(), 256)
990+
values = np.interp(label_values, func(xarr), xarr)
991+
992+
kw = dict(markeredgewidth=self.get_linewidths()[0],
993+
alpha=self.get_alpha())
994+
kw.update(kwargs)
995+
996+
for val, lab in zip(values, label_values):
997+
if prop == "colors":
998+
color = self.cmap(self.norm(val))
999+
elif prop == "sizes":
1000+
size = np.sqrt(val)
1001+
if np.isclose(size, 0.0):
1002+
continue
1003+
h = mlines.Line2D([0], [0], ls="", color=color, ms=size,
1004+
marker=self.get_paths()[0], **kw)
1005+
handles.append(h)
1006+
if hasattr(fmt, "set_locs"):
1007+
fmt.set_locs(label_values)
1008+
l = fmt(lab)
1009+
labels.append(l)
1010+
1011+
return handles, labels
1012+
8931013

8941014
class PolyCollection(_CollectionWithSizes):
8951015
@docstring.dedent_interpd

lib/matplotlib/tests/test_collections.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,3 +669,51 @@ def test_scatter_post_alpha():
669669
# this needs to be here to update internal state
670670
fig.canvas.draw()
671671
sc.set_alpha(.1)
672+
673+
674+
def test_pathcollection_legend_elements():
675+
np.random.seed(19680801)
676+
x, y = np.random.rand(2, 10)
677+
y = np.random.rand(10)
678+
c = np.random.randint(0, 5, size=10)
679+
s = np.random.randint(10, 300, size=10)
680+
681+
fig, ax = plt.subplots()
682+
sc = ax.scatter(x, y, c=c, s=s, cmap="jet", marker="o", linewidths=0)
683+
684+
h, l = sc.legend_elements(fmt="{x:g}")
685+
assert len(h) == 5
686+
assert_array_equal(np.array(l).astype(float), np.arange(5))
687+
colors = np.array([line.get_color() for line in h])
688+
colors2 = sc.cmap(np.arange(5)/4)
689+
assert_array_equal(colors, colors2)
690+
l1 = ax.legend(h, l, loc=1)
691+
692+
h2, lab2 = sc.legend_elements(num=9)
693+
assert len(h2) == 9
694+
l2 = ax.legend(h2, lab2, loc=2)
695+
696+
h, l = sc.legend_elements(prop="sizes", alpha=0.5, color="red")
697+
alpha = np.array([line.get_alpha() for line in h])
698+
assert_array_equal(alpha, 0.5)
699+
color = np.array([line.get_markerfacecolor() for line in h])
700+
assert_array_equal(color, "red")
701+
l3 = ax.legend(h, l, loc=4)
702+
703+
h, l = sc.legend_elements(prop="sizes", num=4, fmt="{x:.2f}",
704+
func=lambda x: 2*x)
705+
actsizes = [line.get_markersize() for line in h]
706+
labeledsizes = np.sqrt(np.array(l).astype(float)/2)
707+
assert_array_almost_equal(actsizes, labeledsizes)
708+
l4 = ax.legend(h, l, loc=3)
709+
710+
import matplotlib.ticker as mticker
711+
loc = mticker.MaxNLocator(nbins=9, min_n_ticks=9-1,
712+
steps=[1, 2, 2.5, 3, 5, 6, 8, 10])
713+
h5, lab5 = sc.legend_elements(num=loc)
714+
assert len(h2) == len(h5)
715+
716+
for l in [l1, l2, l3, l4]:
717+
ax.add_artist(l)
718+
719+
fig.canvas.draw()

0 commit comments

Comments
 (0)