Skip to content

seaborn: complete and fix axisgrid module #11096

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions stubs/seaborn/@tests/stubtest_allowlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ seaborn.external.docscrape.ClassDoc.__init__ # stubtest doesn't like ABC class
seaborn.external.docscrape.NumpyDocString.__str__ # weird signature

seaborn(\.regression)?\.lmplot # the `data` argument is required but it defaults to `None` at runtime

seaborn.axisgrid.Grid.tight_layout # the method doesn't really take pos args but runtime has *args
13 changes: 9 additions & 4 deletions stubs/seaborn/seaborn/_core/typing.pyi
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
from _typeshed import Incomplete
from collections.abc import Iterable
from collections.abc import Iterable, Mapping
from datetime import date, datetime, timedelta
from typing import Any
from typing import Any, Protocol
from typing_extensions import TypeAlias

from matplotlib.colors import Colormap, Normalize
from numpy import ndarray
from pandas import Index, Series, Timedelta, Timestamp
from pandas import DataFrame, Index, Series, Timedelta, Timestamp

class _SupportsDataFrame(Protocol):
# `__dataframe__` should return pandas.core.interchange.dataframe_protocol.DataFrame
# but this class needs to be defined as a Protocol, not as an ABC.
def __dataframe__(self, nan_as_null: bool = ..., allow_copy: bool = ...) -> Incomplete: ...

ColumnName: TypeAlias = str | bytes | date | datetime | timedelta | bool | complex | Timestamp | Timedelta
Vector: TypeAlias = Series[Any] | Index[Any] | ndarray[Any, Any]
VariableSpec: TypeAlias = ColumnName | Vector | None
VariableSpecList: TypeAlias = list[VariableSpec] | Index[Any] | None
DataSource: TypeAlias = Incomplete
DataSource: TypeAlias = DataFrame | _SupportsDataFrame | Mapping[ColumnName, Incomplete] | None
OrderSpec: TypeAlias = Iterable[str] | None
NormSpec: TypeAlias = tuple[float | None, float | None] | Normalize | None
PaletteSpec: TypeAlias = str | list[Incomplete] | dict[Incomplete, Incomplete] | Colormap | None
Expand Down
245 changes: 201 additions & 44 deletions stubs/seaborn/seaborn/axisgrid.pyi
Original file line number Diff line number Diff line change
@@ -1,56 +1,207 @@
import os
from _typeshed import Incomplete
from collections.abc import Callable, Generator, Iterable, Mapping
from typing import Any, TypeVar
from typing_extensions import Concatenate, Literal, ParamSpec, Self
from typing import IO, Any, TypeVar
from typing_extensions import Concatenate, Literal, ParamSpec, Self, TypeAlias

import numpy as np
from matplotlib.artist import Artist
from matplotlib.axes import Axes
from matplotlib.colors import Colormap
from matplotlib.backend_bases import MouseEvent, RendererBase
from matplotlib.colors import Colormap, Normalize
from matplotlib.figure import Figure
from matplotlib.font_manager import FontProperties
from matplotlib.gridspec import SubplotSpec
from matplotlib.legend import Legend
from matplotlib.patches import Patch
from matplotlib.path import Path as mpl_Path
from matplotlib.patheffects import AbstractPathEffect
from matplotlib.scale import ScaleBase
from matplotlib.text import Text
from matplotlib.typing import ColorType
from numpy.typing import NDArray
from matplotlib.transforms import Bbox, BboxBase, Transform, TransformedPath
from matplotlib.typing import ColorType, LineStyleType, MarkerType
from numpy.typing import ArrayLike, NDArray
from pandas import DataFrame, Series

from ._core.typing import ColumnName, DataSource, _SupportsDataFrame
from .palettes import _RGBColorPalette
from .utils import _Palette
from .utils import _DataSourceWideForm, _Palette, _Vector

__all__ = ["FacetGrid", "PairGrid", "JointGrid", "pairplot", "jointplot"]

_P = ParamSpec("_P")
_R = TypeVar("_R")

_LiteralFont: TypeAlias = Literal["xx-small", "x-small", "small", "medium", "large", "x-large", "xx-large"]

