@@ -237,13 +237,25 @@ def test_pytorch_export_default(self):
237
237
model = model_class (config_class .from_pretrained (model ))
238
238
onnx_config = onnx_config_class .from_model_config (model .config )
239
239
240
- with NamedTemporaryFile ("w " ) as output :
240
+ with NamedTemporaryFile ("wb+ " ) as output :
241
241
onnx_inputs , onnx_outputs = export (
242
- tokenizer , model , onnx_config , DEFAULT_ONNX_OPSET , Path ( output . name )
242
+ tokenizer , model , onnx_config , DEFAULT_ONNX_OPSET , output
243
243
)
244
244
245
245
try :
246
- validate_model_outputs (onnx_config , tokenizer , model , Path (output .name ), onnx_outputs , 1e-5 )
246
+ # Reset to the head of the file and read everything
247
+ output .seek (0 )
248
+ model_bytes = output .read ()
249
+ validate_model_outputs (
250
+ onnx_config ,
251
+ tokenizer ,
252
+ model ,
253
+ model_bytes ,
254
+ onnx_outputs ,
255
+ batch_size = - 1 ,
256
+ seq_length = - 1 ,
257
+ atol = 1e-5
258
+ )
247
259
except ValueError as ve :
248
260
self .fail (f"{ name } -> { ve } " )
249
261
@@ -265,11 +277,22 @@ def test_pytorch_export_with_past(self):
265
277
onnx_config .use_past , "OnnxConfigWithPast.use_past should be if called with with_past()"
266
278
)
267
279
268
- with NamedTemporaryFile ("w" ) as output :
269
- output = Path (output .name )
280
+ with NamedTemporaryFile ("wb+" ) as output :
270
281
onnx_inputs , onnx_outputs = export (tokenizer , model , onnx_config , DEFAULT_ONNX_OPSET , output )
271
282
272
283
try :
273
- validate_model_outputs (onnx_config , tokenizer , model , output , onnx_outputs , 1e-5 )
284
+ # Reset to the head of the file and read everything
285
+ output .seek (0 )
286
+ model_bytes = output .read ()
287
+ validate_model_outputs (
288
+ onnx_config ,
289
+ tokenizer ,
290
+ model ,
291
+ model_bytes ,
292
+ onnx_outputs ,
293
+ batch_size = - 1 ,
294
+ seq_length = - 1 ,
295
+ atol = 1e-5
296
+ )
274
297
except ValueError as ve :
275
298
self .fail (f"{ name } -> { ve } " )
0 commit comments