Skip to content

Commit

Permalink
Refine the dynamic shapes based on the suggested fixes (#180)
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby authored Sep 3, 2024
1 parent d8fea7c commit 313d327
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 8 deletions.
35 changes: 29 additions & 6 deletions src/torch_onnx/_capture_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import TYPE_CHECKING, Any, Callable

import torch
import torch._dynamo.exc
from torch.utils import _pytree

from torch_onnx import _torchscript_converter
Expand Down Expand Up @@ -121,9 +122,20 @@ class TorchExportStrategy(CaptureStrategy):
def _capture(
self, model, args, kwargs, dynamic_shapes
) -> torch.export.ExportedProgram:
return torch.export.export(
model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes
)
try:
return torch.export.export(
model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes
)
except torch._dynamo.exc.UserError as exc:
# Refine the dynamic shapes based on the suggested fixes.
new_shapes = (
torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes(
exc.msg, dynamic_shapes
)
)
return torch.export.export(
model, args, kwargs=kwargs, dynamic_shapes=new_shapes
)

def _enter(self, model) -> None:
model_repr = _take_first_line(repr(model))
Expand All @@ -149,9 +161,20 @@ class TorchExportNonStrictStrategy(CaptureStrategy):
def _capture(
self, model, args, kwargs, dynamic_shapes
) -> torch.export.ExportedProgram:
return torch.export.export(
model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes, strict=False
)
try:
return torch.export.export(
model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes, strict=False
)
except torch._dynamo.exc.UserError as exc:
# Refine the dynamic shapes based on the suggested fixes.
new_shapes = (
torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes(
exc.msg, dynamic_shapes
)
)
return torch.export.export(
model, args, kwargs=kwargs, dynamic_shapes=new_shapes, strict=False
)

def _enter(self, model) -> None:
model_repr = _take_first_line(repr(model))
Expand Down
4 changes: 2 additions & 2 deletions src/torch_onnx/_onnx_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import onnxscript._framework_apis.torch_2_5 as onnxscript_apis
import torch
from onnxscript import ir
from torch.utils import _pytree as pytree
from torch.utils import _pytree

if TYPE_CHECKING:
import onnxruntime as ort # type: ignore[import-untyped]
Expand Down Expand Up @@ -215,7 +215,7 @@ def _process_args(args, kwargs) -> tuple[torch.Tensor, ...]:


def _flatten_inputs(model_args, model_kwargs):
flattened_args, _ = pytree.tree_flatten((model_args, model_kwargs))
flattened_args, _ = _pytree.tree_flatten((model_args, model_kwargs))
return flattened_args


Expand Down

0 comments on commit 313d327

Please sign in to comment.