class _BaseGrid:
def set(self, **kwargs: Incomplete) -> Self: ... # **kwargs are passed to `matplotlib.axes.Axes.set`
def set(
self,
*,
# Keywords follow `matplotlib.axes.Axes.set`. Each keyword <KW> corresponds to a `set_<KW>` method
adjustable: Literal["box", "datalim"] = ...,
agg_filter: Callable[[ArrayLike, float], tuple[NDArray[np.floating[Any]], float, float]] | None = ...,
alpha: float | None = ...,
anchor: str | tuple[float, float] = ...,
animated: bool = ...,
aspect: float | Literal["auto", "equal"] = ...,
autoscale_on: bool = ...,
autoscalex_on: bool = ...,
autoscaley_on: bool = ...,
axes_locator: Callable[[Axes, RendererBase], Bbox] = ...,
axisbelow: bool | Literal["line"] = ...,
box_aspect: float | None = ...,
clip_box: BboxBase | None = ...,
clip_on: bool = ...,
clip_path: Patch | mpl_Path | TransformedPath | None = ...,
facecolor: ColorType | None = ...,
frame_on: bool = ...,
gid: str | None = ...,
in_layout: bool = ...,
label: object = ...,
mouseover: bool = ...,
navigate: bool = ...,
path_effects: list[AbstractPathEffect] = ...,
picker: bool | float | Callable[[Artist, MouseEvent], tuple[bool, dict[Any, Any]]] | None = ...,
position: Bbox | tuple[float, float, float, float] = ...,
prop_cycle: Incomplete = ..., # TODO: use cycler.Cycler when cycler gets typed
rasterization_zorder: float | None = ...,
rasterized: bool = ...,
sketch_params: float | None = ...,
snap: bool | None = ...,
subplotspec: SubplotSpec = ...,
title: str = ...,
transform: Transform | None = ...,
url: str | None = ...,
visible: bool = ...,
xbound: float | None | tuple[float | None, float | None] = ...,
xlabel: str = ...,
xlim: float | None | tuple[float | None, float | None] = ...,
xmargin: float = ...,
xscale: str | ScaleBase = ...,
xticklabels: Iterable[str | Text] = ...,
xticks: ArrayLike = ...,
ybound: float | None | tuple[float | None, float | None] = ...,
ylabel: str = ...,
ylim: float | None | tuple[float | None, float | None] = ...,
ymargin: float = ...,
yscale: str | ScaleBase = ...,
yticklabels: Iterable[str | Text] = ...,
yticks: ArrayLike = ...,
zorder: float = ...,
**kwargs: Any,
) -> Self: ...
@property
def fig(self) -> Figure: ...
@property
def figure(self) -> Figure: ...
def apply(self, func: Callable[Concatenate[Self, _P], object], *args: _P.args, **kwargs: _P.kwargs) -> Self: ...
def pipe(self, func: Callable[Concatenate[Self, _P], _R], *args: _P.args, **kwargs: _P.kwargs) -> _R: ...
def savefig(
self, *args: Incomplete, **kwargs: Incomplete
) -> None: ... # *args and **kwargs are passed to `matplotlib.figure.Figure.savefig`
self,
# Signature follows `matplotlib.figure.Figure.savefig`
fname: str | os.PathLike[Any] | IO[Any],
*,
transparent: bool | None = None,
dpi: float | Literal["figure"] | None = 96,
facecolor: ColorType | Literal["auto"] | None = "auto",
edgecolor: ColorType | Literal["auto"] | None = "auto",
orientation: Literal["landscape", "portrait"] = "portrait",
format: str | None = None,
bbox_inches: Literal["tight"] | Bbox | None = "tight",
pad_inches: float | Literal["layout"] | None = None,
backend: str | None = None,
**kwargs: Any,
) -> None: ...

