Skip to content

Commit 90b67fd

Browse files
[generator] Raise StopIteration(value) with value from the return stmt
ghstack-source-id: cc02ea1 Pull Request resolved: #157152
1 parent 7570678 commit 90b67fd

File tree

2 files changed

+51
-1
lines changed

2 files changed

+51
-1
lines changed

test/dynamo/test_generator.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1480,6 +1480,50 @@ def fn(t):
14801480

14811481
self._compile_check(fn)
14821482

1483+
def test_return_value_in_except_and_finally(self):
1484+
def whoo():
1485+
try:
1486+
yield 1
1487+
except ValueError:
1488+
return 2 # noqa: B901
1489+
finally:
1490+
return 3 # noqa: B012, SIM107
1491+
1492+
def fn(t):
1493+
gen = whoo()
1494+
next(gen)
1495+
try:
1496+
gen.throw(ValueError)
1497+
except StopIteration as e:
1498+
assert e.args[0] == 3
1499+
except Exception as e:
1500+
raise AssertionError from e
1501+
return t.sin()
1502+
1503+
self._compile_check(fn)
1504+
1505+
def test_return_None_in_except_and_finally(self):
1506+
def whoo():
1507+
try:
1508+
yield 1
1509+
except ValueError:
1510+
return 2 # noqa: B901
1511+
finally:
1512+
return # noqa: B012, SIM107
1513+
1514+
def fn(t):
1515+
gen = whoo()
1516+
next(gen)
1517+
try:
1518+
gen.throw(ValueError)
1519+
except StopIteration as e:
1520+
assert len(e.args) == 0
1521+
except Exception as e:
1522+
raise AssertionError from e
1523+
return t.sin()
1524+
1525+
self._compile_check(fn)
1526+
14831527

14841528
instantiate_parametrized_tests(GeneratorTests)
14851529
instantiate_parametrized_tests(TestGeneratorSend)

torch/_dynamo/symbolic_convert.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3972,7 +3972,13 @@ def inline_call_(self):
39723972
):
39733973
assert isinstance(self, InliningGeneratorInstructionTranslator)
39743974
# When the generator returns None, we raise StopIteration
3975-
exc.raise_observed_exception(StopIteration, self)
3975+
args = []
3976+
if (
3977+
isinstance(self.symbolic_result, ConstantVariable)
3978+
and self.symbolic_result.value is not None
3979+
):
3980+
args = [self.symbolic_result]
3981+
exc.raise_observed_exception(StopIteration, self, args=args)
39763982
else:
39773983
return self.symbolic_result
39783984
else:

0 commit comments

Comments
 (0)