Skip to content

Commit

Permalink
Rename _torch_onnx_export
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed Sep 6, 2024
1 parent 0f5afaf commit a8d7d3b
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/torch_onnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@
from ._analysis import analyze
from ._core import export, exported_program_to_ir
from ._onnx_program import ONNXProgram
from ._patch import _export_compat as export_compat
from ._patch import _torch_onnx_export as export_compat
from ._patch import patch_torch, unpatch_torch
from ._registration import ONNXRegistry
6 changes: 3 additions & 3 deletions src/torch_onnx/_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def _get_torch_export_args(
return args, kwargs


def _export_compat(
def _torch_onnx_export(
model: torch.nn.Module | torch.export.ExportedProgram,
args: tuple[Any, ...],
f: str | os.PathLike | None = None,
Expand Down Expand Up @@ -261,7 +261,7 @@ def _to_dynamic_shape(x):
else:
dynamic_shapes = None

return _export_compat(
return _torch_onnx_export(
model,
model_args,
kwargs=model_kwargs,
Expand Down Expand Up @@ -311,7 +311,7 @@ def patch_torch(
_ARTIFACTS_DIR = artifacts_dir
global _FALLBACK_TO_LEGACY_EXPORT # noqa: PLW0603
_FALLBACK_TO_LEGACY_EXPORT = fallback
torch.onnx.export = _export_compat # type: ignore[assignment]
torch.onnx.export = _torch_onnx_export # type: ignore[assignment]
torch.onnx.dynamo_export = _torch_onnx_dynamo_export # type: ignore[assignment]


Expand Down
2 changes: 1 addition & 1 deletion tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def forward(self, x, y):
"x": (torch.export.Dim("dx"),),
"y": (torch.export.Dim("dy"),),
}
onnx_program = torch_onnx._patch._export_compat(
onnx_program = torch_onnx.export_compat(
TestRefineDynamicShapeModel(), inps, dynamic_shapes=dynamic_shapes
)
assert onnx_program is not None
Expand Down

0 comments on commit a8d7d3b

Please sign in to comment.