From ccc61cb2b818bbc2c9e9e2e83ebf11000760a7da Mon Sep 17 00:00:00 2001 From: Kyle Sunden Date: Wed, 29 May 2024 20:08:27 -0500 Subject: [PATCH] Backport PR #27001: [TYP] Add overload of `pyplot.subplots` --- lib/matplotlib/figure.pyi | 34 +++++++++++++++++-------- lib/matplotlib/gridspec.pyi | 2 +- lib/matplotlib/pyplot.py | 51 +++++++++++++++++++++++++++++++++++++ 3 files changed, 76 insertions(+), 11 deletions(-) diff --git a/lib/matplotlib/figure.pyi b/lib/matplotlib/figure.pyi index eae21c2614f0..21de9159d56c 100644 --- a/lib/matplotlib/figure.pyi +++ b/lib/matplotlib/figure.pyi @@ -1,12 +1,12 @@ from collections.abc import Callable, Hashable, Iterable import os -from typing import Any, IO, Literal, TypeVar, overload +from typing import Any, IO, Literal, Sequence, TypeVar, overload import numpy as np from numpy.typing import ArrayLike from matplotlib.artist import Artist -from matplotlib.axes import Axes, SubplotBase +from matplotlib.axes import Axes from matplotlib.backend_bases import ( FigureCanvasBase, MouseButton, @@ -92,6 +92,20 @@ class FigureBase(Artist): @overload def add_subplot(self, **kwargs) -> Axes: ... @overload + def subplots( + self, + nrows: Literal[1] = ..., + ncols: Literal[1] = ..., + *, + sharex: bool | Literal["none", "all", "row", "col"] = ..., + sharey: bool | Literal["none", "all", "row", "col"] = ..., + squeeze: Literal[True] = ..., + width_ratios: Sequence[float] | None = ..., + height_ratios: Sequence[float] | None = ..., + subplot_kw: dict[str, Any] | None = ..., + gridspec_kw: dict[str, Any] | None = ..., + ) -> Axes: ... + @overload def subplots( self, nrows: int = ..., @@ -100,11 +114,11 @@ class FigureBase(Artist): sharex: bool | Literal["none", "all", "row", "col"] = ..., sharey: bool | Literal["none", "all", "row", "col"] = ..., squeeze: Literal[False], - width_ratios: ArrayLike | None = ..., - height_ratios: ArrayLike | None = ..., + width_ratios: Sequence[float] | None = ..., + height_ratios: Sequence[float] | None = ..., subplot_kw: dict[str, Any] | None = ..., - gridspec_kw: dict[str, Any] | None = ... - ) -> np.ndarray: ... + gridspec_kw: dict[str, Any] | None = ..., + ) -> np.ndarray: ... # TODO numpy/numpy#24738 @overload def subplots( self, @@ -114,11 +128,11 @@ class FigureBase(Artist): sharex: bool | Literal["none", "all", "row", "col"] = ..., sharey: bool | Literal["none", "all", "row", "col"] = ..., squeeze: bool = ..., - width_ratios: ArrayLike | None = ..., - height_ratios: ArrayLike | None = ..., + width_ratios: Sequence[float] | None = ..., + height_ratios: Sequence[float] | None = ..., subplot_kw: dict[str, Any] | None = ..., - gridspec_kw: dict[str, Any] | None = ... - ) -> np.ndarray | SubplotBase | Axes: ... + gridspec_kw: dict[str, Any] | None = ..., + ) -> Axes | np.ndarray: ... def delaxes(self, ax: Axes) -> None: ... def clear(self, keep_observers: bool = ...) -> None: ... def clf(self, keep_observers: bool = ...) -> None: ... diff --git a/lib/matplotlib/gridspec.pyi b/lib/matplotlib/gridspec.pyi index 1ac1bb0b40e7..b6732ad8fafa 100644 --- a/lib/matplotlib/gridspec.pyi +++ b/lib/matplotlib/gridspec.pyi @@ -54,7 +54,7 @@ class GridSpecBase: sharey: bool | Literal["all", "row", "col", "none"] = ..., squeeze: Literal[True] = ..., subplot_kw: dict[str, Any] | None = ... - ) -> np.ndarray | SubplotBase | Axes: ... + ) -> np.ndarray | Axes: ... class GridSpec(GridSpecBase): left: float | None diff --git a/lib/matplotlib/pyplot.py b/lib/matplotlib/pyplot.py index 925322cdd1e5..a3ce60f01ef5 100644 --- a/lib/matplotlib/pyplot.py +++ b/lib/matplotlib/pyplot.py @@ -1548,6 +1548,57 @@ def subplot(*args, **kwargs) -> Axes: return ax +@overload +def subplots( + nrows: Literal[1] = ..., + ncols: Literal[1] = ..., + *, + sharex: bool | Literal["none", "all", "row", "col"] = ..., + sharey: bool | Literal["none", "all", "row", "col"] = ..., + squeeze: Literal[True] = ..., + width_ratios: Sequence[float] | None = ..., + height_ratios: Sequence[float] | None = ..., + subplot_kw: dict[str, Any] | None = ..., + gridspec_kw: dict[str, Any] | None = ..., + **fig_kw +) -> tuple[Figure, Axes]: + ... + + +@overload +def subplots( + nrows: int = ..., + ncols: int = ..., + *, + sharex: bool | Literal["none", "all", "row", "col"] = ..., + sharey: bool | Literal["none", "all", "row", "col"] = ..., + squeeze: Literal[False], + width_ratios: Sequence[float] | None = ..., + height_ratios: Sequence[float] | None = ..., + subplot_kw: dict[str, Any] | None = ..., + gridspec_kw: dict[str, Any] | None = ..., + **fig_kw +) -> tuple[Figure, np.ndarray]: # TODO numpy/numpy#24738 + ... + + +@overload +def subplots( + nrows: int = ..., + ncols: int = ..., + *, + sharex: bool | Literal["none", "all", "row", "col"] = ..., + sharey: bool | Literal["none", "all", "row", "col"] = ..., + squeeze: bool = ..., + width_ratios: Sequence[float] | None = ..., + height_ratios: Sequence[float] | None = ..., + subplot_kw: dict[str, Any] | None = ..., + gridspec_kw: dict[str, Any] | None = ..., + **fig_kw +) -> tuple[Figure, Axes | np.ndarray]: + ... + + def subplots( nrows: int = 1, ncols: int = 1, *, sharex: bool | Literal["none", "all", "row", "col"] = False,