Skip to content

Commit 1ed9d14

Browse files
committed
MNT: Add a mixin MeshData class for Collections
This adds a private _MeshData mixin class to help handle mesh coordinates and array data validation in a common place.
1 parent a6615e1 commit 1ed9d14

File tree

2 files changed

+196
-64
lines changed

2 files changed

+196
-64
lines changed

lib/matplotlib/collections.py

+133-40
Original file line numberDiff line numberDiff line change
@@ -870,7 +870,7 @@ def update_scalarmappable(self):
870870
# Allow possibility to call 'self.set_array(None)'.
871871
if self._A is not None:
872872
# QuadMesh can map 2d arrays (but pcolormesh supplies 1d array)
873-
if self._A.ndim > 1 and not isinstance(self, QuadMesh):
873+
if self._A.ndim > 1 and not isinstance(self, _MeshData):
874874
raise ValueError('Collections can only map rank 1 arrays')
875875
if np.iterable(self._alpha):
876876
if self._alpha.size != self._A.size:
@@ -1889,9 +1889,11 @@ def draw(self, renderer):
18891889
renderer.close_group(self.__class__.__name__)
18901890

18911891

1892-
class QuadMesh(Collection):
1892+
class _MeshData:
18931893
r"""
1894-
Class for the efficient drawing of a quadrilateral mesh.
1894+
Class for managing the two dimensional coordinates of Quadrilateral meshes
1895+
and the associated data with them. This class is a mixin and is intended to
1896+
be used with another collection that will implement the draw separately.
18951897
18961898
A quadrilateral mesh is a grid of M by N adjacent quadrilaterals that are
18971899
defined via a (M+1, N+1) grid of vertices. The quadrilateral (m, n) is
@@ -1911,42 +1913,12 @@ class QuadMesh(Collection):
19111913
The vertices. ``coordinates[m, n]`` specifies the (x, y) coordinates
19121914
of vertex (m, n).
19131915
1914-
antialiased : bool, default: True
1915-
19161916
shading : {'flat', 'gouraud'}, default: 'flat'
1917-
1918-
Notes
1919-
-----
1920-
Unlike other `.Collection`\s, the default *pickradius* of `.QuadMesh` is 0,
1921-
i.e. `~.Artist.contains` checks whether the test point is within any of the
1922-
mesh quadrilaterals.
1923-
19241917
"""
1925-
1926-
def __init__(self, coordinates, *, antialiased=True, shading='flat',
1927-
**kwargs):
1928-
kwargs.setdefault("pickradius", 0)
1929-
# end of signature deprecation code
1930-
1918+
def __init__(self, coordinates, *, shading='flat'):
19311919
_api.check_shape((None, None, 2), coordinates=coordinates)
19321920
self._coordinates = coordinates
1933-
self._antialiased = antialiased
19341921
self._shading = shading
1935-
self._bbox = transforms.Bbox.unit()
1936-
self._bbox.update_from_data_xy(self._coordinates.reshape(-1, 2))
1937-
# super init delayed after own init because array kwarg requires
1938-
# self._coordinates and self._shading
1939-
super().__init__(**kwargs)
1940-
self.set_mouseover(False)
1941-
1942-
def get_paths(self):
1943-
if self._paths is None:
1944-
self.set_paths()
1945-
return self._paths
1946-
1947-
def set_paths(self):
1948-
self._paths = self._convert_mesh_to_paths(self._coordinates)
1949-
self.stale = True
19501922

19511923
def set_array(self, A):
19521924
"""
@@ -1985,9 +1957,6 @@ def set_array(self, A):
19851957
f"{' or '.join(map(str, ok_shapes))}, not {A.shape}")
19861958
return super().set_array(A)
19871959

1988-
def get_datalim(self, transData):
1989-
return (self.get_transform() - transData).transform_bbox(self._bbox)
1990-
19911960
def get_coordinates(self):
19921961
"""
19931962
Return the vertices of the mesh as an (M+1, N+1, 2) array.
@@ -1998,6 +1967,18 @@ def get_coordinates(self):
19981967
"""
19991968
return self._coordinates
20001969

1970+
def get_edgecolor(self):
1971+
# docstring inherited
1972+
# Note that we want to return an array of shape (N*M, 4)
1973+
# a flattened RGBA collection
1974+
return super().get_edgecolor().reshape(-1, 4)
1975+
1976+
def get_facecolor(self):
1977+
# docstring inherited
1978+
# Note that we want to return an array of shape (N*M, 4)
1979+
# a flattened RGBA collection
1980+
return super().get_facecolor().reshape(-1, 4)
1981+
20011982
@staticmethod
20021983
def _convert_mesh_to_paths(coordinates):
20031984
"""
@@ -2057,6 +2038,64 @@ def _convert_mesh_to_triangles(self, coordinates):
20572038

20582039
return triangles, colors
20592040

