
Commit fd553b9

eellison authored and pytorchmergebot committed
Add remaining method and tests for dtype propagation (#140057)
Adds the remaining unimplemented ops, as well as an assertion failure if someone adds a new op without a dtype rule. We test all unique pointwise operators registered as lowerings that have an OpInfo. There will be some follow-ups for this to work well with `codegen_upcast_to_fp32` as both True and False.

Pull Request resolved: #140057
Approved by: https://github.com/arui-meta, https://github.com/blaine-rister, https://github.com/ezyang
ghstack dependencies: #139945
1 parent 566ceb3 commit fd553b9
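Below is a minimal, hedged sketch of how the new assertion mode can be exercised end to end (it assumes a CUDA device with a working Triton backend; `pointwise_fn` and the tensor shapes are illustrative, not part of this PR). With the test config enabled, Inductor injects `tl.static_assert` dtype checks into the generated Triton kernels, so any disagreement between the propagated dtype and the actual value fails at kernel compile time rather than silently changing precision.

```python
import torch
from torch._inductor import config


def pointwise_fn(x, y):
    # a simple pointwise composition; every traced pointwise op gets dtype-checked
    return torch.mul(x, y) + torch.abs(x)


x = torch.randn(1024, device="cuda", dtype=torch.float32)
y = torch.randn(1024, device="cuda", dtype=torch.float32)

# Enable the runtime dtype assertions added by this PR for the compiled run only.
with config.patch("test_configs.runtime_triton_dtype_assert", True):
    compiled = torch.compile(pointwise_fn)
    torch.testing.assert_close(pointwise_fn(x, y), compiled(x, y))
```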

File tree

7 files changed: +322, -44 lines changed


benchmarks/dynamo/pr_time_benchmarks/expected_results.csv

Lines changed: 6 additions & 6 deletions
@@ -6,27 +6,27 @@ add_loop_eager_dynamic,compile_time_instruction_count,5703000000,0.025
-add_loop_inductor,compile_time_instruction_count,26750000000,0.015
+add_loop_inductor,compile_time_instruction_count,29490000000,0.015
-add_loop_inductor_dynamic_gpu,compile_time_instruction_count,42430000000,0.025
+add_loop_inductor_dynamic_gpu,compile_time_instruction_count,43310000000,0.025
-add_loop_inductor_gpu,compile_time_instruction_count,24790000000,0.015
+add_loop_inductor_gpu,compile_time_instruction_count,25660000000,0.015
 basic_modules_ListOfLinears_eager,compile_time_instruction_count,1033000000,0.015
-basic_modules_ListOfLinears_inductor,compile_time_instruction_count,19970000000,0.015
+basic_modules_ListOfLinears_inductor,compile_time_instruction_count,20790000000,0.015
-basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,16450000000,0.015
+basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17020000000,0.015
@@ -62,4 +62,4 @@ aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3863000000,
-aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10360000000,0.015
+aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10410000000,0.015

test/inductor/test_op_dtype_prop.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
# Owner(s): ["module: inductor"]
import importlib
import os
import sys

import torch
from torch._dynamo.utils import disable_cache_limit
from torch._inductor import config
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_methods_invocations import op_db


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)


importlib.import_module("functorch")
importlib.import_module("filelock")


from torch._inductor.lowering import lowerings
from torch.testing._internal.common_device_type import ops
from torch.testing._internal.inductor_utils import HAS_GPU


unique_pointwise_op_names = set()

for op in lowerings:
    if not isinstance(op, torch._ops.OpOverload):
        continue

    if torch.Tag.pointwise not in op.tags:
        continue

    if op._schema.is_mutable:
        continue

    op_name = (op.name().split("::")[-1]).split(".")[0]
    unique_pointwise_op_names.add(op_name)

pointwise_ops = [
    op
    for op in op_db
    if op.name in unique_pointwise_op_names and "reduction" not in op.variant_test_name
]


class TestCase(InductorTestCase):
    @ops(
        pointwise_ops,
        allowed_dtypes=(
            torch.float32,
            torch.float64,
            torch.int32,
            # torch.int64,  # fixed in follow up
            torch.bool,
        ),
    )
    # @config.patch("triton.codegen_upcast_to_fp32", False)  # TODO enable
    @config.patch("test_configs.runtime_triton_dtype_assert", True)
    @disable_cache_limit()
    def test_op_dtype_propagation(self, op, dtype):
        def run(op, args, kwargs):
            return op(*args, **kwargs)

        if op.name == "add":
            self.skipTest("Fixed in follow ups")

        sample_inputs_itr = op.sample_inputs("cuda", dtype, requires_grad=False)
        for sample_input in sample_inputs_itr:
            args = (sample_input.input,) + sample_input.args
            kwargs = sample_input.kwargs
            out = run(op.get_op(), args, kwargs)
            out_c = torch.compile(run)(op.get_op(), args, kwargs)
            self.assertEqual(out, out_c)


instantiate_device_type_tests(TestCase, globals(), only_for=("cuda",))

if __name__ == "__main__":
    from torch._inductor.test_case import run_tests

    if HAS_GPU:
        run_tests(needs="filelock")
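As a companion to the collection logic in the new test, here is a small standalone sketch (the filtering mirrors the test above; only the set arithmetic and printing are added here) that reports how many pointwise lowerings are exercised through an OpInfo and how many are not:

```python
import torch
from torch._inductor.lowering import lowerings
from torch.testing._internal.common_methods_invocations import op_db

pointwise_names = set()
for lowered_op in lowerings:
    if not isinstance(lowered_op, torch._ops.OpOverload):
        continue
    # same filter as the test: non-mutating ops tagged pointwise
    if torch.Tag.pointwise not in lowered_op.tags or lowered_op._schema.is_mutable:
        continue
    pointwise_names.add(lowered_op.name().split("::")[-1].split(".")[0])

opinfo_names = {opinfo.name for opinfo in op_db}
covered = sorted(pointwise_names & opinfo_names)
uncovered = sorted(pointwise_names - opinfo_names)
print(f"{len(covered)} pointwise lowerings have OpInfo coverage; {len(uncovered)} do not")
```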

torch/_inductor/codegen/common.py

Lines changed: 21 additions & 3 deletions
@@ -1791,21 +1791,39 @@ def inner(*args, **kwargs):
             value = getattr(parent_handler, name)(*args, **kwargs)  # type: ignore[has-type]
             dtype_handler = DtypePropagationOpsHandler()

+            output_idx = 0
+
             def do_cse(v):
-                # TODO - throw on default
                 output_dtype = getattr(
                     dtype_handler,
                     name,
-                    dtype_handler.default_handler,
-                )(*args)
+                )(*args, **kwargs)

                 csevar = V.kernel.cse.generate(
                     V.kernel.compute,
                     v,
                     bounds=bounds,
                     dtype=output_dtype,
                 )
+
+                nonlocal output_idx
+                if config.test_configs.runtime_triton_dtype_assert and not (
+                    V.graph.get_current_device_or_throw().type == "cpu"
+                    and config.cpu_backend != "triton"
+                ):
+                    from torch._inductor.codegen.triton import triton_type
+
+                    # we tree_map over the output, so we need to fetch corresponding dtype
+                    if isinstance(output_dtype, (list, tuple)):
+                        output_dtype = output_dtype[output_idx]
+
+                    V.kernel.compute.writeline(
+                        f"tl.static_assert({csevar}.dtype == {triton_type(output_dtype)})"
+                    )
+                    output_idx += 1
+
                 csevar.update_on_args(name, args, kwargs)
+
                 return csevar

             return pytree.tree_map(do_cse, value)
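To make the effect of the `writeline` call above concrete, here is a hand-written, hypothetical Triton kernel (not actual Inductor output; the kernel name, variable names, and shapes are illustrative) showing the kind of line that gets injected when `test_configs.runtime_triton_dtype_assert` is enabled:

```python
import triton
import triton.language as tl


@triton.jit
def add_one_kernel(in_ptr, out_ptr, xnumel, XBLOCK: tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)
    xmask = xindex < xnumel
    tmp0 = tl.load(in_ptr + xindex, mask=xmask)
    tmp1 = tmp0 + 1.0
    # Injected dtype check: tl.static_assert is evaluated when the kernel is
    # compiled, so a wrong propagated dtype surfaces immediately instead of
    # silently producing a differently-typed intermediate.
    tl.static_assert(tmp1.dtype == tl.float32)
    tl.store(out_ptr + xindex, tmp1, mask=xmask)
```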

torch/_inductor/codegen/triton.py

Lines changed: 3 additions & 1 deletion
@@ -1092,7 +1092,9 @@ def sigmoid(x):
     @staticmethod
     def signbit(x):
         # XX: This is wrong for the value -0.0 in floating point
-        return f"libdevice.signbit({x}) if ({x}).dtype is tl.float32 else {x} < 0"
+        return (
+            f"(libdevice.signbit({x}) != 0) if ({x}).dtype is tl.float32 else {x} < 0"
+        )

     @staticmethod
     def fmod(a, b):

torch/_inductor/config.py

Lines changed: 2 additions & 0 deletions
@@ -1353,6 +1353,8 @@ class trace:
 class test_configs:
     force_extern_kernel_in_multi_template = False

+    runtime_triton_dtype_assert = False
+

 if TYPE_CHECKING:
     from torch.utils._config_typing import *  # noqa: F401, F403