class Grid(_BaseGrid):
def __init__(self) -> None: ...
def tight_layout(
self, *args: Incomplete, **kwargs: Incomplete
) -> Self: ... # *args and **kwargs are passed to `matplotlib.figure.Figure.tight_layout`
self,
*,
# Keywords follow `matplotlib.figure.Figure.tight_layout`
pad: float = 1.08,
h_pad: float | None = None,
w_pad: float | None = None,
rect: tuple[float, float, float, float] | None = None,
) -> Self: ...
def add_legend(
self,
legend_data: Mapping[Any, Artist] | None = None, # cannot use precise key type because of invariant Mapping keys
# Cannot use precise key type with union for legend_data because of invariant Mapping keys
legend_data: Mapping[Any, Artist] | None = None,
title: str | None = None,
label_order: list[str] | None = None,
adjust_subtitles: bool = False,
**kwargs: Incomplete, # **kwargs are passed to `matplotlib.figure.Figure.legend`
*,
# Keywords follow `matplotlib.legend.Legend`
loc: str | int | tuple[float, float] | None = None,
numpoints: int | None = None,
markerscale: float | None = None,
markerfirst: bool = True,
reverse: bool = False,
scatterpoints: int | None = None,
scatteryoffsets: Iterable[float] | None = None,
prop: FontProperties | dict[str, Any] | None = None,
fontsize: int | _LiteralFont | None = None,
labelcolor: str | Iterable[str] | None = None,
borderpad: float | None = None,
labelspacing: float | None = None,
handlelength: float | None = None,
handleheight: float | None = None,
handletextpad: float | None = None,
borderaxespad: float | None = None,
columnspacing: float | None = None,
ncols: int = 1,
mode: Literal["expand"] | None = None,
fancybox: bool | None = None,
shadow: bool | dict[str, float] | None = None,
title_fontsize: int | _LiteralFont | None = None,
framealpha: float | None = None,
edgecolor: ColorType | None = None,
facecolor: ColorType | None = None,
bbox_to_anchor: BboxBase | tuple[float, float] | tuple[float, float, float, float] | None = None,
bbox_transform: Transform | None = None,
frameon: bool | None = None,
handler_map: None = None,
title_fontproperties: FontProperties | None = None,
alignment: Literal["center", "left", "right"] = "center",
ncol: int = 1,
draggable: bool = False,
) -> Self: ...
@property
def legend(self) -> Legend | None: ...
def tick_params(
self, axis: Literal["x", "y", "both"] = "both", **kwargs: Incomplete
) -> Self: ... # **kwargs are passed to `matplotlib.axes.Axes.tick_params`
self,
axis: Literal["x", "y", "both"] = "both",
*,
# Keywords follow `matplotlib.axes.Axes.tick_params`
which: Literal["major", "minor", "both"] = "major",
reset: bool = False,
direction: Literal["in", "out", "inout"] = ...,
length: float = ...,
width: float = ...,
color: ColorType = ...,
pad: float = ...,
labelsize: float | str = ...,
labelcolor: ColorType = ...,
labelfontfamily: str = ...,
colors: ColorType = ...,
zorder: float = ...,
bottom: bool = ...,
top: bool = ...,
left: bool = ...,
right: bool = ...,
labelbottom: bool = ...,
labeltop: bool = ...,
labelleft: bool = ...,
labelright: bool = ...,
labelrotation: float = ...,
grid_color: ColorType = ...,
grid_alpha: float = ...,
grid_linewidth: float = ...,
grid_linestyle: str = ...,
**kwargs: Any,
) -> Self: ...

