Open
Labels: module: dynamo, oncall: pt2, triaged
Description
🐛 Describe the bug
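The expected behavior is for `next()` to raise `StopIteration` once an iterator has been exhausted. Inside Dynamo, however, the iterator is not actually marked as exhausted when `.(force_)unpack_var_sequence(...)` is called on it (for example, by the `list(it)` call in the reproducer), so a subsequent `next(it)` does not raise `StopIteration`.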
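To make the failure mode concrete, here is a toy model in plain Python. It is not Dynamo's code: `ToyListIterator`, `unpack_buggy`, and `unpack_fixed` are hypothetical names, and the sketch only assumes that the tracker keeps a position into the underlying items. Unpacking must both return the remaining elements and advance that position to the end; skipping the second step reproduces the symptom above.

```python
class ToyListIterator:
    """Hypothetical stand-in for a tracker of a Python list iterator."""

    def __init__(self, items):
        self.items = list(items)
        self.index = 0  # position of the next element to yield

    def next_item(self):
        # Mirrors CPython: raise StopIteration once every element is consumed.
        if self.index >= len(self.items):
            raise StopIteration
        value = self.items[self.index]
        self.index += 1
        return value

    def unpack_buggy(self):
        # Returns the remaining elements but leaves the position unchanged,
        # so a later next_item() call still yields the first remaining value.
        return self.items[self.index:]

    def unpack_fixed(self):
        # Returns the remaining elements and marks the iterator as exhausted,
        # matching what list(it) does to a real iterator.
        remaining = self.items[self.index:]
        self.index = len(self.items)
        return remaining


def raises_stop_iteration(it):
    try:
        it.next_item()
    except StopIteration:
        return True
    return False


buggy = ToyListIterator([1, 2, 3])
assert buggy.unpack_buggy() == [1, 2, 3]
print(raises_stop_iteration(buggy))  # False: the position was never advanced

fixed = ToyListIterator([1, 2, 3])
assert fixed.unpack_fixed() == [1, 2, 3]
print(raises_stop_iteration(fixed))  # True
```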
Reproducer:
```python
import torch

@torch.compile(backend="eager", fullgraph=True)
def foo_iter(t):
    it = iter([1, 2, 3])
    _ = list(it)  # consume all elements
    try:
        next(it)
    except StopIteration:
        return t.sin()
    else:
        assert False, "Expected StopIteration"

@torch.compile(backend="eager", fullgraph=True)
def foo_reversed(t):
    rev = reversed([1, 2, 3])
    _ = list(rev)  # consume all elements
    try:
        next(rev)
    except StopIteration:
        return t.sin()
    else:
        assert False, "Expected StopIteration"

t = torch.tensor([1.0])
assert foo_iter(t) == t.sin()
assert foo_reversed(t) == t.sin()
```
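For comparison, the same pattern run in plain CPython (no `torch.compile`) shows the reference behavior the compiled functions are expected to match. `exhausted_after_list` is a hypothetical helper written only for this check.

```python
def exhausted_after_list(make_iterator):
    """Drain an iterator with list() and report whether next() then raises StopIteration."""
    it = make_iterator()
    consumed = list(it)  # drains every element
    try:
        next(it)
    except StopIteration:
        return consumed, True
    return consumed, False


print(exhausted_after_list(lambda: iter([1, 2, 3])))      # ([1, 2, 3], True)
print(exhausted_after_list(lambda: reversed([1, 2, 3])))  # ([3, 2, 1], True)
```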
Versions
PyTorch main
cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames