From e0178fbe7c351175d9f560fca59d5e05fe3e37d1 Mon Sep 17 00:00:00 2001 From: Thomas A Caswell Date: Sun, 27 Aug 2017 22:06:40 -0400 Subject: [PATCH 1/2] ENH: add `ax` or `fig` kwarg to every pyplot function This allows passing `ax=some_ax` or `fig=some_fig` into all of the pyplot plotting functions. --- lib/matplotlib/pyplot.py | 645 +++++++++++++++++++++------- lib/matplotlib/tests/test_pyplot.py | 8 + tools/boilerplate.py | 38 +- 3 files changed, 531 insertions(+), 160 deletions(-) diff --git a/lib/matplotlib/pyplot.py b/lib/matplotlib/pyplot.py index 99af494c7880..3d05e3cb6977 100644 --- a/lib/matplotlib/pyplot.py +++ b/lib/matplotlib/pyplot.py @@ -734,7 +734,7 @@ def xkcd( return stack -## Figures ## +# # Figures # # def figure( # autoincrement if None, else integer from 1-N @@ -979,7 +979,7 @@ def get_figlabels() -> list[Any]: return [m.canvas.figure.get_label() for m in managers] -def get_current_fig_manager() -> FigureManagerBase | None: +def get_current_fig_manager(*, fig: Figure | None = None) -> FigureManagerBase | None: """ Return the figure manager of the current figure. @@ -993,17 +993,23 @@ def get_current_fig_manager() -> FigureManagerBase | None: ------- `.FigureManagerBase` or backend-dependent subclass thereof """ - return gcf().canvas.manager + if fig is None: + fig = gcf() + return fig.canvas.manager @_copy_docstring_and_deprecators(FigureCanvasBase.mpl_connect) -def connect(s: str, func: Callable[[Event], Any]) -> int: - return gcf().canvas.mpl_connect(s, func) +def connect(s: str, func: Callable[[Event], Any], *, fig: Figure | None = None) -> int: + if fig is None: + fig = gcf() + return fig.canvas.mpl_connect(s, func) @_copy_docstring_and_deprecators(FigureCanvasBase.mpl_disconnect) -def disconnect(cid: int) -> None: - return gcf().canvas.mpl_disconnect(cid) +def disconnect(cid: int, *, fig: Figure | None = None) -> None: + if fig is None: + fig = gcf() + return fig.canvas.mpl_disconnect(cid) def close(fig: None | int | str | Figure | Literal["all"] = None) -> None: @@ -1048,14 +1054,17 @@ def close(fig: None | int | str | Figure | Literal["all"] = None) -> None: "or None, not %s" % type(fig)) -def clf() -> None: - """Clear the current figure.""" - gcf().clear() +def clf(*, fig: Figure | None = None) -> None: + """ + Clear the current figure. + """ + if fig is None: + fig = gcf() + fig.clf() -def draw() -> None: - """ - Redraw the current figure. +def draw(*, fig: Figure | None = None) -> None: + """Redraw the current figure. This is used to update a figure that has been altered, but not automatically re-drawn. If interactive mode is on (via `.ion()`), this @@ -1070,23 +1079,25 @@ def draw() -> None: .FigureCanvasBase.draw_idle .FigureCanvasBase.draw """ - gcf().canvas.draw_idle() + if fig is None: + fig = gcf() + fig.canvas.draw_idle() @_copy_docstring_and_deprecators(Figure.savefig) def savefig(*args, **kwargs) -> None: - fig = gcf() # savefig default implementation has no return, so mypy is unhappy # presumably this is here because subclasses can return? - res = fig.savefig(*args, **kwargs) # type: ignore[func-returns-value] + fig = kwargs.pop('fig', None) or gcf() + res = fig.savefig(*args, **kwargs) fig.canvas.draw_idle() # Need this if 'transparent=True', to reset colors. return res ## Putting things in figures ## - - -def figlegend(*args, **kwargs) -> Legend: +def figlegend(*args, fig: Figure | None = None, **kwargs) -> Legend: + if fig is None: + fig = gcf() return gcf().legend(*args, **kwargs) if Figure.legend.__doc__: figlegend.__doc__ = Figure.legend.__doc__ \ @@ -1216,7 +1227,7 @@ def cla() -> None: ## More ways of creating axes ## @_docstring.dedent_interpd -def subplot(*args, **kwargs) -> Axes: +def subplot(*args, fig: Figure | None = None, **kwargs) -> Axes: """ Add an Axes to the current figure or retrieve an existing Axes. @@ -1381,7 +1392,8 @@ def subplot(*args, **kwargs) -> Axes: raise TypeError("subplot() got an unexpected keyword argument 'ncols' " "and/or 'nrows'. Did you intend to call subplots()?") - fig = gcf() + if fig is None: + fig = gcf() # First, search for an existing subplot with a matching spec. key = SubplotSpec._from_subplot_args(fig, args) @@ -1793,7 +1805,7 @@ def subplot_tool(targetfig: Figure | None = None) -> SubplotTool: "an associated toolbar") -def box(on: bool | None = None) -> None: +def box(on: bool | None = None, ax: Axes | None = None) -> None: """ Turn the axes box on or off on the current axes. @@ -1808,7 +1820,8 @@ def box(on: bool | None = None) -> None: :meth:`matplotlib.axes.Axes.set_frame_on` :meth:`matplotlib.axes.Axes.get_frame_on` """ - ax = gca() + if ax is None: + ax = gca() if on is None: on = not ax.get_frame_on() ax.set_frame_on(on) @@ -1834,6 +1847,9 @@ def xlim(*args, **kwargs) -> tuple[float, float]: Setting limits turns autoscaling off for the x-axis. + ax : matplotlib.axes.Axes, optional + Defaults to the current axes. + Returns ------- left, right @@ -2009,7 +2025,7 @@ def yticks( ... rotation=45) # Set text labels and properties. >>> yticks([]) # Disable yticks. """ - ax = gca() + ax = kwargs.pop('ax', None) or gca() if ticks is None: locs = ax.get_yticks(minor=minor) @@ -2095,7 +2111,7 @@ def rgrids( # set the locations and labels of the radial gridlines lines, labels = rgrids( (0.25, 0.5, 1.0), ('Tom', 'Dick', 'Harry' )) """ - ax = gca() + ax = kwargs.pop('ax', None) or gca() if not isinstance(ax, PolarAxes): raise RuntimeError('rgrids only defined for polar axes') if all(p is None for p in [radii, labels, angle, fmt]) and not kwargs: @@ -2170,7 +2186,7 @@ def thetagrids( # set the locations and labels of the angular gridlines lines, labels = thetagrids(range(45, 360, 90), ('NE', 'NW', 'SW', 'SE')) """ - ax = gca() + ax = kwargs.pop('ax', None) or gca() if not isinstance(ax, PolarAxes): raise RuntimeError('thetagrids only defined for polar axes') if all(param is None for param in [angles, labels, fmt]) and not kwargs: @@ -2401,9 +2417,13 @@ def figimage( vmax: float | None = None, origin: Literal["upper", "lower"] | None = None, resize: bool = False, + *, + fig: matplotlib.figure.Figure | None = None, **kwargs, ) -> FigureImage: - return gcf().figimage( + if fig is None: + fig = gcf() + return fig.figimage( X, xo=xo, yo=yo, @@ -2421,21 +2441,33 @@ def figimage( # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Figure.text) def figtext( - x: float, y: float, s: str, fontdict: dict[str, Any] | None = None, **kwargs + x: float, + y: float, + s: str, + fontdict: dict[str, Any] | None = None, + *, + fig: matplotlib.figure.Figure | None = None, + **kwargs, ) -> Text: - return gcf().text(x, y, s, fontdict=fontdict, **kwargs) + if fig is None: + fig = gcf() + return fig.text(x, y, s, fontdict=fontdict, **kwargs) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Figure.gca) -def gca() -> Axes: - return gcf().gca() +def gca(*, fig: matplotlib.figure.Figure | None = None) -> Axes: + if fig is None: + fig = gcf() + return fig.gca() # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Figure._gci) -def gci() -> ScalarMappable | None: - return gcf()._gci() +def gci(*, fig: matplotlib.figure.Figure | None = None) -> ScalarMappable | None: + if fig is None: + fig = gcf() + return fig._gci() # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @@ -2447,8 +2479,12 @@ def ginput( mouse_add: MouseButton = MouseButton.LEFT, mouse_pop: MouseButton = MouseButton.RIGHT, mouse_stop: MouseButton = MouseButton.MIDDLE, + *, + fig: matplotlib.figure.Figure | None = None, ) -> list[tuple[int, int]]: - return gcf().ginput( + if fig is None: + fig = gcf() + return fig.ginput( n=n, timeout=timeout, show_clicks=show_clicks, @@ -2467,16 +2503,22 @@ def subplots_adjust( top: float | None = None, wspace: float | None = None, hspace: float | None = None, + *, + fig: matplotlib.figure.Figure | None = None, ) -> None: - return gcf().subplots_adjust( + if fig is None: + fig = gcf() + return fig.subplots_adjust( left=left, bottom=bottom, right=right, top=top, wspace=wspace, hspace=hspace ) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Figure.suptitle) -def suptitle(t: str, **kwargs) -> Text: - return gcf().suptitle(t, **kwargs) +def suptitle(t: str, *, fig: matplotlib.figure.Figure | None = None, **kwargs) -> Text: + if fig is None: + fig = gcf() + return fig.suptitle(t, **kwargs) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @@ -2487,22 +2529,31 @@ def tight_layout( h_pad: float | None = None, w_pad: float | None = None, rect: tuple[float, float, float, float] | None = None, + fig: matplotlib.figure.Figure | None = None, ) -> None: - return gcf().tight_layout(pad=pad, h_pad=h_pad, w_pad=w_pad, rect=rect) + if fig is None: + fig = gcf() + return fig.tight_layout(pad=pad, h_pad=h_pad, w_pad=w_pad, rect=rect) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Figure.waitforbuttonpress) -def waitforbuttonpress(timeout: float = -1) -> None | bool: - return gcf().waitforbuttonpress(timeout=timeout) +def waitforbuttonpress( + timeout: float = -1, *, fig: matplotlib.figure.Figure | None = None +) -> None | bool: + if fig is None: + fig = gcf() + return fig.waitforbuttonpress(timeout=timeout) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Axes.acorr) def acorr( - x: ArrayLike, *, data=None, **kwargs + x: ArrayLike, *, data=None, ax: matplotlib.axes._axes.Axes | None = None, **kwargs ) -> tuple[np.ndarray, np.ndarray, LineCollection | Line2D, Line2D | None]: - return gca().acorr(x, **({"data": data} if data is not None else {}), **kwargs) + if ax is None: + ax = gca() + return ax.acorr(x, **({"data": data} if data is not None else {}), **kwargs) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @@ -2516,9 +2567,12 @@ def angle_spectrum( sides: Literal["default", "onesided", "twosided"] | None = None, *, data=None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> tuple[np.ndarray, np.ndarray, Line2D]: - return gca().angle_spectrum( + if ax is None: + ax = gca() + return ax.angle_spectrum( x, Fs=Fs, Fc=Fc, @@ -2549,9 +2603,13 @@ def annotate( | None = None, arrowprops: dict[str, Any] | None = None, annotation_clip: bool | None = None, + *, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> Annotation: - return gca().annotate( + if ax is None: + ax = gca() + return ax.annotate( text, xy, xytext=xytext, @@ -2565,8 +2623,18 @@ def annotate( # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Axes.arrow) -def arrow(x: float, y: float, dx: float, dy: float, **kwargs) -> FancyArrow: - return gca().arrow(x, y, dx, dy, **kwargs) +def arrow( + x: float, + y: float, + dx: float, + dy: float, + *, + ax: matplotlib.axes._axes.Axes | None = None, + **kwargs, +) -> FancyArrow: + if ax is None: + ax = gca() + return ax.arrow(x, y, dx, dy, **kwargs) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @@ -2575,22 +2643,43 @@ def autoscale( enable: bool = True, axis: Literal["both", "x", "y"] = "both", tight: bool | None = None, + *, + ax: matplotlib.axes._axes.Axes | None = None, ) -> None: - return gca().autoscale(enable=enable, axis=axis, tight=tight) + if ax is None: + ax = gca() + return ax.autoscale(enable=enable, axis=axis, tight=tight) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Axes.axhline) -def axhline(y: float = 0, xmin: float = 0, xmax: float = 1, **kwargs) -> Line2D: - return gca().axhline(y=y, xmin=xmin, xmax=xmax, **kwargs) +def axhline( + y: float = 0, + xmin: float = 0, + xmax: float = 1, + *, + ax: matplotlib.axes._axes.Axes | None = None, + **kwargs, +) -> Line2D: + if ax is None: + ax = gca() + return ax.axhline(y=y, xmin=xmin, xmax=xmax, **kwargs) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Axes.axhspan) def axhspan( - ymin: float, ymax: float, xmin: float = 0, xmax: float = 1, **kwargs + ymin: float, + ymax: float, + xmin: float = 0, + xmax: float = 1, + *, + ax: matplotlib.axes._axes.Axes | None = None, + **kwargs, ) -> Polygon: - return gca().axhspan(ymin, ymax, xmin=xmin, xmax=xmax, **kwargs) + if ax is None: + ax = gca() + return ax.axhspan(ymin, ymax, xmin=xmin, xmax=xmax, **kwargs) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @@ -2600,9 +2689,12 @@ def axis( /, *, emit: bool = True, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> tuple[float, float, float, float]: - return gca().axis(arg, emit=emit, **kwargs) + if ax is None: + ax = gca() + return ax.axis(arg, emit=emit, **kwargs) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @@ -2612,23 +2704,43 @@ def axline( xy2: tuple[float, float] | None = None, *, slope: float | None = None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> Line2D: - return gca().axline(xy1, xy2=xy2, slope=slope, **kwargs) + if ax is None: + ax = gca() + return ax.axline(xy1, xy2=xy2, slope=slope, **kwargs) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Axes.axvline) -def axvline(x: float = 0, ymin: float = 0, ymax: float = 1, **kwargs) -> Line2D: - return gca().axvline(x=x, ymin=ymin, ymax=ymax, **kwargs) +def axvline( + x: float = 0, + ymin: float = 0, + ymax: float = 1, + *, + ax: matplotlib.axes._axes.Axes | None = None, + **kwargs, +) -> Line2D: + if ax is None: + ax = gca() + return ax.axvline(x=x, ymin=ymin, ymax=ymax, **kwargs) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Axes.axvspan) def axvspan( - xmin: float, xmax: float, ymin: float = 0, ymax: float = 1, **kwargs + xmin: float, + xmax: float, + ymin: float = 0, + ymax: float = 1, + *, + ax: matplotlib.axes._axes.Axes | None = None, + **kwargs, ) -> Polygon: - return gca().axvspan(xmin, xmax, ymin=ymin, ymax=ymax, **kwargs) + if ax is None: + ax = gca() + return ax.axvspan(xmin, xmax, ymin=ymin, ymax=ymax, **kwargs) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @@ -2641,9 +2753,12 @@ def bar( *, align: Literal["center", "edge"] = "center", data=None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> BarContainer: - return gca().bar( + if ax is None: + ax = gca() + return ax.bar( x, height, width=width, @@ -2656,8 +2771,12 @@ def bar( # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Axes.barbs) -def barbs(*args, data=None, **kwargs) -> Barbs: - return gca().barbs(*args, **({"data": data} if data is not None else {}), **kwargs) +def barbs( + *args, data=None, ax: matplotlib.axes._axes.Axes | None = None, **kwargs +) -> Barbs: + if ax is None: + ax = gca() + return ax.barbs(*args, **({"data": data} if data is not None else {}), **kwargs) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @@ -2670,9 +2789,12 @@ def barh( *, align: Literal["center", "edge"] = "center", data=None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> BarContainer: - return gca().barh( + if ax is None: + ax = gca() + return ax.barh( y, width, height=height, @@ -2692,9 +2814,12 @@ def bar_label( fmt: str | Callable[[float], str] = "%g", label_type: Literal["center", "edge"] = "edge", padding: float = 0, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> list[Text]: - return gca().bar_label( + if ax is None: + ax = gca() + return ax.bar_label( container, labels=labels, fmt=fmt, @@ -2736,8 +2861,11 @@ def boxplot( capwidths: float | ArrayLike | None = None, *, data=None, + ax: matplotlib.axes._axes.Axes | None = None, ) -> dict[str, Any]: - return gca().boxplot( + if ax is None: + ax = gca() + return ax.boxplot( x, notch=notch, sym=sym, @@ -2776,17 +2904,28 @@ def broken_barh( yrange: tuple[float, float], *, data=None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> BrokenBarHCollection: - return gca().broken_barh( + if ax is None: + ax = gca() + return ax.broken_barh( xranges, yrange, **({"data": data} if data is not None else {}), **kwargs ) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Axes.clabel) -def clabel(CS: ContourSet, levels: ArrayLike | None = None, **kwargs) -> list[Text]: - return gca().clabel(CS, levels=levels, **kwargs) +def clabel( + CS: ContourSet, + levels: ArrayLike | None = None, + *, + ax: matplotlib.axes._axes.Axes | None = None, + **kwargs, +) -> list[Text]: + if ax is None: + ax = gca() + return ax.clabel(CS, levels=levels, **kwargs) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @@ -2806,9 +2945,12 @@ def cohere( scale_by_freq: bool | None = None, *, data=None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> tuple[np.ndarray, np.ndarray]: - return gca().cohere( + if ax is None: + ax = gca() + return ax.cohere( x, y, NFFT=NFFT, @@ -2827,10 +2969,12 @@ def cohere( # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Axes.contour) -def contour(*args, data=None, **kwargs) -> QuadContourSet: - __ret = gca().contour( - *args, **({"data": data} if data is not None else {}), **kwargs - ) +def contour( + *args, data=None, ax: matplotlib.axes._axes.Axes | None = None, **kwargs +) -> QuadContourSet: + if ax is None: + ax = gca() + __ret = ax.contour(*args, **({"data": data} if data is not None else {}), **kwargs) if __ret._A is not None: sci(__ret) # noqa return __ret @@ -2838,10 +2982,12 @@ def contour(*args, data=None, **kwargs) -> QuadContourSet: # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Axes.contourf) -def contourf(*args, data=None, **kwargs) -> QuadContourSet: - __ret = gca().contourf( - *args, **({"data": data} if data is not None else {}), **kwargs - ) +def contourf( + *args, data=None, ax: matplotlib.axes._axes.Axes | None = None, **kwargs +) -> QuadContourSet: + if ax is None: + ax = gca() + __ret = ax.contourf(*args, **({"data": data} if data is not None else {}), **kwargs) if __ret._A is not None: sci(__ret) # noqa return __ret @@ -2866,9 +3012,12 @@ def csd( return_line: bool | None = None, *, data=None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> tuple[np.ndarray, np.ndarray] | tuple[np.ndarray, np.ndarray, Line2D]: - return gca().csd( + if ax is None: + ax = gca() + return ax.csd( x, y, NFFT=NFFT, @@ -2896,9 +3045,12 @@ def ecdf( orientation: Literal["vertical", "horizonatal"] = "vertical", compress: bool = False, data=None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> Line2D: - return gca().ecdf( + if ax is None: + ax = gca() + return ax.ecdf( x, weights=weights, complementary=complementary, @@ -2929,9 +3081,12 @@ def errorbar( capthick: float | None = None, *, data=None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> ErrorbarContainer: - return gca().errorbar( + if ax is None: + ax = gca() + return ax.errorbar( x, y, yerr=yerr, @@ -2965,9 +3120,12 @@ def eventplot( linestyles: LineStyleType | Sequence[LineStyleType] = "solid", *, data=None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> EventCollection: - return gca().eventplot( + if ax is None: + ax = gca() + return ax.eventplot( positions, orientation=orientation, lineoffsets=lineoffsets, @@ -2983,8 +3141,12 @@ def eventplot( # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Axes.fill) -def fill(*args, data=None, **kwargs) -> list[Polygon]: - return gca().fill(*args, **({"data": data} if data is not None else {}), **kwargs) +def fill( + *args, data=None, ax: matplotlib.axes._axes.Axes | None = None, **kwargs +) -> list[Polygon]: + if ax is None: + ax = gca() + return ax.fill(*args, **({"data": data} if data is not None else {}), **kwargs) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @@ -2998,9 +3160,12 @@ def fill_between( step: Literal["pre", "post", "mid"] | None = None, *, data=None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> PolyCollection: - return gca().fill_between( + if ax is None: + ax = gca() + return ax.fill_between( x, y1, y2=y2, @@ -3023,9 +3188,12 @@ def fill_betweenx( interpolate: bool = False, *, data=None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> PolyCollection: - return gca().fill_betweenx( + if ax is None: + ax = gca() + return ax.fill_betweenx( y, x1, x2=x2, @@ -3043,9 +3211,13 @@ def grid( visible: bool | None = None, which: Literal["major", "minor", "both"] = "major", axis: Literal["both", "x", "y"] = "both", + *, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> None: - return gca().grid(visible=visible, which=which, axis=axis, **kwargs) + if ax is None: + ax = gca() + return ax.grid(visible=visible, which=which, axis=axis, **kwargs) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @@ -3071,9 +3243,12 @@ def hexbin( marginals: bool = False, *, data=None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> PolyCollection: - __ret = gca().hexbin( + if ax is None: + ax = gca() + __ret = ax.hexbin( x, y, C=C, @@ -3119,13 +3294,16 @@ def hist( stacked: bool = False, *, data=None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> tuple[ np.ndarray | list[np.ndarray], np.ndarray, BarContainer | Polygon | list[BarContainer | Polygon], ]: - return gca().hist( + if ax is None: + ax = gca() + return ax.hist( x, bins=bins, range=range, @@ -3156,9 +3334,12 @@ def stairs( baseline: float | ArrayLike = 0, fill: bool = False, data=None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> StepPatch: - return gca().stairs( + if ax is None: + ax = gca() + return ax.stairs( values, edges=edges, orientation=orientation, @@ -3182,9 +3363,12 @@ def hist2d( cmax: float | None = None, *, data=None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> tuple[np.ndarray, np.ndarray, np.ndarray, QuadMesh]: - __ret = gca().hist2d( + if ax is None: + ax = gca() + __ret = ax.hist2d( x, y, bins=bins, @@ -3211,9 +3395,12 @@ def hlines( label: str = "", *, data=None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> LineCollection: - return gca().hlines( + if ax is None: + ax = gca() + return ax.hlines( y, xmin, xmax, @@ -3245,9 +3432,12 @@ def imshow( resample: bool | None = None, url: str | None = None, data=None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> AxesImage: - __ret = gca().imshow( + if ax is None: + ax = gca() + __ret = ax.imshow( X, cmap=cmap, norm=norm, @@ -3272,22 +3462,34 @@ def imshow( # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Axes.legend) -def legend(*args, **kwargs) -> Legend: - return gca().legend(*args, **kwargs) +def legend(*args, ax: matplotlib.axes._axes.Axes | None = None, **kwargs) -> Legend: + if ax is None: + ax = gca() + return ax.legend(*args, **kwargs) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Axes.locator_params) def locator_params( - axis: Literal["both", "x", "y"] = "both", tight: bool | None = None, **kwargs + axis: Literal["both", "x", "y"] = "both", + tight: bool | None = None, + *, + ax: matplotlib.axes._axes.Axes | None = None, + **kwargs, ) -> None: - return gca().locator_params(axis=axis, tight=tight, **kwargs) + if ax is None: + ax = gca() + return ax.locator_params(axis=axis, tight=tight, **kwargs) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Axes.loglog) -def loglog(*args, **kwargs) -> list[Line2D]: - return gca().loglog(*args, **kwargs) +def loglog( + *args, ax: matplotlib.axes._axes.Axes | None = None, **kwargs +) -> list[Line2D]: + if ax is None: + ax = gca() + return ax.loglog(*args, **kwargs) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @@ -3302,9 +3504,12 @@ def magnitude_spectrum( scale: Literal["default", "linear", "dB"] | None = None, *, data=None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> tuple[np.ndarray, np.ndarray, Line2D]: - return gca().magnitude_spectrum( + if ax is None: + ax = gca() + return ax.magnitude_spectrum( x, Fs=Fs, Fc=Fc, @@ -3324,20 +3529,27 @@ def margins( x: float | None = None, y: float | None = None, tight: bool | None = True, + ax: matplotlib.axes._axes.Axes | None = None, ) -> tuple[float, float] | None: - return gca().margins(*margins, x=x, y=y, tight=tight) + if ax is None: + ax = gca() + return ax.margins(*margins, x=x, y=y, tight=tight) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Axes.minorticks_off) -def minorticks_off() -> None: - return gca().minorticks_off() +def minorticks_off(*, ax: matplotlib.axes._axes.Axes | None = None) -> None: + if ax is None: + ax = gca() + return ax.minorticks_off() # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Axes.minorticks_on) -def minorticks_on() -> None: - return gca().minorticks_on() +def minorticks_on(*, ax: matplotlib.axes._axes.Axes | None = None) -> None: + if ax is None: + ax = gca() + return ax.minorticks_on() # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @@ -3351,9 +3563,12 @@ def pcolor( vmin: float | None = None, vmax: float | None = None, data=None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> Collection: - __ret = gca().pcolor( + if ax is None: + ax = gca() + __ret = ax.pcolor( *args, shading=shading, alpha=alpha, @@ -3380,9 +3595,12 @@ def pcolormesh( shading: Literal["flat", "nearest", "gouraud", "auto"] | None = None, antialiased: bool = False, data=None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> QuadMesh: - __ret = gca().pcolormesh( + if ax is None: + ax = gca() + __ret = ax.pcolormesh( *args, alpha=alpha, norm=norm, @@ -3409,9 +3627,12 @@ def phase_spectrum( sides: Literal["default", "onesided", "twosided"] | None = None, *, data=None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> tuple[np.ndarray, np.ndarray, Line2D]: - return gca().phase_spectrum( + if ax is None: + ax = gca() + return ax.phase_spectrum( x, Fs=Fs, Fc=Fc, @@ -3446,8 +3667,11 @@ def pie( normalize: bool = True, hatch: str | Sequence[str] | None = None, data=None, + ax: matplotlib.axes._axes.Axes | None = None, ): - return gca().pie( + if ax is None: + ax = gca() + return ax.pie( x, explode=explode, labels=labels, @@ -3477,9 +3701,12 @@ def plot( scalex: bool = True, scaley: bool = True, data=None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> list[Line2D]: - return gca().plot( + if ax is None: + ax = gca() + return ax.plot( *args, scalex=scalex, scaley=scaley, @@ -3499,9 +3726,12 @@ def plot_date( ydate: bool = False, *, data=None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> list[Line2D]: - return gca().plot_date( + if ax is None: + ax = gca() + return ax.plot_date( x, y, fmt=fmt, @@ -3531,9 +3761,12 @@ def psd( return_line: bool | None = None, *, data=None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> tuple[np.ndarray, np.ndarray] | tuple[np.ndarray, np.ndarray, Line2D]: - return gca().psd( + if ax is None: + ax = gca() + return ax.psd( x, NFFT=NFFT, Fs=Fs, @@ -3552,10 +3785,12 @@ def psd( # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Axes.quiver) -def quiver(*args, data=None, **kwargs) -> Quiver: - __ret = gca().quiver( - *args, **({"data": data} if data is not None else {}), **kwargs - ) +def quiver( + *args, data=None, ax: matplotlib.axes._axes.Axes | None = None, **kwargs +) -> Quiver: + if ax is None: + ax = gca() + __ret = ax.quiver(*args, **({"data": data} if data is not None else {}), **kwargs) sci(__ret) return __ret @@ -3563,9 +3798,18 @@ def quiver(*args, data=None, **kwargs) -> Quiver: # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Axes.quiverkey) def quiverkey( - Q: Quiver, X: float, Y: float, U: float, label: str, **kwargs + Q: Quiver, + X: float, + Y: float, + U: float, + label: str, + *, + ax: matplotlib.axes._axes.Axes | None = None, + **kwargs, ) -> QuiverKey: - return gca().quiverkey(Q, X, Y, U, label, **kwargs) + if ax is None: + ax = gca() + return ax.quiverkey(Q, X, Y, U, label, **kwargs) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @@ -3586,9 +3830,12 @@ def scatter( edgecolors: Literal["face", "none"] | ColorType | Sequence[ColorType] | None = None, plotnonfinite: bool = False, data=None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> PathCollection: - __ret = gca().scatter( + if ax is None: + ax = gca() + __ret = ax.scatter( x, y, s=s, @@ -3611,14 +3858,22 @@ def scatter( # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Axes.semilogx) -def semilogx(*args, **kwargs) -> list[Line2D]: - return gca().semilogx(*args, **kwargs) +def semilogx( + *args, ax: matplotlib.axes._axes.Axes | None = None, **kwargs +) -> list[Line2D]: + if ax is None: + ax = gca() + return ax.semilogx(*args, **kwargs) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Axes.semilogy) -def semilogy(*args, **kwargs) -> list[Line2D]: - return gca().semilogy(*args, **kwargs) +def semilogy( + *args, ax: matplotlib.axes._axes.Axes | None = None, **kwargs +) -> list[Line2D]: + if ax is None: + ax = gca() + return ax.semilogy(*args, **kwargs) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @@ -3644,9 +3899,12 @@ def specgram( vmax: float | None = None, *, data=None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> tuple[np.ndarray, np.ndarray, np.ndarray, AxesImage]: - __ret = gca().specgram( + if ax is None: + ax = gca() + __ret = ax.specgram( x, NFFT=NFFT, Fs=Fs, @@ -3679,9 +3937,13 @@ def spy( markersize: float | None = None, aspect: Literal["equal", "auto"] | float | None = "equal", origin: Literal["upper", "lower"] = "upper", + *, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> AxesImage: - __ret = gca().spy( + if ax is None: + ax = gca() + __ret = ax.spy( Z, precision=precision, marker=marker, @@ -3697,8 +3959,19 @@ def spy( # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Axes.stackplot) -def stackplot(x, *args, labels=(), colors=None, baseline="zero", data=None, **kwargs): - return gca().stackplot( +def stackplot( + x, + *args, + labels=(), + colors=None, + baseline="zero", + data=None, + ax: matplotlib.axes._axes.Axes | None = None, + **kwargs, +): + if ax is None: + ax = gca() + return ax.stackplot( x, *args, labels=labels, @@ -3720,8 +3993,11 @@ def stem( label: str | None = None, orientation: Literal["vertical", "horizontal"] = "vertical", data=None, + ax: matplotlib.axes._axes.Axes | None = None, ) -> StemContainer: - return gca().stem( + if ax is None: + ax = gca() + return ax.stem( *args, linefmt=linefmt, markerfmt=markerfmt, @@ -3741,9 +4017,12 @@ def step( *args, where: Literal["pre", "post", "mid"] = "pre", data=None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> list[Line2D]: - return gca().step( + if ax is None: + ax = gca() + return ax.step( x, y, *args, @@ -3776,8 +4055,11 @@ def streamplot( broken_streamlines=True, *, data=None, + ax: matplotlib.axes._axes.Axes | None = None, ): - __ret = gca().streamplot( + if ax is None: + ax = gca() + __ret = ax.streamplot( x, y, u, @@ -3818,9 +4100,13 @@ def table( loc="bottom", bbox=None, edges="closed", + *, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ): - return gca().table( + if ax is None: + ax = gca() + return ax.table( cellText=cellText, cellColours=cellColours, cellLoc=cellLoc, @@ -3841,15 +4127,30 @@ def table( # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Axes.text) def text( - x: float, y: float, s: str, fontdict: dict[str, Any] | None = None, **kwargs + x: float, + y: float, + s: str, + fontdict: dict[str, Any] | None = None, + *, + ax: matplotlib.axes._axes.Axes | None = None, + **kwargs, ) -> Text: - return gca().text(x, y, s, fontdict=fontdict, **kwargs) + if ax is None: + ax = gca() + return ax.text(x, y, s, fontdict=fontdict, **kwargs) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Axes.tick_params) -def tick_params(axis: Literal["both", "x", "y"] = "both", **kwargs) -> None: - return gca().tick_params(axis=axis, **kwargs) +def tick_params( + axis: Literal["both", "x", "y"] = "both", + *, + ax: matplotlib.axes._axes.Axes | None = None, + **kwargs, +) -> None: + if ax is None: + ax = gca() + return ax.tick_params(axis=axis, **kwargs) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @@ -3862,8 +4163,11 @@ def ticklabel_format( useOffset: bool | float | None = None, useLocale: bool | None = None, useMathText: bool | None = None, + ax: matplotlib.axes._axes.Axes | None = None, ) -> None: - return gca().ticklabel_format( + if ax is None: + ax = gca() + return ax.ticklabel_format( axis=axis, style=style, scilimits=scilimits, @@ -3875,8 +4179,10 @@ def ticklabel_format( # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Axes.tricontour) -def tricontour(*args, **kwargs): - __ret = gca().tricontour(*args, **kwargs) +def tricontour(*args, ax: matplotlib.axes._axes.Axes | None = None, **kwargs): + if ax is None: + ax = gca() + __ret = ax.tricontour(*args, **kwargs) if __ret._A is not None: sci(__ret) # noqa return __ret @@ -3884,8 +4190,10 @@ def tricontour(*args, **kwargs): # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Axes.tricontourf) -def tricontourf(*args, **kwargs): - __ret = gca().tricontourf(*args, **kwargs) +def tricontourf(*args, ax: matplotlib.axes._axes.Axes | None = None, **kwargs): + if ax is None: + ax = gca() + __ret = ax.tricontourf(*args, **kwargs) if __ret._A is not None: sci(__ret) # noqa return __ret @@ -3902,9 +4210,12 @@ def tripcolor( vmax=None, shading="flat", facecolors=None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ): - __ret = gca().tripcolor( + if ax is None: + ax = gca() + __ret = ax.tripcolor( *args, alpha=alpha, norm=norm, @@ -3921,8 +4232,10 @@ def tripcolor( # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Axes.triplot) -def triplot(*args, **kwargs): - return gca().triplot(*args, **kwargs) +def triplot(*args, ax: matplotlib.axes._axes.Axes | None = None, **kwargs): + if ax is None: + ax = gca() + return ax.triplot(*args, **kwargs) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @@ -3943,8 +4256,11 @@ def violinplot( | None = None, *, data=None, + ax: matplotlib.axes._axes.Axes | None = None, ) -> dict[str, Collection]: - return gca().violinplot( + if ax is None: + ax = gca() + return ax.violinplot( dataset, positions=positions, vert=vert, @@ -3970,9 +4286,12 @@ def vlines( label: str = "", *, data=None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> LineCollection: - return gca().vlines( + if ax is None: + ax = gca() + return ax.vlines( x, ymin, ymax, @@ -3995,9 +4314,12 @@ def xcorr( maxlags: int = 10, *, data=None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> tuple[np.ndarray, np.ndarray, LineCollection | Line2D, Line2D | None]: - return gca().xcorr( + if ax is None: + ax = gca() + return ax.xcorr( x, y, normed=normed, @@ -4011,8 +4333,10 @@ def xcorr( # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Axes._sci) -def sci(im: ScalarMappable) -> None: - return gca()._sci(im) +def sci(im: ScalarMappable, *, ax: matplotlib.axes._axes.Axes | None = None) -> None: + if ax is None: + ax = gca() + return ax._sci(im) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @@ -4024,9 +4348,12 @@ def title( pad: float | None = None, *, y: float | None = None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> Text: - return gca().set_title(label, fontdict=fontdict, loc=loc, pad=pad, y=y, **kwargs) + if ax is None: + ax = gca() + return ax.set_title(label, fontdict=fontdict, loc=loc, pad=pad, y=y, **kwargs) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @@ -4037,9 +4364,12 @@ def xlabel( labelpad: float | None = None, *, loc: Literal["left", "center", "right"] | None = None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> Text: - return gca().set_xlabel( + if ax is None: + ax = gca() + return ax.set_xlabel( xlabel, fontdict=fontdict, labelpad=labelpad, loc=loc, **kwargs ) @@ -4052,23 +4382,34 @@ def ylabel( labelpad: float | None = None, *, loc: Literal["bottom", "center", "top"] | None = None, + ax: matplotlib.axes._axes.Axes | None = None, **kwargs, ) -> Text: - return gca().set_ylabel( + if ax is None: + ax = gca() + return ax.set_ylabel( ylabel, fontdict=fontdict, labelpad=labelpad, loc=loc, **kwargs ) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Axes.set_xscale) -def xscale(value: str | ScaleBase, **kwargs) -> None: - return gca().set_xscale(value, **kwargs) +def xscale( + value: str | ScaleBase, *, ax: matplotlib.axes._axes.Axes | None = None, **kwargs +) -> None: + if ax is None: + ax = gca() + return ax.set_xscale(value, **kwargs) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. @_copy_docstring_and_deprecators(Axes.set_yscale) -def yscale(value: str | ScaleBase, **kwargs) -> None: - return gca().set_yscale(value, **kwargs) +def yscale( + value: str | ScaleBase, *, ax: matplotlib.axes._axes.Axes | None = None, **kwargs +) -> None: + if ax is None: + ax = gca() + return ax.set_yscale(value, **kwargs) # Autogenerated by boilerplate.py. Do not edit as changes will be lost. diff --git a/lib/matplotlib/tests/test_pyplot.py b/lib/matplotlib/tests/test_pyplot.py index 16a927e9e154..1c7b183cdb62 100644 --- a/lib/matplotlib/tests/test_pyplot.py +++ b/lib/matplotlib/tests/test_pyplot.py @@ -456,3 +456,11 @@ def test_figure_hook(): fig = plt.figure() assert fig._test_was_here + + +def test_explicit_ax(): + fig, (ax1, ax2) = plt.subplots(2) + plt.plot(range(5)) + plt.plot(range(5)[::-1], ax=ax1) + assert len(ax1.lines) == 1 + assert len(ax2.lines) == 1 diff --git a/tools/boilerplate.py b/tools/boilerplate.py index e3b4809d96b3..6132fb3c3e13 100644 --- a/tools/boilerplate.py +++ b/tools/boilerplate.py @@ -65,7 +65,9 @@ def enum_str_back_compat_patch(self): AXES_CMAPPABLE_METHOD_TEMPLATE = AUTOGEN_MSG + """ @_copy_docstring_and_deprecators(Axes.{called_name}) def {name}{signature}: - __ret = gca().{called_name}{call} + if ax is None: + ax = gca() + __ret = ax.{called_name}{call} {sci_command} return __ret """ @@ -73,13 +75,17 @@ def {name}{signature}: AXES_METHOD_TEMPLATE = AUTOGEN_MSG + """ @_copy_docstring_and_deprecators(Axes.{called_name}) def {name}{signature}: - return gca().{called_name}{call} + if ax is None: + ax = gca() + return ax.{called_name}{call} """ FIGURE_METHOD_TEMPLATE = AUTOGEN_MSG + """ @_copy_docstring_and_deprecators(Figure.{called_name}) def {name}{signature}: - return gcf().{called_name}{call} + if fig is None: + fig = gcf() + return fig.{called_name}{call} """ CMAP_TEMPLATE = ''' @@ -131,7 +137,8 @@ def __repr__(self): return self._repr -def generate_function(name, called_fullname, template, **kwargs): +def generate_function(name, called_fullname, template, implicit_input, + **kwargs): """ Create a wrapper function *pyplot_name* calling *call_name*. @@ -150,6 +157,9 @@ def generate_function(name, called_fullname, template, **kwargs): - called_name: The name of the called function. - call: Parameters passed to *called_name* (including parentheses). + implicit_input : {'ax', 'fig', None} + Any extra kwargs that should be injected. + **kwargs Additional parameters are passed to ``template.format()``. """ @@ -169,10 +179,21 @@ def generate_function(name, called_fullname, template, **kwargs): # Replace self argument. params = list(signature.parameters.values())[1:] + param_tail = [] + if implicit_input is not None: + param_tail.append(inspect.Parameter(name=implicit_input, + default=None, + kind=inspect.Parameter.KEYWORD_ONLY, + annotation=class_ | None)) + param_tail.extend( + [p for p in params + if p.kind is inspect.Parameter.VAR_KEYWORD]) + non_var_params = [p for p in params + if p.kind is not inspect.Parameter.VAR_KEYWORD] signature = str(signature.replace(parameters=[ param.replace(default=value_formatter(param.default)) if param.default is not param.empty else param - for param in params])) + for param in non_var_params] + param_tail)) # How to call the wrapped function. call = '(' + ', '.join(( # Pass "intended-as-positional" parameters positionally to avoid @@ -198,7 +219,7 @@ def generate_function(name, called_fullname, template, **kwargs): None).format(param.name) for param in params) + ')' # Bail out in case of name collision. - for reserved in ('gca', 'gci', 'gcf', '__ret'): + for reserved in ('gca', 'gci', 'gcf', '__ret', implicit_input): if reserved in params: raise ValueError( f'Method {called_fullname} has kwarg named {reserved}') @@ -332,7 +353,7 @@ def boilerplate_gen(): else: name = called_name = spec yield generate_function(name, f'Figure.{called_name}', - FIGURE_METHOD_TEMPLATE) + FIGURE_METHOD_TEMPLATE, implicit_input='fig') for spec in _axes_commands: if ':' in spec: @@ -343,7 +364,8 @@ def boilerplate_gen(): template = (AXES_CMAPPABLE_METHOD_TEMPLATE if name in cmappable else AXES_METHOD_TEMPLATE) yield generate_function(name, f'Axes.{called_name}', template, - sci_command=cmappable.get(name)) + sci_command=cmappable.get(name), + implicit_input='ax') cmaps = ( 'autumn', From 2c7f076e1284fe188535ea62fae746f2d8fa3ed6 Mon Sep 17 00:00:00 2001 From: Thomas A Caswell Date: Mon, 3 Jul 2023 16:44:38 -0400 Subject: [PATCH 2/2] STY: attempt to placate mypy Bit confused why it was not hitting these before. --- lib/matplotlib/cm.pyi | 4 ++++ lib/matplotlib/pyplot.py | 6 +++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/lib/matplotlib/cm.pyi b/lib/matplotlib/cm.pyi index 5a90863dec41..0327affa32d9 100644 --- a/lib/matplotlib/cm.pyi +++ b/lib/matplotlib/cm.pyi @@ -24,6 +24,10 @@ class ScalarMappable: cmap: colors.Colormap | None colorbar: Colorbar | None callbacks: cbook.CallbackRegistry + + # private use + _A: ArrayLike | None + def __init__( self, norm: colors.Normalize | None = ..., diff --git a/lib/matplotlib/pyplot.py b/lib/matplotlib/pyplot.py index 3d05e3cb6977..784f1bbb77e3 100644 --- a/lib/matplotlib/pyplot.py +++ b/lib/matplotlib/pyplot.py @@ -1086,12 +1086,12 @@ def draw(*, fig: Figure | None = None) -> None: @_copy_docstring_and_deprecators(Figure.savefig) def savefig(*args, **kwargs) -> None: + fig = kwargs.pop('fig', None) or gcf() # savefig default implementation has no return, so mypy is unhappy # presumably this is here because subclasses can return? - fig = kwargs.pop('fig', None) or gcf() - res = fig.savefig(*args, **kwargs) + res = fig.savefig(*args, **kwargs) # type: ignore[func-returns-value] fig.canvas.draw_idle() # Need this if 'transparent=True', to reset colors. - return res + return res # type: ignore[return-value] ## Putting things in figures ##