From d9bd109ef1100664ddbc2d88fe5fde39151f66e5 Mon Sep 17 00:00:00 2001
From: ImportanceOfBeingErnest
Date: Mon, 23 Apr 2018 03:54:19 +0200
Subject: [PATCH] legend-for-scatter

---
 .flake8                                   |   1 +
 .../next_whats_new/legend_for_scatter.rst |  19 +++
 .../scatter_with_legend.py                | 104 +++++++++++++-
 lib/matplotlib/collections.py             | 129 ++++++++++++++++++
 lib/matplotlib/tests/test_collections.py  |  52 +++++++
 5 files changed, 299 insertions(+), 6 deletions(-)
 create mode 100644 doc/users/next_whats_new/legend_for_scatter.rst

diff --git a/.flake8 b/.flake8
index b7c36183b232..920ad42ddda8 100644
--- a/.flake8
+++ b/.flake8
@@ -141,6 +141,7 @@ per-file-ignores =
     examples/lines_bars_and_markers/filled_step.py: E402
     examples/lines_bars_and_markers/joinstyle.py: E402
     examples/lines_bars_and_markers/scatter_piecharts.py: E402
+    examples/lines_bars_and_markers/scatter_with_legend.py: E402
     examples/lines_bars_and_markers/span_regions.py: E402
     examples/lines_bars_and_markers/step_demo.py: E402
     examples/misc/agg_buffer.py: E402
diff --git a/doc/users/next_whats_new/legend_for_scatter.rst b/doc/users/next_whats_new/legend_for_scatter.rst
new file mode 100644
index 000000000000..9ae428ccb6a0
--- /dev/null
+++ b/doc/users/next_whats_new/legend_for_scatter.rst
@@ -0,0 +1,19 @@
+:orphan:
+
+Legend for scatter
+------------------
+
+A new method for creating legends for scatter plots has been introduced.
+Previously, in order to obtain a legend for a :meth:`~.axes.Axes.scatter`
+plot, one could either plot several scatters, each with an individual label,
+or create proxy artists to show in the legend manually.
+Now, :class:`~.collections.PathCollection` provides a method
+:meth:`~.collections.PathCollection.legend_elements` to obtain the handles and labels
+for a scatter plot in an automated way. This makes creating a legend for a
+scatter plot as easy as::
+
+    scatter = plt.scatter([1,2,3], [4,5,6], c=[7,2,3])
+    plt.legend(*scatter.legend_elements())
+
+An example can be found in
+:ref:`automatedlegendcreation`.
diff --git a/examples/lines_bars_and_markers/scatter_with_legend.py b/examples/lines_bars_and_markers/scatter_with_legend.py
index f7b7af5ac3b3..fbbbd58614ec 100644
--- a/examples/lines_bars_and_markers/scatter_with_legend.py
+++ b/examples/lines_bars_and_markers/scatter_with_legend.py
@@ -3,20 +3,24 @@
 Scatter plots with a legend
 ===========================
 
