Skip to content

Commit db78943

Browse files
laithsakkapytorchmergebot
authored andcommitted
Fix get_free_symbol_uses for several nodes. (#160134)
get_free_symbol_uses is used to know what unbacked symbols are used by a given node. not having correct get_free_symbol_uses defined properly leads to : 1. eliminating of some nodes due to not detection of any users. (See the added unit test) 2. Incorrect topological sort. Fix get_free_symbol_uses , NopKernel , ConcarKernel, InputsKerenl, external kernel. for ComputedBuffer with NonOwningLayout its interesting case. when layout is NonOwningLayout we need to access the actual view op base layout and use detect symbols in it. Because when we codegen the ComputedBuffer we uses those symbols. Pull Request resolved: #160134 Approved by: https://github.com/bobrenjc93
1 parent 2971231 commit db78943

File tree

2 files changed

+51
-4
lines changed

2 files changed

+51
-4
lines changed

test/test_dynamic_shapes.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3616,6 +3616,17 @@ def func3(x, y):
36163616
def test_unbacked_select_index_cpp_wrapper(self):
36173617
self.test_unbacked_select_index()
36183618

3619+
@torch._dynamo.config.patch("capture_scalar_outputs", True)
3620+
def test_unbacked_select2(self):
3621+
def f(idx, x):
3622+
x = x.select(0, idx.item())
3623+
return x @ x
3624+
3625+
x = torch.randn(3, 3, 3)
3626+
idx = torch.tensor(1, dtype=torch.int64)
3627+
out = torch.compile(f)(idx, x)
3628+
self.assertEqual(out, f(idx, x))
3629+
36193630

36203631
instantiate_parametrized_tests(TestUnbacked)
36213632

torch/_inductor/ir.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4443,14 +4443,30 @@ def get_free_symbol_uses(
44434443
# unusual reason: we only need accurate dependencies for item() call,
44444444
# but it's impossible to end up with a reduction over i0 from an
44454445
# item() call without a regular non-reduction buffer first.
4446-
return (
4446+
4447+
result = (
44474448
get_free_symbols(self.get_size(), unbacked_only)
44484449
| get_free_symbols(self.get_stride(), unbacked_only)
44494450
| get_free_symbols(self.get_offset(), unbacked_only)
44504451
| self.data.get_free_symbol_uses(unbacked_only)
44514452
| self.get_read_writes().get_free_symbol_uses(unbacked_only)
44524453
)
44534454

4455+
if isinstance(self.layout, NonOwningLayout):
4456+
assert isinstance(self.layout.view, ReinterpretView)
4457+
box = self.layout.view.data
4458+
assert isinstance(box, StorageBox), type(box)
4459+
input_buffer = box.data
4460+
assert isinstance(input_buffer, Buffer), type(box)
4461+
result = (
4462+
result
4463+
| get_free_symbols(input_buffer.get_size(), unbacked_only)
4464+
| get_free_symbols(input_buffer.get_stride(), unbacked_only)
4465+
| get_free_symbols(input_buffer.get_offset(), unbacked_only)
4466+
)
4467+
4468+
return result
4469+
44544470
def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
44554471
if (
44564472
not self.get_reduction_type()
@@ -5126,6 +5142,18 @@ def get_read_writes(self) -> dependencies.ReadWrites:
51265142
def get_reads(self) -> OrderedSet[Dep]:
51275143
return self.get_read_writes().reads
51285144

5145+
def get_free_symbol_uses(
5146+
self, unbacked_only: bool = False
5147+
) -> OrderedSet[sympy.Symbol]:
5148+
r = OrderedSet[sympy.Symbol]()
5149+
for inp in self.inputs:
5150+
if isinstance(inp, IRNode):
5151+
r |= inp.get_free_symbol_uses(unbacked_only)
5152+
else:
5153+
for inner_inp in inp:
5154+
r |= inner_inp.get_free_symbol_uses(unbacked_only)
5155+
return r
5156+
51295157
@classmethod
51305158
def unwrap_storage_for_input(cls, x: IRNode) -> IRNode:
51315159
if isinstance(x, TensorBox):
@@ -5172,6 +5200,11 @@ def is_no_op(self) -> bool:
51725200
def get_reads(self) -> OrderedSet[Dep]:
51735201
return OrderedSet()
51745202

5203+
def get_free_symbol_uses(
5204+
self, unbacked_only: bool = False
5205+
) -> OrderedSet[sympy.Symbol]:
5206+
return InputsKernel.get_free_symbol_uses(self, unbacked_only)
5207+
51755208

51765209
class ConcatKernel(NopKernel):
51775210
"""
@@ -5326,6 +5359,11 @@ def can_realize_into_without_copy(
53265359
and not isinstance(src.data, ExternKernelAlloc)
53275360
)
53285361

5362+
def get_free_symbol_uses(
5363+
self, unbacked_only: bool = False
5364+
) -> OrderedSet[sympy.Symbol]:
5365+
return NopKernel.get_free_symbol_uses(self, unbacked_only)
5366+
53295367
@classmethod
53305368
def realize_into(cls, src: IRNode, dst: IRNode) -> IRNode:
53315369
# Attempt to turn this into a ReinterpretView rather than assert.
@@ -6221,12 +6259,10 @@ def canonicalize(self) -> tuple[Expr, Sequence[Expr]]:
62216259
def get_free_symbol_uses(
62226260
self, unbacked_only: bool = False
62236261
) -> OrderedSet[sympy.Symbol]:
6224-
# NB: It's not necessary to check regular inputs as we automatically
6225-
# have dependencies on them
62266262
maybe_get_symbols = (
62276263
maybe_free_unbacked_symbols if unbacked_only else maybe_free_symbols
62286264
)
6229-
r = OrderedSet[sympy.Symbol]()
6265+
r = InputsKernel.get_free_symbol_uses(self, unbacked_only)
62306266
for arg in self.constant_args:
62316267
r |= maybe_get_symbols(arg)
62326268
for arg in self.kwargs.values():

0 commit comments

Comments
 (0)