 from pathlib import Path
+from tempfile import NamedTemporaryFile
 from unittest import TestCase
 from unittest.mock import patch

-from transformers.models.bert.configuration_bert import BertOnnxConfig
+from transformers import AutoTokenizer, is_torch_available, AlbertConfig, DistilBertConfig, LongformerConfig, \
+    RobertaConfig, XLMRobertaConfig
+from transformers.models.albert import AlbertOnnxConfig
+
+from transformers.models.bert.configuration_bert import BertOnnxConfig, BertConfig
+from transformers.models.distilbert import DistilBertOnnxConfig
+from transformers.models.longformer import LongformerOnnxConfig
+from transformers.models.roberta import RobertaOnnxConfig
+from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
 from transformers.onnx import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, ParameterFormat

 # from transformers.onnx.convert import convert_pytorch
+from transformers.onnx.config import DEFAULT_ONNX_OPSET
 from transformers.onnx.utils import (
     compute_effective_axis_dimension,
     compute_serialized_parameters_size,
     flatten_output_collection_property,
 )
-from transformers.testing_utils import require_onnx
+from transformers.testing_utils import require_onnx, slow, require_torch


 @require_onnx
@@ -69,13 +79,35 @@ def test_use_external_data_format(self):
         self.assertTrue(OnnxConfig.use_external_data_format((TWO_GB_LIMIT + 1) // ParameterFormat.Float.size))


-class OnnxExportTestCaseV2(TestCase):
-    EXPORT_DEFAULT_MODELS = {
-        ("BERT", "bert-base-cased", BertOnnxConfig),
+if is_torch_available():
+    from transformers import AlbertModel, BertModel, DistilBertModel, LongformerModel, RobertaModel, XLMRobertaModel
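+    # Each entry: (display name, tokenizer checkpoint, model class, config class, OnnxConfig class).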
+    PYTORCH_EXPORT_DEFAULT_MODELS = {
+        ("ALBERT", "albert-base-v2", AlbertModel, AlbertConfig, AlbertOnnxConfig),
+        ("BERT", "bert-base-cased", BertModel, BertConfig, BertOnnxConfig),
+        ("DistilBERT", "distilbert-base-cased", DistilBertModel, DistilBertConfig, DistilBertOnnxConfig),
+        # ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
+        ("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
+        ("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
75 | 91 | }
|
76 | 92 |
|
77 |
| - def export_default(self): |
78 |
| - pass |
79 | 93 |
|
80 |
| - def export_with_past(self): |
| 94 | +class OnnxExportTestCaseV2(TestCase): |
| 95 | + @slow |
| 96 | + @require_torch |
| 97 | + def test_pytorch_export_default(self): |
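+        # Each model is built from its default config (random weights) and exported to a temporary
+        # ONNX file; the test passes as long as the conversion completes without raising.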
+        from transformers.onnx.convert import convert_pytorch
+
+        for name, model, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_DEFAULT_MODELS:
+
+            with self.subTest(name):
+                tokenizer = AutoTokenizer.from_pretrained(model)
+                model = model_class(config_class())
+                onnx_config = onnx_config_class.default(model.config)
+
+                with NamedTemporaryFile("w") as output:
+                    convert_pytorch(tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, Path(output.name))
+
+    @slow
+    @require_torch
+    def test_pytorch_export_with_past(self):
         pass