From d7513a7d57df9ee8c33fcfc5eeec679ac2deb312 Mon Sep 17 00:00:00 2001 From: Ruth Comer <10599679+rcomer@users.noreply.github.com> Date: Fri, 13 Dec 2024 20:12:57 +0000 Subject: [PATCH] POC: make scaling optional in Collection.get_linestyle --- .../text_labels_and_annotations/legend_demo.py | 7 +++---- lib/matplotlib/collections.py | 12 ++++++++++-- lib/matplotlib/collections.pyi | 2 +- lib/matplotlib/legend_handler.py | 5 +++-- lib/matplotlib/tests/test_collections.py | 8 ++++---- lib/matplotlib/tests/test_legend.py | 12 +++++++++++- lib/mpl_toolkits/mplot3d/tests/test_legend3d.py | 2 +- 7 files changed, 33 insertions(+), 15 deletions(-) diff --git a/galleries/examples/text_labels_and_annotations/legend_demo.py b/galleries/examples/text_labels_and_annotations/legend_demo.py index 2f550729837e..6680045830bf 100644 --- a/galleries/examples/text_labels_and_annotations/legend_demo.py +++ b/galleries/examples/text_labels_and_annotations/legend_demo.py @@ -145,15 +145,14 @@ def create_artists(self, legend, orig_handle, except IndexError: color = orig_handle.get_colors()[0] try: - dashes = orig_handle.get_dashes()[i] + linestyle = orig_handle.get_linestyle(scaled=False)[i] except IndexError: - dashes = orig_handle.get_dashes()[0] + linestyle = orig_handle.get_linestyle(scaled=False)[0] try: lw = orig_handle.get_linewidths()[i] except IndexError: lw = orig_handle.get_linewidths()[0] - if dashes[1] is not None: - legline.set_dashes(dashes[1]) + legline.set_linestyle(linestyle) legline.set_color(color) legline.set_transform(trans) legline.set_linewidth(lw) diff --git a/lib/matplotlib/collections.py b/lib/matplotlib/collections.py index f18d5a4c3a8c..2b9ced21a201 100644 --- a/lib/matplotlib/collections.py +++ b/lib/matplotlib/collections.py @@ -626,6 +626,11 @@ def set_linestyle(self, ls): ':', '', (offset, on-off-seq)}. See `.Line2D.set_linestyle` for a complete description. """ + if isinstance(ls, (str, tuple)): + self._original_linestyle = [ls] + else: + self._original_linestyle = ls + try: dashes = [mlines._get_dash_pattern(ls)] except ValueError: @@ -866,8 +871,11 @@ def set_alpha(self, alpha): def get_linewidth(self): return self._linewidths - def get_linestyle(self): - return self._linestyles + def get_linestyle(self, scaled=False): + if scaled: + return self._linestyles + + return self._original_linestyle def _set_mappable_flags(self): """ diff --git a/lib/matplotlib/collections.pyi b/lib/matplotlib/collections.pyi index 0805adef4293..9f2bef3a99fa 100644 --- a/lib/matplotlib/collections.pyi +++ b/lib/matplotlib/collections.pyi @@ -68,7 +68,7 @@ class Collection(colorizer.ColorizingArtist): def set_edgecolor(self, c: ColorType | Sequence[ColorType]) -> None: ... def set_alpha(self, alpha: float | Sequence[float] | None) -> None: ... def get_linewidth(self) -> float | Sequence[float]: ... - def get_linestyle(self) -> LineStyleType | Sequence[LineStyleType]: ... + def get_linestyle(self, scaled: bool = ...) -> LineStyleType | Sequence[LineStyleType]: ... def update_scalarmappable(self) -> None: ... def get_fill(self) -> bool: ... def update_from(self, other: Artist) -> None: ... diff --git a/lib/matplotlib/legend_handler.py b/lib/matplotlib/legend_handler.py index 97076ad09cb8..75b89c671eca 100644 --- a/lib/matplotlib/legend_handler.py +++ b/lib/matplotlib/legend_handler.py @@ -407,7 +407,7 @@ def get_numpoints(self, legend): def _default_update_prop(self, legend_handle, orig_handle): lw = orig_handle.get_linewidths()[0] - dashes = orig_handle._us_linestyles[0] + dashes = orig_handle.get_linestyles(scaled=False)[0] color = orig_handle.get_colors()[0] legend_handle.set_color(color) legend_handle.set_linestyle(dashes) @@ -798,7 +798,8 @@ def get_first(prop_array): legend_handle._hatch_color = orig_handle._hatch_color # Setters are fine for the remaining attributes. legend_handle.set_linewidth(get_first(orig_handle.get_linewidths())) - legend_handle.set_linestyle(get_first(orig_handle.get_linestyles())) + legend_handle.set_linestyle( + get_first(orig_handle.get_linestyles(scaled=False))) legend_handle.set_transform(get_first(orig_handle.get_transforms())) legend_handle.set_figure(orig_handle.get_figure()) # Alpha is already taken into account by the color attributes. diff --git a/lib/matplotlib/tests/test_collections.py b/lib/matplotlib/tests/test_collections.py index 11934cfca2c3..3622002c3b50 100644 --- a/lib/matplotlib/tests/test_collections.py +++ b/lib/matplotlib/tests/test_collections.py @@ -86,7 +86,7 @@ def test__EventCollection__get_props(): # check that the default lineoffset matches the input lineoffset assert props['lineoffset'] == coll.get_lineoffset() # check that the default linestyle matches the input linestyle - assert coll.get_linestyle() == [(0, None)] + assert coll.get_linestyle(scaled=True) == [(0, None)] # check that the default color matches the input color for color in [coll.get_color(), *coll.get_colors()]: np.testing.assert_array_equal(color, props['color']) @@ -248,7 +248,7 @@ def test__EventCollection__set_lineoffset(): ]) def test__EventCollection__set_prop(): for prop, value, expected in [ - ('linestyle', 'dashed', [(0, (6.0, 6.0))]), + ('linestyle', 'dashed', ['dashed']), ('linestyle', (0, (6., 6.)), [(0, (6.0, 6.0))]), ('linewidth', 5, 5), ]: @@ -666,11 +666,11 @@ def test_lslw_bcast(): col.set_linestyles(['-', '-']) col.set_linewidths([1, 2, 3]) - assert col.get_linestyles() == [(0, None)] * 6 + assert col.get_linestyles(scaled=True) == [(0, None)] * 6 assert col.get_linewidths() == [1, 2, 3] * 2 col.set_linestyles(['-', '-', '-']) - assert col.get_linestyles() == [(0, None)] * 3 + assert col.get_linestyles(scaled=True) == [(0, None)] * 3 assert (col.get_linewidths() == [1, 2, 3]).all() diff --git a/lib/matplotlib/tests/test_legend.py b/lib/matplotlib/tests/test_legend.py index 61892378bd03..21ccd6aa456a 100644 --- a/lib/matplotlib/tests/test_legend.py +++ b/lib/matplotlib/tests/test_legend.py @@ -528,6 +528,16 @@ def test_legend_stackplot(): ax.legend(loc='best') +@mpl.style.context('default') +def test_polycollection_linestyles_unscaled(): + fig, ax = plt.subplots() + q = ax.quiver(0, 0, 7, 1, label='v1 + v2', linewidth=3, linestyle='dotted') + leg = ax.legend() + handle, = leg.legend_handles + + assert q.get_linestyle(scaled=False)[0] == handle.get_linestyle() + + def test_cross_figure_patch_legend(): fig, ax = plt.subplots() fig2, ax2 = plt.subplots() @@ -612,7 +622,7 @@ def test_linecollection_scaled_dashes(): h1, h2, h3 = leg.legend_handles for oh, lh in zip((lc1, lc2, lc3), (h1, h2, h3)): - assert oh.get_linestyles()[0] == lh._dash_pattern + assert oh.get_linestyles()[0] == lh.get_linestyle() def test_handler_numpoints(): diff --git a/lib/mpl_toolkits/mplot3d/tests/test_legend3d.py b/lib/mpl_toolkits/mplot3d/tests/test_legend3d.py index 0935bbe7f6b0..e0d3395c68da 100644 --- a/lib/mpl_toolkits/mplot3d/tests/test_legend3d.py +++ b/lib/mpl_toolkits/mplot3d/tests/test_legend3d.py @@ -55,7 +55,7 @@ def test_linecollection_scaled_dashes(): h1, h2, h3 = leg.legend_handles for oh, lh in zip((lc1, lc2, lc3), (h1, h2, h3)): - assert oh.get_linestyles()[0] == lh._dash_pattern + assert oh.get_linestyles()[0] == lh.get_linestyle() def test_handlerline3d():