From 11da12f813e3e0ad10171550d614fbcf0698ab61 Mon Sep 17 00:00:00 2001
From: Shangdi Yu
Date: Mon, 28 Jul 2025 10:00:16 -0700
Subject: [PATCH] add try catch around provenance tracking

Summary: Add try-except around provenance tracking logic to make it safer

Test Plan: CI

Rollback Plan:

Differential Revision: D79008234
---
 torch/_inductor/compile_fx.py |   1 -
 torch/_inductor/debug.py      | 116 ++++++++++++++++++----------------
 torch/fx/traceback.py         |  25 +++++---
 3 files changed, 78 insertions(+), 64 deletions(-)

diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index 2fef157859d7..a52831701247 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -1058,7 +1058,6 @@ def _compile_fx_inner(
         if config.trace.provenance_tracking:
             # Dump provenance artifacts for debugging trace
             provenance_info = torch._inductor.debug.dump_inductor_provenance_info()
-            # provenance_info might be None if trace.provenance_tracking is not set
             if provenance_info:
                 trace_structured(
                     "artifact",
diff --git a/torch/_inductor/debug.py b/torch/_inductor/debug.py
index 23b26765df2b..c64959ecfba1 100644
--- a/torch/_inductor/debug.py
+++ b/torch/_inductor/debug.py
@@ -851,28 +851,32 @@ def convert_sets_to_lists(d: dict[str, Any]) -> None:
 def dump_inductor_provenance_info(
     filename: str = "inductor_generated_kernel_to_post_grad_nodes.json",
 ) -> dict[str, Any]:
-    global _pre_grad_graph_id
-    global _inductor_post_to_pre_grad_nodes
-    global _inductor_triton_kernel_to_post_grad_node_info
-    if config.trace.enabled:
-        with V.debug.fopen(filename, "w") as fd:
-            log.info("Writing provenance tracing debugging info to %s", fd.name)
-            json.dump(_inductor_triton_kernel_to_post_grad_node_info, fd)
-    node_mapping = {}
-    if _pre_grad_graph_id:
-        node_mapping_kernel = create_node_mapping_kernel_to_post_grad(
-            _inductor_triton_kernel_to_post_grad_node_info
-        )
-        node_mapping = {
-            **_inductor_post_to_pre_grad_nodes,
-            **node_mapping_kernel,
-        }
+    try:
+        global _pre_grad_graph_id
+        global _inductor_post_to_pre_grad_nodes
+        global _inductor_triton_kernel_to_post_grad_node_info
         if config.trace.enabled:
-            with V.debug.fopen(
-                "inductor_provenance_tracking_node_mappings.json", "w"
-            ) as fd:
-                json.dump(node_mapping, fd)
-    return node_mapping
+            with V.debug.fopen(filename, "w") as fd:
+                log.info("Writing provenance tracing debugging info to %s", fd.name)
+                json.dump(_inductor_triton_kernel_to_post_grad_node_info, fd)
+        node_mapping = {}
+        if _pre_grad_graph_id:
+            node_mapping_kernel = create_node_mapping_kernel_to_post_grad(
+                _inductor_triton_kernel_to_post_grad_node_info
+            )
+            node_mapping = {
+                **_inductor_post_to_pre_grad_nodes,
+                **node_mapping_kernel,
+            }
+        if config.trace.enabled:
+            with V.debug.fopen(
+                "inductor_provenance_tracking_node_mappings.json", "w"
+            ) as fd:
+                json.dump(node_mapping, fd)
+        return node_mapping
+    except Exception as e:
+        log.error("Unexpected error in dump_inductor_provenance_info: %s", e)
+        return {}
 
 
 def set_kernel_post_grad_provenance_tracing(
@@ -880,42 +884,46 @@ def set_kernel_post_grad_provenance_tracing(
     kernel_name: str,
     is_extern: bool = False,
 ) -> None:
-    from .codegen.simd_kernel_features import DisableReduction, EnableReduction
+    try:
+        from .codegen.simd_kernel_features import DisableReduction, EnableReduction
 
-    global _inductor_triton_kernel_to_post_grad_node_info
-    if is_extern:
-        assert isinstance(node_schedule, ExternKernelOut)
-        curr_node_info = _inductor_triton_kernel_to_post_grad_node_info.setdefault(
-            kernel_name, []
-        )
-        # 'origins' on IR nodes gives what FX IR nodes contributed to any given fused kernel.
-        # "origin_node" is more precise and says that the contents of this node corresponds
-        # EXACTLY to the output of a particular FX node, but it's not always available
-        if node_schedule.origin_node:
-            origin_node_name = node_schedule.origin_node.name
-            if origin_node_name not in curr_node_info:
-                curr_node_info.append(origin_node_name)
-        else:
-            curr_node_info.extend(
-                origin.name
-                for origin in node_schedule.origins
-                if origin.name not in curr_node_info
+        global _inductor_triton_kernel_to_post_grad_node_info
+        if is_extern:
+            assert isinstance(node_schedule, ExternKernelOut)
+            curr_node_info = _inductor_triton_kernel_to_post_grad_node_info.setdefault(
+                kernel_name, []
             )
-    else:
-        assert isinstance(node_schedule, list)
-        for snode in node_schedule:
-            if snode not in (EnableReduction, DisableReduction):
-                if snode.node is not None:
-                    curr_node_info = (
-                        _inductor_triton_kernel_to_post_grad_node_info.setdefault(
-                            kernel_name, []
+            # 'origins' on IR nodes gives what FX IR nodes contributed to any given fused kernel.
+            # "origin_node" is more precise and says that the contents of this node corresponds
+            # EXACTLY to the output of a particular FX node, but it's not always available
+            if node_schedule.origin_node:
+                origin_node_name = node_schedule.origin_node.name
+                if origin_node_name not in curr_node_info:
+                    curr_node_info.append(origin_node_name)
+            else:
+                curr_node_info.extend(
+                    origin.name
+                    for origin in node_schedule.origins
+                    if origin.name not in curr_node_info
+                )
+        else:
+            assert isinstance(node_schedule, list)
+            for snode in node_schedule:
+                if snode not in (EnableReduction, DisableReduction):
+                    if snode.node is not None:
+                        curr_node_info = (
+                            _inductor_triton_kernel_to_post_grad_node_info.setdefault(
+                                kernel_name, []
+                            )
                         )
-                    )
-                    curr_node_info.extend(
-                        origin.name
-                        for origin in snode.node.origins
-                        if origin.name not in curr_node_info
-                    )
+                        curr_node_info.extend(
+                            origin.name
+                            for origin in snode.node.origins
+                            if origin.name not in curr_node_info
+                        )
+    except Exception as e:
+        log.error("Unexpected error in set_kernel_post_grad_provenance_tracing: %s", e)
+        log.error(traceback.format_exc())
 
 
 def save_args_for_compile_fx_inner(*args: Any, **kwargs: Any) -> None:
diff --git a/torch/fx/traceback.py b/torch/fx/traceback.py
index 648a80b87b68..6d507373d2e8 100644
--- a/torch/fx/traceback.py
+++ b/torch/fx/traceback.py
@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 import copy
+import logging
 import traceback
 from contextlib import contextmanager
 from enum import Enum
@@ -10,6 +11,8 @@
 from .node import Node
 
 
+log = logging.getLogger(__name__)
+
 __all__ = [
     "preserve_node_meta",
     "has_preserved_node_meta",
@@ -311,12 +314,16 @@ def get_graph_provenance_json(graph: Graph) -> dict[str, Any]:
     """
     Given an fx.Graph, return a json that contains the provenance information of each node.
""" - provenance_tracking_json = {} - for node in graph.nodes: - if node.op == "call_function": - provenance_tracking_json[node.name] = ( - [source.to_dict() for source in node.meta["from_node"]] - if "from_node" in node.meta - else [] - ) - return provenance_tracking_json + try: + provenance_tracking_json = {} + for node in graph.nodes: + if node.op == "call_function": + provenance_tracking_json[node.name] = ( + [source.to_dict() for source in node.meta["from_node"]] + if "from_node" in node.meta + else [] + ) + return provenance_tracking_json + except Exception as e: + log.error("Unexpected error in dump_inductor_provenance_info: %s", e) + return {}