-
-
Notifications
You must be signed in to change notification settings - Fork 7.9k
[TYP] Add overload of pyplot.subplots
#27001
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for opening your first PR into Matplotlib!
If you have not heard from us in a week or so, please leave a new comment below and that should bring it to our attention. Most of our reviewers are volunteers and sometimes things fall through the cracks.
You can also join us on gitter for real-time discussion.
For details on testing, writing docs, and our review process, please see the developer guide
We strive to be a welcoming and open project. Please follow our Code of Conduct.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately ndarray doesn't qualify as a Collections.abc.Sequence
(see discussion in #26858) and so I'm reluctant to type it as something that it is not (It is already a thorn in our side already for inputs, I'd rather not add it on outputs).
See also numpy/numpy#24738 re adding the type of the contents of the ndarray.
Additionally, some of the errors could maybe be resolved by removing the = ...
for cases where the default value (1) would preclude that being left as default, though that may incur additional overloads for some cases of positional vs keyword provided args (e.g. passing just ncols=2
as keyword vs passing 1, 2
as positional) it gets a bit complicated...
I also think you are missing the case where nrows>1, ncols>1 and squeeze
is passed as True
. while squeeze has no effect then, it is valid to pass it in (and in fact is the default, which the type hints from this PR contradict)
In all, there is a limitation on static type hinting that it does not allow you to differentiate based on value of of the arguments, which is what is happening here. Literal
can help, but it still gets messy pretty quick, and can lead to false resolutions by type checkers. (e.g. if a variable that can be 1 but is only known to be an int is passed into this chain, the type checker would resolve to the unsqueezed overload, which is incorrect)
That is why we went with Any
to start.
I think it may be reasonable to have a scaled back version of this that that takes the following into account:
- squeeze=false always returns
ndarray
- nrows=ncols=Literal[1] (and squeeze=true) returns
Axes
- anything else returns union of
Axes | ndarray
(orAny
to avoid users having to type check when they know better than the type checker)
This is closer to what I did in Figure.subplots
, which for some reason I did not carry forward to pyplot... (here I only djid squeeze, not nrows=ncols=1):
matplotlib/lib/matplotlib/figure.pyi
Lines 95 to 121 in 6ba7d5f
def subplots( | |
self, | |
nrows: int = ..., | |
ncols: int = ..., | |
*, | |
sharex: bool | Literal["none", "all", "row", "col"] = ..., | |
sharey: bool | Literal["none", "all", "row", "col"] = ..., | |
squeeze: Literal[False], | |
width_ratios: ArrayLike | None = ..., | |
height_ratios: ArrayLike | None = ..., | |
subplot_kw: dict[str, Any] | None = ..., | |
gridspec_kw: dict[str, Any] | None = ... | |
) -> np.ndarray: ... | |
@overload | |
def subplots( | |
self, | |
nrows: int = ..., | |
ncols: int = ..., | |
*, | |
sharex: bool | Literal["none", "all", "row", "col"] = ..., | |
sharey: bool | Literal["none", "all", "row", "col"] = ..., | |
squeeze: bool = ..., | |
width_ratios: ArrayLike | None = ..., | |
height_ratios: ArrayLike | None = ..., | |
subplot_kw: dict[str, Any] | None = ..., | |
gridspec_kw: dict[str, Any] | None = ... | |
) -> np.ndarray | SubplotBase | Axes: ... |
This is somewhat unsatisfactory, but gives ways to gain type information as a caller in relatively explicit ways (e.g. pass squeeze=False to always get ndarray)
Literals are good in some regards, but even if you handle both Literal[True]
and Literal[False]
, the type checker may still complain when a bool
is passed in, because while you have covered all options, a bool variable is not actually either literal. So I have found myself adding a bool
overload with a union return type that I dislike, but covers the options to the satisfaction of the type checker.
07ec432
to
8e48282
Compare
a62d311
to
8e48282
Compare
Co-authored-by: Kyle Sunden <git@ksunden.space>
Thank you for the quick review! |
pyplot.subplot
pyplot.subplots
Note that this might also be better served by #25937. |
The motivation for this PR is that I noticed a loss in development experience when using mpl>=3.8 compared to mpl<3.8, which benefited from microsoft/python-type-stubs. |
So the microsoft ones contain 9 different (but mostly redundant by my analysis) overloaded signatures:
I think what is in this PR is better than what Microsoft had because it reduces redundant cases, and does not include incorrect overlapping cases (which I'm a little surprised doesn't flag for them, to be honest, I've definitely seen similar errors flagged by mypy) but still at least carves out where we can tell based on input types and the most common calls. Additionally, the microsoft ones are missing the I would like to see the three cases from this PR also included in I will note that they have a matplotlib/lib/matplotlib/figure.pyi Line 81 in 9cd2812
|
there are a lot of links and comments. |
I would like to see the overloads unified to the three cases from this PR in all three places where
|
dc387b9
to
c13996b
Compare
|
…001-on-v3.9.x Backport PR #27001 on branch v3.9.x ([TYP] Add overload of `pyplot.subplots`)
This PR does not work as intended. Using the following code from here: import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 2 * np.pi, 400)
y = np.sin(x ** 2)
fig, axs = plt.subplots(2) # or nrows=2
fig.suptitle('Vertically stacked subplots')
axs[0].plot(x, y)
axs[1].plot(x, -y) mypy complains with: > mypy test.py
test.py:8: error: Value of type "Axes | ndarray[Any, Any]" is not indexable [index]
test.py:9: error: Value of type "Axes | ndarray[Any, Any]" is not indexable [index]
Found 2 errors in 1 file (checked 1 source file) There seems to be an issue with the overload. If we can figure that out, great. If not, we will have to revert back to How did I notice this? We use multiple subplots quite heavily in TorchGeo. The PR that updated matplotlib from 3.9.0 to 3.9.1 results in 400+ mypy errors as a result of this PR. We could add type hints to all 400+ locations to clarify that it is indeed a numpy array, but given that we're talking about 400+ locations, I would much rather fix this in 1 location in matplotlib instead. P.S. Let me know if you want me to open a formal issue for this. I thought it would be easier to get the attention of everyone involved in this PR by commenting directly on the PR, but I realize it's easier to track bug reports by having a formal issue opened. |
Possible solution: explicitly passing |
Please review my fix in #28518, it passes all of my tests but more eyes are always better. |
PR summary
Change the type hint of the return value of
pyplot.subplots
fromAny
toAxes
,ndarray
, orAxes | ndarray
.PR checklist