Skip to content

Commit 9572fa2

Browse files
committed
Enable all supported default models.
1 parent fa08375 commit 9572fa2

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

tests/test_onnx_v2.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,16 @@
1010
# LongformerConfig,
1111
RobertaConfig,
1212
XLMRobertaConfig,
13-
is_torch_available,
13+
is_torch_available, BartConfig, GPT2Config, T5Config,
1414
)
1515
from transformers.models.albert import AlbertOnnxConfig
16+
from transformers.models.bart import BartOnnxConfig
1617
from transformers.models.bert.configuration_bert import BertConfig, BertOnnxConfig
1718
from transformers.models.distilbert import DistilBertOnnxConfig
1819
# from transformers.models.longformer import LongformerOnnxConfig
20+
from transformers.models.gpt2 import GPT2OnnxConfig
1921
from transformers.models.roberta import RobertaOnnxConfig
22+
from transformers.models.t5 import T5OnnxConfig
2023
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
2124
from transformers.onnx import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, ParameterFormat
2225

@@ -85,15 +88,24 @@ def test_use_external_data_format(self):
8588

8689

8790
if is_torch_available():
88-
from transformers import AlbertModel, BertModel, DistilBertModel, LongformerModel, RobertaModel, XLMRobertaModel
91+
from transformers import AlbertModel, BartModel, BertModel, DistilBertModel, GPT2Model, RobertaModel, T5Model, XLMRobertaModel
8992

9093
PYTORCH_EXPORT_DEFAULT_MODELS = {
9194
("ALBERT", "albert-base-v2", AlbertModel, AlbertConfig, AlbertOnnxConfig),
95+
("BART", "facebook/bart-base", BartModel, BartConfig, BartOnnxConfig),
9296
("BERT", "bert-base-cased", BertModel, BertConfig, BertOnnxConfig),
9397
("DistilBERT", "distilbert-base-cased", DistilBertModel, DistilBertConfig, DistilBertOnnxConfig),
98+
("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig),
9499
# ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
95100
("Roberta", "roberta-base", RobertaModel, RobertaConfig, RobertaOnnxConfig),
96101
("XLM-Roberta", "roberta-base", XLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
102+
("T5", "t5-small", T5Model, T5Config, T5OnnxConfig)
103+
}
104+
105+
PYTORCH_EXPORT_WITH_PAST_MODELS = {
106+
("BART", ),
107+
("GPT2", ),
108+
("T5", )
97109
}
98110

99111

0 commit comments

Comments
 (0)