add try catch around provenance tracking #159266

Open · wants to merge 1 commit into main
1 change: 0 additions & 1 deletion torch/_inductor/compile_fx.py
@@ -1058,7 +1058,6 @@ def _compile_fx_inner(
     if config.trace.provenance_tracking:
         # Dump provenance artifacts for debugging trace
         provenance_info = torch._inductor.debug.dump_inductor_provenance_info()
-        # provenance_info might be None if trace.provenance_tracking is not set
         if provenance_info:
             trace_structured(
                 "artifact",
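For context, the code path touched here only runs when Inductor's trace debugging and provenance tracking are both switched on. A minimal sketch of enabling those flags before compiling, assuming a recent PyTorch build that has them (the flag names come from the diff above; the toy function is hypothetical):

import torch
import torch._inductor.config as inductor_config

# Both flags gate the provenance code paths in this PR: trace.enabled controls
# the debug-artifact dumps in debug.py, and trace.provenance_tracking controls
# the dump_inductor_provenance_info call in _compile_fx_inner.
inductor_config.trace.enabled = True
inductor_config.trace.provenance_tracking = True

def toy(x):  # hypothetical example function, not part of the PR
    return torch.relu(x) * 2

compiled = torch.compile(toy)
compiled(torch.randn(8))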
116 changes: 62 additions & 54 deletions torch/_inductor/debug.py
@@ -851,71 +851,79 @@ def convert_sets_to_lists(d: dict[str, Any]) -> None:
 def dump_inductor_provenance_info(
     filename: str = "inductor_generated_kernel_to_post_grad_nodes.json",
 ) -> dict[str, Any]:
-    global _pre_grad_graph_id
-    global _inductor_post_to_pre_grad_nodes
-    global _inductor_triton_kernel_to_post_grad_node_info
-    if config.trace.enabled:
-        with V.debug.fopen(filename, "w") as fd:
-            log.info("Writing provenance tracing debugging info to %s", fd.name)
-            json.dump(_inductor_triton_kernel_to_post_grad_node_info, fd)
-    node_mapping = {}
-    if _pre_grad_graph_id:
-        node_mapping_kernel = create_node_mapping_kernel_to_post_grad(
-            _inductor_triton_kernel_to_post_grad_node_info
-        )
-        node_mapping = {
-            **_inductor_post_to_pre_grad_nodes,
-            **node_mapping_kernel,
-        }
-        if config.trace.enabled:
-            with V.debug.fopen(
-                "inductor_provenance_tracking_node_mappings.json", "w"
-            ) as fd:
-                json.dump(node_mapping, fd)
-    return node_mapping
+    try:
+        global _pre_grad_graph_id
+        global _inductor_post_to_pre_grad_nodes
+        global _inductor_triton_kernel_to_post_grad_node_info
+        if config.trace.enabled:
+            with V.debug.fopen(filename, "w") as fd:
+                log.info("Writing provenance tracing debugging info to %s", fd.name)
+                json.dump(_inductor_triton_kernel_to_post_grad_node_info, fd)
+        node_mapping = {}
+        if _pre_grad_graph_id:
+            node_mapping_kernel = create_node_mapping_kernel_to_post_grad(
+                _inductor_triton_kernel_to_post_grad_node_info
+            )
+            node_mapping = {
+                **_inductor_post_to_pre_grad_nodes,
+                **node_mapping_kernel,
+            }
+            if config.trace.enabled:
+                with V.debug.fopen(
+                    "inductor_provenance_tracking_node_mappings.json", "w"
+                ) as fd:
+                    json.dump(node_mapping, fd)
+        return node_mapping
+    except Exception as e:
+        log.error("Unexpected error in dump_inductor_provenance_info: %s", e)
+        return {}


 def set_kernel_post_grad_provenance_tracing(
     node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernelOut],
     kernel_name: str,
     is_extern: bool = False,
 ) -> None:
-    from .codegen.simd_kernel_features import DisableReduction, EnableReduction
+    try:
+        from .codegen.simd_kernel_features import DisableReduction, EnableReduction

-    global _inductor_triton_kernel_to_post_grad_node_info
-    if is_extern:
-        assert isinstance(node_schedule, ExternKernelOut)
-        curr_node_info = _inductor_triton_kernel_to_post_grad_node_info.setdefault(
-            kernel_name, []
-        )
-        # 'origins' on IR nodes gives what FX IR nodes contributed to any given fused kernel.
-        # "origin_node" is more precise and says that the contents of this node corresponds
-        # EXACTLY to the output of a particular FX node, but it's not always available
-        if node_schedule.origin_node:
-            origin_node_name = node_schedule.origin_node.name
-            if origin_node_name not in curr_node_info:
-                curr_node_info.append(origin_node_name)
-        else:
-            curr_node_info.extend(
-                origin.name
-                for origin in node_schedule.origins
-                if origin.name not in curr_node_info
-            )
-    else:
-        assert isinstance(node_schedule, list)
-        for snode in node_schedule:
-            if snode not in (EnableReduction, DisableReduction):
-                if snode.node is not None:
-                    curr_node_info = (
-                        _inductor_triton_kernel_to_post_grad_node_info.setdefault(
-                            kernel_name, []
-                        )
-                    )
-                    curr_node_info.extend(
-                        origin.name
-                        for origin in snode.node.origins
-                        if origin.name not in curr_node_info
-                    )
+        global _inductor_triton_kernel_to_post_grad_node_info
+        if is_extern:
+            assert isinstance(node_schedule, ExternKernelOut)
+            curr_node_info = _inductor_triton_kernel_to_post_grad_node_info.setdefault(
+                kernel_name, []
+            )
+            # 'origins' on IR nodes gives what FX IR nodes contributed to any given fused kernel.
+            # "origin_node" is more precise and says that the contents of this node corresponds
+            # EXACTLY to the output of a particular FX node, but it's not always available
+            if node_schedule.origin_node:
+                origin_node_name = node_schedule.origin_node.name
+                if origin_node_name not in curr_node_info:
+                    curr_node_info.append(origin_node_name)
+            else:
+                curr_node_info.extend(
+                    origin.name
+                    for origin in node_schedule.origins
+                    if origin.name not in curr_node_info
+                )
+        else:
+            assert isinstance(node_schedule, list)
+            for snode in node_schedule:
+                if snode not in (EnableReduction, DisableReduction):
+                    if snode.node is not None:
+                        curr_node_info = (
+                            _inductor_triton_kernel_to_post_grad_node_info.setdefault(
+                                kernel_name, []
+                            )
+                        )
+                        curr_node_info.extend(
+                            origin.name
+                            for origin in snode.node.origins
+                            if origin.name not in curr_node_info
+                        )
+    except Exception as e:
+        log.error("Unexpected error in set_kernel_post_grad_provenance_tracing: %s", e)
+        log.error(traceback.format_exc())


 def save_args_for_compile_fx_inner(*args: Any, **kwargs: Any) -> None:
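Both changes in debug.py follow the same pattern: provenance dumping is best-effort debug output, so the whole body is wrapped in try/except, the failure is logged, and an empty result is returned instead of raising into the compile path. A small self-contained sketch of that pattern (best_effort_dump and the sample mapping are hypothetical, for illustration only):

import logging
import traceback
from typing import Any, Callable

log = logging.getLogger(__name__)

def best_effort_dump(build_mapping: Callable[[], dict[str, Any]]) -> dict[str, Any]:
    # Mirror of the PR's approach: never let a provenance/debug failure
    # propagate; log it and fall back to an empty result.
    try:
        return build_mapping()
    except Exception as e:
        log.error("Unexpected error while dumping provenance info: %s", e)
        log.error(traceback.format_exc())
        return {}

# Callers can then keep a simple truthiness guard, as _compile_fx_inner does
# with `if provenance_info:`.
provenance_info = best_effort_dump(lambda: {"triton_poi_fused_relu_0": ["relu"]})
if provenance_info:
    print(provenance_info)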
25 changes: 16 additions & 9 deletions torch/fx/traceback.py
@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 import copy
+import logging
 import traceback
 from contextlib import contextmanager
 from enum import Enum
@@ -10,6 +11,8 @@
 from .node import Node


+log = logging.getLogger(__name__)
+
 __all__ = [
     "preserve_node_meta",
     "has_preserved_node_meta",
@@ -311,12 +314,16 @@ def get_graph_provenance_json(graph: Graph) -> dict[str, Any]:
"""
Given an fx.Graph, return a json that contains the provenance information of each node.
"""
provenance_tracking_json = {}
for node in graph.nodes:
if node.op == "call_function":
provenance_tracking_json[node.name] = (
[source.to_dict() for source in node.meta["from_node"]]
if "from_node" in node.meta
else []
)
return provenance_tracking_json
try:
provenance_tracking_json = {}
for node in graph.nodes:
if node.op == "call_function":
provenance_tracking_json[node.name] = (
[source.to_dict() for source in node.meta["from_node"]]
if "from_node" in node.meta
else []
)
return provenance_tracking_json
except Exception as e:
log.error("Unexpected error in dump_inductor_provenance_info: %s", e)
return {}
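get_graph_provenance_json can be called on any fx.Graph; a minimal usage sketch, assuming a PyTorch build that ships this helper (the Toy module is made up, and plain symbolic_trace does not populate "from_node" metadata, so the provenance lists come back empty):

import torch
from torch.fx import symbolic_trace
from torch.fx.traceback import get_graph_provenance_json

class Toy(torch.nn.Module):  # hypothetical module for demonstration
    def forward(self, x):
        return torch.relu(x) + 1

gm = symbolic_trace(Toy())
# Each call_function node gets an entry; without "from_node" metadata the
# entry is an empty list, e.g. {"relu": [], "add": []}.
print(get_graph_provenance_json(gm.graph))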