23
23
from transformers .models .xlm_roberta import XLMRobertaOnnxConfig
24
24
from transformers .onnx import EXTERNAL_DATA_FORMAT_SIZE_LIMIT , OnnxConfig , ParameterFormat
25
25
26
- from transformers .onnx .config import DEFAULT_ONNX_OPSET
26
+ from transformers .onnx .config import DEFAULT_ONNX_OPSET , OnnxConfigWithPast
27
+ from transformers .onnx .convert import validate_model_outputs
27
28
from transformers .onnx .utils import (
28
29
compute_effective_axis_dimension ,
29
30
compute_serialized_parameters_size ,
34
35
35
36
@require_onnx
36
37
class OnnxUtilsTestCaseV2 (TestCase ):
38
+ """
39
+ Cover all the utilities involved to export ONNX models
40
+ """
41
+
37
42
def test_compute_effective_axis_dimension (self ):
43
+ """
44
+ When exporting ONNX model with dynamic axis (batch or sequence) we set batch_size and/or sequence_length = -1.
45
+ We cannot generate an effective tensor with axis dim == -1, so we trick by using some "fixed" values
46
+ (> 1 to avoid ONNX squeezing the axis).
47
+
48
+ This test ensures we are correctly replacing generated batch / sequence tensor with axis > 1
49
+ """
50
+
38
51
# Dynamic axis (batch, no token added by the tokenizer)
39
52
self .assertEqual (compute_effective_axis_dimension (- 1 , fixed_dimension = 2 , num_token_to_add = 0 ), 2 )
40
53
@@ -50,9 +63,19 @@ def test_compute_effective_axis_dimension(self):
50
63
self .assertEqual (compute_effective_axis_dimension (0 , fixed_dimension = 8 , num_token_to_add = 3 ), 5 )
51
64
52
65
def test_compute_parameters_serialized_size (self ):
66
+ """
67
+ This test ensures we compute a "correct" approximation of the underlying storage requirement (size) for all the
68
+ parameters for the specified parameter's dtype.
69
+ """
53
70
self .assertEqual (compute_serialized_parameters_size (2 , ParameterFormat .Float ), 2 * ParameterFormat .Float .size )
54
71
55
72
def test_flatten_output_collection_property (self ):
73
+ """
74
+ This test ensures we correctly flatten nested collection such as the one we use when returning past_keys.
75
+ past_keys = Tuple[Tuple]
76
+
77
+ ONNX exporter will export nested collections as ${collection_name}.${level_idx_0}.${level_idx_1}...${idx_n}
78
+ """
56
79
self .assertEqual (
57
80
flatten_output_collection_property ("past_key" , [[0 ], [1 ], [2 ]]),
58
81
{
@@ -64,6 +87,12 @@ def test_flatten_output_collection_property(self):
64
87
65
88
66
89
class OnnxConfigTestCaseV2 (TestCase ):
90
+ """
91
+ Cover the test for models default.
92
+
93
+ Default means no specific features is being enabled on the model.
94
+ """
95
+
67
96
@patch .multiple (OnnxConfig , __abstractmethods__ = set ())
68
97
def test_use_external_data_format (self ):
69
98
"""
@@ -87,37 +116,96 @@ def test_use_external_data_format(self):
87
116
self .assertTrue (OnnxConfig .use_external_data_format ((TWO_GB_LIMIT + 1 ) // ParameterFormat .Float .size ))
88
117
89
118
119
class OnnxConfigWithPastTestCaseV2(TestCase):
    """
    Cover the tests for models which have the use_cache feature (i.e. "with_past" for ONNX).
    """

    # (display name, config class) pairs for every architecture supporting past key/values export.
    SUPPORTED_WITH_PAST_CONFIGS = {
        ("BART", BartConfig),
        ("GPT2", GPT2Config),
        ("T5", T5Config),
    }

    @patch.multiple(OnnxConfigWithPast, __abstractmethods__=set())
    def test_use_past(self):
        """
        Ensure the use_past variable is correctly being set by default() / with_past() constructors.
        """
        for name, config in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS:
            with self.subTest(name):
                self.assertFalse(
                    OnnxConfigWithPast.default(config()).use_past,
                    "OnnxConfigWithPast.default() should not use_past",
                )

                self.assertTrue(
                    OnnxConfigWithPast.with_past(config()).use_past,
                    # Fixed: this message previously (and wrongly) referenced default().
                    "OnnxConfigWithPast.with_past() should use_past",
                )

    @patch.multiple(OnnxConfigWithPast, __abstractmethods__=set())
    def test_values_override(self):
        """
        Ensure the use_past variable correctly sets the `use_cache` value in the model's configuration.
        """
        for name, config in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS:
            with self.subTest(name):

                # Without past: the default export must force use_cache=False on the model config.
                onnx_config_default = OnnxConfigWithPast.default(config())
                self.assertIsNotNone(onnx_config_default.values_override, "values_override should not be None")
                self.assertIn("use_cache", onnx_config_default.values_override, "use_cache should be present")
                self.assertFalse(
                    onnx_config_default.values_override["use_cache"],
                    "use_cache should be False if not using past",
                )

                # With past: the with-past export must force use_cache=True on the model config.
                # Fixed: use a distinct variable name instead of reusing onnx_config_default.
                onnx_config_with_past = OnnxConfigWithPast.with_past(config())
                self.assertIsNotNone(onnx_config_with_past.values_override, "values_override should not be None")
                self.assertIn("use_cache", onnx_config_with_past.values_override, "use_cache should be present")
                self.assertTrue(
                    onnx_config_with_past.values_override["use_cache"],
                    # Fixed: this message previously (and wrongly) said "False if not using past".
                    "use_cache should be True when using past",
                )
173
+
90
174
if is_torch_available():
    from transformers import AlbertModel, BartModel, BertModel, DistilBertModel, GPT2Model, RobertaModel, T5Model, XLMRobertaModel

    # (display name, checkpoint id, model class, config class, ONNX config class) per architecture.
    # NOTE(review): every default-export entry is commented out in this revision — presumably a WIP
    # state while the with-past path is being developed. TODO: re-enable before merging.
    PYTORCH_EXPORT_DEFAULT_MODELS = {
        # ("ALBERT", "albert-base-v2", AlbertModel, AlbertConfig, AlbertOnnxConfig),
        # ("BART", "facebook/bart-base", BartModel, BartConfig, BartOnnxConfig),
        # ("BERT", "bert-base-cased", BertModel, BertConfig, BertOnnxConfig),
        # ("DistilBERT", "distilbert-base-cased", DistilBertModel, DistilBertConfig, DistilBertOnnxConfig),
        # ("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig),
        # # ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
        # ("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
        # ("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
        # ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig)
    }

    # Architectures exercised by the with-past (use_cache) export test below.
    PYTORCH_EXPORT_WITH_PAST_MODELS = {
        ("BART", "facebook/bart-base", BartModel, BartConfig, BartOnnxConfig),
        ("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig),
        # ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig )
    }
110
194
111
195
112
196
class OnnxExportTestCaseV2 (TestCase ):
197
+ """
198
+ Integration tests ensuring supported models are correctly exported
199
+ """
113
200
@slow
114
201
@require_torch
115
202
def test_pytorch_export_default (self ):
116
203
from transformers .onnx .convert import convert_pytorch
117
204
118
205
for name , model , model_class , config_class , onnx_config_class in PYTORCH_EXPORT_DEFAULT_MODELS :
119
-
120
206
with self .subTest (name ):
207
+ self .assertTrue (hasattr (onnx_config_class , "default" ))
208
+
121
209
tokenizer = AutoTokenizer .from_pretrained (model )
122
210
model = model_class (config_class ())
123
211
onnx_config = onnx_config_class .default (model .config )
@@ -128,4 +216,27 @@ def test_pytorch_export_default(self):
128
216
@slow
129
217
@require_torch
130
218
def test_pytorch_export_with_past (self ):
131
- pass
219
+ from transformers .onnx .convert import convert_pytorch
220
+
221
+ for name , model , model_class , config_class , onnx_config_class in PYTORCH_EXPORT_WITH_PAST_MODELS :
222
+ with self .subTest (name ):
223
+ self .assertTrue (hasattr (onnx_config_class , "with_past" ), "OnnxConfigWithPast should have with_past()" )
224
+
225
+ tokenizer = AutoTokenizer .from_pretrained (model )
226
+ model = model_class (config_class ())
227
+ onnx_config = onnx_config_class .with_past (model .config )
228
+
229
+ self .assertTrue (hasattr (onnx_config , "use_past" ), "OnnxConfigWithPast should have use_past attribute." )
230
+ self .assertTrue (
231
+ onnx_config .use_past ,
232
+ "OnnxConfigWithPast.use_past should be if called with with_past()"
233
+ )
234
+
235
+ with NamedTemporaryFile ("w" ) as output :
236
+ onnx_inputs , onnx_outputs = \
237
+ convert_pytorch (tokenizer , model , onnx_config , DEFAULT_ONNX_OPSET , Path (output .name ))
238
+
239
+ try :
240
+ validate_model_outputs (onnx_config , tokenizer , model , Path (output .name ), onnx_outputs , 1e-5 )
241
+ except ValueError as ve :
242
+ self .fail (f"{ name } -> { ve } " )
0 commit comments