Skip to content

Commit

Permalink
Refactor errors (#189)
Browse files Browse the repository at this point in the history
Refactor errors according to
pytorch/pytorch#135180
  • Loading branch information
justinchuby authored Sep 6, 2024
1 parent 1edda23 commit 47bf897
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 27 deletions.
6 changes: 3 additions & 3 deletions src/torch_onnx/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ def _add_nodes(
# No lowering
_handle_call_function_node(model.graph, node, node_name_to_values)
except Exception as e:
raise errors.OnnxConversionError(
raise errors.ConversionError(
f"Error when translating node {node.format_node()}. See the stack trace for more information."
) from e
return node_name_to_values
Expand Down Expand Up @@ -1097,7 +1097,7 @@ def export(
else:
report_path = None

raise errors.OnnxConversionError(
raise errors.ConversionError(
_STEP_TWO_ERROR_MESSAGE
+ (f"\nError report has been saved to '{report_path}'." if report else "")
+ _summarize_exception_stack(e)
Expand Down Expand Up @@ -1158,7 +1158,7 @@ def export(
else:
report_path = None

raise errors.OnnxConversionError(
raise errors.ConversionError(
_STEP_TWO_ERROR_MESSAGE
+ (f"\nError report has been saved to '{report_path}'." if report else "")
+ _summarize_exception_stack(e)
Expand Down
28 changes: 5 additions & 23 deletions src/torch_onnx/errors.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,28 @@
class ExporterError(RuntimeError):
class OnnxExporterError(RuntimeError):
"""Error during export."""

pass


class TorchExportError(ExporterError):
class TorchExportError(OnnxExporterError):
"""Error during torch.export.export."""

pass


class OnnxConversionError(ExporterError):
class ConversionError(OnnxExporterError):
"""Error during ONNX conversion."""

pass


class DispatchError(OnnxConversionError):
class DispatchError(ConversionError):
"""Error during ONNX Funtion dispatching."""

pass


class GraphConstructionError(OnnxConversionError):
class GraphConstructionError(ConversionError):
"""Error during graph construction."""

pass


class OnnxCheckerError(ExporterError):
"""Error during ONNX model checking."""

pass


class OnnxRuntimeError(ExporterError):
"""Error during ONNX Runtime execution."""

pass


class OnnxValidationError(ExporterError):
"""Output value mismatch."""

pass
2 changes: 1 addition & 1 deletion tests/torch_tests/fx_consistency_test_.py
Original file line number Diff line number Diff line change
Expand Up @@ -1878,7 +1878,7 @@ def _run_test_output_match(
model,
*inputs,
)
except torch.onnx.OnnxExporterError as e:
except torch_onnx.errors.OnnxExporterError as e:
# NOTE: If the model has unsupported nodes, we will skip the test
# with non-strict xfail. Otherwise, we will raise the error.
if hasattr(
Expand Down

0 comments on commit 47bf897

Please sign in to comment.