
Commit 4eebfda

Rename convert_pytorch to export
1 parent 2fc079f commit 4eebfda

File tree

4 files changed (+9 / -9 lines):

src/transformers/onnx/__init__.py
src/transformers/onnx/__main__.py
src/transformers/onnx/convert.py
tests/test_onnx_v2.py


src/transformers/onnx/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -14,4 +14,5 @@
 # limitations under the License.

 from .config import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, OnnxConfigWithPast
+from .convert import export, validate_model_outputs
 from .utils import ParameterFormat, compute_serialized_parameters_size
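With this re-export, callers can pull both helpers from the package root instead of reaching into the submodule. A minimal sketch of the updated import surface:

# Before this commit the symbols lived only in the submodule:
#   from transformers.onnx.convert import convert_pytorch, validate_model_outputs
# After this commit they are re-exported (and renamed) at the package root:
from transformers.onnx import export, validate_model_outputs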

src/transformers/onnx/__main__.py

Lines changed: 2 additions & 2 deletions

@@ -29,7 +29,7 @@

 from .. import is_torch_available
 from ..utils import logging
-from .convert import convert_pytorch, validate_model_outputs
+from .convert import export, validate_model_outputs


 if is_torch_available():
@@ -142,7 +142,7 @@ def main():
             f"At least {onnx_config.default_onnx_opset} is required."
         )

-    onnx_inputs, onnx_outputs = convert_pytorch(tokenizer, model, onnx_config, args.opset, args.output)
+    onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, args.opset, args.output)

    validate_model_outputs(onnx_config, tokenizer, model, args.output, onnx_outputs, args.atol)
    logger.info(f"All good, model saved at: {args.output.as_posix()}")

src/transformers/onnx/convert.py

Lines changed: 1 addition & 1 deletion

@@ -62,7 +62,7 @@ def check_onnxruntime_requirements(minimum_version: Version):
        )


-def convert_pytorch(
+def export(
    tokenizer: PreTrainedTokenizer, model: PreTrainedModel, config: OnnxConfig, opset: int, output: Path
) -> Tuple[List[str], List[str]]:
    """

tests/test_onnx_v2.py

Lines changed: 5 additions & 6 deletions

@@ -24,9 +24,8 @@
 from transformers.models.roberta import RobertaOnnxConfig
 from transformers.models.t5 import T5OnnxConfig
 from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
-from transformers.onnx import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, ParameterFormat
+from transformers.onnx import EXTERNAL_DATA_FORMAT_SIZE_LIMIT, OnnxConfig, ParameterFormat, validate_model_outputs
 from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast
-from transformers.onnx.convert import validate_model_outputs
 from transformers.onnx.utils import (
     compute_effective_axis_dimension,
     compute_serialized_parameters_size,
@@ -204,7 +203,7 @@ class OnnxExportTestCaseV2(TestCase):
     @slow
     @require_torch
     def test_pytorch_export_default(self):
-        from transformers.onnx.convert import convert_pytorch
+        from transformers.onnx import export

         for name, model, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_DEFAULT_MODELS:
             with self.subTest(name):
@@ -215,7 +214,7 @@ def test_pytorch_export_default(self):
                 onnx_config = onnx_config_class.default(model.config)

                 with NamedTemporaryFile("w") as output:
-                    onnx_inputs, onnx_outputs = convert_pytorch(
+                    onnx_inputs, onnx_outputs = export(
                         tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, Path(output.name)
                     )

@@ -227,7 +226,7 @@ def test_pytorch_export_default(self):
     @slow
     @require_torch
     def test_pytorch_export_with_past(self):
-        from transformers.onnx.convert import convert_pytorch
+        from transformers.onnx import export

         for name, model, model_class, config_class, onnx_config_class in PYTORCH_EXPORT_WITH_PAST_MODELS:
             with self.subTest(name):
@@ -244,7 +243,7 @@ def test_pytorch_export_with_past(self):

                 with NamedTemporaryFile("w") as output:
                     output = Path(output.name)
-                    onnx_inputs, onnx_outputs = convert_pytorch(
+                    onnx_inputs, onnx_outputs = export(
                         tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, output
                     )
