From 4b99bc6c72add8b5535feec7154d6a5c160e1976 Mon Sep 17 00:00:00 2001 From: Elliott Sales de Andrade Date: Wed, 3 May 2023 19:50:07 -0400 Subject: [PATCH] Fix default return of Collection.get_{cap,join}style If neither are specified at object creation, the default is to be `None`. This broke `get_{cap,join}style` when the enum wrappers were created as they assume the internal value is always an enum value. --- lib/matplotlib/collections.py | 20 ++++++++++++++++++-- lib/matplotlib/collections.pyi | 4 ++-- lib/matplotlib/tests/test_collections.py | 4 ++++ 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/lib/matplotlib/collections.py b/lib/matplotlib/collections.py index 68140194942b..a6e4af5af11f 100644 --- a/lib/matplotlib/collections.py +++ b/lib/matplotlib/collections.py @@ -641,8 +641,16 @@ def set_capstyle(self, cs): """ self._capstyle = CapStyle(cs) + @_docstring.interpd def get_capstyle(self): - return self._capstyle.name + """ + Return the cap style for the collection (for all its elements). + + Returns + ------- + %(CapStyle)s or None + """ + return self._capstyle.name if self._capstyle else None @_docstring.interpd def set_joinstyle(self, js): @@ -655,8 +663,16 @@ def set_joinstyle(self, js): """ self._joinstyle = JoinStyle(js) + @_docstring.interpd def get_joinstyle(self): - return self._joinstyle.name + """ + Return the join style for the collection (for all its elements). + + Returns + ------- + %(JoinStyle)s or None + """ + return self._joinstyle.name if self._joinstyle else None @staticmethod def _bcast_lwls(linewidths, dashes): diff --git a/lib/matplotlib/collections.pyi b/lib/matplotlib/collections.pyi index 25eece49d755..2bb7d38369b5 100644 --- a/lib/matplotlib/collections.pyi +++ b/lib/matplotlib/collections.pyi @@ -51,9 +51,9 @@ class Collection(artist.Artist, cm.ScalarMappable): def set_linewidth(self, lw: float | Sequence[float]) -> None: ... def set_linestyle(self, ls: LineStyleType | Sequence[LineStyleType]) -> None: ... def set_capstyle(self, cs: CapStyleType) -> None: ... - def get_capstyle(self) -> Literal["butt", "projecting", "round"]: ... + def get_capstyle(self) -> Literal["butt", "projecting", "round"] | None: ... def set_joinstyle(self, js: JoinStyleType) -> None: ... - def get_joinstyle(self) -> Literal["miter", "round", "bevel"]: ... + def get_joinstyle(self) -> Literal["miter", "round", "bevel"] | None: ... def set_antialiased(self, aa: bool | Sequence[bool]) -> None: ... def set_color(self, c: ColorType | Sequence[ColorType]) -> None: ... def set_facecolor(self, c: ColorType | Sequence[ColorType]) -> None: ... diff --git a/lib/matplotlib/tests/test_collections.py b/lib/matplotlib/tests/test_collections.py index 2a1002b6df59..43bbea34a2e5 100644 --- a/lib/matplotlib/tests/test_collections.py +++ b/lib/matplotlib/tests/test_collections.py @@ -625,6 +625,8 @@ def test_set_wrong_linestyle(): @mpl.style.context('default') def test_capstyle(): + col = mcollections.PathCollection([]) + assert col.get_capstyle() is None col = mcollections.PathCollection([], capstyle='round') assert col.get_capstyle() == 'round' col.set_capstyle('butt') @@ -633,6 +635,8 @@ def test_capstyle(): @mpl.style.context('default') def test_joinstyle(): + col = mcollections.PathCollection([]) + assert col.get_joinstyle() is None col = mcollections.PathCollection([], joinstyle='round') assert col.get_joinstyle() == 'round' col.set_joinstyle('miter')