Skip to content

Commit a801e21

Browse files
legend-for-scatter
1 parent 0984694 commit a801e21

File tree

4 files changed

+233
-6
lines changed

4 files changed

+233
-6
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
Legend for scatter
2+
------------------
3+
4+
A new method for creating legends for scatter plots has been introduced.
5+
Previously, in order to obtain a legend for a :meth:`~.Axes.scatter` plot, one
6+
could either plot several scatters, each with an individual label, or create
7+
proxy artists to show in the legend manually.
8+
Now, :class:`~.collections.PathCollection` provides a method
9+
:meth:`~.collections.PathCollection.legend_items` to obtain the handles and labels
10+
for a scatter plot in an automated way. This makes creating a legend for a
11+
scatter plot as easy as::
12+
13+
scatter = plt.scatter([1,2,3], [4,5,6], c=[7,2,3])
14+
plt.legend(*scatter.legend_items())
15+
16+
An example can be found in
17+
:ref:`automatedlegendcreation`.

examples/lines_bars_and_markers/scatter_with_legend.py

+61-6
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,79 @@
33
Scatter plots with a legend
44
===========================
55
6-
Also demonstrates how transparency of the markers
7-
can be adjusted by giving ``alpha`` a value between
8-
0 and 1.
6+
To create a scatter plot with a legend one may use a loop and create one
7+
`~.Axes.scatter` plot per item to appear in the legend and set the ``label``
8+
accordingly.
9+
10+
The following also demonstrates how transparency of the markers
11+
can be adjusted by giving ``alpha`` a value between 0 and 1.
912
"""
1013

14+
import numpy as np
15+
np.random.seed(19680801)
1116
import matplotlib.pyplot as plt
12-
from numpy.random import rand
1317

1418

1519
fig, ax = plt.subplots()
1620
for color in ['red', 'green', 'blue']:
1721
n = 750
18-
x, y = rand(2, n)
19-
scale = 200.0 * rand(n)
22+
x, y = np.random.rand(2, n)
23+
scale = 200.0 * np.random.rand(n)
2024
ax.scatter(x, y, c=color, s=scale, label=color,
2125
alpha=0.3, edgecolors='none')
2226

2327
ax.legend()
2428
ax.grid(True)
2529

2630
plt.show()
31+
32+
33+
##############################################################################
34+
# .. _automatedlegendcreation:
35+
#
36+
# Automated legend creation
37+
# -------------------------
38+
#
39+
# Another option for creating a legend for a scatter is to use the
40+
# :class:`~matplotlib.collections.PathCollection`'s
41+
# :meth:`~.PathCollection.legend_items` method.
42+
# It will automatically try to determine a useful number of legend entries
43+
# to be shown and return a tuple of handles and labels. Those can be passed
44+
# to the call to :meth:`~.axes.Axes.legend`.
45+
46+
47+
N = 45
48+
x, y = np.random.rand(2, N)
49+
c = np.random.randint(1, 5, size=N)
50+
s = np.random.randint(10, 220, size=N)
51+
52+
fig, ax = plt.subplots()
53+
54+
scatter = ax.scatter(x, y, c=c, s=s)
55+
56+
# produce a legend with the unique colors from the scatter
57+
legend1 = ax.legend(*scatter.legend_items(), loc=3, title="Classes")
58+
ax.add_artist(legend1)
59+
60+
# produce a legend with a cross section of sizes from the scatter
61+
handles, labels = scatter.legend_items(mode="sizes", alpha=0.6)
62+
legend2 = ax.legend(handles, labels, loc=1, title="Sizes")
63+
64+
plt.show()
65+
66+
67+
#############################################################################
68+
#
69+
# ------------
70+
#
71+
# References
72+
# """"""""""
73+
#
74+
# The usage of the following functions and methods is shown in this example:
75+
76+
import matplotlib
77+
matplotlib.axes.Axes.scatter
78+
matplotlib.pyplot.scatter
79+
matplotlib.axes.Axes.legend
80+
matplotlib.pyplot.legend
81+
matplotlib.collections.PathCollection.legend_items

lib/matplotlib/collections.py

