|
1 | 1 | from collections import OrderedDict
|
| 2 | +from collections.abc import Iterable |
2 | 3 | from contextlib import ExitStack
|
3 | 4 | import functools
|
4 | 5 | import inspect
|
@@ -446,8 +447,22 @@ def _plot_args(self, tup, kwargs, return_kwargs=False):
|
446 | 447 | ncx, ncy = x.shape[1], y.shape[1]
|
447 | 448 | if ncx > 1 and ncy > 1 and ncx != ncy:
|
448 | 449 | raise ValueError(f"x has {ncx} columns but y has {ncy} columns")
|
| 450 | + |
| 451 | + if ('label' in kwargs and isinstance(kwargs['label'], Iterable) |
| 452 | + and not isinstance(kwargs['label'], str)): |
| 453 | + if len(kwargs['label']) != max(ncx, ncy): |
| 454 | + raise ValueError(f"if label is iterable label and input data" |
| 455 | + f" must have same length, but have lengths " |
| 456 | + f"{len(kwargs['label'])} and " |
| 457 | + f"{max(ncx, ncy)}") |
| 458 | + |
| 459 | + result = (func(x[:, j % ncx], y[:, j % ncy], kw, |
| 460 | + {**kwargs, 'label':kwargs['label'][j]}) |
| 461 | + for j in range(max(ncx, ncy))) |
| 462 | + |
449 | 463 | result = (func(x[:, j % ncx], y[:, j % ncy], kw, kwargs)
|
450 |
| - for j in range(max(ncx, ncy))) |
| 464 | + for j in range(max(ncx, ncy))) |
| 465 | + |
451 | 466 | if return_kwargs:
|
452 | 467 | return list(result)
|
453 | 468 | else:
|
|
0 commit comments