|
2 | 2 | from unittest import TestCase
|
3 | 3 | from unittest.mock import patch
|
4 | 4 |
|
5 |
| -from transformers import BertConfig, PreTrainedModel |
6 | 5 | from transformers.models.bert.configuration_bert import BertOnnxConfig
|
7 | 6 | from transformers.onnx import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, ParameterFormat
|
8 | 7 |
|
@@ -52,37 +51,22 @@ def test_use_external_data_format(self):
|
52 | 51 | """
|
53 | 52 | External data format is required only if the serialized size of the parameters if bigger than 2Gb
|
54 | 53 | """
|
55 |
| - LIMIT = EXTERNAL_DATA_FORMAT_SIZE_LIMIT |
| 54 | + TWO_GB_LIMIT = EXTERNAL_DATA_FORMAT_SIZE_LIMIT |
56 | 55 |
|
57 | 56 | # No parameters
|
58 |
| - with patch.object(PreTrainedModel, "num_parameters", return_value=0): |
59 |
| - model = PreTrainedModel(BertConfig()) |
60 |
| - onnx = OnnxConfig(model) |
61 |
| - self.assertFalse(onnx.use_external_data_format(model.num_parameters())) |
| 57 | + self.assertFalse(OnnxConfig.use_external_data_format(0)) |
62 | 58 |
|
63 | 59 | # Some parameters
|
64 |
| - with patch.object(PreTrainedModel, "num_parameters", return_value=2): |
65 |
| - model = PreTrainedModel(BertConfig()) |
66 |
| - onnx = OnnxConfig(model) |
67 |
| - self.assertFalse(onnx.use_external_data_format(model.num_parameters())) |
| 60 | + self.assertFalse(OnnxConfig.use_external_data_format(1)) |
68 | 61 |
|
69 | 62 | # Almost 2Gb parameters
|
70 |
| - with patch.object(PreTrainedModel, "num_parameters", return_value=(LIMIT - 1) // ParameterFormat.Float.size): |
71 |
| - model = PreTrainedModel(BertConfig()) |
72 |
| - onnx = OnnxConfig(model) |
73 |
| - self.assertFalse(onnx.use_external_data_format(model.num_parameters())) |
| 63 | + self.assertFalse(OnnxConfig.use_external_data_format((TWO_GB_LIMIT - 1) // ParameterFormat.Float.size)) |
74 | 64 |
|
75 | 65 | # Exactly 2Gb parameters
|
76 |
| - with patch.object(PreTrainedModel, "num_parameters", return_value=LIMIT): |
77 |
| - model = PreTrainedModel(BertConfig()) |
78 |
| - onnx = OnnxConfig(model) |
79 |
| - self.assertTrue(onnx.use_external_data_format(model.num_parameters())) |
| 66 | + self.assertTrue(OnnxConfig.use_external_data_format(TWO_GB_LIMIT)) |
80 | 67 |
|
81 | 68 | # More than 2Gb parameters
|
82 |
| - with patch.object(PreTrainedModel, "num_parameters", return_value=(LIMIT + 1) // ParameterFormat.Float.size): |
83 |
| - model = PreTrainedModel(BertConfig()) |
84 |
| - onnx = OnnxConfig(model) |
85 |
| - self.assertTrue(onnx.use_external_data_format(model.num_parameters())) |
| 69 | + self.assertTrue(OnnxConfig.use_external_data_format((TWO_GB_LIMIT + 1) // ParameterFormat.Float.size)) |
86 | 70 |
|
87 | 71 |
|
88 | 72 | class OnnxExportTestCaseV2(TestCase):
|
|
0 commit comments