Skip to content

Commit 92f9d9f

Browse files
committed
Added unittests and docstrings.
1 parent 3feaa5b commit 92f9d9f

File tree

1 file changed

+126
-15
lines changed

1 file changed

+126
-15
lines changed

tests/test_onnx_v2.py

Lines changed: 126 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
2424
from transformers.onnx import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, ParameterFormat
2525

26-
from transformers.onnx.config import DEFAULT_ONNX_OPSET
26+
from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast
27+
from transformers.onnx.convert import validate_model_outputs
2728
from transformers.onnx.utils import (
2829
compute_effective_axis_dimension,
2930
compute_serialized_parameters_size,
@@ -34,7 +35,19 @@
3435

3536
@require_onnx
3637
class OnnxUtilsTestCaseV2(TestCase):
38+
"""
39+
Cover all the utilities involved to export ONNX models
40+
"""
41+
3742
def test_compute_effective_axis_dimension(self):
43+
"""
44+
When exporting an ONNX model with a dynamic axis (batch or sequence) we set batch_size and/or sequence_length = -1.
45+
We cannot generate an effective tensor with axis dim == -1, so we trick by using some "fixed" values
46+
(> 1 to avoid ONNX squeezing the axis).
47+
48+
This test ensures we are correctly replacing generated batch / sequence tensors with axis > 1
49+
"""
50+
3851
# Dynamic axis (batch, no token added by the tokenizer)
3952
self.assertEqual(compute_effective_axis_dimension(-1, fixed_dimension=2, num_token_to_add=0), 2)
4053

@@ -50,9 +63,19 @@ def test_compute_effective_axis_dimension(self):
5063
self.assertEqual(compute_effective_axis_dimension(0, fixed_dimension=8, num_token_to_add=3), 5)
5164

5265
def test_compute_parameters_serialized_size(self):
66+
"""
67+
This test ensures we compute a "correct" approximation of the underlying storage requirement (size) for all the
68+
parameters for the specified parameter's dtype.
69+
"""
5370
self.assertEqual(compute_serialized_parameters_size(2, ParameterFormat.Float), 2 * ParameterFormat.Float.size)
5471

5572
def test_flatten_output_collection_property(self):
73+
"""
74+
This test ensures we correctly flatten nested collections such as the one we use when returning past_keys.
75+
past_keys = Tuple[Tuple]
76+
77+
ONNX exporter will export nested collections as ${collection_name}.${level_idx_0}.${level_idx_1}...${idx_n}
78+
"""
5679
self.assertEqual(
5780
flatten_output_collection_property("past_key", [[0], [1], [2]]),
5881
{
@@ -64,6 +87,12 @@ def test_flatten_output_collection_property(self):
6487

6588

6689
class OnnxConfigTestCaseV2(TestCase):
90+
"""
91+
Cover the tests for models' default behavior.
92+
93+
Default means no specific feature is being enabled on the model.
94+
"""
95+
6796
@patch.multiple(OnnxConfig, __abstractmethods__=set())
6897
def test_use_external_data_format(self):
6998
"""
@@ -87,37 +116,96 @@ def test_use_external_data_format(self):
87116
self.assertTrue(OnnxConfig.use_external_data_format((TWO_GB_LIMIT + 1) // ParameterFormat.Float.size))
88117

89118

119+
class OnnxConfigWithPastTestCaseV2(TestCase):
120+
"""
121+
Cover the tests for models which have the use_cache feature (i.e. "with_past" for ONNX)
122+
"""
123+
124+
SUPPORTED_WITH_PAST_CONFIGS = {
125+
("BART", BartConfig),
126+
("GPT2", GPT2Config),
127+
("T5", T5Config)
128+
}
129+
130+
@patch.multiple(OnnxConfigWithPast, __abstractmethods__=set())
131+
def test_use_past(self):
132+
"""
133+
Ensure the use_past variable is correctly being set
134+
"""
135+
for name, config in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS:
136+
with self.subTest(name):
137+
self.assertFalse(
138+
OnnxConfigWithPast.default(config()).use_past,
139+
"OnnxConfigWithPast.default() should not use_past"
140+
)
141+
142+
self.assertTrue(
143+
OnnxConfigWithPast.with_past(config()).use_past,
144+
"OnnxConfigWithPast.with_past() should use_past"
145+
)
146+
147+
@patch.multiple(OnnxConfigWithPast, __abstractmethods__=set())
148+
def test_values_override(self):
149+
"""
150+
Ensure the use_past variable correctly set the `use_cache` value in model's configuration
151+
"""
152+
for name, config in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS:
153+
with self.subTest(name):
154+
155+
# without past
156+
onnx_config_default = OnnxConfigWithPast.default(config())
157+
self.assertIsNotNone(onnx_config_default.values_override, "values_override should not be None")
158+
self.assertIn("use_cache", onnx_config_default.values_override, "use_cache should be present")
159+
self.assertFalse(
160+
onnx_config_default.values_override["use_cache"],
161+
"use_cache should be False if not using past"
162+
)
163+
164+
# with past
165+
onnx_config_default = OnnxConfigWithPast.with_past(config())
166+
self.assertIsNotNone(onnx_config_default.values_override, "values_override should not be None")
167+
self.assertIn("use_cache", onnx_config_default.values_override, "use_cache should be present")
168+
self.assertTrue(
169+
onnx_config_default.values_override["use_cache"],
170+
"use_cache should be False if not using past"
171+
)
172+
173+
90174
if is_torch_available():
91175
from transformers import AlbertModel, BartModel, BertModel, DistilBertModel, GPT2Model, RobertaModel, T5Model, XLMRobertaModel
92176

93177
PYTORCH_EXPORT_DEFAULT_MODELS = {
94-
("ALBERT", "albert-base-v2", AlbertModel, AlbertConfig, AlbertOnnxConfig),
95-
("BART", "facebook/bart-base", BartModel, BartConfig, BartOnnxConfig),
96-
("BERT", "bert-base-cased", BertModel, BertConfig, BertOnnxConfig),
97-
("DistilBERT", "distilbert-base-cased", DistilBertModel, DistilBertConfig, DistilBertOnnxConfig),
98-
("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig),
99-
# ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
100-
("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
101-
("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
102-
("T5", "t5-small", T5Model, T5Config, T5OnnxConfig)
178+
# ("ALBERT", "albert-base-v2", AlbertModel, AlbertConfig, AlbertOnnxConfig),
179+
# ("BART", "facebook/bart-base", BartModel, BartConfig, BartOnnxConfig),
180+
# ("BERT", "bert-base-cased", BertModel, BertConfig, BertOnnxConfig),
181+
# ("DistilBERT", "distilbert-base-cased", DistilBertModel, DistilBertConfig, DistilBertOnnxConfig),
182+
# ("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig),
183+
# # ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
184+
# ("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
185+
# ("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
186+
# ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig)
103187
}
104188

105189
PYTORCH_EXPORT_WITH_PAST_MODELS = {
106-
("BART", ),
107-
("GPT2", ),
108-
("T5", )
190+
("BART", "facebook/bart-base", BartModel, BartConfig, BartOnnxConfig),
191+
("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig),
192+
# ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig)
109193
}
110194

111195

112196
class OnnxExportTestCaseV2(TestCase):
197+
"""
198+
Integration tests ensuring supported models are correctly exported
199+
"""
113200
@slow
114201
@require_torch
115202
def test_pytorch_export_default(self):
116203
from transformers.onnx.convert import convert_pytorch
117204

118205
for name, model, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_DEFAULT_MODELS:
119-
120206
with self.subTest(name):
207+
self.assertTrue(hasattr(onnx_config_class, "default"))
208+
121209
tokenizer = AutoTokenizer.from_pretrained(model)
122210
model = model_class(config_class())
123211
onnx_config = onnx_config_class.default(model.config)
@@ -128,4 +216,27 @@ def test_pytorch_export_default(self):
128216
@slow
129217
@require_torch
130218
def test_pytorch_export_with_past(self):
131-
pass
219+
from transformers.onnx.convert import convert_pytorch
220+
221+
for name, model, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_WITH_PAST_MODELS:
222+
with self.subTest(name):
223+
self.assertTrue(hasattr(onnx_config_class, "with_past"), "OnnxConfigWithPast should have with_past()")
224+
225+
tokenizer = AutoTokenizer.from_pretrained(model)
226+
model = model_class(config_class())
227+
onnx_config = onnx_config_class.with_past(model.config)
228+
229+
self.assertTrue(hasattr(onnx_config, "use_past"), "OnnxConfigWithPast should have use_past attribute.")
230+
self.assertTrue(
231+
onnx_config.use_past,
232+
"OnnxConfigWithPast.use_past should be True if called with with_past()"
233+
)
234+
235+
with NamedTemporaryFile("w") as output:
236+
onnx_inputs, onnx_outputs = \
237+
convert_pytorch(tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, Path(output.name))
238+
239+
try:
240+
validate_model_outputs(onnx_config, tokenizer, model, Path(output.name), onnx_outputs, 1e-5)
241+
except ValueError as ve:
242+
self.fail(f"{name} -> {ve}")

0 commit comments

Comments
 (0)