Skip to content

Commit ef1d45b

Browse files
eellison authored and pytorchmergebot committed
Cleanup parent fallback logic (#154006)
The `parent` argument in `fallback_node_due_to_unsupported_type` duplicates the `unsupported_output_tensor` logic, so remove it. Verified that the tests in `test_add_complex` produce the same codegen. This fixes an issue in mx that @drisspg was running into. Pull Request resolved: #154006. Approved by: https://github.com/drisspg
1 parent d6e29bf commit ef1d45b

File tree

2 files changed

+29
-21
lines changed

2 files changed

+29
-21
lines changed

test/inductor/test_cuda_repro.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1949,16 +1949,19 @@ def foo(x0):
19491949

19501950
def foo(x0):
19511951
x1 = x0 + 1
1952-
x2 = x1.view(dtype)
1952+
x2 = x1.view(dtype).view([16 * 16])
19531953
return x2
19541954

19551955
x0 = torch.randint(0, 255, (16, 16), device=device, dtype=torch.uint8)
19561956
foo_c = torch.compile(foo, backend="inductor", fullgraph=True)
19571957

19581958
with torch.no_grad():
1959-
y_c = foo_c(x0)
1959+
result, code = run_and_get_code(foo_c, x0)
19601960

1961-
self.assertEqual(foo(x0), y_c)
1961+
FileCheck().check("call").check_not("torch.ops.aten.reshape.default(").run(
1962+
code[0]
1963+
)
1964+
self.assertEqual(foo(x0), result)
19621965

19631966
@unittest.skipIf(
19641967
not config.is_fbcode(),

torch/_inductor/lowering.py

Lines changed: 23 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -1888,15 +1888,10 @@ def _warn_complex_not_supported():
18881888

18891889
# There are some types (CPU) which we accept as input but not as
18901890
# output.
1891-
def unsupported_input_tensor(t: torch.Tensor, parent=None, node=None):
1891+
def unsupported_input_tensor(t: torch.Tensor, node=None):
18921892
"Do not support reading or writing to this tensor"
18931893
if t.is_complex():
18941894
# Complex views are supported with IR ComplexView
1895-
if parent and parent.target in (
1896-
torch.ops.aten.view.dtype,
1897-
torch.ops.prims.convert_element_type.default,
1898-
):
1899-
return False
19001895
_warn_complex_not_supported()
19011896
return True
19021897

@@ -1910,11 +1905,12 @@ def unsupported_input_tensor(t: torch.Tensor, parent=None, node=None):
19101905
# allow bitcast, views, memory movement, but not arithmetic
19111906
# TODO: delete once triton adds native support
19121907
return not (
1913-
isinstance(parent.target, torch._ops.OpOverload)
1914-
and parent.target
1908+
isinstance(node.target, torch._ops.OpOverload)
1909+
and node.target
19151910
in (
19161911
aten.view.dtype,
19171912
aten.cat.default,
1913+
aten.clone.default,
19181914
aten._scaled_mm.default,
19191915
)
19201916
or (isinstance(node.target, torch._ops.OpOverload) and is_view(node.target))
@@ -1923,9 +1919,15 @@ def unsupported_input_tensor(t: torch.Tensor, parent=None, node=None):
19231919
return False
19241920

19251921

1926-
def unsupported_output_tensor(t: torch.Tensor, parent=None, node=None):
1922+
def unsupported_output_tensor(t: torch.Tensor, node=None):
19271923
"Do not support writing tensor but can read from it"
1928-
if unsupported_input_tensor(t, parent):
1924+
supported_complex_views = (
1925+
aten.view.dtype,
1926+
torch.ops.prims.convert_element_type.default,
1927+
)
1928+
if node is not None and node.target in supported_complex_views and t.is_complex():
1929+
return False
1930+
if unsupported_input_tensor(t, node):
19291931
return True
19301932
return t.is_cpu and config.disable_cpp_codegen
19311933

@@ -1935,36 +1937,39 @@ def fallback_node_due_to_unsupported_type(node: torch.fx.Node, allow_cpu_inputs=
19351937
if node.target is aten.view_as_complex.default:
19361938
return False
19371939

1940+
if node.op == "placeholder":
1941+
return False
1942+
19381943
# We should be able to remove this special case once `disable_cpp_codegen` is killed.
19391944
if node.target is aten.lift_fresh_copy.default:
19401945
return False
19411946

1942-
def check_skip_condition(node, parent, is_output):
1943-
if not isinstance(node, torch.fx.Node):
1947+
def check_skip_condition(inp_out_node, is_output):
1948+
if not isinstance(inp_out_node, torch.fx.Node):
19441949
return False
19451950

1946-
if "val" not in node.meta:
1951+
if "val" not in inp_out_node.meta:
19471952
return False
19481953

1949-
for meta in pytree.tree_leaves(node.meta["val"]):
1954+
for meta in pytree.tree_leaves(inp_out_node.meta["val"]):
19501955
if not isinstance(meta, torch._subclasses.FakeTensor):
19511956
continue
19521957

19531958
if is_output:
1954-
if unsupported_output_tensor(meta, parent, node):
1959+
if unsupported_output_tensor(meta, node):
19551960
return True
19561961
else:
1957-
if unsupported_input_tensor(meta, parent, node):
1962+
if unsupported_input_tensor(meta, node):
19581963
return True
19591964

19601965
return False
19611966

19621967
# only skip codegen if there is a cpu output, not input
19631968
for arg in pytree.arg_tree_leaves(*node.args, **node.kwargs):
1964-
if check_skip_condition(arg, node, is_output=False):
1969+
if check_skip_condition(arg, is_output=False):
19651970
return True
19661971

1967-
return check_skip_condition(node, node, is_output=True)
1972+
return check_skip_condition(node, is_output=True)
19681973

19691974

19701975
def make_fallback(op, layout_constraint=None, warn=True, override_decomp=False):

0 commit comments

Comments (0)