Skip to content

Commit 7fac6d3

Browse files
yozhikofftimhoffm
andcommitted
Apply suggestions from code review
Co-authored-by: Tim Hoffmann <2836374+timhoffm@users.noreply.github.com>
1 parent 353985a commit 7fac6d3

File tree

3 files changed

+24
-33
lines changed

3 files changed

+24
-33
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,17 @@
11
An iterable object with labels can be passed to `.Axes.plot`
22
-------------------------------------------------------------
33

4-
If multidimensional data is used for plotting, labels can be specified in
5-
a vectorized way with an iterable object of size corresponding to the
6-
data array shape (exactly 5 labels are expected when plotting 5 lines).
7-
It works with `.Axes.plot` as well as with it's wrapper `.pyplot.plot`.
4+
When plotting multiple datasets by passing 2D data as *y* value to `~.Axes.plot`, labels for the datasets can be passed as a list, the length matching the number of columns in *y*.
85

96
.. plot::
107

11-
from matplotlib import pyplot as plt
8+
import matplotlib.pyplot as plt
9+
10+
x = [1, 2, 3]
1211

13-
x = [1, 2, 5]
12+
y = [[1, 9],
13+
[2, 8],
14+
[4, 6]]
1415

15-
y = [[2, 4, 3],
16-
[4, 7, 1],
17-
[3, 9, 2]]
18-
19-
plt.plot(x, y, label=['one', 'two', 'three'])
20-
plt.legend()
16+
plt.plot(x, y, label=['low', 'high'])
17+
plt.legend()

lib/matplotlib/axes/_base.py

+11-13
Original file line numberDiff line numberDiff line change
@@ -448,22 +448,20 @@ def _plot_args(self, tup, kwargs, return_kwargs=False):
448448
if ncx > 1 and ncy > 1 and ncx != ncy:
449449
raise ValueError(f"x has {ncx} columns but y has {ncy} columns")
450450

451-
if ('label' in kwargs and max(ncx, ncy) > 1
452-
and isinstance(kwargs['label'], Iterable)
453-
and not isinstance(kwargs['label'], str)):
454-
if len(kwargs['label']) != max(ncx, ncy):
455-
raise ValueError(f"if label is iterable label and input data"
456-
f" must have same length, but have lengths "
457-
f"{len(kwargs['label'])} and "
458-
f"{max(ncx, ncy)}")
459-
elif 'label' in kwargs:
460-
kwargs['label'] = [kwargs['label']] * max(ncx, ncy)
451+
label = kwargs.get('label')
452+
n_datasets = max(ncx, ncy)
453+
if n_datasets > 1 and not cbook.is_scalar_or_string(label):
454+
if len(label) != n_datasets:
455+
raise ValueError(f"label must be scalar or have the same "
456+
f"length as the input data, but found "
457+
f"{len(label)} for {n_datasets} datasets.")
458+
labels = label
461459
else:
462-
kwargs['label'] = [None] * max(ncx, ncy)
460+
labels = [label] * n_datasets
463461

464462
result = (func(x[:, j % ncx], y[:, j % ncy], kw,
465-
{**kwargs, 'label': kwargs['label'][j]})
466-
for j in range(max(ncx, ncy)))
463+
{**kwargs, 'label': label})
464+
for j, label in enumerate(labels))
467465

468466
if return_kwargs:
469467
return list(result)

lib/matplotlib/tests/test_legend.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -688,10 +688,8 @@ def test_plot_multiple_input_multiple_label():
688688
ax.plot(x, y, label=label)
689689
leg = ax.legend()
690690

691-
assert len(leg.get_texts()) == 3
692-
assert leg.get_texts()[0].get_text() == 'one'
693-
assert leg.get_texts()[1].get_text() == 'two'
694-
assert leg.get_texts()[2].get_text() == 'three'
691+
legend_texts = [entry.get_text() for entry in leg.get_texts()]
692+
assert legend_texts == ['one', 'two', 'three']
695693

696694

697695
def test_plot_multiple_input_single_label():
@@ -706,10 +704,8 @@ def test_plot_multiple_input_single_label():
706704
ax.plot(x, y, label=label)
707705
leg = ax.legend()
708706

709-
assert len(leg.get_texts()) == 3
710-
assert leg.get_texts()[0].get_text() == str(label)
711-
assert leg.get_texts()[1].get_text() == str(label)
712-
assert leg.get_texts()[2].get_text() == str(label)
707+
legend_texts = [entry.get_text() for entry in leg.get_texts()]
708+
assert legend_texts == [str(label)] * 3
713709

714710

715711
def test_plot_single_input_multiple_label():

0 commit comments

Comments
 (0)