-Also demonstrates how transparency of the markers
-can be adjusted by giving ``alpha`` a value between
-0 and 1.
+To create a scatter plot with a legend one may use a loop and create one
+`~.Axes.scatter` plot per item to appear in the legend and set the ``label``
+accordingly.
+
+The following also demonstrates how transparency of the markers
+can be adjusted by giving ``alpha`` a value between 0 and 1.
 """
 
+import numpy as np
+np.random.seed(19680801)
 import matplotlib.pyplot as plt
-from numpy.random import rand
 
 
 fig, ax = plt.subplots()
 for color in ['tab:blue', 'tab:orange', 'tab:green']:
     n = 750
-    x, y = rand(2, n)
-    scale = 200.0 * rand(n)
+    x, y = np.random.rand(2, n)
+    scale = 200.0 * np.random.rand(n)
     ax.scatter(x, y, c=color, s=scale, label=color,
                alpha=0.3, edgecolors='none')
 
@@ -24,3 +28,91 @@
 ax.grid(True)
 
 plt.show()
+
+
+##############################################################################
+# .. _automatedlegendcreation:
+#
+# Automated legend creation
+# -------------------------
+#
+# Another option for creating a legend for a scatter is to use the
+# :class:`~matplotlib.collections.PathCollection`'s
+# :meth:`~.PathCollection.legend_elements` method.
+# It will automatically try to determine a useful number of legend entries
+# to be shown and return a tuple of handles and labels. Those can be passed
+# to the call to :meth:`~.axes.Axes.legend`.
+
+
+N = 45
+x, y = np.random.rand(2, N)
+c = np.random.randint(1, 5, size=N)
+s = np.random.randint(10, 220, size=N)
+
+fig, ax = plt.subplots()
+
+scatter = ax.scatter(x, y, c=c, s=s)
+
+# produce a legend with the unique colors from the scatter
+legend1 = ax.legend(*scatter.legend_elements(),
+                    loc="lower left", title="Classes")
+ax.add_artist(legend1)
+
+# produce a legend with a cross section of sizes from the scatter
+handles, labels = scatter.legend_elements(prop="sizes", alpha=0.6)
+legend2 = ax.legend(handles, labels, loc="upper right", title="Sizes")
+
+plt.show()
+
+
+##############################################################################
+# Further arguments to the :meth:`~.PathCollection.legend_elements` method
+# can be used to steer how many legend entries are to be created and how they
+# should be labeled. The following shows how to use some of them.
+#
+
+volume = np.random.rayleigh(27, size=40)
+amount = np.random.poisson(10, size=40)
+ranking = np.random.normal(size=40)
+price = np.random.uniform(1, 10, size=40)
+
+fig, ax = plt.subplots()
+
+# Because the price is much too small when being provided as size for ``s``,
+# we normalize it to some useful point sizes, s=0.3*(price*3)**2
+scatter = ax.scatter(volume, amount, c=ranking, s=0.3*(price*3)**2,
+                     vmin=-3, vmax=3, cmap="Spectral")
+
+# Produce a legend for the ranking (colors). Even though there are 40 different
+# rankings, we only want to show 5 of them in the legend.
+legend1 = ax.legend(*scatter.legend_elements(num=5),
+                    loc="upper left", title="Ranking")
+ax.add_artist(legend1)
+
+# Produce a legend for the price (sizes). Because we want to show the prices
+# in dollars, we use the *func* argument to supply the inverse of the function
+# used to calculate the sizes from above. The *fmt* ensures the price is shown
+# in dollars. Note how we target 5 elements here, but obtain only 4 in the
+# created legend due to the automatic round prices that are chosen for us.
+kw = dict(prop="sizes", num=5, color=scatter.cmap(0.7), fmt="$ {x:.2f}",
+          func=lambda s: np.sqrt(s/.3)/3)
+legend2 = ax.legend(*scatter.legend_elements(**kw),
+                    loc="lower right", title="Price")
+
+plt.show()
+
+#############################################################################
+#
+# ------------
+#
+# References
+# """"""""""
+#
+# The usage of the following functions and methods is shown in this example:
+
+import matplotlib
+matplotlib.axes.Axes.scatter
+matplotlib.pyplot.scatter
+matplotlib.axes.Axes.legend
+matplotlib.pyplot.legend
+matplotlib.collections.PathCollection.legend_elements
diff --git a/lib/matplotlib/collections.py b/lib/matplotlib/collections.py
index cfa57b4fe607..af6f8ed12a98 100644
--- a/lib/matplotlib/collections.py
+++ b/lib/matplotlib/collections.py
@@ -16,6 +16,7 @@
 import matplotlib as mpl
 from . import (_path, artist, cbook, cm, colors as mcolors, docstring,
                lines as mlines, path as mpath, transforms)