2041+
2042+
class QuadMesh(_MeshData, Collection):
2043+
r"""
2044+
Class for the efficient drawing of a quadrilateral mesh.
2045+
2046+
A quadrilateral mesh is a grid of M by N adjacent quadrilaterals that are
2047+
defined via a (M+1, N+1) grid of vertices. The quadrilateral (m, n) is
2048+
defined by the vertices ::
2049+
2050+
(m+1, n) ----------- (m+1, n+1)
2051+
/ /
2052+
/ /
2053+
/ /
2054+
(m, n) -------- (m, n+1)
2055+
2056+
The mesh need not be regular and the polygons need not be convex.
2057+
2058+
Parameters
2059+
----------
2060+
coordinates : (M+1, N+1, 2) array-like
2061+
The vertices. ``coordinates[m, n]`` specifies the (x, y) coordinates
2062+
of vertex (m, n).
2063+
2064+
antialiased : bool, default: True
2065+
2066+
shading : {'flat', 'gouraud'}, default: 'flat'
2067+
2068+
Notes
2069+
-----
2070+
Unlike other `.Collection`\s, the default *pickradius* of `.QuadMesh` is 0,
2071+
i.e. `~.Artist.contains` checks whether the test point is within any of the
2072+
mesh quadrilaterals.
2073+
2074+
"""
2075+
2076+
def __init__(self, coordinates, *, antialiased=True, shading='flat',
2077+
**kwargs):
2078+
kwargs.setdefault("pickradius", 0)
2079+
super().__init__(coordinates=coordinates, shading=shading)
2080+
Collection.__init__(self, **kwargs)
2081+
2082+
self._antialiased = antialiased
2083+
self._bbox = transforms.Bbox.unit()
2084+
self._bbox.update_from_data_xy(self._coordinates.reshape(-1, 2))
2085+
self.set_mouseover(False)
2086+
2087+
def get_paths(self):
2088+
if self._paths is None:
2089+
self.set_paths()
2090+
return self._paths
2091+
2092+
def set_paths(self):
2093+
self._paths = self._convert_mesh_to_paths(self._coordinates)
2094+
self.stale = True
2095+
2096+
def get_datalim(self, transData):
2097+
return (self.get_transform() - transData).transform_bbox(self._bbox)
2098+
20602099
@artist.allow_rasterization
20612100
def draw(self, renderer):
20622101
if not self.get_visible():
@@ -2113,8 +2152,41 @@ def get_cursor_data(self, event):
21132152
return None
21142153

21152154

2116-
class PolyQuadMesh(PolyCollection, QuadMesh):
2155+
class PolyQuadMesh(_MeshData, PolyCollection):
2156+
"""
2157+
Class for drawing a quadrilateral mesh as individual Polygons.
2158+
2159+
A quadrilateral mesh is a grid of M by N adjacent quadrilaterals that are
2160+
defined via a (M+1, N+1) grid of vertices. The quadrilateral (m, n) is
2161+
defined by the vertices ::
2162+
2163+
(m+1, n) ----------- (m+1, n+1)
2164+
/ /
2165+
/ /
2166+
/ /
2167+
(m, n) -------- (m, n+1)
2168+
2169+
The mesh need not be regular and the polygons need not be convex.
2170+
2171+
Parameters
2172+
----------
2173+
coordinates : (M+1, N+1, 2) array-like
2174+
The vertices. ``coordinates[m, n]`` specifies the (x, y) coordinates
2175+
of vertex (m, n).
2176+
2177+
Notes
2178+
-----
2179+
Unlike `.QuadMesh`, this class will draw each cell as an individual Polygon.
2180+
This is significantly slower, but allows for more flexibility when wanting
2181+
to add additional properties to the cells, such as hatching.
2182+
2183+
Another difference from `.QuadMesh` is that if any of the vertices or data
2184+
of a cell are masked, that Polygon will **not** be drawn and it won't be in
2185+
the list of paths returned.
2186+
"""
2187+
21172188
def __init__(self, coordinates, **kwargs):
2189+
super().__init__(coordinates=coordinates)
21182190
X = coordinates[..., 0]
21192191
Y = coordinates[..., 1]
21202192

@@ -2128,6 +2200,7 @@ def __init__(self, coordinates, **kwargs):
21282200
mask |= np.ma.getmaskarray(C)
21292201

21302202
unmask = ~mask
2203+
self._valid_polys = unmask.ravel()
21312204
X1 = np.ma.filled(X[:-1, :-1])[unmask]
21322205
Y1 = np.ma.filled(Y[:-1, :-1])[unmask]
21332206
X2 = np.ma.filled(X[1:, :-1])[unmask]
@@ -2140,5 +2213,25 @@ def __init__(self, coordinates, **kwargs):
21402213

21412214
xy = np.ma.stack([X1, Y1, X2, Y2, X3, Y3, X4, Y4, X1, Y1], axis=-1)
21422215
verts = xy.reshape((npoly, 5, 2))
2143-
# We need both verts and coordinates here to go through the super chain
2144-
super().__init__(coordinates=coordinates, verts=verts, **kwargs)
2216+
# Setting the verts updates the paths of the PolyCollection
2217+
PolyCollection.__init__(self, verts=verts, **kwargs)
2218+
2219+
def get_edgecolor(self):
2220+
# docstring inherited
2221+
# We only want to return the facecolors of the polygons
2222+
# that were drawn.
2223+
ec = super().get_edgecolor()
2224+
if len(ec) != len(self._valid_polys):
2225+
# Mapping is off
2226+
return ec
2227+
return ec[self._valid_polys, :]
2228+
2229+
def get_facecolor(self):
2230+
# docstring inherited
2231+
# We only want to return the facecolors of the polygons
2232+
# that were drawn.
2233+
fc = super().get_facecolor()
2234+
if len(fc) != len(self._valid_polys):
2235+
# Mapping is off
2236+
return fc
2237+
return fc[self._valid_polys, :]

0 commit comments

Comments
 (0)