+113
Original file line numberDiff line numberDiff line change
@@ -877,6 +877,7 @@ def draw(self, renderer):
877877
class PathCollection(_CollectionWithSizes):
878878
"""
879879
This is the most basic :class:`Collection` subclass.
880+
A :class:`PathCollection` is e.g. created by a :meth:`~.Axes.scatter` plot.
880881
"""
881882
@docstring.dedent_interpd
882883
def __init__(self, paths, sizes=None, **kwargs):
@@ -899,6 +900,118 @@ def set_paths(self, paths):
899900
def get_paths(self):
900901
return self._paths
901902

903+
def legend_items(self, mode="colors", useall="auto", num=10,
                 fmt=None, func=lambda x: x, **kwargs):
    """
    Create legend handles and labels for a PathCollection.

    This is useful for obtaining a legend for a :meth:`~.Axes.scatter`
    plot. E.g.::

        scatter = plt.scatter([1,2,3], [4,5,6], c=[7,2,3])
        plt.legend(*scatter.legend_items())

    Also see the :ref:`automatedlegendcreation` example.

    Parameters
    ----------
    mode : string, optional, default *"colors"*
        Can be *"colors"* or *"sizes"*. In case of *"colors"*, the legend
        handles will show the different colors of the collection. In case
        of *"sizes"*, the legend will show the different sizes.
    useall : bool or string "auto", optional (default "auto")
        If True, use all unique elements of the mappable array. If False,
        target to use *num* elements in the normed range. If *"auto"*,
        use all unique elements when there are at most *num* of them and
        fall back to the *useall=False* behavior otherwise.
    num : int, optional (default 10)
        Target number of elements to create in case of *useall=False*.
        The number of created elements may slightly deviate from that due
        to a `~.ticker.Locator` being used to find useful locations.
    fmt : string, `~matplotlib.ticker.Formatter`, or None (default)
        The format or formatter to use for the labels. If a string, it
        must be a valid input for a `~.StrMethodFormatter`. If None (the
        default), use a `~.ScalarFormatter`.
    func : function, default *lambda x: x*
        Function to calculate the shown labels. This converts the initial
        values for color or size and needs to take a numpy array as input.
        Note that if e.g. the :meth:`~.Axes.scatter`'s *s* parameter has
        been calculated from some values *x* via a function
        *f* as *s = f(x)*, you need to supply the inverse of that function
        *f_inv* here, *func = f_inv*.
    kwargs : further parameters
        Allowed kwargs are *color* and *size*. E.g. it may be useful to
        set the color of the markers if *mode="sizes"* is used; similarly
        to set the size of the markers if *mode="colors"* is used.
        Any further parameters are passed onto the `.Line2D` instance.
        This may be useful to e.g. specify a different *markeredgecolor* or
        *alpha* for the legend handles.

    Returns
    -------
    tuple (handles, labels)
        with *handles* being a list of `.Line2D` objects
        and *labels* a list of strings of the same length. Both lists are
        empty if *mode* is invalid, if *mode="colors"* is requested for a
        collection without an array (a warning is emitted in these two
        cases), or if there is no data to build entries from.
    """
    handles = []
    labels = []

    # Select the data the legend entries are derived from.
    data = self.get_array()
    if mode == "colors" and data is not None:
        by_color = True
    elif mode == "sizes":
        by_color = False
        data = self.get_sizes()
    else:
        warnings.warn("Invalid mode provided, or collections without "
                      "array used.")
        # Nothing sensible can be built; bail out instead of failing
        # further down on undefined data.
        return handles, labels

    u = np.unique(data)
    if u.size == 0:
        # min()/max() below are undefined for empty data.
        return handles, labels

    # The property that is *not* being varied is fixed for all handles.
    if by_color:
        size = kwargs.pop("size", mpl.rcParams["lines.markersize"])
    else:
        color = kwargs.pop("color", "k")

    if fmt is None:
        fmt = mpl.ticker.ScalarFormatter(useOffset=False, useMathText=True)
    elif isinstance(fmt, str):
        fmt = mpl.ticker.StrMethodFormatter(fmt)
    fmt.create_dummy_axis()

    shown = func(u)
    fmt.set_bounds(shown.min(), shown.max())

    if useall == "auto":
        useall = len(u) <= num

    if useall:
        values = u
        label_values = func(values)
    else:
        transformed = func(data)
        vmin, vmax = transformed.min(), transformed.max()
        loc = mpl.ticker.MaxNLocator(nbins=num, min_n_ticks=num-1,
                                     steps=[1, 2, 2.5, 3, 5, 6, 8, 10])
        label_values = loc.tick_values(vmin, vmax)
        # The locator may return ticks outside of the data range.
        inside = (label_values >= vmin) & (label_values <= vmax)
        label_values = label_values[inside]
        # Map the label values back to color/size values by interpolating
        # over a sampled version of *func*.
        xarr = np.linspace(data.min(), data.max(), 256)
        values = np.interp(label_values, func(xarr), xarr)

    kw = dict(markeredgewidth=self.get_linewidths()[0],
              alpha=self.get_alpha())
    kw.update(kwargs)

    # Both are the same for every entry, so set them up once.
    if hasattr(fmt, "set_locs"):
        fmt.set_locs(label_values)
    marker = self.get_paths()[0]

    for val, lab in zip(values, label_values):
        if by_color:
            color = self.cmap(self.norm(val))
        else:
            size = np.sqrt(val)
            if np.isclose(size, 0.0):
                # A marker of (almost) zero size would be invisible.
                continue
        handles.append(mlines.Line2D([0], [0], ls="", color=color, ms=size,
                                     marker=marker, **kw))
        labels.append(fmt(lab))

    return handles, labels
