Skip to content

Commit a0207c8

Browse files
janselpytorchmergebot
authored andcommitted
[dynamo] Fix support for classmethod(property(...)) (#134968)
Fixes #134451 Pull Request resolved: #134968 Approved by: https://github.com/yanboliang
1 parent 9aa22ea commit a0207c8

File tree

3 files changed

+86
-2
lines changed

3 files changed

+86
-2
lines changed

test/dynamo/test_repros.py

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from abc import ABC
2121
from collections import namedtuple
2222
from copy import deepcopy
23-
from enum import Enum
23+
from enum import Enum, IntEnum
2424
from functools import wraps
2525
from typing import Any, Dict, Iterator, List, Tuple
2626
from unittest import mock
@@ -457,7 +457,7 @@ def forward(
457457
if past_key_value is not None:
458458
assert (
459459
len(past_key_value) == 2
460-
), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
460+
), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
461461
real_seq_length += (
462462
past_key_value[0].shape[2] if query_length is None else query_length
463463
)
@@ -4546,6 +4546,84 @@ def f(*args):
45464546
f(*args)
45474547
self.assertEqual(num_compiles, 1)
45484548

4549+
def test_issue134451(self):
4550+
class BoundingBox2DIndex(IntEnum):
4551+
_X = 0
4552+
_Y = 1
4553+
_HEADING = 2
4554+
_LENGTH = 3
4555+
_WIDTH = 4
4556+
4557+
@classmethod
4558+
def size(cls):
4559+
return 5
4560+
4561+
@classmethod
4562+
@property
4563+
def X(cls):
4564+
return cls._X
4565+
4566+
@classmethod
4567+
@property
4568+
def Y(cls):
4569+
return cls._Y
4570+
4571+
@classmethod
4572+
@property
4573+
def HEADING(cls):
4574+
return cls._HEADING
4575+
4576+
@classmethod
4577+
@property
4578+
def LENGTH(cls):
4579+
return cls._LENGTH
4580+
4581+
@classmethod
4582+
@property
4583+
def WIDTH(cls):
4584+
return cls._WIDTH
4585+
4586+
@classmethod
4587+
@property
4588+
def POINT(cls):
4589+
# assumes X, Y have subsequent indices
4590+
return slice(cls._X, cls._Y + 1)
4591+
4592+
@classmethod
4593+
@property
4594+
def STATE_SE2(cls):
4595+
# assumes X, Y, HEADING have subsequent indices
4596+
return slice(cls._X, cls._HEADING + 1)
4597+
4598+
class SimpleModel(nn.Module):
4599+
def __init__(self):
4600+
super().__init__()
4601+
self._mlp_states = nn.Sequential(
4602+
nn.Linear(10, 20),
4603+
nn.ReLU(),
4604+
nn.Linear(20, BoundingBox2DIndex.size()),
4605+
)
4606+
4607+
def forward(self, x):
4608+
agent_states = self._mlp_states(x)
4609+
agent_states[..., BoundingBox2DIndex.POINT] = (
4610+
agent_states[..., BoundingBox2DIndex.POINT].tanh() * 32
4611+
)
4612+
agent_states[..., BoundingBox2DIndex.HEADING] = (
4613+
agent_states[..., BoundingBox2DIndex.HEADING].tanh() * torch.pi
4614+
)
4615+
return agent_states
4616+
4617+
model = SimpleModel().eval()
4618+
input_tensor = torch.randn(1, 10, dtype=torch.float32)
4619+
opt = torch.compile(model.eval(), backend="eager", fullgraph=True)
4620+
actual = opt(input_tensor)
4621+
try:
4622+
expected = model(input_tensor)
4623+
except Exception as e:
4624+
raise unittest.SkipTest("eager failed, requires Python>=3.12") from e
4625+
self.assertEqual(actual, expected)
4626+
45494627
def test_invalid_seq_unpack(self):
45504628
def myfn(arg):
45514629
(a, b) = arg

torch/_dynamo/variables/constant.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,8 @@ def create(cls, cls_type, value_vt, options):
222222
unimplemented("Enum variable is constructed with non constant values")
223223

224224
def as_proxy(self):
225+
if isinstance(self.value, int):
226+
return int(self.value) # convert IntEnum to a normal int
225227
return self.value
226228

227229
def __str__(self) -> str:

torch/_dynamo/variables/user_defined.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,10 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke
193193
else:
194194
return SourcelessBuilder.create(tx, func)
195195
elif isinstance(obj, classmethod):
196+
if isinstance(obj.__func__, property):
197+
return variables.UserFunctionVariable(obj.__func__.fget).call_function(
198+
tx, [self], {}
199+
)
196200
return variables.UserMethodVariable(obj.__func__, self, source=source)
197201
elif isinstance(obj, types.ClassMethodDescriptorType):
198202
# e.g.: inspect.getattr_static(dict, "fromkeys")

0 commit comments

Comments
 (0)