Skip to content

Commit 19a3a33

Browse files
committed
Enable outputs validation for default export.
1 parent 6f56f65 commit 19a3a33

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

tests/test_onnx_v2.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,13 @@ def test_pytorch_export_default(self):
211211
onnx_config = onnx_config_class.default(model.config)
212212

213213
with NamedTemporaryFile("w") as output:
214-
convert_pytorch(tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, Path(output.name))
214+
onnx_inputs, onnx_outputs = \
215+
convert_pytorch(tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, Path(output.name))
216+
217+
try:
218+
validate_model_outputs(onnx_config, tokenizer, model, Path(output.name), onnx_outputs, 1e-5)
219+
except ValueError as ve:
220+
self.fail(f"{name} -> {ve}")
215221

216222
@slow
217223
@require_torch

0 commit comments

Comments
 (0)