|
20 | 20 | from abc import ABC
|
21 | 21 | from collections import namedtuple
|
22 | 22 | from copy import deepcopy
|
23 |
| -from enum import Enum |
| 23 | +from enum import Enum, IntEnum |
24 | 24 | from functools import wraps
|
25 | 25 | from typing import Any, Dict, Iterator, List, Tuple
|
26 | 26 | from unittest import mock
|
@@ -457,7 +457,7 @@ def forward(
|
457 | 457 | if past_key_value is not None:
|
458 | 458 | assert (
|
459 | 459 | len(past_key_value) == 2
|
460 |
| - ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" |
| 460 | + ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states" |
461 | 461 | real_seq_length += (
|
462 | 462 | past_key_value[0].shape[2] if query_length is None else query_length
|
463 | 463 | )
|
@@ -4546,6 +4546,84 @@ def f(*args):
|
4546 | 4546 | f(*args)
|
4547 | 4547 | self.assertEqual(num_compiles, 1)
|
4548 | 4548 |
|
| 4549 | + def test_issue134451(self): |
| 4550 | + class BoundingBox2DIndex(IntEnum): |
| 4551 | + _X = 0 |
| 4552 | + _Y = 1 |
| 4553 | + _HEADING = 2 |
| 4554 | + _LENGTH = 3 |
| 4555 | + _WIDTH = 4 |
| 4556 | + |
| 4557 | + @classmethod |
| 4558 | + def size(cls): |
| 4559 | + return 5 |
| 4560 | + |
| 4561 | + @classmethod |
| 4562 | + @property |
| 4563 | + def X(cls): |
| 4564 | + return cls._X |
| 4565 | + |
| 4566 | + @classmethod |
| 4567 | + @property |
| 4568 | + def Y(cls): |
| 4569 | + return cls._Y |
| 4570 | + |
| 4571 | + @classmethod |
| 4572 | + @property |
| 4573 | + def HEADING(cls): |
| 4574 | + return cls._HEADING |
| 4575 | + |
| 4576 | + @classmethod |
| 4577 | + @property |
| 4578 | + def LENGTH(cls): |
| 4579 | + return cls._LENGTH |
| 4580 | + |
| 4581 | + @classmethod |
| 4582 | + @property |
| 4583 | + def WIDTH(cls): |
| 4584 | + return cls._WIDTH |
| 4585 | + |
| 4586 | + @classmethod |
| 4587 | + @property |
| 4588 | + def POINT(cls): |
| 4589 | + # assumes X, Y have subsequent indices |
| 4590 | + return slice(cls._X, cls._Y + 1) |
| 4591 | + |
| 4592 | + @classmethod |
| 4593 | + @property |
| 4594 | + def STATE_SE2(cls): |
| 4595 | + # assumes X, Y, HEADING have subsequent indices |
| 4596 | + return slice(cls._X, cls._HEADING + 1) |
| 4597 | + |
| 4598 | + class SimpleModel(nn.Module): |
| 4599 | + def __init__(self): |
| 4600 | + super().__init__() |
| 4601 | + self._mlp_states = nn.Sequential( |
| 4602 | + nn.Linear(10, 20), |
| 4603 | + nn.ReLU(), |
| 4604 | + nn.Linear(20, BoundingBox2DIndex.size()), |
| 4605 | + ) |
| 4606 | + |
| 4607 | + def forward(self, x): |
| 4608 | + agent_states = self._mlp_states(x) |
| 4609 | + agent_states[..., BoundingBox2DIndex.POINT] = ( |
| 4610 | + agent_states[..., BoundingBox2DIndex.POINT].tanh() * 32 |
| 4611 | + ) |
| 4612 | + agent_states[..., BoundingBox2DIndex.HEADING] = ( |
| 4613 | + agent_states[..., BoundingBox2DIndex.HEADING].tanh() * torch.pi |
| 4614 | + ) |
| 4615 | + return agent_states |
| 4616 | + |
| 4617 | + model = SimpleModel().eval() |
| 4618 | + input_tensor = torch.randn(1, 10, dtype=torch.float32) |
| 4619 | + opt = torch.compile(model.eval(), backend="eager", fullgraph=True) |
| 4620 | + actual = opt(input_tensor) |
| 4621 | + try: |
| 4622 | + expected = model(input_tensor) |
| 4623 | + except Exception as e: |
| 4624 | + raise unittest.SkipTest("eager failed, requires Python>=3.12") from e |
| 4625 | + self.assertEqual(actual, expected) |
| 4626 | + |
4549 | 4627 | def test_invalid_seq_unpack(self):
|
4550 | 4628 | def myfn(arg):
|
4551 | 4629 | (a, b) = arg
|
|
0 commit comments