Skip to content

[TYP] Add overload of pyplot.subplots #27001

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 30, 2024

Conversation

n-takumasa
Copy link
Contributor

@n-takumasa n-takumasa commented Oct 5, 2023

PR summary

Change the type hint of the return value of pyplot.subplots from Any to Axes, ndarray, or Axes | ndarray.

PR checklist

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for opening your first PR into Matplotlib!

If you have not heard from us in a week or so, please leave a new comment below and that should bring it to our attention. Most of our reviewers are volunteers and sometimes things fall through the cracks.

You can also join us on gitter for real-time discussion.

For details on testing, writing docs, and our review process, please see the developer guide

We strive to be a welcoming and open project. Please follow our Code of Conduct.

Copy link
Member

@ksunden ksunden left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately ndarray doesn't qualify as a Collections.abc.Sequence (see discussion in #26858) and so I'm reluctant to type it as something that it is not (It is already a thorn in our side already for inputs, I'd rather not add it on outputs).

See also numpy/numpy#24738 re adding the type of the contents of the ndarray.

Additionally, some of the errors could maybe be resolved by removing the = ... for cases where the default value (1) would preclude that being left as default, though that may incur additional overloads for some cases of positional vs keyword provided args (e.g. passing just ncols=2 as keyword vs passing 1, 2 as positional) it gets a bit complicated...

I also think you are missing the case where nrows>1, ncols>1 and squeeze is passed as True. while squeeze has no effect then, it is valid to pass it in (and in fact is the default, which the type hints from this PR contradict)

In all, there is a limitation on static type hinting that it does not allow you to differentiate based on value of of the arguments, which is what is happening here. Literal can help, but it still gets messy pretty quick, and can lead to false resolutions by type checkers. (e.g. if a variable that can be 1 but is only known to be an int is passed into this chain, the type checker would resolve to the unsqueezed overload, which is incorrect)

That is why we went with Any to start.

I think it may be reasonable to have a scaled back version of this that that takes the following into account:

  • squeeze=false always returns ndarray
  • nrows=ncols=Literal[1] (and squeeze=true) returns Axes
  • anything else returns union of Axes | ndarray (or Any to avoid users having to type check when they know better than the type checker)

This is closer to what I did in Figure.subplots, which for some reason I did not carry forward to pyplot... (here I only djid squeeze, not nrows=ncols=1):

def subplots(
self,
nrows: int = ...,
ncols: int = ...,
*,
sharex: bool | Literal["none", "all", "row", "col"] = ...,
sharey: bool | Literal["none", "all", "row", "col"] = ...,
squeeze: Literal[False],
width_ratios: ArrayLike | None = ...,
height_ratios: ArrayLike | None = ...,
subplot_kw: dict[str, Any] | None = ...,
gridspec_kw: dict[str, Any] | None = ...
) -> np.ndarray: ...
@overload
def subplots(
self,
nrows: int = ...,
ncols: int = ...,
*,
sharex: bool | Literal["none", "all", "row", "col"] = ...,
sharey: bool | Literal["none", "all", "row", "col"] = ...,
squeeze: bool = ...,
width_ratios: ArrayLike | None = ...,
height_ratios: ArrayLike | None = ...,
subplot_kw: dict[str, Any] | None = ...,
gridspec_kw: dict[str, Any] | None = ...
) -> np.ndarray | SubplotBase | Axes: ...

This is somewhat unsatisfactory, but gives ways to gain type information as a caller in relatively explicit ways (e.g. pass squeeze=False to always get ndarray)

Literals are good in some regards, but even if you handle both Literal[True] and Literal[False], the type checker may still complain when a bool is passed in, because while you have covered all options, a bool variable is not actually either literal. So I have found myself adding a bool overload with a union return type that I dislike, but covers the options to the satisfaction of the type checker.

@n-takumasa n-takumasa force-pushed the fix-subplots-typing branch from 07ec432 to 8e48282 Compare October 5, 2023 23:33
@n-takumasa n-takumasa force-pushed the fix-subplots-typing branch from a62d311 to 8e48282 Compare October 5, 2023 23:41
Co-authored-by: Kyle Sunden <git@ksunden.space>
@n-takumasa
Copy link
Contributor Author

Thank you for the quick review!
Adding variables as arguments is not something I typically do, so it was a complete oversight on my part.
I would have preferred the Axes to be inferred when indexing, but I also understand the limitation and previous discussions, so I also think a scaled back version is reasonable.

@QuLogic QuLogic changed the title [TYP] Add overload of pyplot.subplot [TYP] Add overload of pyplot.subplots Oct 6, 2023
@QuLogic
Copy link
Member

QuLogic commented Oct 6, 2023

Note that this might also be better served by #25937.

@n-takumasa
Copy link
Contributor Author

The motivation for this PR is that I noticed a loss in development experience when using mpl>=3.8 compared to mpl<3.8, which benefited from microsoft/python-type-stubs.
I think it would be great to define a class better than ndarray, but it would be painful to develop with object-oriented-interface fig, ax = plt.subplots() until it it implemented.

@jklymak
Copy link
Member

jklymak commented Oct 7, 2023

@ksunden
Copy link
Member

ksunden commented Oct 11, 2023

So the microsoft ones contain 9 different (but mostly redundant by my analysis) overloaded signatures:

A) squeeze=false, nrows=ncols=int -> ndarray
   - One of the cases in this PR
    
