Fix typing errors in the torch.distributions module #45689
Conversation
mypy.ini (Outdated)

```ini
[mypy-torch.jit.quantized]
ignore_errors = True

[mypy-torch.nn.functional]
```
are these actually intended to be added? or is this a merge master issue?
They were added because, after annotating the distributions.* directory, a lot of typing errors appeared under these two modules. It may be due to a master update; I am looking into this.
torch/distributions/distribution.py (Outdated)

```diff
@@ -71,7 +72,7 @@ def event_shape(self):
         """
         return self._event_shape

-    @property
+    @property  # type: ignore[no-redef]
```
do we have a mypy git issue link? (is it: python/mypy#6185?)
> do we have a mypy git issue link? (is it: python/mypy#6185?)

No, it is not. This happens simply because the name "arg_constraints" is reused as both a member variable and a method name. I can't find a related mypy issue. It could be resolved by renaming one of them, but that would introduce too much code churn IMO.
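For context, here is a minimal, self-contained sketch (with hypothetical names, not the real `Distribution` code) of the pattern described above that triggers mypy's `[no-redef]` error:

```python
# A name used both as a class-level attribute and as a @property: mypy
# reports 'Name "support" already defined  [no-redef]' on the property,
# which the PR silences with a targeted ignore comment.
class Base:
    support = None  # subclasses may override this as a plain attribute

    @property  # type: ignore[no-redef]  # silences mypy's redefinition error
    def support(self):
        # At runtime the property wins: it shadows the class attribute above.
        return "real"


print(Base().support)
```

At runtime the later definition simply replaces the earlier one, so the code works; only the static checker objects.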
it would be nice to put a comment to explain this
Hmm, if arg_constraints is an internal variable, perhaps an underscore prefix should be added to it? (This could be done as a separate PR.)
torch/distributions/multinomial.py (Outdated)

```diff
@@ -2,7 +2,7 @@
 from torch._six import inf
 from torch.distributions.distribution import Distribution
 from torch.distributions import Categorical
-from numbers import Number
+import numbers
```
Seems like only `Integral` was used? Should we just keep the original format?

```suggestion
from numbers import Integral
```
> seems like only `Integral` was used? should we just keep original format?

Nice catch. Will update.
torch/distributions/multinomial.py (Outdated)

```diff
@@ -88,7 +89,7 @@ def param_shape(self):

     def sample(self, sample_shape=torch.Size()):
         sample_shape = torch.Size(sample_shape)
-        samples = self._categorical.sample(torch.Size((self.total_count,)) + sample_shape)
+        samples = self._categorical.sample(torch.Size((int(self.total_count),).__iter__()) + sample_shape)
```
I was wondering if we can remove `__iter__()`; seems like `(int(val),)` should already return a `Tuple[int]` type. Yes?
You are right! I have removed it.
That's correct! Will update
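The point can be checked with plain tuples; `torch.Size` subclasses `tuple`, so the same concatenation rules apply (values below are made up for illustration):

```python
# (int(x),) is already a one-element tuple of type Tuple[int], so an
# extra .__iter__() call adds nothing. Plain tuples are used here to
# stay dependency-free; torch.Size behaves the same way.
total_count = 5.0          # stand-in for a float-like scalar total_count
sample_shape = (2, 3)      # stand-in for torch.Size([2, 3])

extended = (int(total_count),) + sample_shape
assert extended == (5, 2, 3)
assert isinstance(extended[0], int)
```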
```diff
@@ -94,7 +96,7 @@ class lazy_property(object):
     """
     def __init__(self, wrapped):
         self.wrapped = wrapped
-        update_wrapper(self, wrapped)
+        update_wrapper(self, wrapped)  # type: ignore[arg-type]
```
is there a link to the reason?
Fixing this error would require annotating a function in the Python standard library (`functools`), so I just ignore it.
More details: `update_wrapper(wrapper, wrapped)` is a `functools` function whose first parameter is expected to be a `Callable`. Here argument 1 is `self`, an instance of `class lazy_property`, so the types mismatch.
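For context, a simplified, self-contained sketch of a `lazy_property`-style descriptor (not the exact torch implementation) showing where the `arg-type` ignore lands:

```python
from functools import update_wrapper


class lazy_property:
    """Simplified sketch of a descriptor that computes a value once and
    caches it on the instance (modeled after, but not identical to,
    torch.distributions.utils.lazy_property)."""

    def __init__(self, wrapped):
        self.wrapped = wrapped
        # mypy's complaint: update_wrapper's first parameter is annotated
        # as a Callable, but `self` is a lazy_property instance, hence the
        # targeted `# type: ignore[arg-type]` in the PR.
        update_wrapper(self, wrapped)  # type: ignore[arg-type]

    def __get__(self, instance, obj_type=None):
        if instance is None:
            return self
        value = self.wrapped(instance)
        # Caching: writing the value into the instance dict shadows this
        # non-data descriptor, so later accesses skip __get__ entirely.
        setattr(instance, self.wrapped.__name__, value)
        return value
```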
To fix this properly we would need to annotate `update_wrapper(wrapper, wrapped)`, which lives in the standard library (`functools`). The function's first parameter is typed as a `Callable`, whereas argument 1 here (`self`) has type `class lazy_property(object)`.
torch/distributions/normal.py (Outdated)

```diff
@@ -1,4 +1,5 @@
 import math
+import numbers
```
Same here.

```suggestion
from numbers import Real
```
torch/distributions/beta.py (Outdated)

```diff
@@ -1,3 +1,4 @@
+import numbers
```
Same here.

```suggestion
from numbers import Real
```
Force-pushed from 182a3a2 to f1e9efc.
lgtm
Codecov Report

```
@@            Coverage Diff             @@
##           master   #45689      +/-   ##
==========================================
+ Coverage   68.25%   68.26%   +0.01%
==========================================
  Files         410      410
  Lines       53246    53266     +20
==========================================
+ Hits        36344    36363     +19
- Misses      16902    16903      +1
```

Continue to review the full report at Codecov.
torch/distributions/__init__.py (Outdated)

```diff
@@ -155,4 +156,4 @@
     'register_kl',
     'transform_to',
 ]
-__all__.extend(transforms.__all__)
+__all__.extend(transform_all)
```
Can you please add a comment explaining why this is needed?
torch/distributions/distribution.py (Outdated)

```diff
@@ -81,7 +83,8 @@ def arg_constraints(self):
     """
     raise NotImplementedError

-    @property
+    # Ignore the mypy type error caused by redefining `support` as a method
+    @property  # type: ignore[no-redef]
```
Same as above, perhaps the variable should be renamed to something else.
torch/distributions/transforms.py (Outdated)

```diff
@@ -500,8 +511,8 @@ def __eq__(self, other):

     @property
     def sign(self):
-        if isinstance(self.scale, numbers.Number):
-            return 1 if self.scale > 0 else -1 if self.scale < 0 else 0
+        if isinstance(self.scale, numbers.Integral):
```
Hmm, changing `Number` to `Integral` here would lead to a crash if `scale` is a float, right? (Because there is no `sign` method for the float type, is there?)
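The concern can be verified directly with the `numbers` ABCs (a standalone check, not the torch code itself):

```python
import numbers

# floats satisfy Number but not Integral, so narrowing the isinstance
# check from Number to Integral silently diverts float scales to the
# other branch, which expects a .sign() method that plain floats lack.
assert isinstance(2.5, numbers.Number)
assert not isinstance(2.5, numbers.Integral)
assert isinstance(3, numbers.Integral)
assert not hasattr(2.5, "sign")  # plain floats have no .sign() method
```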
Force-pushed from f1e9efc to 2e718f1.
torch/distributions/multinomial.py (Outdated)

```diff
@@ -50,7 +51,7 @@ def variance(self):
         return self.total_count * self.probs * (1 - self.probs)

     def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
-        if not isinstance(total_count, Number):
+        if not isinstance(total_count, Integral):
```
Are you sure that `total_count` must be an integral value?
torch/distributions/multinomial.py (Outdated)

```diff
@@ -40,6 +40,7 @@ class Multinomial(Distribution):
     """
     arg_constraints = {'probs': constraints.simplex,
                        'logits': constraints.real}
+    total_count: Integral
```
Why can this one not be just `int`?

```suggestion
total_count: int
```
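A standalone sketch of the suggested change, using a simplified stand-in class (not the real `Multinomial`): annotate with the concrete builtin `int`, which mypy understands natively, instead of the abstract `numbers.Integral`:

```python
class Multinomial:
    """Simplified stand-in for torch.distributions.Multinomial, showing
    only the total_count handling discussed in this thread."""

    total_count: int  # the suggested concrete annotation

    def __init__(self, total_count: int = 1):
        # Runtime guard mirroring the isinstance check in the PR; it
        # rejects floats and other non-int inputs.
        if not isinstance(total_count, int):
            raise NotImplementedError("inhomogeneous total_count is not supported")
        self.total_count = total_count
```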
Force-pushed from f32e6de to 47cb358.
💊 CI failures summary: as of commit 3928b46 (more details on the Dr. CI page), no failures so far. 💚 (This comment was generated automatically by Dr. CI and has been revised 7 times.)
@xuzhao9 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Force-pushed from 6ff4772 to 3928b46.
Summary: Fixes hiding of the `CatTransform` docstring due to a misplaced type annotation in #45689. Also fixes that annotation and adds it to `StackTransform` to keep `CatTransform` similar. Pull Request resolved: #73747 Reviewed By: mruberry Differential Revision: D34649927 Pulled By: neerajprad fbshipit-source-id: e4594fd76020a37ac4cada88e5fb08191984e911 (cherry picked from commit cec256c)
Fixes #42979.