Skip to content

Commit db3b72d

Browse files
committed
Forward batch_size and seq_length to validate_model_outputs in order to test with potentially different input shapes.
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
1 parent fe85d3a commit db3b72d

File tree

2 files changed

+43
-12
lines changed

2 files changed

+43
-12
lines changed

src/transformers/onnx/convert.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -138,19 +138,27 @@ def validate_model_outputs(
138138
reference_model: Union[PreTrainedModel, TFPreTrainedModel],
139139
onnx_model_path_or_bytes: Union[PathLike, bytes],
140140
onnx_named_outputs: List[str],
141-
atol: float,
141+
batch_size: int = -1,
142+
seq_length: int = -1,
143+
atol: float = 1e-5,
142144
):
143145
from onnxruntime import InferenceSession, SessionOptions
144146

145147
logger.info("Validating ONNX model...")
146148

147-
# TODO: generate inputs with a different batch_size and seq_len that was used for conversion to properly test
148-
# dynamic input shapes.
149-
reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
150-
151149
# Create ONNX Runtime session
152150
options = SessionOptions()
153-
session = InferenceSession(onnx_model.as_posix(), options)
151+
options.add_session_config_entry('session.load_model_format', 'ONNX')
152+
session = InferenceSession(onnx_model_path_or_bytes, options)
153+
154+
# TODO: generate inputs with a different batch_size and seq_len that was used for conversion to properly test
155+
# dynamic input shapes.
156+
reference_model_inputs = config.generate_dummy_inputs(
157+
tokenizer=tokenizer,
158+
batch_size=batch_size if batch_size > 0 else 3,
159+
seq_length=seq_length if seq_length > 0 else 31,
160+
framework=TensorType.PYTORCH
161+
)
154162

155163
# Compute outputs from the reference model
156164
ref_outputs = reference_model(**reference_model_inputs)

tests/test_onnx_v2.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -237,13 +237,25 @@ def test_pytorch_export_default(self):
237237
model = model_class(config_class.from_pretrained(model))
238238
onnx_config = onnx_config_class.from_model_config(model.config)
239239

240-
with NamedTemporaryFile("w") as output:
240+
with NamedTemporaryFile("wb+") as output:
241241
onnx_inputs, onnx_outputs = export(
242-
tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, Path(output.name)
242+
tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, output
243243
)
244244

245245
try:
246-
validate_model_outputs(onnx_config, tokenizer, model, Path(output.name), onnx_outputs, 1e-5)
246+
# Reset to the head of the file and read everything
247+
output.seek(0)
248+
model_bytes = output.read()
249+
validate_model_outputs(
250+
onnx_config,
251+
tokenizer,
252+
model,
253+
model_bytes,
254+
onnx_outputs,
255+
batch_size=-1,
256+
seq_length=-1,
257+
atol=1e-5
258+
)
247259
except ValueError as ve:
248260
self.fail(f"{name} -> {ve}")
249261

@@ -265,11 +277,22 @@ def test_pytorch_export_with_past(self):
265277
onnx_config.use_past, "OnnxConfigWithPast.use_past should be if called with with_past()"
266278
)
267279

268-
with NamedTemporaryFile("w") as output:
269-
output = Path(output.name)
280+
with NamedTemporaryFile("wb+") as output:
270281
onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, output)
271282

272283
try:
273-
validate_model_outputs(onnx_config, tokenizer, model, output, onnx_outputs, 1e-5)
284+
# Reset to the head of the file and read everything
285+
output.seek(0)
286+
model_bytes = output.read()
287+
validate_model_outputs(
288+
onnx_config,
289+
tokenizer,
290+
model,
291+
model_bytes,
292+
onnx_outputs,
293+
batch_size=-1,
294+
seq_length=-1,
295+
atol=1e-5
296+
)
274297
except ValueError as ve:
275298
self.fail(f"{name} -> {ve}")

0 commit comments

Comments
 (0)