class FacetGrid(Grid):
data: DataFrame
Expand All @@ -60,7 +211,7 @@ class FacetGrid(Grid):
hue_kws: dict[str, Any]
def __init__(
self,
data: DataFrame,
data: DataFrame | _SupportsDataFrame,
Comment on lines -63 to +214
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seaborn version 0.13 started supporting other data frames that support the __dataframe__ protocol: https://seaborn.pydata.org/whatsnew/v0.13.0.html#support-for-alternate-dataframe-libraries

*,
row: str | None = None,
col: str | None = None,
Expand Down Expand Up @@ -88,10 +239,10 @@ class FacetGrid(Grid):
def map(self, func: Callable[..., object], *args: str, **kwargs: Any) -> Self: ...
def map_dataframe(self, func: Callable[..., object], *args: str, **kwargs: Any) -> Self: ...
def facet_axis(self, row_i: int, col_j: int, modify_state: bool = True) -> Axes: ...
# `despine` should be kept roughly in line with `seaborn.utils.despine`
def despine(
self,
*,
fig: Figure | None = None,
ax: Axes | None = None,
top: bool = True,
right: bool = True,
Expand All @@ -111,7 +262,13 @@ class FacetGrid(Grid):
self, template: str | None = None, row_template: str | None = None, col_template: str | None = None, **kwargs: Any
) -> Self: ...
def refline(
self, *, x: float | None = None, y: float | None = None, color: ColorType = ".5", linestyle: str = "--", **line_kws: Any
self,
*,
x: float | None = None,
y: float | None = None,
color: ColorType = ".5",
linestyle: LineStyleType = "--",
**line_kws: Any,
) -> Self: ...
@property
def axes(self) -> NDArray[Incomplete]: ... # array of `Axes`
Expand All @@ -127,15 +284,15 @@ class PairGrid(Grid):
axes: NDArray[Incomplete] # two-dimensional array of `Axes`
data: DataFrame
diag_sharey: bool
diag_vars: NDArray[Incomplete] | None # array of `str`
diag_axes: NDArray[Incomplete] | None # array of `Axes`
diag_vars: list[str] | None
diag_axes: list[Axes] | None
Comment on lines -130 to +288
Copy link
Contributor Author

@hamdanal hamdanal Dec 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are lists, not sure why I typed them as arrays earlier

hue_names: list[str]
hue_vals: Series[Incomplete]
hue_vals: Series[Any]
hue_kws: dict[str, Any]
palette: _RGBColorPalette
def __init__(
self,
data: DataFrame,
data: DataFrame | _SupportsDataFrame,
*,
hue: str | None = None,
vars: Iterable[str] | None = None,
Expand All @@ -162,25 +319,25 @@ class JointGrid(_BaseGrid):
ax_joint: Axes
ax_marg_x: Axes
ax_marg_y: Axes
x: Series[Incomplete]
y: Series[Incomplete]
hue: Series[Incomplete]
x: Series[Any]
y: Series[Any]
hue: Series[Any]
def __init__(
self,
data: Incomplete | None = None,
data: DataSource | _DataSourceWideForm | None = None,
*,
x: Incomplete | None = None,
y: Incomplete | None = None,
hue: Incomplete | None = None,
x: ColumnName | _Vector | None = None,
y: ColumnName | _Vector | None = None,
hue: ColumnName | _Vector | None = None,
height: float = 6,
ratio: float = 5,
space: float = 0.2,
palette: _Palette | Colormap | None = None,
hue_order: Iterable[str] | None = None,
hue_norm: Incomplete | None = None,
hue_order: Iterable[ColumnName] | None = None,
hue_norm: tuple[float, float] | Normalize | None = None,
dropna: bool = False,
xlim: Incomplete | None = None,
ylim: Incomplete | None = None,
xlim: float | tuple[float, float] | None = None,
ylim: float | tuple[float, float] | None = None,
marginal_ticks: bool = False,
) -> None: ...
def plot(self, joint_func: Callable[..., object], marginal_func: Callable[..., object], **kwargs: Any) -> Self: ...
Expand All @@ -194,7 +351,7 @@ class JointGrid(_BaseGrid):
joint: bool = True,
marginal: bool = True,
color: ColorType = ".5",
linestyle: str = "--",
linestyle: LineStyleType = "--",
**line_kws: Any,
) -> Self: ...
def set_axis_labels(self, xlabel: str = "", ylabel: str = "", **kwargs: Any) -> Self: ...
Expand All @@ -210,7 +367,7 @@ def pairplot(
y_vars: Iterable[str] | str | None = None,
kind: Literal["scatter", "kde", "hist", "reg"] = "scatter",
diag_kind: Literal["auto", "hist", "kde"] | None = "auto",
markers: Incomplete | None = None,
markers: MarkerType | list[MarkerType] | None = None,
height: float = 2.5,
aspect: float = 1,
corner: bool = False,
Expand All @@ -221,22 +378,22 @@ def pairplot(
size: float | None = None, # deprecated
) -> PairGrid: ...
def jointplot(
data: Incomplete | None = None,
data: DataSource | _DataSourceWideForm | None = None,
*,
x: Incomplete | None = None,
y: Incomplete | None = None,
hue: Incomplete | None = None,
kind: str = "scatter", # ideally Literal["scatter", "kde", "hist", "hex", "reg", "resid"] but it is checked with startswith
x: ColumnName | _Vector | None = None,
y: ColumnName | _Vector | None = None,
hue: ColumnName | _Vector | None = None,
kind: Literal["scatter", "kde", "hist", "hex", "reg", "resid"] = "scatter",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was changed to str earlier thanks to the feedback on the PR that introduced seaborn stubs. Upon further investigation, the function seems to only accepts this set of strings (also verified in the repl).

height: float = 6,
ratio: float = 5,
space: float = 0.2,
dropna: bool = False,
xlim: Incomplete | None = None,
ylim: Incomplete | None = None,
xlim: float | tuple[float, float] | None = None,
ylim: float | tuple[float, float] | None = None,
color: ColorType | None = None,
palette: _Palette | Colormap | None = None,
hue_order: Iterable[str] | None = None,
hue_norm: Incomplete | None = None,
hue_order: Iterable[ColumnName] | None = None,
hue_norm: tuple[float, float] | Normalize | None = None,
marginal_ticks: bool = False,
joint_kws: dict[str, Any] | None = None,
marginal_kws: dict[str, Any] | None = None,
Expand Down
Loading