@@ -1888,15 +1888,10 @@ def _warn_complex_not_supported():
1888
1888
1889
1889
# There are some types (CPU) which we accept as input but not as
1890
1890
# output.
1891
- def unsupported_input_tensor (t : torch .Tensor , parent = None , node = None ):
1891
+ def unsupported_input_tensor (t : torch .Tensor , node = None ):
1892
1892
"Do not support reading or writing to this tensor"
1893
1893
if t .is_complex ():
1894
1894
# Complex views are supported with IR ComplexView
1895
- if parent and parent .target in (
1896
- torch .ops .aten .view .dtype ,
1897
- torch .ops .prims .convert_element_type .default ,
1898
- ):
1899
- return False
1900
1895
_warn_complex_not_supported ()
1901
1896
return True
1902
1897
@@ -1910,11 +1905,12 @@ def unsupported_input_tensor(t: torch.Tensor, parent=None, node=None):
1910
1905
# allow bitcast, views, memory movement, but not arithmetic
1911
1906
# TODO: delete once triton adds native support
1912
1907
return not (
1913
- isinstance (parent .target , torch ._ops .OpOverload )
1914
- and parent .target
1908
+ isinstance (node .target , torch ._ops .OpOverload )
1909
+ and node .target
1915
1910
in (
1916
1911
aten .view .dtype ,
1917
1912
aten .cat .default ,
1913
+ aten .clone .default ,
1918
1914
aten ._scaled_mm .default ,
1919
1915
)
1920
1916
or (isinstance (node .target , torch ._ops .OpOverload ) and is_view (node .target ))
@@ -1923,9 +1919,15 @@ def unsupported_input_tensor(t: torch.Tensor, parent=None, node=None):
1923
1919
return False
1924
1920
1925
1921
1926
def unsupported_output_tensor(t: torch.Tensor, node=None):
    """Return whether tensor `t` is unsupported as an output of `node`.

    A complex tensor is accepted (returns False) when `node` is given and its
    target is `aten.view.dtype` or `prims.convert_element_type.default`.
    Otherwise the tensor is rejected when `unsupported_input_tensor` rejects
    it; failing that, the result is `t.is_cpu and config.disable_cpp_codegen`.

    Args:
        t: The tensor to check.
        node: Optional node whose `target` selects the complex-view
            exemption; it is also forwarded to `unsupported_input_tensor`.
    """
    complex_view_targets = (
        aten.view.dtype,
        torch.ops.prims.convert_element_type.default,
    )
    # Evaluated in the same order as a chained `and`: the node check comes
    # first so `node.target` is never touched when no node was supplied.
    is_complex_view_output = (
        node is not None
        and node.target in complex_view_targets
        and t.is_complex()
    )
    if is_complex_view_output:
        return False

    # Anything that cannot even be read cannot be written either.
    if unsupported_input_tensor(t, node):
        return True

    return t.is_cpu and config.disable_cpp_codegen
1931
1933
@@ -1935,36 +1937,39 @@ def fallback_node_due_to_unsupported_type(node: torch.fx.Node, allow_cpu_inputs=
1935
1937
if node .target is aten .view_as_complex .default :
1936
1938
return False
1937
1939
1940
+ if node .op == "placeholder" :
1941
+ return False
1942
+
1938
1943
# We should be able to remove this special case once `disable_cpp_codegen` is killed.
1939
1944
if node .target is aten .lift_fresh_copy .default :
1940
1945
return False
1941
1946
1942
def check_skip_condition(inp_out_node, is_output):
    """Return True if any fake-tensor value recorded on `inp_out_node` is unsupported.

    Args:
        inp_out_node: Candidate to inspect. Anything that is not a
            `torch.fx.Node`, or a node without a "val" entry in its `meta`,
            yields False.
        is_output: When truthy, leaves are checked with
            `unsupported_output_tensor`; otherwise with
            `unsupported_input_tensor`.

    `node` is not a parameter: it is read from the enclosing scope and passed
    as the second argument to whichever checker is selected.
    """
    if not isinstance(inp_out_node, torch.fx.Node):
        return False
    if "val" not in inp_out_node.meta:
        return False

    # Pick the checker once; the choice does not depend on the leaf.
    is_unsupported = (
        unsupported_output_tensor if is_output else unsupported_input_tensor
    )

    # Only FakeTensor leaves are examined; `any` stops at the first hit.
    return any(
        is_unsupported(leaf, node)
        for leaf in pytree.tree_leaves(inp_out_node.meta["val"])
        if isinstance(leaf, torch._subclasses.FakeTensor)
    )
1961
1966
1962
1967
# only skip codegen if there is a cpu output, not input
1963
1968
for arg in pytree .arg_tree_leaves (* node .args , ** node .kwargs ):
1964
- if check_skip_condition (arg , node , is_output = False ):
1969
+ if check_skip_condition (arg , is_output = False ):
1965
1970
return True
1966
1971
1967
- return check_skip_condition (node , node , is_output = True )
1972
+ return check_skip_condition (node , is_output = True )
1968
1973
1969
1974
1970
1975
def make_fallback (op , layout_constraint = None , warn = True , override_decomp = False ):
0 commit comments