B) squeeze=True, nrows=ncols=1 -> Axes
   - One of the cases in this PR
            
C) Squeeze=True, nrows=1 ncols=int -> ndarray
   - Technically wrong, to my understanding, as ncols can still be 1 but only known to be int at type check time
         
D) Squeeze=True, nrows=int ncols=1 -> ndarray
   - Technically wrong, to my understanding, for the same reason
    
E) Squeeze=True, nrows=ncols=int; keyword only -> ndarray
   - Technically wrong, to my understanding
   - redundant to nrows/ncols not being kwonly  (A/I)

F) squeeze=True, nrows omitted, ncols=int; kwonly -> ndarray
   - redundant to C
  
G) squeeze=True, nrows=int, ncols omitted -> ndarray
   - redundant to D

H) squeeze=True, nrows and ncols omitted -> Axes
   - redundant to B
    
I) squeeze=bool, nrows=ncols-int -> ndarray
   - Technically wrong, as squeeze=True, nrows=1,ncols=1 will match but not return an ndarray
   - Closest to the last one from this PR (though there is a union to ensure technically correct)

I think what is in this PR is better than what Microsoft had because it reduces redundant cases, and does not include incorrect overlapping cases (which I'm a little surprised doesn't flag for them, to be honest, I've definitely seen similar errors flagged by mypy) but still at least carves out where we can tell based on input types and the most common calls.

Additionally, the microsoft ones are missing the width/height_ratios arguments and do not specify that the dictionaries have string keys.

I would like to see the three cases from this PR also included in figure.pyi and gridspec.pyi versions of subplots. (These already have 2 overloads, but could be made better by including the Literal[1] case)

I will note that they have a SubplotBase as a possible return, which should also be removed. It doesn't seem to fail mypy without it, and seems to come back to Figure.add_subplot's docstring which stated (but has now removed) it was a possible return, but then it was omitted from the type hints with a comment:

# TODO: docstring indicates SubplotSpec a valid arg, but none of the listed signatures appear to be that

@dlefcoe
Copy link

dlefcoe commented May 12, 2024

there are a lot of links and comments.
please advise if the original request is complete...

@ksunden
Copy link
Member

ksunden commented May 13, 2024

I would like to see the overloads unified to the three cases from this PR in all three places where subplots method exists:

@n-takumasa n-takumasa force-pushed the fix-subplots-typing branch from dc387b9 to c13996b Compare May 14, 2024 03:19
@n-takumasa
Copy link
Contributor Author

GridSpecBase.subplots has no arguments for narrowing :(

@ksunden ksunden added this to the v3.9.1 milestone May 30, 2024
@ksunden ksunden merged commit a833d99 into matplotlib:main May 30, 2024
37 of 42 checks passed
meeseeksmachine pushed a commit to meeseeksmachine/matplotlib that referenced this pull request May 30, 2024
QuLogic added a commit that referenced this pull request May 30, 2024
…001-on-v3.9.x

Backport PR #27001 on branch v3.9.x ([TYP] Add overload of `pyplot.subplots`)
@adamjstewart
Copy link
Contributor

This PR does not work as intended. Using the following code from here:

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2 * np.pi, 400)
y = np.sin(x ** 2)
fig, axs = plt.subplots(2)  # or nrows=2
fig.suptitle('Vertically stacked subplots')
axs[0].plot(x, y)
axs[1].plot(x, -y)

mypy complains with:

> mypy test.py
test.py:8: error: Value of type "Axes | ndarray[Any, Any]" is not indexable  [index]
test.py:9: error: Value of type "Axes | ndarray[Any, Any]" is not indexable  [index]
Found 2 errors in 1 file (checked 1 source file)

There seems to be an issue with the overload. If we can figure that out, great. If not, we will have to revert back to Any and treat this function as being dynamic (static type analysis is not possible).

How did I notice this? We use multiple subplots quite heavily in TorchGeo. The PR that updated matplotlib from 3.9.0 to 3.9.1 results in 400+ mypy errors as a result of this PR. We could add type hints to all 400+ locations to clarify that it is indeed a numpy array, but given that we're talking about 400+ locations, I would much rather fix this in 1 location in matplotlib instead.

P.S. Let me know if you want me to open a formal issue for this. I thought it would be easier to get the attention of everyone involved in this PR by commenting directly on the PR, but I realize it's easier to track bug reports by having a formal issue opened.

@adamjstewart
Copy link
Contributor

adamjstewart commented Jul 6, 2024

Possible solution: explicitly passing squeeze=False makes type checking much easier. I would also be okay with using a return type of np.ndarray for squeeze=False and Any for squeeze=True if we can't get nrows/ncols working properly with static type checking.

@adamjstewart
Copy link
Contributor

Please review my fix in #28518, it passes all of my tests but more eyes are always better.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Development

Successfully merging this pull request may close these issues.

6 participants