
Commit cee7d1d

haowu14 authored and facebook-github-bot committed
Graph split event tracker (#159795)
Summary: A tool to track events during graph split, specifically how nodes end up in the acc or cpu subgraphs.

Usage: use env vars to specify a mode and the necessary arguments.

FX_NET_ACC_SPLITTER_TRACKER_MODE: tracker mode.

```
Different modes of the event tracker:
"0": Tracker enabled but no local dumps. Information is available by setting breakpoints and visually inspecting in pdb.
"1": Tracker enabled; dumps all events to DUMP_PREFIX_all.txt.
"2": In addition to the events dump, tracks nodes specified by ENV_FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES recursively and dumps to DUMP_PREFIX_nodes.txt.
"3": In addition to the events dump, tracks all nodes with more than 1 event recursively and dumps to DUMP_PREFIX_nodes.txt.
Regardless of the mode, the tracker always dumps via trace_structured.
```

FX_NET_ACC_SPLITTER_TRACKER_DUMP_PATH: overrides the dump path. Leave empty for `~`.

FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES: nodes to track for mode "2".

Test Plan: New unit test

```
buck test caffe2/test:fx -- test_fx_split_node_finder
```

----

```
TORCH_TRACE=~/my_trace_log_dir FX_NET_ACC_SPLITTER_TRACKER_MODE=3 ../buck-out/v2/gen/fbcode/6f6fe98d41631b2e/inference_enablement/model_processing/infra/components/lowering/re/__re_cinder__/re_cinder.par -r '{"aot_inductor":{"serialized_inference_model_input_path":"ads_storage_fblearner/tree/user/facebook/fblearner/predictor/895540436/4/gpu_lowering/input.merge.61759375","serialized_inference_model_output_path":"ads_storage_fblearner/tree/user/facebook/fblearner/predictor/895540436/4/gpu_lowering/inductor_output.merge.61759375","submodule_names_to_lower":["merge"],"inductor_lowering_context":{"aot_inductor_lowering_settings":{"max_batch_size":2048,"min_acc_module_size":10,"workdir":"/tmp/local","name":"merge","dll_name":"inductor_engine.so","use_scripting":true,"preset_lowerer":"gr;disable_new_lowering_weights;disable_dper_passes:passes=fuse_parallel_linear_no_weight_change|fuse_parallel_linear","precision":4,"output_precision":4,"remote_cache_file_path_folder":"ads_storage_fblearner/tree/user/facebook/fblearner/predictor/895540436/","save_remote_cache":true,"aot_inductor_config":"{\"max_autotune\":True,\"comprehensive_padding\":False}","disable_dynamic_shapes":false,"remove_unexpected_type_cast":false,"disable_constraint_solver":false,"sample_input_tile_factor":32,"disable_acc_tracer":true,"generate_sample_inputs":true,"tile_sample_input_by_dynamic_shape":false,"node_replacement_dict":"","dynamic_shapes_strategy":73728,"auto_dynamic_shapes":false,"auto_dynamic_shapes_min_size":1,"auto_dynamic_shapes_max_size":1048576,"max_acc_splits":-1,"dynamic_size":-1,"pre_dispatch_export":true,"merge_split_optimization":false,"use_dynamic_dim_hints":false,"allow_refine_dynamic_shapes_on_constants":false,"use_sigmoid":false}},"model_entity_id":895540436,"model_snapshot_id":4,"add_sample_inputs":false,"platform_arch":0,"lowering_lib_pkg":"ien.lower:prod","dense_in_place_format":2}}'
```

Events dump: P1896093119
Nodes track dump: P1896110514

The above files are generated locally:

```
? _fx_net_tracker_all.txt
? _fx_net_tracker_nodes.txt
```

Also in torch_trace you can find all events: https://www.internalfb.com/intern/paste/P1897874179/

Rollback Plan:

Reviewed By: georgiaphillips

Differential Revision: D79203595
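As a usage illustration, a minimal sketch of driving the tracker from Python; the traced module and operator-support wiring are placeholders, and note that `torch.fx.passes.splitter_base` reads these env vars at import time, so they must be set before the import:

```
import os

# Env vars are read when torch.fx.passes.splitter_base is imported,
# so set them before importing it. (Values here are illustrative.)
os.environ["FX_NET_ACC_SPLITTER_TRACKER_MODE"] = "3"  # events dump + recursive dump of multi-event nodes
os.environ["FX_NET_ACC_SPLITTER_TRACKER_DUMP_PATH"] = "/tmp/split_debug"

import torch
from torch.fx.passes.operator_support import OperatorSupportBase
from torch.fx.passes.splitter_base import FxNetAccNodesFinder

# `gm` (a traced GraphModule) and `op_support` (an OperatorSupportBase)
# are placeholders for your own model and support policy:
# acc_nodes = FxNetAccNodesFinder(gm, op_support, False)()
# -> events land in /tmp/split_debug_all.txt and /tmp/split_debug_nodes.txt
```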
1 parent 9ccd0f5 commit cee7d1d

File tree: 2 files changed (+268, -8 lines)


test/fx/test_fx_split_node_finder.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
```
# Owner(s): ["module: fx"]

# pyre-strict
import torch
from torch.fx.passes.operator_support import OperatorSupportBase
from torch.fx.passes.splitter_base import (
    FxNetAccNodesFinder,
    NodeEventTracker,
    ShapeProp,
)
from torch.testing._internal.common_utils import TestCase


# Wrapper function to make it supported
@torch.fx.wrap
def sup_f(x):
    return x


class TestFxSplitNodeFinder(TestCase):
    def testFinder(self):
        class IsNodeSupported(OperatorSupportBase):
            def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
                return "sup_" in node.name

        class TestModule(torch.nn.Module):
            def forward(self, x, y):
                x = sup_f(x)
                y = sup_f(y)
                b = x + y  # non-supported to break graph
                return sup_f(b)

        gm = torch.fx.symbolic_trace(TestModule())
        ShapeProp(gm).propagate(*(torch.rand((2, 2)), 3))
        finder = FxNetAccNodesFinder(gm, IsNodeSupported(), False)

        # override tracker without having to run with env var
        tracker = NodeEventTracker(
            1,  # mode 1: enable the tracker without node-level dumps
            "",  # dump_path. We don't need it.
        )
        finder.tracker = tracker

        acc_nodes = finder()

        def getEvents(tracker, node):
            return [tracker.events[idx] for idx in tracker.node_events[node.name]]

        # check that acc nodes events are as expected
        for node in gm.graph.nodes:
            if node.name == "sup_f_1":
                # this node should be removed from acc nodes.
                self.assertFalse(node in acc_nodes)
                events = getEvents(tracker, node)
                # 2 events.
                self.assertEqual(len(events), 2)
                # 1st event is init_acc as supported operator
                self.assertTrue(
                    events[0].desc.startswith(
                        "init_acc|callable_and_operator_supported"
                    )
                )
                # 2nd event is acc_del as non-tensor output with cpu user
                self.assertTrue(
                    events[1].desc.startswith("acc_del|non_tensor_output_with_cpu_user")
                )
            elif node.name.startswith("sup_f"):
                # other supported nodes should remain in acc nodes.
                self.assertTrue(node in acc_nodes)
                events = getEvents(tracker, node)
                self.assertEqual(len(events), 1)
                self.assertTrue(
                    events[0].desc.startswith(
                        "init_acc|callable_and_operator_supported"
                    )
                )
            else:
                # other nodes are on cpu.
                self.assertFalse(node in acc_nodes)
```
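Event descriptors follow a `stage|reason` convention (`init_acc|...`, `init_cpu|...`, `acc_del|...`, `new_cpu_node|...`). A small sketch, reusing `tracker` from the test above, of walking the raw events directly:

```
# Sketch: iterate every recorded event per node. `desc` is "stage|reason";
# `dep`, when present, is the node that triggered the event.
for name, idxs in tracker.node_events.items():
    for idx in idxs:
        event = tracker.events[idx]
        stage, reason = event.desc.split("|", 1)
        dep = event.dep.name if event.dep else "-"
        print(f"{name}: stage={stage} reason={reason} dep={dep}")
```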

torch/fx/passes/splitter_base.py

Lines changed: 189 additions & 8 deletions
```
@@ -1,16 +1,19 @@
 # mypy: allow-untyped-defs
 import argparse
 import copy
+import json
 import logging
+import os
 from collections import defaultdict
 from collections.abc import Iterable, Sequence
 from dataclasses import dataclass
-from typing import Any, NamedTuple, Optional
+from typing import Any, Literal, NamedTuple, Optional
 
 import torch
 from torch.fx._compatibility import compatibility
 from torch.fx.node import map_arg
 from torch.fx.passes.graph_manipulation import get_size_of_node
+from torch._logging import trace_structured
 
 from .graph_drawer import FxGraphDrawer
 from .operator_support import get_node_target, OperatorSupportBase
@@ -32,13 +35,44 @@
     "Subgraph",
     "SplitResult",
     "generate_inputs_for_submodules",
+    "NodeEvent",
+    "NodeEventTracker",
 ]
 _LOGGER = logging.getLogger(__name__)
 
 DEFAULT_MIN_ACC_MODULE_SIZE = 1
 DEFAULT_SKIP_FUSION = False
 DEFAULT_ALLOW_NON_TENSOR = False
 
+# ENV vars and constants for the node tracker
+
+TRACKER_DUMP_PATH = "_fx_net_tracker"
+NODES_SUFFIX = "_nodes.txt"
+ALL_SUFFIX = "_all.txt"
+
+ENV_FX_NET_ACC_SPLITTER_TRACKER_MODE = "FX_NET_ACC_SPLITTER_TRACKER_MODE"
+ENV_FX_NET_ACC_SPLITTER_TRACKER_DUMP_PATH = "FX_NET_ACC_SPLITTER_TRACKER_DUMP_PATH"
+ENV_FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES = (
+    "FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES"
+)
+
+DUMP_PREFIX = os.environ.get(
+    ENV_FX_NET_ACC_SPLITTER_TRACKER_DUMP_PATH, TRACKER_DUMP_PATH
+)
+
+"""
+Different modes of the event tracker:
+"0": Tracker enabled but no local dumps. Information is available by setting
+     breakpoints and visually inspecting in pdb.
+"1": Tracker enabled; dumps all events to DUMP_PREFIX_all.txt.
+"2": In addition to the events dump, tracks nodes specified by
+     ENV_FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES recursively and dumps to
+     DUMP_PREFIX_nodes.txt.
+"3": In addition to the events dump, tracks all nodes with more than 1 event
+     recursively and dumps to DUMP_PREFIX_nodes.txt.
+Regardless of the mode, the tracker always dumps via trace_structured.
+"""
+TRACKER_MODE: Literal["0", "1", "2", "3"] = os.environ.get(
+    ENV_FX_NET_ACC_SPLITTER_TRACKER_MODE, "0"
+)  # type: ignore[assignment]
+
 
 class _SplitterSettingBase:
     def __init__(
@@ -99,6 +133,140 @@ def __init__(
         self.max_acc_splits: int = max_acc_splits
 
 
+@compatibility(is_backward_compatible=False)
+class NodeEvent:
+    """
+    An event in graph split that happened on a node.
+    source: subject of the event
+    desc: readable description
+    dep: optional dependency, usually the node that caused the event
+    """
+
+    def __init__(self, source: torch.fx.Node, desc: str, dep: Optional[torch.fx.Node] = None):
+        self.source = source
+        self.desc = desc
+        self.dep = dep
+
+    def to_str(self):
+        return f"{self.source.name}: {self.desc} {self.dep.name if self.dep else '#'}"
+
+
+@compatibility(is_backward_compatible=False)
+class NodeEventTracker:
+    """
+    Tracks node events during the splitter execution.
+    """
+
+    def __init__(self, tracker_mode, dump_prefix):
+        self.tracker_mode = tracker_mode
+        self.dump_prefix = dump_prefix
+        self.enabled = self.tracker_mode > 0
+        # list of events
+        self.events = []
+        # dict from node name to event indices
+        self.node_events = {}
+        self.writer = print
+
+    def add(self, node: torch.fx.Node, desc: str, dep: Optional[torch.fx.Node] = None):
+        """
+        Add a new event to the tracker.
+        """
+        if not self.enabled:
+            return
+        event = NodeEvent(node, desc, dep)
+        self.events.append(event)
+        if node.name not in self.node_events:
+            self.node_events[node.name] = []
+        self.node_events[node.name].append(len(self.events) - 1)
+
+    def print_node(self, node_name, recursive=False, tab="", writer=None):
+        """
+        Print a node and its events.
+        @param recursive: if True, print nodes that caused the events on this current node.
+        @param tab: indentation for dependencies.
+        @param writer: function to write to file. If None, use print.
+        """
+        if not writer:
+            writer = self.writer
+        for idx in self.node_events.get(node_name, []):
+            event = self.events[idx]
+            writer(tab + event.to_str())
+            if recursive and event.dep is not None:
+                self.print_node(
+                    event.dep.name, recursive=True, tab="| " + tab, writer=writer
+                )
+
+    def to_dict(self):
+        """
+        Create a dict dump of all events.
+        """
+        ret = {}
+        for name in self.node_events.keys():
+            ret[name] = []
+            for idx in self.node_events.get(name, []):
+                event = self.events[idx]
+                ret[name].append(event.to_str())
+        return ret
+
+    def print_all(self, writer=None):
+        """
+        Print all nodes in a list.
+        @param writer: function to write to file. If None, use print.
+        """
+        if not writer:
+            writer = self.writer
+        for name in self.node_events.keys():
+            writer(f"Node: {name}: ")
+            self.print_node(name, recursive=False, tab="  ", writer=writer)
+
+    def close(self):
+        """
+        Invoked at the end of the finder execution to print out the tracked
+        events as specified by the mode.
+        """
+        # if enabled, always dump via trace_structured
+        trace_structured(
+            "artifact",
+            metadata_fn=lambda: {
+                "name": "fx_net_acc_splitter_finder_events",
+                "encoding": "json",
+            },
+            payload_fn=lambda: json.dumps(self.to_dict()),
+        )
+
+        def writeln(f):
+            def fn(x):
+                return f.write(x + "\n")
+
+            return fn
+
+        # Mode 0: no local dump
+
+        # Mode >=1: dump all events to file
+        if self.tracker_mode >= 1:
+            with open(self.dump_prefix + ALL_SUFFIX, "w") as f:
+                self.print_all(writeln(f))
+
+        def dump_selected_nodes(nodes):
+            with open(self.dump_prefix + NODES_SUFFIX, "w") as f:
+                write = writeln(f)
+                for node_name in nodes:
+                    write(f"===== Tracking node {node_name} =====")
+                    self.print_node(
+                        node_name, recursive=True, tab="|-", writer=write
+                    )
+                    write(f"===== End of tracking node {node_name} =====")
+
+        # Mode 2: dump specific nodes in a recursive manner.
+        # Mode 3: dump all nodes with more than 1 event in a recursive manner.
+        if self.tracker_mode == 2 or self.tracker_mode == 3:
+            nodes = (
+                os.environ.get(ENV_FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES, "").split(
                    ","
                )
+                if self.tracker_mode == 2
+                else [
+                    name for name, events in self.node_events.items() if len(events) > 1
+                ]
+            )
+            dump_selected_nodes(nodes)
+
+
 @compatibility(is_backward_compatible=False)
 class FxNetAccNodesFinder:
     """
@@ -125,6 +293,8 @@ def __init__(
         self.allow_non_tensor = allow_non_tensor
         self.acc_nodes: NodeSet = set()
 
+        self.tracker = NodeEventTracker(int(TRACKER_MODE), DUMP_PREFIX)
+
     def reduce_acc_nodes_non_tensor_input_helper(self, cpu_worklist: NodeList):
         """
         Transitively excludes nodes from ACC supported set.
@@ -139,7 +309,9 @@ def reduce_acc_nodes_non_tensor_input_helper(self, cpu_worklist: NodeList):
             for user in node.users:
                 if user in self.acc_nodes:
                     self.acc_nodes.remove(user)
+                    self.tracker.add(user, "acc_del|user_of_new_cpu_node", node)
                     if not is_node_output_tensor(user):
+                        self.tracker.add(user, "new_cpu_node|non_tensor_output")
                         cpu_worklist.append(user)
 
     def reduce_acc_nodes_non_tensor_input(self):
@@ -156,6 +328,7 @@ def reduce_acc_nodes_non_tensor_input(self):
                 continue
             if is_node_output_tensor(node):
                 continue
+            self.tracker.add(node, "new_cpu_node|callable_non_tensor_input")
             non_tensor_cpu_nodes.append(node)
 
         self.reduce_acc_nodes_non_tensor_input_helper(non_tensor_cpu_nodes)
@@ -174,6 +347,9 @@ def reduce_acc_nodes_non_tensor_output(self):
                 for user in acc_node.users:
                     if user not in self.acc_nodes:
                         new_cpu_nodes.append(acc_node)
+                        self.tracker.add(
+                            acc_node, "acc_del|non_tensor_output_with_cpu_user", user
+                        )
                         break
 
             if not new_cpu_nodes:
@@ -186,17 +362,22 @@ def reduce_acc_nodes_non_tensor_output(self):
 
     def __call__(self) -> NodeSet:
         submodules = dict(self.module.named_modules())
-        self.acc_nodes = {
-            n
-            for n in self.module.graph.nodes
-            if n.op in CALLABLE_NODE_OPS
-            and self.operator_support.is_node_supported(submodules, n)
-        }
+        self.acc_nodes = set()
+        for n in self.module.graph.nodes:
+            if n.op not in CALLABLE_NODE_OPS:
+                self.tracker.add(n, "init_cpu|not_callable")
+                continue
+            if not self.operator_support.is_node_supported(submodules, n):
+                self.tracker.add(n, "init_cpu|operator_support")
+                continue
+
+            self.tracker.add(n, "init_acc|callable_and_operator_supported")
+            self.acc_nodes.add(n)
 
         if not self.allow_non_tensor:
             self.reduce_acc_nodes_non_tensor_input()
             self.reduce_acc_nodes_non_tensor_output()
-
+        self.tracker.close()
         return self.acc_nodes
```
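To see what the recursive `print_node` output looks like without running a full split, here is a minimal standalone sketch using the tracker API added above (the module, node choices, and events are illustrative):

```
import torch
from torch.fx.passes.splitter_base import NodeEventTracker


class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


gm = torch.fx.symbolic_trace(M())
# Traced graph is: placeholder(x) -> relu -> add -> output
relu, add = list(gm.graph.nodes)[1:3]

# mode 1 enables recording; close() is never called here, so nothing is written
tracker = NodeEventTracker(1, "")
tracker.add(relu, "init_acc|callable_and_operator_supported")
tracker.add(add, "acc_del|user_of_new_cpu_node", relu)

# Recursive printing follows each event's dep chain, indenting with "| ":
tracker.print_node(add.name, recursive=True, tab="|-")
# |-add: acc_del|user_of_new_cpu_node relu
# | |-relu: init_acc|callable_and_operator_supported #
```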
