
iter() and reversed() do not raise StopIteration when exhausted in torch.compile #152262

@guilhermeleobas


🐛 Describe the bug

Once an iterator is exhausted, calling next() on it should raise StopIteration. Inside Dynamo, however, the iterator is not actually marked as exhausted when (force_)unpack_var_sequence(...) consumes it (for example, via list(it)), so a subsequent next() call does not raise StopIteration.
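For reference, this is the plain CPython behavior (no torch.compile) that the compiled functions are expected to match. The helper name `exhausts` is only for illustration.

```python
# Plain CPython, no torch.compile: the reference behavior for both
# iter() and reversed().
def exhausts(make_iter):
    it = make_iter([1, 2, 3])
    consumed = list(it)  # drains every remaining element
    try:
        next(it)
    except StopIteration:
        return consumed  # the exhausted iterator raised, as expected
    raise AssertionError("Expected StopIteration")


print(exhausts(iter))      # [1, 2, 3]
print(exhausts(reversed))  # [3, 2, 1]
```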

Reproducer:

import torch

@torch.compile(backend="eager", fullgraph=True)
def foo_iter(t):
    it = iter([1, 2, 3])
    _ = list(it)        # consume all elements
    try:
        next(it)
    except StopIteration:
        return t.sin()
    else:
        assert False, "Expected StopIteration"


@torch.compile(backend="eager", fullgraph=True)
def foo_reversed(t):
    rev = reversed([1, 2, 3])
    _ = list(rev)       # consume all elements
    try:
        next(rev)
    except StopIteration:
        return t.sin()
    else:
        assert False, "Expected StopIteration"


t = torch.tensor([1.0])
assert foo_iter(t) == t.sin()
assert foo_reversed(t) == t.sin()
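To illustrate the failure mode described above, here is a toy model that assumes a traced iterator is represented as a list of items plus a cursor. This is an assumption, not Dynamo's implementation: `ToyListIterator`, `unpack_buggy`, and `unpack_fixed` are hypothetical names. If unpacking returns the remaining items without advancing the cursor, a later next() still yields a value instead of raising StopIteration.

```python
# Hypothetical toy model of a traced list iterator (NOT Dynamo's code):
# a fixed list of items plus a cursor into it.
class ToyListIterator:
    def __init__(self, items):
        self.items = list(items)
        self.index = 0

    def unpack_buggy(self):
        # Returns the remaining items but leaves the cursor in place,
        # so the iterator still looks unconsumed afterwards.
        return self.items[self.index:]

    def unpack_fixed(self):
        # Returns the remaining items and moves the cursor to the end,
        # mirroring how list(it) drains a real iterator.
        remaining = self.items[self.index:]
        self.index = len(self.items)
        return remaining

    def next_item(self):
        # Models next(it): raise once the cursor reaches the end.
        if self.index >= len(self.items):
            raise StopIteration
        value = self.items[self.index]
        self.index += 1
        return value


def raises_after_unpack(unpack_name):
    it = ToyListIterator([1, 2, 3])
    getattr(it, unpack_name)()  # models `_ = list(it)`
    try:
        it.next_item()
    except StopIteration:
        return True
    return False


print(raises_after_unpack("unpack_buggy"))  # False: next_item() still yields 1
print(raises_after_unpack("unpack_fixed"))  # True
```

The same reasoning applies to reversed(), which in this model would just be a ToyListIterator over the reversed items.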

Versions

PyTorch main

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames

Labels: module: dynamo, oncall: pt2, triaged

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions