Export a model with control flow to ONNX
Author: Xavier Dupré
Overview
This tutorial demonstrates how to handle control flow logic while exporting a PyTorch model to ONNX. It highlights the challenges of exporting conditional statements directly and provides solutions to circumvent them.
Conditional logic cannot be exported to ONNX unless it is refactored to use torch.cond(). Let's start with a simple model that implements such a test.
What you will learn:
How to refactor the model to use torch.cond() for exporting.
How to export a model with control flow logic to ONNX.
How to optimize the exported model using the ONNX optimizer.
Prerequisites
torch >= 2.6
import torch
Define the Models
Two models are defined:
ForwardWithControlFlowTest: a model whose forward method contains an if-else conditional.
ModelWithControlFlowTest: a model that incorporates ForwardWithControlFlowTest as part of a simple MLP.
The models are tested with a random input tensor to confirm they execute as expected.
class ForwardWithControlFlowTest(torch.nn.Module):
    def forward(self, x):
        # Data-dependent branch: the condition depends on tensor values.
        if x.sum():
            return x * 2
        return -x


class ModelWithControlFlowTest(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(3, 2),
            torch.nn.Linear(2, 1),
            ForwardWithControlFlowTest(),
        )

    def forward(self, x):
        out = self.mlp(x)
        return out


model = ModelWithControlFlowTest()
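Before exporting, it can help to confirm in eager mode which branch the model takes. The sketch below rebuilds the two modules so that it runs on its own, fixes a random seed (an assumption added only for reproducibility), and checks that the model output equals the output of the two Linear layers either doubled or negated, exactly as the if statement prescribes.

```python
import torch


class ForwardWithControlFlowTest(torch.nn.Module):
    def forward(self, x):
        # Data-dependent branch: taken when the sum of x is non-zero.
        if x.sum():
            return x * 2
        return -x


class ModelWithControlFlowTest(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(3, 2),
            torch.nn.Linear(2, 1),
            ForwardWithControlFlowTest(),
        )

    def forward(self, x):
        return self.mlp(x)


torch.manual_seed(0)  # assumption: fixed seed so the check is reproducible
model = ModelWithControlFlowTest()
x = torch.randn(3)

with torch.no_grad():
    # Output of the two Linear layers, i.e. the input of the branching module.
    hidden = model.mlp[1](model.mlp[0](x))
    out = model(x)

# Recompute the branch by hand and compare it with the model output.
expected = hidden * 2 if hidden.sum() else -hidden
print("branch taken:", "x * 2" if hidden.sum() else "-x")
assert torch.allclose(out, expected)
```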
Exporting the Model: First Attempt
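Exporting this model with torch.export.export fails because the if statement in the forward pass branches on the value of a tensor. The exporter traces the model without concrete values, so it cannot decide which branch to take (the "data-dependent expression" in the error below). This behavior is expected: conditional logic that depends on tensor values and is not written with torch.cond() is unsupported.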
A try-except block is used to capture the expected failure during the export
process. If the export unexpectedly succeeds, an AssertionError
is raised.
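x = torch.randn(3)
model(x)  # eager execution works

try:
    torch.export.export(model, (x,), strict=False)
except Exception as e:
    print(e)
else:
    # Raised outside the try block so that it is not swallowed by the except clause.
    raise AssertionError("This export should have failed unless PyTorch now supports this model.")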
def forward(self, arg0_1: "f32[2, 3]", arg1_1: "f32[2]", arg2_1: "f32[1, 2]", arg3_1: "f32[1]", arg4_1: "f32[3]"):
# File: /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[2]" = torch.ops.aten.linear.default(arg4_1, arg0_1, arg1_1); arg4_1 = arg0_1 = arg1_1 = None
linear_1: "f32[1]" = torch.ops.aten.linear.default(linear, arg2_1, arg3_1); linear = arg2_1 = arg3_1 = None
# File: /var/lib/workspace/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:56 in forward, code: if x.sum():
sum_1: "f32[]" = torch.ops.aten.sum.default(linear_1); linear_1 = None
ne: "b8[]" = torch.ops.aten.ne.Scalar(sum_1, 0); sum_1 = None
item: "Sym(Eq(u0, 1))" = torch.ops.aten.item.default(ne); ne = item = None
Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)). (Size-like symbols: none)
Caused by: (_export/non_strict_utils.py:1051 in __torch_function__)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing
For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
The following call raised this error:
File "/var/lib/workspace/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py", line 56, in forward
if x.sum():
The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.
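Using torch.onnx.export() with dynamo=True
When exporting the model using torch.onnx.export() with the dynamo=True argument, the exporter first tries torch.export.export with strict=False and then with strict=True; both fail as shown above. It then falls back to draft_export, which specializes the data-dependent condition to the value observed on the example input. This fallback allows the model to export, but the resulting ONNX graph does not faithfully represent the original model logic: only the x * 2 branch is kept, as the single Mul node in the graph below shows.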
onnx_program = torch.onnx.export(model, (x,), dynamo=True)
print(onnx_program.model)
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`...
def forward(self, arg0_1: "f32[2, 3]", arg1_1: "f32[2]", arg2_1: "f32[1, 2]", arg3_1: "f32[1]", arg4_1: "f32[3]"):
# File: /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[2]" = torch.ops.aten.linear.default(arg4_1, arg0_1, arg1_1); arg4_1 = arg0_1 = arg1_1 = None
linear_1: "f32[1]" = torch.ops.aten.linear.default(linear, arg2_1, arg3_1); linear = arg2_1 = arg3_1 = None
# File: /var/lib/workspace/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:56 in forward, code: if x.sum():
sum_1: "f32[]" = torch.ops.aten.sum.default(linear_1); linear_1 = None
ne: "b8[]" = torch.ops.aten.ne.Scalar(sum_1, 0); sum_1 = None
item: "Sym(Eq(u0, 1))" = torch.ops.aten.item.default(ne); ne = item = None
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`... ❌
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=True)`...
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[3][1]cpu"):
l_x_ = L_x_
# File: /var/lib/workspace/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:71 in forward, code: out = self.mlp(x)
l__self___mlp_0: "f32[2][1]cpu" = self.L__self___mlp_0(l_x_); l_x_ = None
l__self___mlp_1: "f32[1][1]cpu" = self.L__self___mlp_1(l__self___mlp_0); l__self___mlp_0 = None
# File: /var/lib/workspace/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:56 in forward, code: if x.sum():
sum_1: "f32[][]cpu" = l__self___mlp_1.sum(); l__self___mlp_1 = sum_1 = None
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=True)`... ❌
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export draft_export`...
[torch.onnx] Draft Export report:
###################################################################################################
WARNING: 1 issue(s) found during export, and it was not able to soundly produce a graph.
Please follow the instructions to fix the errors.
###################################################################################################
1. Data dependent error.
When exporting, we were unable to evaluate the value of `Eq(u0, 1)`.
This was encountered 1 times.
This occurred at the following user stacktrace:
File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py, lineno 1773, in _wrapped_call_impl
File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py, lineno 1784, in _call_impl
if x.sum():
Locals:
x: ['Tensor(shape: torch.Size([1]), stride: (1,), storage_offset: 0)']
And the following framework stacktrace:
File /usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py, lineno 1360, in __torch_function__
File /usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py, lineno 1407, in __torch_function__
return func(*args, **kwargs)
As a result, it was specialized to a constant (e.g. `1` in the 1st occurrence), and asserts were inserted into the graph.
Please add `torch._check(...)` to the original code to assert this data-dependent assumption.
Please refer to https://docs.google.com/document/d/1kZ_BbB3JnoLbUZleDT6635dHs88ZVYId8jT-yTFgf3A/edit#heading=h.boi2xurpqa0o for more details.
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export draft_export`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
<
ir_version=10,
opset_imports={'': 18},
producer_name='pytorch',
producer_version='2.8.0+cu128',
domain=None,
model_version=None,
>
graph(
name=main_graph,
inputs=(
%"x"<FLOAT,[3]>
),
outputs=(
%"mul"<FLOAT,[1]>
),
initializers=(
%"mlp.0.bias"<FLOAT,[2]>{TorchTensor<FLOAT,[2]>(Parameter containing: tensor([ 0.1817, -0.4669], requires_grad=True), name='mlp.0.bias')},
%"mlp.1.bias"<FLOAT,[1]>{TorchTensor<FLOAT,[1]>(Parameter containing: tensor([-0.3096], requires_grad=True), name='mlp.1.bias')},
%"val_0"<FLOAT,[3,2]>{Tensor<FLOAT,[3,2]>(array([[-0.02091814, 0.06923048], [ 0.20628652, -0.1461392 ], [-0.05838434, -0.20525895]], dtype=float32), name='val_0')},
%"val_2"<FLOAT,[2,1]>{Tensor<FLOAT,[2,1]>(array([[-0.1529778 ], [-0.41561294]], dtype=float32), name='val_2')},
%"scalar_tensor_default"<FLOAT,[]>{Tensor<FLOAT,[]>(array(2., dtype=float32), name='scalar_tensor_default')}
),
) {
0 | # node_MatMul_1
%"val_1"<FLOAT,[2]> ⬅️ ::MatMul(%"x", %"val_0"{[[-0.020918138325214386, 0.06923048198223114], [0.20628651976585388, -0.1461392045021057], [-0.05838434025645256, -0.20525895059108734]]})
1 | # node_linear
%"linear"<FLOAT,[2]> ⬅️ ::Add(%"val_1", %"mlp.0.bias"{[0.1816701740026474, -0.4669439494609833]})
2 | # node_MatMul_3
%"val_3"<FLOAT,[1]> ⬅️ ::MatMul(%"linear", %"val_2"{[[-0.15297779440879822], [-0.41561293601989746]]})
3 | # node_linear_1
%"linear_1"<FLOAT,[1]> ⬅️ ::Add(%"val_3", %"mlp.1.bias"{[-0.30958300828933716]})
4 | # node_mul
%"mul"<FLOAT,[1]> ⬅️ ::Mul(%"linear_1", %"scalar_tensor_default"{2.0})
return %"mul"<FLOAT,[1]>
}
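Why does a specialized graph matter? When the exporter specializes a data-dependent condition, it records the branch that the example input happened to take and drops the other one. The plain-PyTorch sketch below mimics that effect without any exporter: specialized is a hypothetical stand-in for the specialized graph, frozen on the branch chosen for example, and it diverges from the eager function as soon as an input takes the other branch. It uses a > 0 test, like the refactored version later in this tutorial, so that the two branches produce visibly different results.

```python
import torch


def eager(x):
    # The condition is re-evaluated on every call.
    if x.sum() > 0:
        return x * 2
    return -x


# The example input used at "export" time takes the x * 2 branch.
example = torch.tensor([1.0])
branch_is_true = bool(example.sum() > 0)


def specialized(x):
    # Hypothetical stand-in for a specialized graph: the branch is frozen.
    return x * 2 if branch_is_true else -x


same = torch.tensor([3.0])   # takes the same branch as the example input
flip = torch.tensor([-1.0])  # takes the other branch

print(eager(same), specialized(same))  # identical results
print(eager(flip), specialized(flip))  # different results: -x versus x * 2
```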
Suggested Patch: Refactoring with torch.cond()
To make the control flow exportable, the tutorial replaces the forward method of ForwardWithControlFlowTest with a refactored version that uses torch.cond().
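Details of the Refactoring:
* Two helper functions, identity2 and neg, represent the two branches of the conditional logic.
* torch.cond() specifies the condition, the two branches, and the input arguments passed to them.
* The updated forward method is then assigned dynamically to the ForwardWithControlFlowTest instance within the model. The submodules are printed while searching for that instance.
Note that the predicate x.sum() > 0 is not strictly equivalent to the original condition: if x.sum() takes the x * 2 branch for any non-zero sum, including a negative one. Use x.sum() != 0 instead when the original semantics must be preserved exactly.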
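def new_forward(x):
    # Branch taken when the predicate is True.
    def identity2(x):
        return x * 2

    # Branch taken when the predicate is False.
    def neg(x):
        return -x

    # Both branches are captured as subgraphs instead of being resolved at trace time.
    return torch.cond(x.sum() > 0, identity2, neg, (x,))


print("the list of submodules")
for name, mod in model.named_modules():
    print(name, type(mod))
    if isinstance(mod, ForwardWithControlFlowTest):
        # Replace forward on this instance only; the class is left unchanged.
        mod.forward = new_forward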
the list of submodules
<class '__main__.ModelWithControlFlowTest'>
mlp <class 'torch.nn.modules.container.Sequential'>
mlp.0 <class 'torch.nn.modules.linear.Linear'>
mlp.1 <class 'torch.nn.modules.linear.Linear'>
mlp.2 <class '__main__.ForwardWithControlFlowTest'>
Let’s see what the FX graph looks like.
print(torch.export.export(model, (x,), strict=False))
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_mlp_0_weight: "f32[2, 3]", p_mlp_0_bias: "f32[2]", p_mlp_1_weight: "f32[1, 2]", p_mlp_1_bias: "f32[1]", x: "f32[3]"):
# File: /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[2]" = torch.ops.aten.linear.default(x, p_mlp_0_weight, p_mlp_0_bias); x = p_mlp_0_weight = p_mlp_0_bias = None
linear_1: "f32[1]" = torch.ops.aten.linear.default(linear, p_mlp_1_weight, p_mlp_1_bias); linear = p_mlp_1_weight = p_mlp_1_bias = None
# File: /usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py:244 in forward, code: input = module(input)
sum_1: "f32[]" = torch.ops.aten.sum.default(linear_1)
gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None
# File: <eval_with_key>.25:9 in forward, code: cond = torch.ops.higher_order.cond(l_args_0_, cond_true_0, cond_false_0, (l_args_3_0_,)); l_args_0_ = cond_true_0 = cond_false_0 = l_args_3_0_ = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (linear_1,)); gt = true_graph_0 = false_graph_0 = linear_1 = None
getitem: "f32[1]" = cond[0]; cond = None
return (getitem,)
class true_graph_0(torch.nn.Module):
def forward(self, linear_1: "f32[1]"):
# File: <eval_with_key>.22:6 in forward, code: mul = l_args_3_0__1.mul(2); l_args_3_0__1 = None
mul: "f32[1]" = torch.ops.aten.mul.Tensor(linear_1, 2); linear_1 = None
return (mul,)
class false_graph_0(torch.nn.Module):
def forward(self, linear_1: "f32[1]"):
# File: <eval_with_key>.23:6 in forward, code: neg = l_args_3_0__1.neg(); l_args_3_0__1 = None
neg: "f32[1]" = torch.ops.aten.neg.default(linear_1); linear_1 = None
return (neg,)
Graph signature:
# inputs
p_mlp_0_weight: PARAMETER target='mlp.0.weight'
p_mlp_0_bias: PARAMETER target='mlp.0.bias'
p_mlp_1_weight: PARAMETER target='mlp.1.weight'
p_mlp_1_bias: PARAMETER target='mlp.1.bias'
x: USER_INPUT
# outputs
getitem: USER_OUTPUT
Range constraints: {}
Let’s export again.
onnx_program = torch.onnx.export(model, (x,), dynamo=True)
print(onnx_program.model)
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
<
ir_version=10,
opset_imports={'': 18},
producer_name='pytorch',
producer_version='2.8.0+cu128',
domain=None,
model_version=None,
>
graph(
name=main_graph,
inputs=(
%"x"<FLOAT,[3]>
),
outputs=(
%"getitem"<FLOAT,[1]>
),
initializers=(
%"mlp.0.bias"<FLOAT,[2]>{TorchTensor<FLOAT,[2]>(Parameter containing: tensor([ 0.1817, -0.4669], requires_grad=True), name='mlp.0.bias')},
%"mlp.1.bias"<FLOAT,[1]>{TorchTensor<FLOAT,[1]>(Parameter containing: tensor([-0.3096], requires_grad=True), name='mlp.1.bias')},
%"val_0"<FLOAT,[3,2]>{Tensor<FLOAT,[3,2]>(array([[-0.02091814, 0.06923048], [ 0.20628652, -0.1461392 ], [-0.05838434, -0.20525895]], dtype=float32), name='val_0')},
%"val_2"<FLOAT,[2,1]>{Tensor<FLOAT,[2,1]>(array([[-0.1529778 ], [-0.41561294]], dtype=float32), name='val_2')},
%"scalar_tensor_default"<FLOAT,[]>{Tensor<FLOAT,[]>(array(0., dtype=float32), name='scalar_tensor_default')},
%"scalar_tensor_default_2"<FLOAT,[]>{Tensor<FLOAT,[]>(array(2., dtype=float32), name='scalar_tensor_default_2')}
),
) {
0 | # node_MatMul_1
%"val_1"<FLOAT,[2]> ⬅️ ::MatMul(%"x", %"val_0"{[[-0.020918138325214386, 0.06923048198223114], [0.20628651976585388, -0.1461392045021057], [-0.05838434025645256, -0.20525895059108734]]})
1 | # node_linear
%"linear"<FLOAT,[2]> ⬅️ ::Add(%"val_1", %"mlp.0.bias"{[0.1816701740026474, -0.4669439494609833]})
2 | # node_MatMul_3
%"val_3"<FLOAT,[1]> ⬅️ ::MatMul(%"linear", %"val_2"{[[-0.15297779440879822], [-0.41561293601989746]]})
3 | # node_linear_1
%"linear_1"<FLOAT,[1]> ⬅️ ::Add(%"val_3", %"mlp.1.bias"{[-0.30958300828933716]})
4 | # node_sum_1
%"sum_1"<FLOAT,[]> ⬅️ ::ReduceSum(%"linear_1") {noop_with_empty_axes=0, keepdims=False}
5 | # node_gt
%"gt"<BOOL,[]> ⬅️ ::Greater(%"sum_1", %"scalar_tensor_default"{0.0})
6 | # node_cond__0
%"getitem"<FLOAT,[1]> ⬅️ ::If(%"gt") {then_branch=
graph(
name=true_graph_0,
inputs=(
),
outputs=(
%"mul_true_graph_0"<FLOAT,[1]>
),
) {
0 | # node_mul
%"mul_true_graph_0"<FLOAT,[1]> ⬅️ ::Mul(%"linear_1", %"scalar_tensor_default_2"{2.0})
return %"mul_true_graph_0"<FLOAT,[1]>
}, else_branch=
graph(
name=false_graph_0,
inputs=(
),
outputs=(
%"neg_false_graph_0"<FLOAT,[1]>
),
) {
0 | # node_neg
%"neg_false_graph_0"<FLOAT,[1]> ⬅️ ::Neg(%"linear_1")
return %"neg_false_graph_0"<FLOAT,[1]>
}}
return %"getitem"<FLOAT,[1]>
}
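We can optimize the model with onnx_program.optimize(), which is meant to get rid of the model-local functions that may be created to capture the control flow branches. In this run, the branches were already emitted as subgraphs of the If node, so the optimized model printed below is identical to the one above.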
onnx_program.optimize()
print(onnx_program.model)
<
ir_version=10,
opset_imports={'': 18},
producer_name='pytorch',
producer_version='2.8.0+cu128',
domain=None,
model_version=None,
>
graph(
name=main_graph,
inputs=(
%"x"<FLOAT,[3]>
),
outputs=(
%"getitem"<FLOAT,[1]>
),
initializers=(
%"mlp.0.bias"<FLOAT,[2]>{TorchTensor<FLOAT,[2]>(Parameter containing: tensor([ 0.1817, -0.4669], requires_grad=True), name='mlp.0.bias')},
%"mlp.1.bias"<FLOAT,[1]>{TorchTensor<FLOAT,[1]>(Parameter containing: tensor([-0.3096], requires_grad=True), name='mlp.1.bias')},
%"val_0"<FLOAT,[3,2]>{Tensor<FLOAT,[3,2]>(array([[-0.02091814, 0.06923048], [ 0.20628652, -0.1461392 ], [-0.05838434, -0.20525895]], dtype=float32), name='val_0')},
%"val_2"<FLOAT,[2,1]>{Tensor<FLOAT,[2,1]>(array([[-0.1529778 ], [-0.41561294]], dtype=float32), name='val_2')},
%"scalar_tensor_default"<FLOAT,[]>{Tensor<FLOAT,[]>(array(0., dtype=float32), name='scalar_tensor_default')},
%"scalar_tensor_default_2"<FLOAT,[]>{Tensor<FLOAT,[]>(array(2., dtype=float32), name='scalar_tensor_default_2')}
),
) {
0 | # node_MatMul_1
%"val_1"<FLOAT,[2]> ⬅️ ::MatMul(%"x", %"val_0"{[[-0.020918138325214386, 0.06923048198223114], [0.20628651976585388, -0.1461392045021057], [-0.05838434025645256, -0.20525895059108734]]})
1 | # node_linear
%"linear"<FLOAT,[2]> ⬅️ ::Add(%"val_1", %"mlp.0.bias"{[0.1816701740026474, -0.4669439494609833]})
2 | # node_MatMul_3
%"val_3"<FLOAT,[1]> ⬅️ ::MatMul(%"linear", %"val_2"{[[-0.15297779440879822], [-0.41561293601989746]]})
3 | # node_linear_1
%"linear_1"<FLOAT,[1]> ⬅️ ::Add(%"val_3", %"mlp.1.bias"{[-0.30958300828933716]})
4 | # node_sum_1
%"sum_1"<FLOAT,[]> ⬅️ ::ReduceSum(%"linear_1") {noop_with_empty_axes=0, keepdims=False}
5 | # node_gt
%"gt"<BOOL,[]> ⬅️ ::Greater(%"sum_1", %"scalar_tensor_default"{0.0})
6 | # node_cond__0
%"getitem"<FLOAT,[1]> ⬅️ ::If(%"gt") {then_branch=
graph(
name=true_graph_0,
inputs=(
),
outputs=(
%"mul_true_graph_0"<FLOAT,[1]>
),
) {
0 | # node_mul
%"mul_true_graph_0"<FLOAT,[1]> ⬅️ ::Mul(%"linear_1", %"scalar_tensor_default_2"{2.0})
return %"mul_true_graph_0"<FLOAT,[1]>
}, else_branch=
graph(
name=false_graph_0,
inputs=(
),
outputs=(
%"neg_false_graph_0"<FLOAT,[1]>
),
) {
0 | # node_neg
%"neg_false_graph_0"<FLOAT,[1]> ⬅️ ::Neg(%"linear_1")
return %"neg_false_graph_0"<FLOAT,[1]>
}}
return %"getitem"<FLOAT,[1]>
}
Conclusion
This tutorial demonstrates the challenges of exporting models with conditional logic to ONNX and presents a practical solution using torch.cond().
While torch.export.export fails on data-dependent conditions and the draft_export fallback produces a graph specialized to a single branch, refactoring the model's logic with torch.cond() makes it exportable and yields a faithful ONNX representation in which both branches are kept in an If node.
By understanding these techniques, we can overcome common pitfalls when working with control flow in PyTorch models and ensure smooth integration with ONNX workflows.
Further reading
The list below refers to tutorials that range from basic examples to advanced scenarios, not necessarily in the order they are listed. Feel free to jump directly to specific topics of interest, or sit tight and go through all of them to learn all there is about the ONNX exporter.