+import warnings
 
 
 @cbook._define_aliases({
@@ -868,6 +869,7 @@ def draw(self, renderer):
 class PathCollection(_CollectionWithSizes):
     """
     This is the most basic :class:`Collection` subclass.
+    A :class:`PathCollection` is e.g. created by a :meth:`~.Axes.scatter` plot.
     """
     @docstring.dedent_interpd
     def __init__(self, paths, sizes=None, **kwargs):
@@ -890,6 +892,133 @@ def set_paths(self, paths):
     def get_paths(self):
         return self._paths
 
+    def legend_elements(self, prop="colors", num="auto",
+                        fmt=None, func=lambda x: x, **kwargs):
+        """
+        Creates legend handles and labels for a PathCollection. This is useful
+        for obtaining a legend for a :meth:`~.Axes.scatter` plot. E.g.::
+
+            scatter = plt.scatter([1,2,3], [4,5,6], c=[7,2,3])
+            plt.legend(*scatter.legend_elements())
+
+        Also see the :ref:`automatedlegendcreation` example.
+
+        Parameters
+        ----------
+        prop : string, optional, default *"colors"*
+            Can be *"colors"* or *"sizes"*. In case of *"colors"*, the legend
+            handles will show the different colors of the collection. In case
+            of "sizes", the legend will show the different sizes.
+        num : int, None, "auto" (default), array-like, or `~.ticker.Locator`,
+            optional
+            Target number of elements to create.
+            If None, use all unique elements of the mappable array. If an
+            integer, target to use *num* elements in the normed range.
+            If *"auto"*, try to determine which option better suits the nature
+            of the data.
+            The number of created elements may slightly deviate from *num* due
+            to a `~.ticker.Locator` being used to find useful locations.
+            If a list or array, use exactly those elements for the legend.
+            Finally, a `~.ticker.Locator` can be provided.
+        fmt : string, `~matplotlib.ticker.Formatter`, or None (default)
+            The format or formatter to use for the labels. A string must be
+            a valid input for a `~.StrMethodFormatter`. If None (the default),
+            use a `~.ScalarFormatter`.
+        func : function, default *lambda x: x*
+            Function to calculate the labels. Often the size (or color)
+            argument to :meth:`~.Axes.scatter` will have been pre-processed
+            by the user using a function *s = f(x)* to make the markers
+            visible; e.g. *size = np.log10(x)*. Providing the inverse of this
+            function here allows that pre-processing to be inverted, so that
+            the legend labels have the correct values;
+            e.g. *func = lambda x: 10**x*.
+        kwargs : further parameters
+            Allowed kwargs are *color* and *size*. E.g. it may be useful to
+            set the color of the markers if *prop="sizes"* is used; similarly
+            to set the size of the markers if *prop="colors"* is used.
+            Any further parameters are passed onto the `.Line2D` instance.
+            This may be useful to e.g. specify a different *markeredgecolor* or
+            *alpha* for the legend handles.
+
+        Returns
+        -------
+        tuple (handles, labels)
+            with *handles* being a list of `.Line2D` objects
+            and *labels* a matching list of strings.
+        """
+        handles = []
+        labels = []
+        hasarray = self.get_array() is not None
+        if fmt is None:
+            fmt = mpl.ticker.ScalarFormatter(useOffset=False, useMathText=True)
+        elif isinstance(fmt, str):
+            fmt = mpl.ticker.StrMethodFormatter(fmt)
+        fmt.create_dummy_axis()
+
+        if prop == "colors":
+            if not hasarray:
+                warnings.warn("Collection without array used. Make sure to "
+                              "specify the values to be colormapped via the "
+                              "`c` argument.")
+                return handles, labels
+            u = np.unique(self.get_array())
+            size = kwargs.pop("size", mpl.rcParams["lines.markersize"])
+        elif prop == "sizes":
+            u = np.unique(self.get_sizes())
+            color = kwargs.pop("color", "k")
+        else:
+            raise ValueError("Valid values for `prop` are 'colors' or "
+                             f"'sizes'. You supplied '{prop}' instead.")
+
+        fmt.set_bounds(func(u).min(), func(u).max())
+        if num == "auto":
+            num = 9
+            if len(u) <= num:
+                num = None
+        if num is None:
+            values = u
+            label_values = func(values)
+        else:
+            if prop == "colors":
+                arr = self.get_array()
+            elif prop == "sizes":
+                arr = self.get_sizes()
+            if isinstance(num, mpl.ticker.Locator):
+                loc = num
+            elif np.iterable(num):
+                loc = mpl.ticker.FixedLocator(num)
+            else:
+                num = int(num)
+                loc = mpl.ticker.MaxNLocator(nbins=num, min_n_ticks=num-1,
+                                             steps=[1, 2, 2.5, 3, 5, 6, 8, 10])
+            label_values = loc.tick_values(func(arr).min(), func(arr).max())
+            cond = ((label_values >= func(arr).min()) &
+                    (label_values <= func(arr).max()))
+            label_values = label_values[cond]
+            xarr = np.linspace(arr.min(), arr.max(), 256)
+            values = np.interp(label_values, func(xarr), xarr)
+
+        kw = dict(markeredgewidth=self.get_linewidths()[0],
+                  alpha=self.get_alpha())
+        kw.update(kwargs)
+
+        for val, lab in zip(values, label_values):
+            if prop == "colors":
+                color = self.cmap(self.norm(val))
+            elif prop == "sizes":
+                size = np.sqrt(val)
+                if np.isclose(size, 0.0):
+                    continue
+            h = mlines.Line2D([0], [0], ls="", color=color, ms=size,
+                              marker=self.get_paths()[0], **kw)
+            handles.append(h)
+            if hasattr(fmt, "set_locs"):
+                fmt.set_locs(label_values)
+            l = fmt(lab)
+            labels.append(l)
+
+        return handles, labels
+
 
 class PolyCollection(_CollectionWithSizes):
     @docstring.dedent_interpd
diff --git a/lib/matplotlib/tests/test_collections.py b/lib/matplotlib/tests/test_collections.py
index 6812ee1ad427..c12ed7004d61 100644
--- a/lib/matplotlib/tests/test_collections.py
+++ b/lib/matplotlib/tests/test_collections.py
@@ -669,3 +669,55 @@ def test_scatter_post_alpha():
     # this needs to be here to update internal state
     fig.canvas.draw()
     sc.set_alpha(.1)
+
+
+def test_pathcollection_legend_elements():
+    np.random.seed(19680801)
+    x, y = np.random.rand(2, 10)
+    y = np.random.rand(10)
+    c = np.random.randint(0, 5, size=10)
+    s = np.random.randint(10, 300, size=10)
+
+    fig, ax = plt.subplots()
+    sc = ax.scatter(x, y, c=c, s=s, cmap="jet", marker="o", linewidths=0)
+
+    h, l = sc.legend_elements(fmt="{x:g}")
+    assert len(h) == 5
+    assert_array_equal(np.array(l).astype(float), np.arange(5))
+    colors = np.array([line.get_color() for line in h])
+    colors2 = sc.cmap(np.arange(5)/4)
+    assert_array_equal(colors, colors2)
+    l1 = ax.legend(h, l, loc=1)
+
+    h2, lab2 = sc.legend_elements(num=9)
+    assert len(h2) == 9
+    l2 = ax.legend(h2, lab2, loc=2)
+
+    h, l = sc.legend_elements(prop="sizes", alpha=0.5, color="red")
+    alpha = np.array([line.get_alpha() for line in h])
+    assert_array_equal(alpha, 0.5)
+    color = np.array([line.get_markerfacecolor() for line in h])
+    assert_array_equal(color, "red")
+    l3 = ax.legend(h, l, loc=4)
+
+    h, l = sc.legend_elements(prop="sizes", num=4, fmt="{x:.2f}",
+                              func=lambda x: 2*x)
+    actsizes = [line.get_markersize() for line in h]
+    labeledsizes = np.sqrt(np.array(l).astype(float)/2)
+    assert_array_almost_equal(actsizes, labeledsizes)
+    l4 = ax.legend(h, l, loc=3)
+
+    import matplotlib.ticker as mticker
+    loc = mticker.MaxNLocator(nbins=9, min_n_ticks=9-1,
+                              steps=[1, 2, 2.5, 3, 5, 6, 8, 10])
+    h5, lab5 = sc.legend_elements(num=loc)
+    assert len(h2) == len(h5)
+
+    levels = [-1, 0, 55.4, 260]
+    h6, lab6 = sc.legend_elements(num=levels, prop="sizes", fmt="{x:g}")
+    assert_array_equal(np.array(lab6).astype(float), levels[2:])
+
+    for l in [l1, l2, l3, l4]:
+        ax.add_artist(l)
+
+    fig.canvas.draw()