Skip to content

Commit 4311e8a

Browse files
committed
Enable ONNX export test for supported model.
1 parent 6b1899f commit 4311e8a

File tree

1 file changed

+40
-8
lines changed

1 file changed

+40
-8
lines changed

tests/test_onnx_v2.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,27 @@
11
from pathlib import Path
2+
from tempfile import NamedTemporaryFile
23
from unittest import TestCase
34
from unittest.mock import patch
45

5-
from transformers.models.bert.configuration_bert import BertOnnxConfig
6+
from transformers import AutoTokenizer, is_torch_available, AlbertConfig, DistilBertConfig, LongformerConfig, \
7+
RobertaConfig, XLMRobertaConfig
8+
from transformers.models.albert import AlbertOnnxConfig
9+
10+
from transformers.models.bert.configuration_bert import BertOnnxConfig, BertConfig
11+
from transformers.models.distilbert import DistilBertOnnxConfig
12+
from transformers.models.longformer import LongformerOnnxConfig
13+
from transformers.models.roberta import RobertaOnnxConfig
14+
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
615
from transformers.onnx import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, ParameterFormat
716

817
# from transformers.onnx.convert import convert_pytorch
18+
from transformers.onnx.config import DEFAULT_ONNX_OPSET
919
from transformers.onnx.utils import (
1020
compute_effective_axis_dimension,
1121
compute_serialized_parameters_size,
1222
flatten_output_collection_property,
1323
)
14-
from transformers.testing_utils import require_onnx
24+
from transformers.testing_utils import require_onnx, slow, require_torch
1525

1626

1727
@require_onnx
@@ -69,13 +79,35 @@ def test_use_external_data_format(self):
6979
self.assertTrue(OnnxConfig.use_external_data_format((TWO_GB_LIMIT + 1) // ParameterFormat.Float.size))
7080

7181

72-
class OnnxExportTestCaseV2(TestCase):
73-
EXPORT_DEFAULT_MODELS = {
74-
("BERT", "bert-base-cased", BertOnnxConfig),
82+
if is_torch_available():
83+
from transformers import AlbertModel, BertModel, DistilBertModel, LongformerModel, RobertaModel, XLMRobertaModel
84+
PYTORCH_EXPORT_DEFAULT_MODELS = {
85+
("ALBERT", "albert-base-v2", AlbertModel, AlbertConfig, AlbertOnnxConfig),
86+
("BERT", "bert-base-cased", BertModel, BertConfig, BertOnnxConfig),
87+
("DistilBERT", "distilbert-base-cased", DistilBertModel, DistilBertConfig, DistilBertOnnxConfig),
88+
# ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
89+
("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
90+
("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
7591
}
7692

77-
def export_default(self):
78-
pass
7993

80-
def export_with_past(self):
94+
class OnnxExportTestCaseV2(TestCase):
95+
@slow
96+
@require_torch
97+
def test_pytorch_export_default(self):
98+
from transformers.onnx.convert import convert_pytorch
99+
100+
for name, model, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_DEFAULT_MODELS:
101+
102+
with self.subTest(name):
103+
tokenizer = AutoTokenizer.from_pretrained(model)
104+
model = model_class(config_class())
105+
onnx_config = onnx_config_class.default(model.config)
106+
107+
with NamedTemporaryFile("w") as output:
108+
convert_pytorch(tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, Path(output.name))
109+
110+
@slow
111+
@require_torch
112+
def test_pytorch_export_with_past(self):
81113
pass

0 commit comments

Comments
 (0)