diff --git a/doc/missing-references.json b/doc/missing-references.json index d1dc67b7e6a8..654b3ffce066 100644 --- a/doc/missing-references.json +++ b/doc/missing-references.json @@ -304,8 +304,8 @@ "matplotlib.collections._CollectionWithSizes.set_sizes": [ "lib/matplotlib/axes/_axes.py:docstring of matplotlib.axes._axes.Axes.barbs:179", "lib/matplotlib/axes/_axes.py:docstring of matplotlib.axes._axes.Axes.broken_barh:84", - "lib/matplotlib/axes/_axes.py:docstring of matplotlib.axes._axes.Axes.fill_between:120", - "lib/matplotlib/axes/_axes.py:docstring of matplotlib.axes._axes.Axes.fill_betweenx:120", + "lib/matplotlib/axes/_axes.py:docstring of matplotlib.axes._axes.Axes.fill_between:121", + "lib/matplotlib/axes/_axes.py:docstring of matplotlib.axes._axes.Axes.fill_betweenx:121", "lib/matplotlib/axes/_axes.py:docstring of matplotlib.axes._axes.Axes.hexbin:213", "lib/matplotlib/axes/_axes.py:docstring of matplotlib.axes._axes.Axes.pcolor:182", "lib/matplotlib/axes/_axes.py:docstring of matplotlib.axes._axes.Axes.quiver:215", @@ -316,10 +316,11 @@ "lib/matplotlib/collections.py:docstring of matplotlib.artist.PolyQuadMesh.set:44", "lib/matplotlib/collections.py:docstring of matplotlib.artist.RegularPolyCollection.set:44", "lib/matplotlib/collections.py:docstring of matplotlib.artist.StarPolygonCollection.set:44", + "lib/matplotlib/collections.py:docstring of matplotlib.artist.FillBetweenPolyCollection.set:45", "lib/matplotlib/pyplot.py:docstring of matplotlib.pyplot.barbs:179", "lib/matplotlib/pyplot.py:docstring of matplotlib.pyplot.broken_barh:84", - "lib/matplotlib/pyplot.py:docstring of matplotlib.pyplot.fill_between:120", - "lib/matplotlib/pyplot.py:docstring of matplotlib.pyplot.fill_betweenx:120", + "lib/matplotlib/pyplot.py:docstring of matplotlib.pyplot.fill_between:121", + "lib/matplotlib/pyplot.py:docstring of matplotlib.pyplot.fill_betweenx:121", "lib/matplotlib/pyplot.py:docstring of matplotlib.pyplot.hexbin:213", "lib/matplotlib/pyplot.py:docstring of matplotlib.pyplot.pcolor:182", "lib/matplotlib/pyplot.py:docstring of matplotlib.pyplot.quiver:215", diff --git a/doc/users/next_whats_new/fill_between_poly_collection.rst b/doc/users/next_whats_new/fill_between_poly_collection.rst new file mode 100644 index 000000000000..6c3b7673e631 --- /dev/null +++ b/doc/users/next_whats_new/fill_between_poly_collection.rst @@ -0,0 +1,22 @@ +``FillBetweenPolyCollection`` +----------------------------- + +The new class :class:`matplotlib.collections.FillBetweenPolyCollection` provides +the ``set_data`` method, enabling e.g. resampling +(:file:`galleries/event_handling/resample.html`). +:func:`matplotlib.axes.Axes.fill_between` and +:func:`matplotlib.axes.Axes.fill_betweenx` now return this new class. + +.. code-block:: python + + import numpy as np + from matplotlib import pyplot as plt + + t = np.linspace(0, 1) + + fig, ax = plt.subplots() + coll = ax.fill_between(t, -t**2, t**2) + fig.savefig("before.png") + + coll.set_data(t, -t**4, t**4) + fig.savefig("after.png") diff --git a/galleries/examples/event_handling/resample.py b/galleries/examples/event_handling/resample.py index 913cac9cdf0c..f4209ddc6334 100644 --- a/galleries/examples/event_handling/resample.py +++ b/galleries/examples/event_handling/resample.py @@ -22,13 +22,19 @@ # A class that will downsample the data and recompute when zoomed. class DataDisplayDownsampler: - def __init__(self, xdata, ydata): - self.origYData = ydata + def __init__(self, xdata, y1data, y2data): + self.origY1Data = y1data + self.origY2Data = y2data self.origXData = xdata self.max_points = 50 self.delta = xdata[-1] - xdata[0] - def downsample(self, xstart, xend): + def plot(self, ax): + x, y1, y2 = self._downsample(self.origXData.min(), self.origXData.max()) + (self.line,) = ax.plot(x, y1, 'o-') + self.poly_collection = ax.fill_between(x, y1, y2, step="pre", color="r") + + def _downsample(self, xstart, xend): # get the points in the view range mask = (self.origXData > xstart) & (self.origXData < xend) # dilate the mask by one to catch the points just outside @@ -39,36 +45,41 @@ def downsample(self, xstart, xend): # mask data xdata = self.origXData[mask] - ydata = self.origYData[mask] + y1data = self.origY1Data[mask] + y2data = self.origY2Data[mask] # downsample data xdata = xdata[::ratio] - ydata = ydata[::ratio] + y1data = y1data[::ratio] + y2data = y2data[::ratio] - print(f"using {len(ydata)} of {np.sum(mask)} visible points") + print(f"using {len(y1data)} of {np.sum(mask)} visible points") - return xdata, ydata + return xdata, y1data, y2data def update(self, ax): - # Update the line + # Update the artists lims = ax.viewLim if abs(lims.width - self.delta) > 1e-8: self.delta = lims.width xstart, xend = lims.intervalx - self.line.set_data(*self.downsample(xstart, xend)) + x, y1, y2 = self._downsample(xstart, xend) + self.line.set_data(x, y1) + self.poly_collection.set_data(x, y1, y2, step="pre") ax.figure.canvas.draw_idle() # Create a signal xdata = np.linspace(16, 365, (365-16)*4) -ydata = np.sin(2*np.pi*xdata/153) + np.cos(2*np.pi*xdata/127) +y1data = np.sin(2*np.pi*xdata/153) + np.cos(2*np.pi*xdata/127) +y2data = y1data + .2 -d = DataDisplayDownsampler(xdata, ydata) +d = DataDisplayDownsampler(xdata, y1data, y2data) fig, ax = plt.subplots() # Hook up the line -d.line, = ax.plot(xdata, ydata, 'o-') +d.plot(ax) ax.set_autoscale_on(False) # Otherwise, infinite loop # Connect for changing the view limits diff --git a/lib/matplotlib/axes/_axes.py b/lib/matplotlib/axes/_axes.py index a16e19f90152..5462b6fe5096 100644 --- a/lib/matplotlib/axes/_axes.py +++ b/lib/matplotlib/axes/_axes.py @@ -6,7 +6,6 @@ import re import numpy as np -from numpy import ma import matplotlib as mpl import matplotlib.category # Register category unit converter as side effect. @@ -5551,18 +5550,18 @@ def _fill_between_x_or_y( i.e. constant in between *{ind}*. The value determines where the step will occur: - - 'pre': The y value is continued constantly to the left from - every *x* position, i.e. the interval ``(x[i-1], x[i]]`` has the - value ``y[i]``. + - 'pre': The {dep} value is continued constantly to the left from + every *{ind}* position, i.e. the interval ``({ind}[i-1], {ind}[i]]`` + has the value ``{dep}[i]``. - 'post': The y value is continued constantly to the right from - every *x* position, i.e. the interval ``[x[i], x[i+1])`` has the - value ``y[i]``. - - 'mid': Steps occur half-way between the *x* positions. + every *{ind}* position, i.e. the interval ``[{ind}[i], {ind}[i+1])`` + has the value ``{dep}[i]``. + - 'mid': Steps occur half-way between the *{ind}* positions. Returns ------- - `.PolyCollection` - A `.PolyCollection` containing the plotted polygons. + `.FillBetweenPolyCollection` + A `.FillBetweenPolyCollection` containing the plotted polygons. Other Parameters ---------------- @@ -5570,124 +5569,39 @@ def _fill_between_x_or_y( DATA_PARAMETER_PLACEHOLDER **kwargs - All other keyword arguments are passed on to `.PolyCollection`. - They control the `.Polygon` properties: + All other keyword arguments are passed on to + `.FillBetweenPolyCollection`. They control the `.Polygon` properties: - %(PolyCollection:kwdoc)s + %(FillBetweenPolyCollection:kwdoc)s See Also -------- fill_between : Fill between two sets of y-values. fill_betweenx : Fill between two sets of x-values. """ - - dep_dir = {"x": "y", "y": "x"}[ind_dir] + dep_dir = mcoll.FillBetweenPolyCollection._f_dir_from_t(ind_dir) if not mpl.rcParams["_internal.classic_mode"]: kwargs = cbook.normalize_kwargs(kwargs, mcoll.Collection) if not any(c in kwargs for c in ("color", "facecolor")): - kwargs["facecolor"] = \ - self._get_patches_for_fill.get_next_color() - - # Handle united data, such as dates - ind, dep1, dep2 = map( - ma.masked_invalid, self._process_unit_info( - [(ind_dir, ind), (dep_dir, dep1), (dep_dir, dep2)], kwargs)) - - for name, array in [ - (ind_dir, ind), (f"{dep_dir}1", dep1), (f"{dep_dir}2", dep2)]: - if array.ndim > 1: - raise ValueError(f"{name!r} is not 1-dimensional") + kwargs["facecolor"] = self._get_patches_for_fill.get_next_color() - if where is None: - where = True - else: - where = np.asarray(where, dtype=bool) - if where.size != ind.size: - raise ValueError(f"where size ({where.size}) does not match " - f"{ind_dir} size ({ind.size})") - where = where & ~functools.reduce( - np.logical_or, map(np.ma.getmaskarray, [ind, dep1, dep2])) - - ind, dep1, dep2 = np.broadcast_arrays( - np.atleast_1d(ind), dep1, dep2, subok=True) - - polys = [] - for idx0, idx1 in cbook.contiguous_regions(where): - indslice = ind[idx0:idx1] - dep1slice = dep1[idx0:idx1] - dep2slice = dep2[idx0:idx1] - if step is not None: - step_func = cbook.STEP_LOOKUP_MAP["steps-" + step] - indslice, dep1slice, dep2slice = \ - step_func(indslice, dep1slice, dep2slice) - - if not len(indslice): - continue + ind, dep1, dep2 = self._fill_between_process_units( + ind_dir, dep_dir, ind, dep1, dep2, **kwargs) - N = len(indslice) - pts = np.zeros((2 * N + 2, 2)) - - if interpolate: - def get_interp_point(idx): - im1 = max(idx - 1, 0) - ind_values = ind[im1:idx+1] - diff_values = dep1[im1:idx+1] - dep2[im1:idx+1] - dep1_values = dep1[im1:idx+1] - - if len(diff_values) == 2: - if np.ma.is_masked(diff_values[1]): - return ind[im1], dep1[im1] - elif np.ma.is_masked(diff_values[0]): - return ind[idx], dep1[idx] - - diff_order = diff_values.argsort() - diff_root_ind = np.interp( - 0, diff_values[diff_order], ind_values[diff_order]) - ind_order = ind_values.argsort() - diff_root_dep = np.interp( - diff_root_ind, - ind_values[ind_order], dep1_values[ind_order]) - return diff_root_ind, diff_root_dep - - start = get_interp_point(idx0) - end = get_interp_point(idx1) - else: - # Handle scalar dep2 (e.g. 0): the fill should go all - # the way down to 0 even if none of the dep1 sample points do. - start = indslice[0], dep2slice[0] - end = indslice[-1], dep2slice[-1] - - pts[0] = start - pts[N + 1] = end - - pts[1:N+1, 0] = indslice - pts[1:N+1, 1] = dep1slice - pts[N+2:, 0] = indslice[::-1] - pts[N+2:, 1] = dep2slice[::-1] - - if ind_dir == "y": - pts = pts[:, ::-1] - - polys.append(pts) - - collection = mcoll.PolyCollection(polys, **kwargs) - - # now update the datalim and autoscale - pts = np.vstack([np.hstack([ind[where, None], dep1[where, None]]), - np.hstack([ind[where, None], dep2[where, None]])]) - if ind_dir == "y": - pts = pts[:, ::-1] - - up_x = up_y = True - if "transform" in kwargs: - up_x, up_y = kwargs["transform"].contains_branch_seperately(self.transData) - self.update_datalim(pts, updatex=up_x, updatey=up_y) + collection = mcoll.FillBetweenPolyCollection( + ind_dir, ind, dep1, dep2, + where=where, interpolate=interpolate, step=step, **kwargs) - self.add_collection(collection, autolim=False) + self.add_collection(collection) self._request_autoscale_view() return collection + def _fill_between_process_units(self, ind_dir, dep_dir, ind, dep1, dep2, **kwargs): + """Handle united data, such as dates.""" + return map(np.ma.masked_invalid, self._process_unit_info( + [(ind_dir, ind), (dep_dir, dep1), (dep_dir, dep2)], kwargs)) + def fill_between(self, x, y1, y2=0, where=None, interpolate=False, step=None, **kwargs): return self._fill_between_x_or_y( diff --git a/lib/matplotlib/axes/_axes.pyi b/lib/matplotlib/axes/_axes.pyi index 4e7746dd9ef4..2c54c9b55ce0 100644 --- a/lib/matplotlib/axes/_axes.pyi +++ b/lib/matplotlib/axes/_axes.pyi @@ -5,6 +5,7 @@ from matplotlib.artist import Artist from matplotlib.backend_bases import RendererBase from matplotlib.collections import ( Collection, + FillBetweenPolyCollection, LineCollection, PathCollection, PolyCollection, @@ -459,7 +460,7 @@ class Axes(_AxesBase): *, data=..., **kwargs - ) -> PolyCollection: ... + ) -> FillBetweenPolyCollection: ... def fill_betweenx( self, y: ArrayLike, @@ -471,7 +472,7 @@ class Axes(_AxesBase): *, data=..., **kwargs - ) -> PolyCollection: ... + ) -> FillBetweenPolyCollection: ... def imshow( self, X: ArrayLike | PIL.Image.Image, diff --git a/lib/matplotlib/collections.py b/lib/matplotlib/collections.py index 397e4c12557d..e668308abc82 100644 --- a/lib/matplotlib/collections.py +++ b/lib/matplotlib/collections.py @@ -10,6 +10,7 @@ """ import itertools +import functools import math from numbers import Number, Real import warnings @@ -1254,6 +1255,248 @@ def set_verts_and_codes(self, verts, codes): self.stale = True +class FillBetweenPolyCollection(PolyCollection): + """ + `.PolyCollection` that fills the area between two x- or y-curves. + """ + def __init__( + self, t_direction, t, f1, f2, *, + where=None, interpolate=False, step=None, **kwargs): + """ + Parameters + ---------- + t_direction : {{'x', 'y'}} + The axes on which the variable lies. + + - 'x': the curves are ``(t, f1)`` and ``(t, f2)``. + - 'y': the curves are ``(f1, t)`` and ``(f2, t)``. + + t : array (length N) + The ``t_direction`` coordinates of the nodes defining the curves. + + f1 : array (length N) or scalar + The other coordinates of the nodes defining the first curve. + + f2 : array (length N) or scalar + The other coordinates of the nodes defining the second curve. + + where : array of bool (length N), optional + Define *where* to exclude some {dir} regions from being filled. + The filled regions are defined by the coordinates ``t[where]``. + More precisely, fill between ``t[i]`` and ``t[i+1]`` if + ``where[i] and where[i+1]``. Note that this definition implies + that an isolated *True* value between two *False* values in *where* + will not result in filling. Both sides of the *True* position + remain unfilled due to the adjacent *False* values. + + interpolate : bool, default: False + This option is only relevant if *where* is used and the two curves + are crossing each other. + + Semantically, *where* is often used for *f1* > *f2* or + similar. By default, the nodes of the polygon defining the filled + region will only be placed at the positions in the *t* array. + Such a polygon cannot describe the above semantics close to the + intersection. The t-sections containing the intersection are + simply clipped. + + Setting *interpolate* to *True* will calculate the actual + intersection point and extend the filled region up to this point. + + step : {{'pre', 'post', 'mid'}}, optional + Define *step* if the filling should be a step function, + i.e. constant in between *t*. The value determines where the + step will occur: + + - 'pre': The f value is continued constantly to the left from + every *t* position, i.e. the interval ``(t[i-1], t[i]]`` has the + value ``f[i]``. + - 'post': The y value is continued constantly to the right from + every *x* position, i.e. the interval ``[t[i], t[i+1])`` has the + value ``f[i]``. + - 'mid': Steps occur half-way between the *t* positions. + + **kwargs + Forwarded to `.PolyCollection`. + + See Also + -------- + .Axes.fill_between, .Axes.fill_betweenx + """ + self.t_direction = t_direction + self._interpolate = interpolate + self._step = step + verts = self._make_verts(t, f1, f2, where) + super().__init__(verts, **kwargs) + + @staticmethod + def _f_dir_from_t(t_direction): + """The direction that is other than `t_direction`.""" + if t_direction == "x": + return "y" + elif t_direction == "y": + return "x" + else: + msg = f"t_direction must be 'x' or 'y', got {t_direction!r}" + raise ValueError(msg) + + @property + def _f_direction(self): + """The direction that is other than `self.t_direction`.""" + return self._f_dir_from_t(self.t_direction) + + def set_data(self, t, f1, f2, *, where=None): + """ + Set new values for the two bounding curves. + + Parameters + ---------- + t : array (length N) + The ``self.t_direction`` coordinates of the nodes defining the curves. + + f1 : array (length N) or scalar + The other coordinates of the nodes defining the first curve. + + f2 : array (length N) or scalar + The other coordinates of the nodes defining the second curve. + + where : array of bool (length N), optional + Define *where* to exclude some {dir} regions from being filled. + The filled regions are defined by the coordinates ``t[where]``. + More precisely, fill between ``t[i]`` and ``t[i+1]`` if + ``where[i] and where[i+1]``. Note that this definition implies + that an isolated *True* value between two *False* values in *where* + will not result in filling. Both sides of the *True* position + remain unfilled due to the adjacent *False* values. + + See Also + -------- + .PolyCollection.set_verts, .Line2D.set_data + """ + t, f1, f2 = self.axes._fill_between_process_units( + self.t_direction, self._f_direction, t, f1, f2) + + verts = self._make_verts(t, f1, f2, where) + self.set_verts(verts) + + def get_datalim(self, transData): + """Calculate the data limits and return them as a `.Bbox`.""" + datalim = transforms.Bbox.null() + datalim.update_from_data_xy((self.get_transform() - transData).transform( + np.concatenate([self._bbox, [self._bbox.minpos]]))) + return datalim + + def _make_verts(self, t, f1, f2, where): + """ + Make verts that can be forwarded to `.PolyCollection`. + """ + self._validate_shapes(self.t_direction, self._f_direction, t, f1, f2) + + where = self._get_data_mask(t, f1, f2, where) + t, f1, f2 = np.broadcast_arrays(np.atleast_1d(t), f1, f2, subok=True) + + self._bbox = transforms.Bbox.null() + self._bbox.update_from_data_xy(self._fix_pts_xy_order(np.concatenate([ + np.stack((t[where], f[where]), axis=-1) for f in (f1, f2)]))) + + return [ + self._make_verts_for_region(t, f1, f2, idx0, idx1) + for idx0, idx1 in cbook.contiguous_regions(where) + ] + + def _get_data_mask(self, t, f1, f2, where): + """ + Return a bool array, with True at all points that should eventually be rendered. + + The array is True at a point if none of the data inputs + *t*, *f1*, *f2* is masked and if the input *where* is true at that point. + """ + if where is None: + where = True + else: + where = np.asarray(where, dtype=bool) + if where.size != t.size: + msg = "where size ({}) does not match {!r} size ({})".format( + where.size, self.t_direction, t.size) + raise ValueError(msg) + return where & ~functools.reduce( + np.logical_or, map(np.ma.getmaskarray, [t, f1, f2])) + + @staticmethod + def _validate_shapes(t_dir, f_dir, t, f1, f2): + """Validate that t, f1 and f2 are 1-dimensional and have the same length.""" + names = (d + s for d, s in zip((t_dir, f_dir, f_dir), ("", "1", "2"))) + for name, array in zip(names, [t, f1, f2]): + if array.ndim > 1: + raise ValueError(f"{name!r} is not 1-dimensional") + if t.size > 1 and array.size > 1 and t.size != array.size: + msg = "{!r} has size {}, but {!r} has an unequal size of {}".format( + t_dir, t.size, name, array.size) + raise ValueError(msg) + + def _make_verts_for_region(self, t, f1, f2, idx0, idx1): + """ + Make ``verts`` for a contiguous region between ``idx0`` and ``idx1``, taking + into account ``step`` and ``interpolate``. + """ + t_slice = t[idx0:idx1] + f1_slice = f1[idx0:idx1] + f2_slice = f2[idx0:idx1] + if self._step is not None: + step_func = cbook.STEP_LOOKUP_MAP["steps-" + self._step] + t_slice, f1_slice, f2_slice = step_func(t_slice, f1_slice, f2_slice) + + if self._interpolate: + start = self._get_interpolating_points(t, f1, f2, idx0) + end = self._get_interpolating_points(t, f1, f2, idx1) + else: + # Handle scalar f2 (e.g. 0): the fill should go all + # the way down to 0 even if none of the dep1 sample points do. + start = t_slice[0], f2_slice[0] + end = t_slice[-1], f2_slice[-1] + + pts = np.concatenate(( + np.asarray([start]), + np.stack((t_slice, f1_slice), axis=-1), + np.asarray([end]), + np.stack((t_slice, f2_slice), axis=-1)[::-1])) + + return self._fix_pts_xy_order(pts) + + @classmethod + def _get_interpolating_points(cls, t, f1, f2, idx): + """Calculate interpolating points.""" + im1 = max(idx - 1, 0) + t_values = t[im1:idx+1] + diff_values = f1[im1:idx+1] - f2[im1:idx+1] + f1_values = f1[im1:idx+1] + + if len(diff_values) == 2: + if np.ma.is_masked(diff_values[1]): + return t[im1], f1[im1] + elif np.ma.is_masked(diff_values[0]): + return t[idx], f1[idx] + + diff_root_t = cls._get_diff_root(0, diff_values, t_values) + diff_root_f = cls._get_diff_root(diff_root_t, t_values, f1_values) + return diff_root_t, diff_root_f + + @staticmethod + def _get_diff_root(x, xp, fp): + """Calculate diff root.""" + order = xp.argsort() + return np.interp(x, xp[order], fp[order]) + + def _fix_pts_xy_order(self, pts): + """ + Fix pts calculation results with `self.t_direction`. + + In the workflow, it is assumed that `self.t_direction` is 'x'. If this + is not true, we need to exchange the coordinates. + """ + return pts[:, ::-1] if self.t_direction == "y" else pts + + class RegularPolyCollection(_CollectionWithSizes): """A collection of n-sided regular polygons.""" diff --git a/lib/matplotlib/collections.pyi b/lib/matplotlib/collections.pyi index e4c46229517f..06d8676867ee 100644 --- a/lib/matplotlib/collections.pyi +++ b/lib/matplotlib/collections.pyi @@ -106,6 +106,29 @@ class PolyCollection(_CollectionWithSizes): self, verts: Sequence[ArrayLike | Path], codes: Sequence[int] ) -> None: ... +class FillBetweenPolyCollection(PolyCollection): + def __init__( + self, + t_direction: Literal["x", "y"], + t: ArrayLike, + f1: ArrayLike, + f2: ArrayLike, + *, + where: Sequence[bool] | None = ..., + interpolate: bool = ..., + step: Literal["pre", "post", "mid"] | None = ..., + **kwargs, + ) -> None: ... + def set_data( + self, + t: ArrayLike, + f1: ArrayLike, + f2: ArrayLike, + *, + where: Sequence[bool] | None = ..., + ) -> None: ... + def get_datalim(self, transData: transforms.Transform) -> transforms.Bbox: ... + class RegularPolyCollection(_CollectionWithSizes): def __init__( self, numsides: int, *, rotation: float = ..., sizes: ArrayLike = ..., **kwargs diff --git a/lib/matplotlib/pyplot.py b/lib/matplotlib/pyplot.py index 69c80e6d3579..af9b9096451a 100644 --- a/lib/matplotlib/pyplot.py +++ b/lib/matplotlib/pyplot.py @@ -101,6 +101,7 @@ from matplotlib.contour import ContourSet, QuadContourSet from matplotlib.collections import ( Collection, + FillBetweenPolyCollection, LineCollection, PolyCollection, PathCollection, @@ -3315,7 +3316,7 @@ def fill_between( *, data=None, **kwargs, -) -> PolyCollection: +) -> FillBetweenPolyCollection: return gca().fill_between( x, y1, @@ -3340,7 +3341,7 @@ def fill_betweenx( *, data=None, **kwargs, -) -> PolyCollection: +) -> FillBetweenPolyCollection: return gca().fill_betweenx( y, x1, diff --git a/lib/matplotlib/tests/test_collections.py b/lib/matplotlib/tests/test_collections.py index f9b2f5624a61..44292cbc1d53 100644 --- a/lib/matplotlib/tests/test_collections.py +++ b/lib/matplotlib/tests/test_collections.py @@ -17,6 +17,7 @@ import matplotlib.transforms as mtransforms from matplotlib.collections import (Collection, LineCollection, EventCollection, PolyCollection) +from matplotlib.collections import FillBetweenPolyCollection from matplotlib.testing.decorators import check_figures_equal, image_comparison @@ -830,6 +831,35 @@ def test_collection_set_verts_array(): assert np.array_equal(ap._codes, atp._codes) +@check_figures_equal(extensions=["png"]) +@pytest.mark.parametrize("kwargs", [{}, {"step": "pre"}]) +def test_fill_between_poly_collection_set_data(fig_test, fig_ref, kwargs): + t = np.linspace(0, 16) + f1 = np.sin(t) + f2 = f1 + 0.2 + + fig_ref.subplots().fill_between(t, f1, f2, **kwargs) + + coll = fig_test.subplots().fill_between(t, -1, 1.2, **kwargs) + coll.set_data(t, f1, f2) + + +@pytest.mark.parametrize(("t_direction", "f1", "shape", "where", "msg"), [ + ("z", None, None, None, r"t_direction must be 'x' or 'y', got 'z'"), + ("x", None, (-1, 1), None, r"'x' is not 1-dimensional"), + ("x", None, None, [False] * 3, r"where size \(3\) does not match 'x' size \(\d+\)"), + ("y", [1, 2], None, None, r"'y' has size \d+, but 'x1' has an unequal size of \d+"), +]) +def test_fill_between_poly_collection_raise(t_direction, f1, shape, where, msg): + t = np.linspace(0, 16) + f1 = np.sin(t) if f1 is None else np.asarray(f1) + f2 = f1 + 0.2 + if shape: + t = t.reshape(*shape) + with pytest.raises(ValueError, match=msg): + FillBetweenPolyCollection(t_direction, t, f1, f2, where=where) + + def test_collection_set_array(): vals = [*range(10)]