Skip to content

Let collections return linewidths "as is", without cycling. #26043

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 22 additions & 41 deletions lib/matplotlib/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,8 @@ def __init__(self, *,
"""
artist.Artist.__init__(self)
cm.ScalarMappable.__init__(self, norm, cmap)
# list of un-scaled dash patterns
# this is needed scaling the dash pattern by linewidth
self._us_linestyles = [(0, None)]
# list of dash patterns
self._linestyles = [(0, None)]
# list of unbroadcast/scaled linewidths
self._us_lw = [0]
self._unscaled_linestyles = [(0, None)] # dash patterns prior to scaling by lw
self._linestyles = [(0, None)] # dash patterns scaled by lw
self._linewidths = [0]

self._gapcolor = None # Currently only used by LineCollection.
Expand Down Expand Up @@ -578,12 +573,9 @@ def set_linewidth(self, lw):
"""
if lw is None:
lw = self._get_default_linewidth()
# get the un-scaled/broadcast lw
self._us_lw = np.atleast_1d(lw)

# scale all of the dash patterns.
self._linewidths, self._linestyles = self._bcast_lwls(
self._us_lw, self._us_linestyles)
self._linewidths = np.atleast_1d(lw)
self._linestyles = self._compute_scaled_linestyles(
self._unscaled_linestyles, self._linewidths)
self.stale = True

def set_linestyle(self, ls):
Expand Down Expand Up @@ -620,13 +612,9 @@ def set_linestyle(self, ls):
except ValueError as err:
emsg = f'Do not know how to convert {ls!r} to dashes'
raise ValueError(emsg) from err

# get the list of raw 'unscaled' dash patterns
self._us_linestyles = dashes

# broadcast and scale the lw and dash patterns
self._linewidths, self._linestyles = self._bcast_lwls(
self._us_lw, self._us_linestyles)
self._unscaled_linestyles = dashes # raw 'unscaled' dash patterns
self._linestyles = self._compute_scaled_linestyles(
self._unscaled_linestyles, self._linewidths)

@_docstring.interpd
def set_capstyle(self, cs):
Expand Down Expand Up @@ -657,42 +645,35 @@ def get_joinstyle(self):
return self._joinstyle.name

@staticmethod
def _bcast_lwls(linewidths, dashes):
def _compute_scaled_linestyles(unscaled_linestyles, linewidths):
"""
Internal helper function to broadcast + scale ls/lw
Internal helper function to scale linestyles by linewidths.

In the collection drawing code, the linewidth and linestyle are cycled
through as circular buffers (via ``v[i % len(v)]``). Thus, if we are
going to scale the dash pattern at set time (not draw time) we need to
do the broadcasting now and expand both lists to be the same length.
do the cycling now and expand both lists to be the same length.

Parameters
----------
unscaled_linestyles
dash specification (offset, (dash pattern tuple))
linewidths : list
line widths of collection
dashes : list
dash specification (offset, (dash pattern tuple))

Returns
-------
linewidths, dashes : list
dashes : list
Will be the same length, dashes are scaled by paired linewidth
"""
if mpl.rcParams['_internal.classic_mode']:
return linewidths, dashes
# make sure they are the same length so we can zip them
if len(dashes) != len(linewidths):
l_dashes = len(dashes)
l_lw = len(linewidths)
gcd = math.gcd(l_dashes, l_lw)
dashes = list(dashes) * (l_lw // gcd)
linewidths = list(linewidths) * (l_dashes // gcd)

# scale the dash patterns
dashes = [mlines._scale_dashes(o, d, lw)
for (o, d), lw in zip(dashes, linewidths)]

return linewidths, dashes
return unscaled_linestyles
n_ls = len(unscaled_linestyles)
n_lw = len(linewidths)
return [mlines._scale_dashes(o, d, lw) for (o, d), lw in zip(
# How many cycles do we need?
list(unscaled_linestyles) * (n_lw // math.gcd(n_ls, n_lw)),
itertools.cycle(linewidths))]

def set_antialiased(self, aa):
"""
Expand Down Expand Up @@ -919,7 +900,7 @@ def update_from(self, other):
self._facecolors = other._facecolors
self._linewidths = other._linewidths
self._linestyles = other._linestyles
self._us_linestyles = other._us_linestyles
self._unscaled_linestyles = other._unscaled_linestyles
self._pickradius = other._pickradius
self._hatch = other._hatch

Expand Down
2 changes: 1 addition & 1 deletion lib/matplotlib/legend_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def get_numpoints(self, legend):

def _default_update_prop(self, legend_handle, orig_handle):
lw = orig_handle.get_linewidths()[0]
dashes = orig_handle._us_linestyles[0]
dashes = orig_handle._unscaled_linestyles[0]
color = orig_handle.get_colors()[0]
legend_handle.set_color(color)
legend_handle.set_linestyle(dashes)
Expand Down
2 changes: 1 addition & 1 deletion lib/matplotlib/tests/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ def test_lslw_bcast():
col.set_linewidths([1, 2, 3])

assert col.get_linestyles() == [(0, None)] * 6
assert col.get_linewidths() == [1, 2, 3] * 2
assert (col.get_linewidths() == [1, 2, 3]).all()

col.set_linestyles(['-', '-', '-'])
assert col.get_linestyles() == [(0, None)] * 3
Expand Down