Skip to content

POC: make scaling optional in Collection.get_linestyle #29304

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions galleries/examples/text_labels_and_annotations/legend_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,14 @@ def create_artists(self, legend, orig_handle,
except IndexError:
color = orig_handle.get_colors()[0]
try:
dashes = orig_handle.get_dashes()[i]
linestyle = orig_handle.get_linestyle(scaled=False)[i]
except IndexError:
dashes = orig_handle.get_dashes()[0]
linestyle = orig_handle.get_linestyle(scaled=False)[0]
try:
lw = orig_handle.get_linewidths()[i]
except IndexError:
lw = orig_handle.get_linewidths()[0]
if dashes[1] is not None:
legline.set_dashes(dashes[1])
legline.set_linestyle(linestyle)
legline.set_color(color)
legline.set_transform(trans)
legline.set_linewidth(lw)
Expand Down
12 changes: 10 additions & 2 deletions lib/matplotlib/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,11 @@ def set_linestyle(self, ls):
':', '', (offset, on-off-seq)}. See `.Line2D.set_linestyle` for a
complete description.
"""
if isinstance(ls, (str, tuple)):
self._original_linestyle = [ls]
Copy link
Member Author

@rcomer rcomer Dec 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got a lot more failures if I did not make sure this is a list. I note that Collection.get_facecolor always returns a sequence of colours, having applied to_rgba_array.

else:
self._original_linestyle = ls

try:
dashes = [mlines._get_dash_pattern(ls)]
except ValueError:
Expand Down Expand Up @@ -866,8 +871,11 @@ def set_alpha(self, alpha):
def get_linewidth(self):
return self._linewidths

def get_linestyle(self):
return self._linestyles
def get_linestyle(self, scaled=False):
if scaled:
return self._linestyles

return self._original_linestyle

def _set_mappable_flags(self):
"""
Expand Down
2 changes: 1 addition & 1 deletion lib/matplotlib/collections.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class Collection(colorizer.ColorizingArtist):
def set_edgecolor(self, c: ColorType | Sequence[ColorType]) -> None: ...
def set_alpha(self, alpha: float | Sequence[float] | None) -> None: ...
def get_linewidth(self) -> float | Sequence[float]: ...
def get_linestyle(self) -> LineStyleType | Sequence[LineStyleType]: ...
def get_linestyle(self, scaled: bool = ...) -> LineStyleType | Sequence[LineStyleType]: ...
def update_scalarmappable(self) -> None: ...
def get_fill(self) -> bool: ...
def update_from(self, other: Artist) -> None: ...
Expand Down
5 changes: 3 additions & 2 deletions lib/matplotlib/legend_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def get_numpoints(self, legend):

def _default_update_prop(self, legend_handle, orig_handle):
lw = orig_handle.get_linewidths()[0]
dashes = orig_handle._us_linestyles[0]
dashes = orig_handle.get_linestyles(scaled=False)[0]
color = orig_handle.get_colors()[0]
legend_handle.set_color(color)
legend_handle.set_linestyle(dashes)
Expand Down Expand Up @@ -798,7 +798,8 @@ def get_first(prop_array):
legend_handle._hatch_color = orig_handle._hatch_color
# Setters are fine for the remaining attributes.
legend_handle.set_linewidth(get_first(orig_handle.get_linewidths()))
legend_handle.set_linestyle(get_first(orig_handle.get_linestyles()))
legend_handle.set_linestyle(
get_first(orig_handle.get_linestyles(scaled=False)))
legend_handle.set_transform(get_first(orig_handle.get_transforms()))
legend_handle.set_figure(orig_handle.get_figure())
# Alpha is already taken into account by the color attributes.
Expand Down
8 changes: 4 additions & 4 deletions lib/matplotlib/tests/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test__EventCollection__get_props():
# check that the default lineoffset matches the input lineoffset
assert props['lineoffset'] == coll.get_lineoffset()
# check that the default linestyle matches the input linestyle
assert coll.get_linestyle() == [(0, None)]
assert coll.get_linestyle(scaled=True) == [(0, None)]
# check that the default color matches the input color
for color in [coll.get_color(), *coll.get_colors()]:
np.testing.assert_array_equal(color, props['color'])
Expand Down Expand Up @@ -248,7 +248,7 @@ def test__EventCollection__set_lineoffset():
])
def test__EventCollection__set_prop():
for prop, value, expected in [
('linestyle', 'dashed', [(0, (6.0, 6.0))]),
('linestyle', 'dashed', ['dashed']),
('linestyle', (0, (6., 6.)), [(0, (6.0, 6.0))]),
('linewidth', 5, 5),
]:
Expand Down Expand Up @@ -666,11 +666,11 @@ def test_lslw_bcast():
col.set_linestyles(['-', '-'])
col.set_linewidths([1, 2, 3])

assert col.get_linestyles() == [(0, None)] * 6
assert col.get_linestyles(scaled=True) == [(0, None)] * 6
assert col.get_linewidths() == [1, 2, 3] * 2

col.set_linestyles(['-', '-', '-'])
assert col.get_linestyles() == [(0, None)] * 3
assert col.get_linestyles(scaled=True) == [(0, None)] * 3
assert (col.get_linewidths() == [1, 2, 3]).all()


Expand Down
12 changes: 11 additions & 1 deletion lib/matplotlib/tests/test_legend.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,16 @@ def test_legend_stackplot():
ax.legend(loc='best')


@mpl.style.context('default')
def test_polycollection_linestyles_unscaled():
fig, ax = plt.subplots()
q = ax.quiver(0, 0, 7, 1, label='v1 + v2', linewidth=3, linestyle='dotted')
leg = ax.legend()
handle, = leg.legend_handles

assert q.get_linestyle(scaled=False)[0] == handle.get_linestyle()


def test_cross_figure_patch_legend():
fig, ax = plt.subplots()
fig2, ax2 = plt.subplots()
Expand Down Expand Up @@ -612,7 +622,7 @@ def test_linecollection_scaled_dashes():
h1, h2, h3 = leg.legend_handles

for oh, lh in zip((lc1, lc2, lc3), (h1, h2, h3)):
assert oh.get_linestyles()[0] == lh._dash_pattern
assert oh.get_linestyles()[0] == lh.get_linestyle()


def test_handler_numpoints():
Expand Down
2 changes: 1 addition & 1 deletion lib/mpl_toolkits/mplot3d/tests/test_legend3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_linecollection_scaled_dashes():
h1, h2, h3 = leg.legend_handles

for oh, lh in zip((lc1, lc2, lc3), (h1, h2, h3)):
assert oh.get_linestyles()[0] == lh._dash_pattern
assert oh.get_linestyles()[0] == lh.get_linestyle()


def test_handlerline3d():
Expand Down
Loading