
Commit 2398ecb

add test cases for strong typing (#3739)
1 parent 1313d1b commit 2398ecb

7 files changed: +98 −71 lines changed

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 9 additions & 9 deletions

@@ -277,18 +277,18 @@ def _populate_trt_builder_config(
             trt.MemoryPoolType.DLA_GLOBAL_DRAM,
             self.compilation_settings.dla_global_dram_size,
         )
+        if not self.compilation_settings.use_explicit_typing:
+            if dtype.float16 in self.compilation_settings.enabled_precisions:
+                builder_config.set_flag(trt.BuilderFlag.FP16)
 
-        if dtype.float16 in self.compilation_settings.enabled_precisions:
-            builder_config.set_flag(trt.BuilderFlag.FP16)
+            if dtype.int8 in self.compilation_settings.enabled_precisions:
+                builder_config.set_flag(trt.BuilderFlag.INT8)
 
-        if dtype.int8 in self.compilation_settings.enabled_precisions:
-            builder_config.set_flag(trt.BuilderFlag.INT8)
+            if dtype.fp8 in self.compilation_settings.enabled_precisions:
+                builder_config.set_flag(trt.BuilderFlag.FP8)
 
-        if dtype.fp8 in self.compilation_settings.enabled_precisions:
-            builder_config.set_flag(trt.BuilderFlag.FP8)
-
-        if dtype.bfloat16 in self.compilation_settings.enabled_precisions:
-            builder_config.set_flag(trt.BuilderFlag.BF16)
+            if dtype.bfloat16 in self.compilation_settings.enabled_precisions:
+                builder_config.set_flag(trt.BuilderFlag.BF16)
 
         if self.compilation_settings.sparse_weights:
             builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS)

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 5 additions & 5 deletions

@@ -60,28 +60,28 @@ def batch_norm(
 ):
     # We name the weight here according to the state_dict name
     weight = (
-        get_trt_tensor(ctx, 1.0, f"{name}_weight")
+        get_trt_tensor(ctx, 1.0, f"{name}_weight", dtype=input.dtype)
         if weight is None
         else get_trt_tensor(ctx, weight, f"{name}_weight")
     )
     bias = (
-        get_trt_tensor(ctx, 0.0, f"{name}_bias")
+        get_trt_tensor(ctx, 0.0, f"{name}_bias", dtype=input.dtype)
         if bias is None
         else get_trt_tensor(ctx, bias, f"{name}_bias")
     )
     running_mean = (
-        get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
+        get_trt_tensor(ctx, 0.0, f"{name}_running_mean", dtype=input.dtype)
         if running_mean is None
         else get_trt_tensor(ctx, running_mean, f"{name}_running_mean")
     )
     running_var = (
-        get_trt_tensor(ctx, 1.0, f"{name}_running_var")
+        get_trt_tensor(ctx, 1.0, f"{name}_running_var", dtype=input.dtype)
         if running_var is None
         else get_trt_tensor(ctx, running_var, f"{name}_running_var")
     )
 
     # eps_tensor for numerical stability
-    eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps")
+    eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps", dtype=input.dtype)
 
     # adjusted_var = running_var + eps
     adjusted_var = impl.elementwise.add(
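
With these changes, every constant the converter fabricates (unit weight, zero bias, default running stats, and eps) inherits the input tensor's dtype, so the layer stays strongly typed end to end. A plain PyTorch sketch of the same inference-mode arithmetic, with illustrative shapes and eps value:

import torch

x = torch.randn(2, 4, 8, 8, dtype=torch.float16)
mean = torch.zeros(4, dtype=x.dtype)     # default running_mean in the input dtype
var = torch.ones(4, dtype=x.dtype)       # default running_var in the input dtype
eps = torch.tensor(1e-5, dtype=x.dtype)  # eps tensor matches the input dtype too

# adjusted_var = running_var + eps, then normalize; no implicit upcast occurs
out = (x - mean[None, :, None, None]) / torch.sqrt(var[None, :, None, None] + eps)
assert out.dtype == x.dtype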

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 3 additions & 1 deletion

@@ -483,7 +483,9 @@ def scaled_dot_product_attention_decomposition(
     attn_weight = query @ key.transpose(-2, -1)
 
     if scale is None:
-        scale = torch.sqrt(torch.scalar_tensor(query.size(-1), dtype=torch.int))
+        scale = torch.sqrt(torch.scalar_tensor(query.size(-1), dtype=torch.int)).to(
+            query.dtype
+        )
         attn_weight = attn_weight / scale
     else:
         attn_weight = attn_weight * scale
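
The cast matters because torch.sqrt promotes an integer scalar to float32, so the uncast scale would inject a float32 operand into an otherwise float16/bfloat16 attention graph. A small sketch demonstrating the promotion (head dimension 64 is illustrative):

import torch

query = torch.randn(1, 8, 16, 64, dtype=torch.float16)

# sqrt of an int32 scalar comes back as float32 ...
scale = torch.sqrt(torch.scalar_tensor(query.size(-1), dtype=torch.int))
assert scale.dtype == torch.float32

# ... so the decomposition now casts it back to the query dtype.
scale = scale.to(query.dtype)
assert scale.dtype == query.dtype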

tests/py/dynamo/models/test_dtype_support.py

Lines changed: 13 additions & 9 deletions

@@ -42,6 +42,7 @@ def forward(self, x):
         use_python_runtime=False,
         cache_built_engines=False,
         reuse_cached_engines=False,
+        use_explicit_typing=True,
     )
 
     torch_model_results = mod(in_tensor)
@@ -82,12 +83,13 @@ def forward(self, x):
         use_python_runtime=True,
         cache_built_engines=False,
         reuse_cached_engines=False,
+        use_explicit_typing=True,
     )
 
     torch_model_results = mod(in_tensor)
     with torch_tensorrt.logging.debug():
         optimized_model_results = trt_mod(in_tensor)
-
+    assert torch_model_results.dtype == optimized_model_results.dtype
     max_diff = float(
         torch.max(torch.abs(optimized_model_results - torch_model_results))
     )
@@ -128,11 +130,12 @@ def forward(self, x):
         use_python_runtime=False,
         cache_built_engines=False,
         reuse_cached_engines=False,
+        use_explicit_typing=True,
     )
 
     torch_model_results = mod(in_tensor)
     optimized_model_results = trt_mod(in_tensor)
-
+    assert torch_model_results.dtype == optimized_model_results.dtype
     max_diff = float(
         torch.max(torch.abs(optimized_model_results - torch_model_results))
     )
@@ -169,11 +172,12 @@ def forward(self, x):
         use_python_runtime=True,
         cache_built_engines=False,
         reuse_cached_engines=False,
+        use_explicit_typing=True,
     )
 
     torch_model_results = mod(in_tensor)
     optimized_model_results = trt_mod(in_tensor)
-
+    assert torch_model_results.dtype == optimized_model_results.dtype
     max_diff = float(
         torch.max(torch.abs(optimized_model_results - torch_model_results))
     )
@@ -218,16 +222,16 @@ def forward(self, x):
         exp_mod,
         inputs=[in_tensor],
         pass_through_build_failures=True,
-        enabled_precisions={torch.float, torch.bfloat16, torch.half},
         min_block_size=1,
         use_python_runtime=False,
         cache_built_engines=False,
         reuse_cached_engines=False,
+        use_explicit_typing=True,
     )
 
     torch_model_results = mod(in_tensor)
     optimized_model_results = trt_mod(in_tensor)
-
+    assert torch_model_results.dtype == optimized_model_results.dtype
     max_diff = float(
         torch.max(torch.abs(optimized_model_results - torch_model_results))
     )
@@ -258,16 +262,16 @@ def forward(self, x):
         exp_mod,
         inputs=[in_tensor],
         pass_through_build_failures=True,
-        enabled_precisions={torch.float, torch.bfloat16, torch.half},
         min_block_size=1,
         use_python_runtime=True,
         cache_built_engines=False,
        reuse_cached_engines=False,
+        use_explicit_typing=True,
     )
 
     torch_model_results = mod(in_tensor)
     optimized_model_results = trt_mod(in_tensor)
-
+    assert torch_model_results.dtype == optimized_model_results.dtype
     max_diff = float(
         torch.max(torch.abs(optimized_model_results - torch_model_results))
     )
@@ -296,16 +300,16 @@ def forward(self, x):
         mod,
         ir="torch_compile",
         inputs=inputs,
-        enabled_precisions={torch.bfloat16},
         min_block_size=1,
         device=device,
         cache_built_engines=False,
         reuse_cached_engines=False,
+        use_explicit_typing=True,
     )
 
     torch_model_results = mod(*inputs)
     optimized_model_results = trt_mod(*inputs)
-
+    assert torch_model_results.dtype == optimized_model_results.dtype
     max_diff = float(
         torch.max(torch.abs(optimized_model_results - torch_model_results))
     )
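
Every test in this file now follows the same recipe: drop enabled_precisions, pass use_explicit_typing=True, and assert dtype preservation before comparing values. A runnable stand-in for that assertion pattern (Identity modules replace the real fixtures so the sketch runs anywhere):

import torch

# Stand-ins: mod is the eager module, trt_mod the compiled one.
mod = torch.nn.Identity()
trt_mod = torch.nn.Identity()
in_tensor = torch.randn(2, 4, dtype=torch.bfloat16)

torch_model_results = mod(in_tensor)
optimized_model_results = trt_mod(in_tensor)

# Under strong typing the engine must not silently up- or down-cast,
# so exact dtype equality is a meaningful check before the value diff.
assert torch_model_results.dtype == optimized_model_results.dtype
max_diff = float(
    torch.max(torch.abs(optimized_model_results - torch_model_results))
)
assert max_diff == 0.0  # identical stand-ins; the real tests use a threshold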

tests/py/dynamo/models/test_dyn_models.py

Lines changed: 12 additions & 7 deletions

@@ -178,26 +178,27 @@ def forward(self, x):
     not importlib.util.find_spec("torchvision"), "torchvision not installed"
 )
 @pytest.mark.unit
-def test_resnet_dynamic(ir):
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
+def test_resnet_dynamic(ir, dtype):
     """
     Tests the Resnet18 model (which is fully convertible) with dynamic shapes
     """
     import torchvision.models as models
 
-    model = models.resnet18(pretrained=True).eval().to("cuda")
+    model = models.resnet18(pretrained=True).eval().to("cuda").to(dtype)
 
     compile_spec = {
         "device": torchtrt.Device("cuda:0"),
-        "enabled_precisions": {torch.float},
         "ir": ir,
         "pass_through_build_failures": True,
         "min_block_size": 1,
         "cache_built_engines": False,
         "reuse_cached_engines": False,
+        "use_explicit_typing": True,
     }
 
     if ir == "torch_compile":
-        input_bs2 = torch.randn((2, 3, 224, 224)).to("cuda")
+        input_bs2 = torch.randn((2, 3, 224, 224)).to("cuda").to(dtype)
         torch._dynamo.mark_dynamic(input_bs2, 0, min=1, max=8)
         # Compile the model
         trt_model = torch.compile(model, backend="tensorrt", options=compile_spec)
@@ -208,14 +209,18 @@ def test_resnet_dynamic(ir):
                 min_shape=(1, 3, 224, 224),
                 opt_shape=(4, 3, 224, 224),
                 max_shape=(8, 3, 224, 224),
-                dtype=torch.float32,
+                dtype=dtype,
                 name="x",
             )
         ]
         trt_model = torchtrt.compile(model, **compile_spec)
 
-    input_bs6 = torch.randn((6, 3, 224, 224)).to("cuda")
-    cos_sim = cosine_similarity(model(input_bs6), trt_model(input_bs6))
+    input_bs6 = torch.randn((6, 3, 224, 224)).to("cuda").to(dtype)
+    pyt_output = model(input_bs6)
+    trt_output = trt_model(input_bs6)
+    assert pyt_output.dtype == trt_output.dtype
+    assert trt_output.dtype == dtype
+    cos_sim = cosine_similarity(pyt_output, trt_output)
     assertions.assertTrue(
         cos_sim > COSINE_THRESHOLD,
         msg=f"test_resnet_dynamic model TRT outputs don't match with the pytorch model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",

tests/py/dynamo/models/test_models.py

Lines changed: 33 additions & 19 deletions

@@ -136,28 +136,31 @@ def test_resnet18_torch_exec_ops(ir):
     not importlib.util.find_spec("torchvision"),
     "torchvision is not installed",
 )
-def test_mobilenet_v2(ir):
-    model = models.mobilenet_v2(pretrained=True).eval().to("cuda")
-    input = torch.randn((1, 3, 224, 224)).to("cuda")
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
+def test_mobilenet_v2(ir, dtype):
+    model = models.mobilenet_v2(pretrained=True).eval().to("cuda").to(dtype)
+    input = torch.randn((1, 3, 224, 224)).to("cuda").to(dtype)
 
     compile_spec = {
         "inputs": [
-            torchtrt.Input(
-                input.shape, dtype=torch.float, format=torch.contiguous_format
-            )
+            torchtrt.Input(input.shape, dtype=dtype, format=torch.contiguous_format)
         ],
         "device": torchtrt.Device("cuda:0"),
-        "enabled_precisions": {torch.float},
         "ir": ir,
         "pass_through_build_failures": True,
         "optimization_level": 1,
         "min_block_size": 10,
         "cache_built_engines": False,
         "reuse_cached_engines": False,
+        "use_explicit_typing": True,
     }
 
     trt_mod = torchtrt.compile(model, **compile_spec)
-    cos_sim = cosine_similarity(model(input), trt_mod(input))
+    pyt_output = model(input)
+    trt_output = trt_mod(input)
+    assert pyt_output.dtype == trt_output.dtype
+    assert pyt_output.dtype == dtype
+    cos_sim = cosine_similarity(pyt_output, trt_output)
     assertions.assertTrue(
         cos_sim > COSINE_THRESHOLD,
         msg=f"Mobilenet v2 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
@@ -172,28 +175,36 @@ def test_mobilenet_v2(ir):
     not importlib.util.find_spec("timm") or not importlib.util.find_spec("torchvision"),
     "timm or torchvision not installed",
 )
-def test_efficientnet_b0(ir):
-    model = timm.create_model("efficientnet_b0", pretrained=True).eval().to("cuda")
-    input = torch.randn((1, 3, 224, 224)).to("cuda")
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
+def test_efficientnet_b0(ir, dtype):
+    model = (
+        timm.create_model("efficientnet_b0", pretrained=True)
+        .eval()
+        .to("cuda")
+        .to(dtype)
+    )
+    input = torch.randn((1, 3, 224, 224)).to("cuda").to(dtype)
 
     compile_spec = {
         "inputs": [
-            torchtrt.Input(
-                input.shape, dtype=torch.float, format=torch.contiguous_format
-            )
+            torchtrt.Input(input.shape, dtype=dtype, format=torch.contiguous_format)
         ],
         "device": torchtrt.Device("cuda:0"),
-        "enabled_precisions": {torch.float},
         "ir": ir,
         "pass_through_build_failures": True,
         "optimization_level": 1,
         "min_block_size": 10,
         "cache_built_engines": False,
         "reuse_cached_engines": False,
+        "use_explicit_typing": True,
     }
 
     trt_mod = torchtrt.compile(model, **compile_spec)
-    cos_sim = cosine_similarity(model(input), trt_mod(input))
+    pyt_output = model(input)
+    trt_output = trt_mod(input)
+    assert pyt_output.dtype == trt_output.dtype
+    assert pyt_output.dtype == dtype
+    cos_sim = cosine_similarity(pyt_output, trt_output)
     assertions.assertTrue(
         cos_sim > COSINE_THRESHOLD,
         msg=f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
@@ -208,10 +219,11 @@ def test_efficientnet_b0(ir):
     not importlib.util.find_spec("transformers"),
     "transformers is required to run this test",
 )
-def test_bert_base_uncased(ir):
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
+def test_bert_base_uncased(ir, dtype):
     from transformers import BertModel
 
-    model = BertModel.from_pretrained("bert-base-uncased").cuda().eval()
+    model = BertModel.from_pretrained("bert-base-uncased").cuda().eval().to(dtype)
     input = torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda")
     input2 = torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda")
 
@@ -229,21 +241,23 @@ def test_bert_base_uncased(ir):
             ),
         ],
         "device": torchtrt.Device("cuda:0"),
-        "enabled_precisions": {torch.float},
         "truncate_double": True,
         "ir": ir,
         "pass_through_build_failures": True,
         "optimization_level": 1,
         "min_block_size": 15,
         "cache_built_engines": False,
         "reuse_cached_engines": False,
+        "use_explicit_typing": True,
     }
     trt_mod = torchtrt.compile(model, **compile_spec)
 
     model_outputs = model(input, input2)
     trt_model_outputs = trt_mod(input, input2)
     for key in model_outputs.keys():
         out, trt_out = model_outputs[key], trt_model_outputs[key]
+        assert out.dtype == trt_out.dtype
+        assert out.dtype == dtype
         cos_sim = cosine_similarity(out, trt_out)
         assertions.assertTrue(
             cos_sim > COSINE_THRESHOLD,
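
Because transformers models return dict-like outputs, the dtype assertions in the BERT test run per key rather than once per tensor. A runnable sketch of that loop with stand-in outputs (the key name and shape are illustrative):

import torch

dtype = torch.float16
model_outputs = {"last_hidden_state": torch.randn(1, 14, 768, dtype=dtype)}
trt_model_outputs = {k: v.clone() for k, v in model_outputs.items()}

for key in model_outputs.keys():
    out, trt_out = model_outputs[key], trt_model_outputs[key]
    assert out.dtype == trt_out.dtype  # engine preserved the eager dtype
    assert out.dtype == dtype          # and it matches the parametrized dtype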

0 commit comments
