Skip to content

Commit bb886ec

Browse files
haowu14facebook-github-bot
authored and committed
Graph split event tracker (#159795)
Summary: A tool to track events in graph split, specifically on how nodes being end up in acc or cpu subgraphs. Usage: use env var to specify a mode and necessary arguments. FX_NET_ACC_SPLITTER_TRACKER_MODE: Tracker mode. ``` Different modes of the event tracker: "0": Tracker not enabled (by default) "1": Tracker enabled but no dumps. Information available by setting breakpoints and visually inspect in pdb. "2": Tracker enabled and dumps all events to DUMP_PREFIX_all.txt "3": In addition to events dump, track nodes specified by ENV_FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES recusrively and dump to DUMP_PREFIX_nodex.txt "4:: In addition to events dump, track all nodes with more than 1 event recusrively and dump to DUMP_PREFIX_nodex.txt ``` FX_NET_ACC_SPLITTER_TRACKER_DUMP_PATH: overriding dump path. Leave empty for `~`. FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES: Nodes to track for mode "3". Test Plan: New unit test ``` buck test caffe2/test:fx -- test_fx_split_node_finder ``` ---- ``` FX_NET_ACC_SPLITTER_TRACKER_MODE=4 ../buck-out/v2/gen/fbcode/6f6fe98d41631b2e/inference_enablement/model_processing/infra/components/lowering/re/__re_cinder__/re_cinder.par -r 
'{"aot_inductor":{"serialized_inference_model_input_path":"ads_storage_fblearner/tree/user/facebook/fblearner/predictor/895540436/4/gpu_lowering/input.merge.61759375","serialized_inference_model_output_path":"ads_storage_fblearner/tree/user/facebook/fblearner/predictor/895540436/4/gpu_lowering/inductor_output.merge.61759375","submodule_names_to_lower":["merge"],"inductor_lowering_context":{"aot_inductor_lowering_settings":{"max_batch_size":2048,"min_acc_module_size":10,"workdir":"/tmp/local","name":"merge","dll_name":"inductor_engine.so","use_scripting":true,"preset_lowerer":"gr;disable_new_lowering_weights;disable_dper_passes:passes=fuse_parallel_linear_no_weight_change|fuse_parallel_linear","precision":4,"output_precision":4,"remote_cache_file_path_folder":"ads_storage_fblearner/tree/user/facebook/fblearner/predictor/895540436/","save_remote_cache":true,"aot_inductor_config":"{\"max_autotune\":True,\"comprehensive_padding\":False}","disable_dynamic_shapes":false,"remove_unexpected_type_cast":false,"disable_constraint_solver":false,"sample_input_tile_factor":32,"disable_acc_tracer":true,"generate_sample_inputs":true,"tile_sample_input_by_dynamic_shape":false,"node_replacement_dict":"","dynamic_shapes_strategy":73728,"auto_dynamic_shapes":false,"auto_dynamic_shapes_min_size":1,"auto_dynamic_shapes_max_size":1048576,"max_acc_splits":-1,"dynamic_size":-1,"pre_dispatch_export":true,"merge_split_optimization":false,"use_dynamic_dim_hints":false,"allow_refine_dynamic_shapes_on_constants":false,"use_sigmoid":false}},"model_entity_id":895540436,"model_snapshot_id":4,"add_sample_inputs":false,"platform_arch":0,"lowering_lib_pkg":"ien.lower:prod","dense_in_place_format":2}}' ``` Events dump: P1896093119 Nodes track dump: P1896110514 The above files are generated locally ``` ? _fx_net_tracker_all.txt ? _fx_net_tracker_nodes.txt ``` Rollback Plan: Reviewed By: georgiaphillips Differential Revision: D79203595
1 parent e619c6b commit bb886ec

File tree

2 files changed

+248
-8
lines changed

2 files changed

+248
-8
lines changed

test/fx/test_fx_split_node_finder.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Owner(s): ["module: fx"]
2+
3+
# pyre-strict
4+
import torch
5+
from torch.fx.passes.operator_support import OperatorSupportBase
6+
from torch.fx.passes.splitter_base import (
7+
FxNetAccNodesFinder,
8+
NodeEventTracker,
9+
ShapeProp,
10+
)
11+
from torch.testing._internal.common_utils import TestCase
12+
13+
14+
# Wrapper function to make it supported: ``torch.fx.wrap`` registers ``sup_f``
# so symbolic tracing emits a call_function node for it (whose name contains
# "sup_") instead of tracing through the identity body.
@torch.fx.wrap
def sup_f(x):
    """Identity passthrough used as a tracer-visible 'supported' op."""
    return x
18+
19+
20+
class TestFxSplitNodeFinder(TestCase):
    """Exercises FxNetAccNodesFinder together with NodeEventTracker."""

    def testFinder(self):
        class IsNodeSupported(OperatorSupportBase):
            # A node counts as supported iff its name carries the "sup_" marker.
            def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
                return "sup_" in node.name

        class TestModule(torch.nn.Module):
            def forward(self, x, y):
                x = sup_f(x)
                y = sup_f(y)
                b = x + y  # non-supported to break graph
                return sup_f(b)

        gm = torch.fx.symbolic_trace(TestModule())
        ShapeProp(gm).propagate(*(torch.rand((2, 2)), 3))
        finder = FxNetAccNodesFinder(gm, IsNodeSupported(), False)

        # Install the tracker directly so no env var is needed for the test.
        tracker = NodeEventTracker(
            1,  # mode: just enable the tracker without dumping
            "",  # dump_path. We don't need it.
        )
        finder.tracker = tracker

        acc_nodes = finder()

        def events_of(node):
            return [tracker.events[i] for i in tracker.node_events[node.name]]

        # Verify the recorded events match each node's final placement.
        for node in gm.graph.nodes:
            if node.name == "sup_f_1":
                # This node must have been evicted from the acc set.
                self.assertNotIn(node, acc_nodes)
                node_events = events_of(node)
                # Exactly two events: admission then eviction.
                self.assertEqual(len(node_events), 2)
                # First event: admitted as a supported callable.
                self.assertTrue(
                    node_events[0].desc.startswith(
                        "init_acc|callable_and_operator_supported"
                    )
                )
                # Second event: evicted for a non-tensor output with a cpu user.
                self.assertTrue(
                    node_events[1].desc.startswith(
                        "acc_del|non_tensor_output_with_cpu_user"
                    )
                )
            elif node.name.startswith("sup_f"):
                # Remaining supported nodes stay in the acc set with one event.
                self.assertIn(node, acc_nodes)
                node_events = events_of(node)
                self.assertEqual(len(node_events), 1)
                self.assertTrue(
                    node_events[0].desc.startswith(
                        "init_acc|callable_and_operator_supported"
                    )
                )
            else:
                # Everything else lands on cpu.
                self.assertNotIn(node, acc_nodes)

torch/fx/passes/splitter_base.py

Lines changed: 169 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
import argparse
33
import copy
44
import logging
5+
import os
56
from collections import defaultdict
67
from collections.abc import Iterable, Sequence
78
from dataclasses import dataclass
8-
from typing import Any, NamedTuple, Optional
9+
from typing import Any, Literal, NamedTuple, Optional
910

1011
import torch
1112
from torch.fx._compatibility import compatibility
@@ -39,6 +40,35 @@
3940
DEFAULT_SKIP_FUSION = False
4041
DEFAULT_ALLOW_NON_TENSOR = False
4142

43+
# ENV var and constants for node tracker
44+
45+
TRACKER_DUMP_PATH = "_fx_net_tracker"
46+
NODES_SUFFIX = "_nodes.txt"
47+
ALL_SUFFIX = "_all.txt"
48+
49+
ENV_FX_NET_ACC_SPLITTER_TRACKER_MODE = "FX_NET_ACC_SPLITTER_TRACKER_MODE"
50+
ENV_FX_NET_ACC_SPLITTER_TRACKER_DUMP_PATH = "FX_NET_ACC_SPLITTER_TRACKER_DUMP_PATH"
51+
ENV_FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES = (
52+
"FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES"
53+
)
54+
55+
DUMP_PREFIX = os.environ.get(
56+
ENV_FX_NET_ACC_SPLITTER_TRACKER_DUMP_PATH, TRACKER_DUMP_PATH
57+
)
58+
59+
"""
60+
Different modes of the event tracker:
61+
"0": Tracker not enabled (by default)
62+
"1": Tracker enabled but no dumps. Information available by setting breakpoints and visually inspect in pdb.
63+
"2": Tracker enabled and dumps all events to DUMP_PREFIX_all.txt
64+
"3": In addition to events dump, track nodes specified by ENV_FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES
65+
recursively and dump to DUMP_PREFIX_nodex.txt
66+
"4": In addition to events dump, track all nodes with more than 1 event recursively and dump to DUMP_PREFIX_nodex.txt
67+
"""
68+
TRACKER_MODE: Literal["0", "1", "2", "3", "4"] = os.environ.get(
69+
ENV_FX_NET_ACC_SPLITTER_TRACKER_MODE, "0"
70+
) # type: ignore[assignment]
71+
4272

4373
class _SplitterSettingBase:
4474
def __init__(
@@ -99,6 +129,124 @@ def __init__(
99129
self.max_acc_splits: int = max_acc_splits
100130

101131

132+
@compatibility(is_backward_compatible=False)
class NodeEvent:
    """
    An event in graph split that happened on a node.

    Attributes:
        source: Subject of the event.
        desc: Readable description (e.g. "init_acc|..." or "acc_del|...").
        dep: Optional dependency, usually the node that caused the event.
    """

    def __init__(
        self,
        source: torch.fx.Node,
        desc: str,
        # Fixed implicit-Optional: default None requires Optional[...] per PEP 484.
        dep: Optional[torch.fx.Node] = None,
    ) -> None:
        self.source = source
        self.desc = desc
        self.dep = dep

    def to_str(self) -> str:
        """Render as "<source>: <desc> <dep-name>"; '#' marks no dependency."""
        return f"{self.source.name}: {self.desc} {self.dep.name if self.dep else '#'}"
148+
149+
150+
@compatibility(is_backward_compatible=False)
class NodeEventTracker:
    """
    Tracks node events during the splitter execution.

    Events are appended to ``events`` in arrival order; ``node_events`` maps
    each node name to the indices of that node's events so a per-node history
    can be replayed (optionally following dependency nodes recursively).
    """

    def __init__(self, tracker_mode: int, dump_prefix: str) -> None:
        """
        @param tracker_mode: integer form of TRACKER_MODE; 0 disables tracking.
        @param dump_prefix: path prefix for files written by close().
        """
        self.tracker_mode = tracker_mode
        self.dump_prefix = dump_prefix
        self.enabled = self.tracker_mode > 0
        # list of events (quoted annotation: NodeEvent is defined above in the file)
        self.events: list["NodeEvent"] = []
        # dict from node name to indices into self.events
        self.node_events: dict[str, list[int]] = {}
        self.writer = print

    def add(
        self,
        node: torch.fx.Node,
        desc: str,
        # Fixed implicit-Optional: default None requires Optional[...] per PEP 484.
        dep: Optional[torch.fx.Node] = None,
    ) -> None:
        """
        Add a new event to the tracker. No-op when tracking is disabled.
        """
        if not self.enabled:
            return
        event = NodeEvent(node, desc, dep)
        self.events.append(event)
        self.node_events.setdefault(node.name, []).append(len(self.events) - 1)

    def print_node(self, node_name, recursive=False, tab="", writer=None):
        """
        Print a node and its events.
        @param recursive: if True, print nodes that caused the events on this current node.
        @param tab: Indentation for dependencies.
        @param writer: function to write to file. If None, use print.
        """
        if not writer:
            writer = self.writer
        for idx in self.node_events.get(node_name, []):
            event = self.events[idx]
            writer(tab + event.to_str())
            if recursive and event.dep is not None:
                self.print_node(
                    event.dep.name, recursive=True, tab="| " + tab, writer=writer
                )

    def print_all(self, writer=None):
        """
        Print all nodes in a list.
        @param writer: function to write to file. If None, use print.
        """
        if not writer:
            writer = self.writer
        for name in self.node_events.keys():
            writer(f"Node: {name}: ")
            self.print_node(name, recursive=False, tab=" ", writer=writer)

    def close(self):
        """
        Function to be invoked at the end of the finder execution to print out
        tracked events as specified by the mode.
        """

        def writeln(f):
            # Adapt a file object into a line-writing callable.
            def fn(x):
                return f.write(x + "\n")

            return fn

        if not self.enabled:
            return
        # Mode 1: no dump

        # Mode >=2: dump all events to <prefix>_all.txt
        if self.tracker_mode >= 2:
            with open(self.dump_prefix + ALL_SUFFIX, "w") as f:
                self.print_all(writeln(f))

        def dump_selected_nodes(nodes):
            with open(self.dump_prefix + NODES_SUFFIX, "w") as f:
                out = writeln(f)
                for node_name in nodes:
                    # Bug fix: the banner lines were previously passed to
                    # ``writeln`` itself (writeln(f"...")) which treated the
                    # string as the file and discarded the returned closure,
                    # so the banners never reached the dump file.
                    out(f"===== Tracking node {node_name} =====")
                    self.print_node(node_name, recursive=True, tab="|-", writer=out)
                    out(f"===== End of tracking node {node_name} =====")

        # Mode 3: dump the env-specified nodes recursively.
        # Mode 4: dump all nodes with more than 1 event recursively.
        if self.tracker_mode == 3 or self.tracker_mode == 4:
            nodes = (
                os.environ.get(ENV_FX_NET_ACC_SPLITTER_TRACKER_TRACKED_NODES, "").split(
                    ","
                )
                if self.tracker_mode == 3
                else [
                    name for name, events in self.node_events.items() if len(events) > 1
                ]
            )
            dump_selected_nodes(nodes)
248+
249+
102250
@compatibility(is_backward_compatible=False)
103251
class FxNetAccNodesFinder:
104252
"""
@@ -125,6 +273,8 @@ def __init__(
125273
self.allow_non_tensor = allow_non_tensor
126274
self.acc_nodes: NodeSet = set()
127275

276+
self.tracker = NodeEventTracker(int(TRACKER_MODE), DUMP_PREFIX)
277+
128278
def reduce_acc_nodes_non_tensor_input_helper(self, cpu_worklist: NodeList):
129279
"""
130280
Transitively excludes nodes from ACC supported set.
@@ -139,7 +289,9 @@ def reduce_acc_nodes_non_tensor_input_helper(self, cpu_worklist: NodeList):
139289
for user in node.users:
140290
if user in self.acc_nodes:
141291
self.acc_nodes.remove(user)
292+
self.tracker.add(user, "acc_del|user_of_new_cpu_node", node)
142293
if not is_node_output_tensor(user):
294+
self.tracker.add(user, "new_cpu_node|non_tensor_output")
143295
cpu_worklist.append(user)
144296

145297
def reduce_acc_nodes_non_tensor_input(self):
@@ -156,6 +308,7 @@ def reduce_acc_nodes_non_tensor_input(self):
156308
continue
157309
if is_node_output_tensor(node):
158310
continue
311+
self.tracker.add(node, "new_cpu_node|callable_non_tensor_input")
159312
non_tensor_cpu_nodes.append(node)
160313

161314
self.reduce_acc_nodes_non_tensor_input_helper(non_tensor_cpu_nodes)
@@ -174,6 +327,9 @@ def reduce_acc_nodes_non_tensor_output(self):
174327
for user in acc_node.users:
175328
if user not in self.acc_nodes:
176329
new_cpu_nodes.append(acc_node)
330+
self.tracker.add(
331+
acc_node, "acc_del|non_tensor_output_with_cpu_user", user
332+
)
177333
break
178334

179335
if not new_cpu_nodes:
@@ -186,17 +342,22 @@ def reduce_acc_nodes_non_tensor_output(self):
186342

187343
def __call__(self) -> NodeSet:
    """
    Classify every node of the traced module as acc or cpu, recording a
    tracker event for each decision, and return the set of acc nodes.
    """
    submodules = dict(self.module.named_modules())
    self.acc_nodes = set()
    for node in self.module.graph.nodes:
        if node.op not in CALLABLE_NODE_OPS:
            self.tracker.add(node, "init_cpu|not_callable")
        elif not self.operator_support.is_node_supported(submodules, node):
            self.tracker.add(node, "init_cpu|operator_support")
        else:
            # Callable and supported: initially admitted to the acc set.
            self.tracker.add(node, "init_acc|callable_and_operator_supported")
            self.acc_nodes.add(node)

    if not self.allow_non_tensor:
        self.reduce_acc_nodes_non_tensor_input()
        self.reduce_acc_nodes_non_tensor_output()
    self.tracker.close()
    return self.acc_nodes
201362

202363

0 commit comments

Comments
 (0)