diff --git a/doc/users/next_whats_new/bar_plot_labels.rst b/doc/users/next_whats_new/bar_plot_labels.rst new file mode 100644 index 000000000000..6da57a317f59 --- /dev/null +++ b/doc/users/next_whats_new/bar_plot_labels.rst @@ -0,0 +1,16 @@ +Easier labelling of bars in bar plot +------------------------------------ + +The ``label`` argument of `~matplotlib.axes.Axes.bar` can now +be passed a list of labels for the bars. + +.. code-block:: python + + import matplotlib.pyplot as plt + + x = ["a", "b", "c"] + y = [10, 20, 15] + + fig, ax = plt.subplots() + bar_container = ax.barh(x, y, label=x) + [bar.get_label() for bar in bar_container] diff --git a/lib/matplotlib/axes/_axes.py b/lib/matplotlib/axes/_axes.py index 25b3cf36651f..3159a55b7e83 100644 --- a/lib/matplotlib/axes/_axes.py +++ b/lib/matplotlib/axes/_axes.py @@ -2256,6 +2256,14 @@ def bar(self, x, height, width=0.8, bottom=None, *, align="center", The tick labels of the bars. Default: None (Use default numeric labels.) + label : str or list of str, optional + A single label is attached to the resulting `.BarContainer` as a + label for the whole dataset. + If a list is provided, it must be the same length as *x* and + labels the individual bars. Repeated labels are not de-duplicated + and will cause repeated label entries, so this is best used when + bars also differ in style (e.g., by passing a list to *color*.) + xerr, yerr : float or array-like of shape(N,) or shape(2, N), optional If not *None*, add horizontal / vertical errorbars to the bar tips. The values are +/- sizes relative to the data: @@ -2381,6 +2389,16 @@ def bar(self, x, height, width=0.8, bottom=None, *, align="center", tick_label_axis = self.yaxis tick_label_position = y + if not isinstance(label, str) and np.iterable(label): + bar_container_label = '_nolegend_' + patch_labels = label + else: + bar_container_label = label + patch_labels = ['_nolegend_'] * len(x) + if len(patch_labels) != len(x): + raise ValueError(f'number of labels ({len(patch_labels)}) ' + f'does not match number of bars ({len(x)}).') + linewidth = itertools.cycle(np.atleast_1d(linewidth)) hatch = itertools.cycle(np.atleast_1d(hatch)) color = itertools.chain(itertools.cycle(mcolors.to_rgba_array(color)), @@ -2420,14 +2438,14 @@ def bar(self, x, height, width=0.8, bottom=None, *, align="center", patches = [] args = zip(left, bottom, width, height, color, edgecolor, linewidth, - hatch) - for l, b, w, h, c, e, lw, htch in args: + hatch, patch_labels) + for l, b, w, h, c, e, lw, htch, lbl in args: r = mpatches.Rectangle( xy=(l, b), width=w, height=h, facecolor=c, edgecolor=e, linewidth=lw, - label='_nolegend_', + label=lbl, hatch=htch, ) r._internal_update(kwargs) @@ -2466,7 +2484,8 @@ def bar(self, x, height, width=0.8, bottom=None, *, align="center", datavalues = width bar_container = BarContainer(patches, errorbar, datavalues=datavalues, - orientation=orientation, label=label) + orientation=orientation, + label=bar_container_label) self.add_container(bar_container) if tick_labels is not None: diff --git a/lib/matplotlib/tests/test_axes.py b/lib/matplotlib/tests/test_axes.py index 51fbea3cfaaa..3a1ba341b2e0 100644 --- a/lib/matplotlib/tests/test_axes.py +++ b/lib/matplotlib/tests/test_axes.py @@ -1886,6 +1886,35 @@ def test_bar_hatches(fig_test, fig_ref): ax_test.bar(x, y, hatch=hatches) +@pytest.mark.parametrize( + ("x", "width", "label", "expected_labels", "container_label"), + [ + ("x", 1, "x", ["_nolegend_"], "x"), + (["a", "b", "c"], [10, 20, 15], ["A", "B", "C"], + ["A", "B", "C"], "_nolegend_"), + (["a", "b", "c"], [10, 20, 15], ["R", "Y", "_nolegend_"], + ["R", "Y", "_nolegend_"], "_nolegend_"), + (["a", "b", "c"], [10, 20, 15], "bars", + ["_nolegend_", "_nolegend_", "_nolegend_"], "bars"), + ] +) +def test_bar_labels(x, width, label, expected_labels, container_label): + _, ax = plt.subplots() + bar_container = ax.bar(x, width, label=label) + bar_labels = [bar.get_label() for bar in bar_container] + assert expected_labels == bar_labels + assert bar_container.get_label() == container_label + + +def test_bar_labels_length(): + _, ax = plt.subplots() + with pytest.raises(ValueError): + ax.bar(["x", "y"], [1, 2], label=["X", "Y", "Z"]) + _, ax = plt.subplots() + with pytest.raises(ValueError): + ax.bar(["x", "y"], [1, 2], label=["X"]) + + def test_pandas_minimal_plot(pd): # smoke test that series and index objects do not warn for x in [pd.Series([1, 2], dtype="float64"),