Skip to content

Commit ecfb07e

Browse files
committed
Fix unittest to remove usage of PyTorch model for onnx.utils.
1 parent ac2e005 commit ecfb07e

File tree

1 file changed

+6
-22
lines changed

1 file changed

+6
-22
lines changed

tests/test_onnx_v2.py

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from unittest import TestCase
33
from unittest.mock import patch
44

5-
from transformers import BertConfig, PreTrainedModel
65
from transformers.models.bert.configuration_bert import BertOnnxConfig
76
from transformers.onnx import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, ParameterFormat
87

@@ -52,37 +51,22 @@ def test_use_external_data_format(self):
5251
"""
5352
External data format is required only if the serialized size of the parameters if bigger than 2Gb
5453
"""
55-
LIMIT = EXTERNAL_DATA_FORMAT_SIZE_LIMIT
54+
TWO_GB_LIMIT = EXTERNAL_DATA_FORMAT_SIZE_LIMIT
5655

5756
# No parameters
58-
with patch.object(PreTrainedModel, "num_parameters", return_value=0):
59-
model = PreTrainedModel(BertConfig())
60-
onnx = OnnxConfig(model)
61-
self.assertFalse(onnx.use_external_data_format(model.num_parameters()))
57+
self.assertFalse(OnnxConfig.use_external_data_format(0))
6258

6359
# Some parameters
64-
with patch.object(PreTrainedModel, "num_parameters", return_value=2):
65-
model = PreTrainedModel(BertConfig())
66-
onnx = OnnxConfig(model)
67-
self.assertFalse(onnx.use_external_data_format(model.num_parameters()))
60+
self.assertFalse(OnnxConfig.use_external_data_format(1))
6861

6962
# Almost 2Gb parameters
70-
with patch.object(PreTrainedModel, "num_parameters", return_value=(LIMIT - 1) // ParameterFormat.Float.size):
71-
model = PreTrainedModel(BertConfig())
72-
onnx = OnnxConfig(model)
73-
self.assertFalse(onnx.use_external_data_format(model.num_parameters()))
63+
self.assertFalse(OnnxConfig.use_external_data_format((TWO_GB_LIMIT - 1) // ParameterFormat.Float.size))
7464

7565
# Exactly 2Gb parameters
76-
with patch.object(PreTrainedModel, "num_parameters", return_value=LIMIT):
77-
model = PreTrainedModel(BertConfig())
78-
onnx = OnnxConfig(model)
79-
self.assertTrue(onnx.use_external_data_format(model.num_parameters()))
66+
self.assertTrue(OnnxConfig.use_external_data_format(TWO_GB_LIMIT))
8067

8168
# More than 2Gb parameters
82-
with patch.object(PreTrainedModel, "num_parameters", return_value=(LIMIT + 1) // ParameterFormat.Float.size):
83-
model = PreTrainedModel(BertConfig())
84-
onnx = OnnxConfig(model)
85-
self.assertTrue(onnx.use_external_data_format(model.num_parameters()))
69+
self.assertTrue(OnnxConfig.use_external_data_format((TWO_GB_LIMIT + 1) // ParameterFormat.Float.size))
8670

8771

8872
class OnnxExportTestCaseV2(TestCase):

0 commit comments

Comments
 (0)