Commit 23b6ae9

Ankit-Jaiswal-AMD authored and naveenthangudu committed
[ZENDNN] Added op tests for zendnn_linear
Change-Id: Iedd2988244b35d44f596dd427f2a5378154143ee
1 parent a9b2cbe commit 23b6ae9
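
In short, the new tests check that torch.ops.aten.zendnn_linear agrees with the stock torch.nn.functional.linear reference (y = x @ w.T + b) across 2D, 3D, and N-D inputs, with and without bias, in float32 and bfloat16, plus the error paths for malformed shapes and dtypes. A minimal sketch of the core check, assuming a ZenDNN-enabled PyTorch build (torch._C.has_zendnn is true) and arbitrary illustrative shapes:

import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, in_features) x (out_features, in_features)
x = torch.randn(8, 64)
w = torch.randn(32, 64)
b = torch.randn(32)

expected = F.linear(x, w, b)                    # reference: y = x @ w.T + b
result = torch.ops.aten.zendnn_linear(x, w, b)  # ZenDNN op under test

torch.testing.assert_close(result, expected, rtol=1e-4, atol=1e-4)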

File tree

1 file changed: +230 -0 lines changed

test/test_zendnn_linear.py

Lines changed: 230 additions & 0 deletions
@@ -0,0 +1,230 @@
# Owner(s): ["module: unknown"]
import unittest
import torch
import torch.nn.functional as F
from torch.testing._internal.common_utils import TestCase, run_tests
from hypothesis import given, strategies as st, settings


@unittest.skipIf(
    not torch._C.has_zendnn, "ZenDNN is not available in this PyTorch build"
)
class TestZenDNNLinear(TestCase):
    def setUp(self):
        self.device = torch.device("cpu")
        # Check if bfloat16 is supported on the current device
        self.bf16_supported = torch._C._cpu._is_avx512_bf16_supported()

    def _test_zendnn_linear(self, input, weight, bias, atol, rtol):
        # Run reference implementation using torch.nn.functional.linear
        expected = F.linear(input, weight, bias)

        # Run ZenDNN implementation
        if bias is not None:
            result = torch.ops.aten.zendnn_linear(input=input, weight=weight, bias=bias)
        else:
            result = torch.ops.aten.zendnn_linear(input=input, weight=weight)

        # Compare results
        torch.testing.assert_close(result, expected, rtol=rtol, atol=atol)

    @given(
        batch_size=st.integers(1, 32),
        in_features=st.integers(2, 256),
        out_features=st.integers(2, 256),
        has_bias=st.booleans(),
        use_bf16=st.booleans(),
    )
    @settings(deadline=None)
    def test_zendnn_linear_2d_input(
        self, batch_size, in_features, out_features, has_bias, use_bf16
    ):
        if use_bf16 and not self.bf16_supported:
            # Skip test if bf16 is requested but not supported
            self.skipTest("BFloat16 not supported on this device")
        dtype = torch.bfloat16 if use_bf16 else torch.float32

        # Create input tensor
        input = torch.randn(batch_size, in_features, device=self.device, dtype=dtype)

        # Create weight tensor
        weight = torch.randn(out_features, in_features, device=self.device, dtype=dtype)

        # Create bias tensor (optional)
        bias = (
            torch.randn(out_features, device=self.device, dtype=dtype)
            if has_bias
            else None
        )
        rtol = 1e-2 if use_bf16 else 1e-4  # Relax tolerances for BF16
        atol = 1e-2 if use_bf16 else 1e-4
        self._test_zendnn_linear(input, weight, bias, atol, rtol)

    @given(
        batch_size=st.integers(1, 16),
        seq_len=st.integers(1, 32),
        in_features=st.integers(2, 128),
        out_features=st.integers(2, 128),
        has_bias=st.booleans(),
        use_bf16=st.booleans(),
    )
    @settings(deadline=None)
    def test_zendnn_linear_3d_input(
        self, batch_size, seq_len, in_features, out_features, has_bias, use_bf16
    ):
        if use_bf16 and not self.bf16_supported:
            # Skip test if bf16 is requested but not supported
            self.skipTest("BFloat16 not supported on this device")

        dtype = torch.bfloat16 if use_bf16 else torch.float32

        # Create input tensor
        input = torch.randn(
            batch_size, seq_len, in_features, device=self.device, dtype=dtype
        )

        # Create weight tensor
        weight = torch.randn(out_features, in_features, device=self.device, dtype=dtype)

        # Create bias tensor (optional)
        bias = (
            torch.randn(out_features, device=self.device, dtype=dtype)
            if has_bias
            else None
        )

        rtol = 1e-2 if use_bf16 else 1e-4  # Relax tolerances for BF16
        atol = 1e-2 if use_bf16 else 1e-4
        self._test_zendnn_linear(input, weight, bias, atol, rtol)

    @given(
        dims=st.integers(4, 5),
        batch_dim=st.integers(1, 8),
        in_features=st.integers(2, 64),
        out_features=st.integers(2, 64),
        has_bias=st.booleans(),
        use_bf16=st.booleans(),
    )
    @settings(deadline=None)
    def test_zendnn_linear_nd_input(
        self, dims, batch_dim, in_features, out_features, has_bias, use_bf16
    ):
        if use_bf16 and not self.bf16_supported:
            # Skip test if bf16 is requested but not supported
            self.skipTest("BFloat16 not supported on this device")

        dtype = torch.bfloat16 if use_bf16 else torch.float32

        # Create shape with multiple batch dimensions
        shape = [batch_dim] * (dims - 1) + [in_features]

        # Create input tensor
        input = torch.randn(*shape, device=self.device, dtype=dtype)

        # Create weight tensor
        weight = torch.randn(out_features, in_features, device=self.device, dtype=dtype)

        # Create bias tensor (optional)
        bias = (
            torch.randn(out_features, device=self.device, dtype=dtype)
            if has_bias
            else None
        )

        rtol = 1e-2 if use_bf16 else 1e-4  # Relax tolerances for BF16
        atol = 1e-2 if use_bf16 else 1e-4
        self._test_zendnn_linear(input, weight, bias, atol, rtol)

    @given(
        batch_size=st.integers(1, 32),
        in_features=st.integers(2, 128),
        out_features=st.integers(2, 128),
        use_bf16=st.booleans(),
    )
    @settings(deadline=None)
    def test_zendnn_linear_keyword_args(
        self, batch_size, in_features, out_features, use_bf16
    ):
        if use_bf16 and not self.bf16_supported:
            # Skip test if bf16 is requested but not supported
            self.skipTest("BFloat16 not supported on this device")

        dtype = torch.bfloat16 if use_bf16 else torch.float32

        # Create tensors
        input = torch.randn(batch_size, in_features, device=self.device, dtype=dtype)
        weight = torch.randn(out_features, in_features, device=self.device, dtype=dtype)
        bias = torch.randn(out_features, device=self.device, dtype=dtype)

        # Run with positional arguments
        result1 = torch.ops.aten.zendnn_linear(input, weight, bias)

        # Run with keyword arguments
        result2 = torch.ops.aten.zendnn_linear(input=input, weight=weight, bias=bias)

        # Compare results
        rtol = 1e-2 if use_bf16 else 1e-4  # Relax tolerances for BF16
        atol = 1e-2 if use_bf16 else 1e-4
        torch.testing.assert_close(result1, result2, rtol=rtol, atol=atol)

    def test_zendnn_linear_exception_weight_dim(self):
        # Test invalid weight dimension
        input = torch.randn(10, 20)
        weight = torch.randn(30, 20, 5)  # Should be 2D

        with self.assertRaises(RuntimeError):
            torch.ops.aten.zendnn_linear(input, weight)

    def test_zendnn_linear_exception_bias_dim(self):
        # Test invalid bias dimension
        input = torch.randn(10, 20)
        weight = torch.randn(30, 20)
        bias = torch.randn(30, 5)  # Should be 1D

        with self.assertRaises(RuntimeError):
            torch.ops.aten.zendnn_linear(input, weight, bias)

    def test_zendnn_linear_exception_feature_mismatch(self):
        # Test mismatch in feature dimensions
        input = torch.randn(10, 20)
        weight = torch.randn(30, 25)  # Should be (30, 20)

        with self.assertRaises(RuntimeError):
            torch.ops.aten.zendnn_linear(input, weight)

    def test_zendnn_linear_exception_bias_size(self):
        # Test mismatch in bias size
        input = torch.randn(10, 20)
        weight = torch.randn(30, 20)
        bias = torch.randn(35)  # Should be size 30

        with self.assertRaises(RuntimeError):
            torch.ops.aten.zendnn_linear(input, weight, bias)

    def test_zendnn_linear_dtype_mismatch(self):
        # Test dtype mismatch between input tensors
        input = torch.randn(10, 20, dtype=torch.float32)
        weight = torch.randn(30, 20, dtype=torch.float64)  # Different dtype

        with self.assertRaises(RuntimeError):
            torch.ops.aten.zendnn_linear(input, weight)

    def test_zendnn_linear_bf16(self):
        # Skip if BF16 is not supported
        if not self.bf16_supported:
            self.skipTest("BFloat16 not supported on this device")

        # Create BF16 tensors
        input = torch.randn(10, 20, dtype=torch.bfloat16)
        weight = torch.randn(30, 20, dtype=torch.bfloat16)
        bias = torch.randn(30, dtype=torch.bfloat16)

        # Verify both implementations produce similar results
        expected = F.linear(input, weight, bias)
        result = torch.ops.aten.zendnn_linear(input, weight, bias)

        torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
    run_tests()
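
A note on the relaxed BF16 tolerances used above: bfloat16 keeps only 8 bits of mantissa, so even a correct matmul drifts from a float32 reference by roughly 1e-3 to 1e-2 relative error, which is why the BF16 paths widen rtol/atol from 1e-4 to 1e-2. A standalone sketch of the effect (no ZenDNN build required; the sizes are arbitrary picks from the same ranges Hypothesis draws above):

import torch
import torch.nn.functional as F

x = torch.randn(32, 256)
w = torch.randn(128, 256)

ref = F.linear(x, w)                           # float32 reference
approx = F.linear(x.bfloat16(), w.bfloat16())  # same matmul in bfloat16

rel_err = ((approx.float() - ref).abs() / ref.abs().clamp_min(1e-6)).median().item()
print(f"median relative error: {rel_err:.1e}")  # typically around 1e-3 to 1e-2

Since the file ends in run_tests(), it can also be executed directly (python test/test_zendnn_linear.py), which routes through PyTorch's standard unittest harness.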
