     # LongformerConfig,
     RobertaConfig,
     XLMRobertaConfig,
-    is_torch_available,
+    is_torch_available, BartConfig, GPT2Config, T5Config,
 )
 from transformers.models.albert import AlbertOnnxConfig
+from transformers.models.bart import BartOnnxConfig
 from transformers.models.bert.configuration_bert import BertConfig, BertOnnxConfig
 from transformers.models.distilbert import DistilBertOnnxConfig
 # from transformers.models.longformer import LongformerOnnxConfig
+from transformers.models.gpt2 import GPT2OnnxConfig
 from transformers.models.roberta import RobertaOnnxConfig
+from transformers.models.t5 import T5OnnxConfig
 from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
 from transformers.onnx import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, ParameterFormat
 
@@ -85,15 +88,24 @@ def test_use_external_data_format(self):
 
 
 if is_torch_available():
-    from transformers import AlbertModel, BertModel, DistilBertModel, LongformerModel, RobertaModel, XLMRobertaModel
+    from transformers import AlbertModel, BartModel, BertModel, DistilBertModel, GPT2Model, RobertaModel, T5Model, XLMRobertaModel
 
     PYTORCH_EXPORT_DEFAULT_MODELS = {
         ("ALBERT", "albert-base-v2", AlbertModel, AlbertConfig, AlbertOnnxConfig),
+        ("BART", "facebook/bart-base", BartModel, BartConfig, BartOnnxConfig),
         ("BERT", "bert-base-cased", BertModel, BertConfig, BertOnnxConfig),
         ("DistilBERT", "distilbert-base-cased", DistilBertModel, DistilBertConfig, DistilBertOnnxConfig),
+        ("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig),
         # ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
         ("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
         ("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
+        ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig)
+    }
+
+    PYTORCH_EXPORT_WITH_PAST_MODELS = {
+        ("BART", ),
+        ("GPT2", ),
+        ("T5", )
     }
 
 