1014+
9021015

9031016
class PolyCollection(_CollectionWithSizes):
9041017
@docstring.dedent_interpd

lib/matplotlib/tests/test_collections.py

+42
Original file line numberDiff line numberDiff line change
@@ -669,3 +669,45 @@ def test_scatter_post_alpha():
669669
# this needs to be here to update internal state
670670
fig.canvas.draw()
671671
sc.set_alpha(.1)
672+
673+
674+
def test_pathcollection_legend_items():
    """Check handles and labels returned by PathCollection.legend_items."""
    np.random.seed(19680801)
    # Two draws for *y* are made on purpose: the seeded random stream (and
    # thereby the values of *c* and *s* below) depends on it.
    x, _ = np.random.rand(2, 10)
    y = np.random.rand(10)
    c = np.random.randint(0, 5, size=10)
    s = np.random.randint(10, 300, size=10)

    fig, ax = plt.subplots()
    sc = ax.scatter(x, y, c=c, s=s, cmap="jet", marker="o", linewidths=0)
    legends = []

    # Default (color) mode with few unique values: one entry per value.
    handles, labels = sc.legend_items(fmt="{x:g}")
    assert len(handles) == 5
    assert_array_equal(np.array(labels).astype(float), np.arange(5))
    handle_colors = np.array([handle.get_color() for handle in handles])
    expected_colors = sc.cmap(np.arange(5)/4)
    assert_array_equal(handle_colors, expected_colors)
    legends.append(ax.legend(handles, labels, loc="upper right"))

    # Color mode with locator-chosen values.
    handles, labels = sc.legend_items(useall=False)
    assert len(handles) == 9
    legends.append(ax.legend(handles, labels, loc="upper left"))

    # Size mode; extra keyword arguments end up on the handles.
    handles, labels = sc.legend_items(mode="sizes", useall=False,
                                      alpha=0.5, color="red")
    handle_alphas = np.array([handle.get_alpha() for handle in handles])
    assert_array_equal(handle_alphas, 0.5)
    face_colors = np.array([handle.get_markerfacecolor()
                            for handle in handles])
    assert_array_equal(face_colors, "red")
    legends.append(ax.legend(handles, labels, loc="lower right"))

    # Size mode with a label function; labels must map back to the sizes.
    handles, labels = sc.legend_items(mode="sizes", useall=False, num=4,
                                      fmt="{x:.2f}", func=lambda x: 2*x)
    actual_sizes = [handle.get_markersize() for handle in handles]
    labeled_sizes = np.sqrt(np.array(labels).astype(float)/2)
    assert_array_almost_equal(actual_sizes, labeled_sizes)
    legends.append(ax.legend(handles, labels, loc="lower left"))

    for legend in legends:
        ax.add_artist(legend)

    fig.canvas.draw()

0 commit comments

Comments
 (0)