
Commit f59fbd0 (1 parent: 8918c7e)

Commit message: minor tests
4 files changed: 18 additions (+), 14 deletions (-)

test/distributed/tensor/test_dtensor_compile.py

Lines changed: 2 additions & 1 deletion

@@ -891,7 +891,8 @@ def forward(self, input):
         FileCheck().check(
             "buf0 = torch.ops._c10d_functional.all_gather_into_tensor.default(primal"
         ).check("torch.ops._c10d_functional.wait_tensor.default(buf0").check(
-            "extern_kernels.mm(buf0,"
+            "extern_kernels.mm(buf0," if not config.triton.enable_native_matmul
+            else "triton_per_fused_add_addmm_0.run(buf6"
         ).run(code)

     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
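
Note on the change above: with config.triton.enable_native_matmul enabled,
Inductor emits the matmul as a generated Triton kernel instead of an
extern_kernels.mm call, so the substring the test expects in the generated
code changes. Below is a minimal, self-contained sketch of the FileCheck
pattern the test relies on; torch.testing.FileCheck and its chained
.check()/.run() calls are the real API, while the code string and the
enable_native_matmul variable are invented stand-ins for illustration.

    from torch.testing import FileCheck

    # Stand-in for the Inductor-generated source the real test captures.
    code = """
    buf0 = torch.ops._c10d_functional.all_gather_into_tensor.default(primals_1)
    buf1 = torch.ops._c10d_functional.wait_tensor.default(buf0)
    extern_kernels.mm(buf0, buf2, out=buf3)
    """

    enable_native_matmul = False  # stands in for config.triton.enable_native_matmul

    # Each .check() asserts its substring appears after the previous match;
    # .run() raises if any check fails.
    FileCheck().check(
        "torch.ops._c10d_functional.all_gather_into_tensor.default"
    ).check("torch.ops._c10d_functional.wait_tensor.default").check(
        "extern_kernels.mm(buf0,"
        if not enable_native_matmul
        else "triton_per_fused_add_addmm_0.run("
    ).run(code)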

test/inductor/test_aot_inductor.py

Lines changed: 9 additions & 8 deletions

@@ -2718,6 +2718,9 @@ def forward(self, x, y):
         result_package = model_package(*inputs_on_device)
         self.assertTrue(same(result_ref.cpu(), result_package.cpu()))

+    @unittest.skipIf(
+        config.triton.enable_native_matmul, "sin and mm are fused in native matmul"
+    )
     def test_reuse_kernel(self):
         class Model(torch.nn.Module):
             def __init__(self) -> None:

@@ -2736,14 +2739,9 @@ def forward(self, x, y):
         )
         model = Model()

-        if config.triton.enable_native_matmul:
-            atol, rtol = 2e-4, 2e-4
-        else:
-            # 1e-4 is the tol value used in pytorch/torch/_dynamo/utils.py
-            atol, rtol = 1e-4, 1e-4
-
+        # 1e-4 is the tol value used in pytorch/torch/_dynamo/utils.py
         self.check_model(
-            model, example_inputs, atol=atol, rtol=rtol
+            model, example_inputs, atol=1e-4, rtol=1e-4
         )

         if self.device == "mps":

@@ -4921,7 +4919,10 @@ def forward(self, image: torch.Tensor, target_size: torch.Tensor):
             "target_size": None,
         }
         self.check_model(model, example_inputs, dynamic_shapes=dynamic_shapes)
-
+
+    @unittest.skipIf(
+        config.triton.enable_native_matmul, "matmul is generated"
+    )
     def test_aoti_debug_printer_codegen(self):
         # basic addmm model to test codegen for aoti intermediate debug printer
         class Model(torch.nn.Module):
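
Note on the changes above: instead of branching on the config flag to pick
looser tolerances, the commit gates whole tests with unittest.skipIf on
config.triton.enable_native_matmul; the same pattern recurs in the two files
below. A minimal sketch of the gating pattern, with an invented test body:

    import unittest

    import torch
    from torch._inductor import config


    class NativeMatmulGatedTest(unittest.TestCase):
        # Skipped when native matmul changes which kernels Inductor emits,
        # invalidating the test's expectations.
        @unittest.skipIf(
            config.triton.enable_native_matmul,
            "kernel names differ when matmul is generated natively",
        )
        def test_mm_matches_eager(self):
            a, b = torch.randn(4, 4), torch.randn(4, 4)
            torch.testing.assert_close(torch.mm(a, b), a @ b)


    if __name__ == "__main__":
        unittest.main()

One caveat worth noting: the skipIf condition is evaluated when the class body
executes, so the config flag must already be set (for example via the
environment) before the test module is imported.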

test/inductor/test_flex_attention.py

Lines changed: 3 additions & 0 deletions

@@ -3650,6 +3650,9 @@ def mask_mod(b, h, q, kv):

     @supported_platform
     @skip_on_cpu
+    @unittest.skipIf(
+        config.triton.enable_native_matmul, "different dynamo counters"
+    )
     def test_free_symbol_dynamic(self, device):
         def batch_flip_causal(b, h, q_idx, kv_idx):
             return (q_idx >= kv_idx) & (b % 2 == 0)

test/inductor/test_torchinductor.py

Lines changed: 4 additions & 5 deletions

@@ -6304,6 +6304,9 @@ def fn(x1, x2, x3, x4):
     @skip_if_gpu_halide
     # Constant folding was explicitly turned off due to issue #108388
     # Turn it back on for test
+    @unittest.skipIf(
+        config.triton.enable_native_matmul, "native matmul has better precision"
+    )
     @torch._inductor.config.patch(joint_graph_constant_folding=True)
     def test_remove_no_ops(self):
         def matmul_with_op(x, y, fn):

@@ -6329,11 +6332,7 @@ def matmul_with_op(x, y, fn):
         if self.device == "cpu":
             FileCheck().check_not("cpp_fused").run(source_codes[0])
         else:
-            if config.triton.enable_native_matmul:
-                FileCheck().check("triton.jit").run(source_codes[0])
-                # atol, rtol = 1e-2, 1e-2
-            else:
-                FileCheck().check_not("triton.jit").run(source_codes[0])
+            FileCheck().check_not("triton.jit").run(source_codes[0])

         # test dtype conversion
         for lowp_dtype in [torch.float16, torch.bfloat16]:
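
Note on the changes above: under native matmul, test_remove_no_ops previously
asserted that a Triton kernel appears in the generated source; the commit
removes that branch and instead skips the test via the decorator added in the
first hunk. The test also uses torch._inductor.config.patch to re-enable
joint-graph constant folding for its own scope. A minimal sketch of that
mechanism, with an invented model function:

    import torch
    from torch._inductor import config


    def matmul_with_noop(x):
        # Adding zero is a no-op the joint-graph constant folder can remove.
        return (x @ torch.eye(x.shape[-1])) + 0


    # As a context manager: the option is restored when the block exits.
    with config.patch(joint_graph_constant_folding=True):
        y = torch.compile(matmul_with_noop)(torch.randn(8, 8))


    # As a decorator, matching the usage in the diff above.
    @config.patch(joint_graph_constant_folding=True)
    def run_with_folding(x):
        return torch.compile(matmul_with_noop)(x)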
