Skip to content

Commit 063bbf9

Browse files
committed
[simplefsdp] add multi parallelism autobucketing
ghstack-source-id: f759afb Pull-Request: #160282
1 parent 579aa07 commit 063bbf9

File tree

2 files changed

+25
-10
lines changed

2 files changed

+25
-10
lines changed

torch/_inductor/simple_fsdp/auto_bucket_plan.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
bucket_reduce_scatters,
1919
get_fx_node,
2020
)
21+
from .reorder import _check_ir_node_fsdp
2122

2223

2324
def get_dynamic_memory_threshold(
@@ -349,9 +350,10 @@ def get_bucketing_plan(
349350
if has_reduce_scatter:
350351
peak_memory = peak_memory + config.simplefsdp.peak_memory_offset
351352
for idx, snode in enumerate(snodes):
353+
# we only bucket on FSDP comm
352354
if is_collective(
353355
snode.node, op=torch.ops._c10d_functional.all_gather_into_tensor.default
354-
):
356+
) and _check_ir_node_fsdp(snode.node):
355357
current_ag_bucket.append(snode)
356358
estimated_comm, comm_size_inp, comm_size_out = estimate_bucketed_node(
357359
current_ag_bucket,
@@ -473,7 +475,7 @@ def get_bucketing_plan(
473475
) = 0, 0
474476
elif is_collective(
475477
snode.node, op=torch.ops._c10d_functional.reduce_scatter_tensor.default
476-
):
478+
) and _check_ir_node_fsdp(snode.node):
477479
current_rs_bucket.append(snode)
478480
heuristic_info["this_step_rs_comm"], _, rs_comm_size_out = (
479481
estimate_bucketed_node(
@@ -499,14 +501,27 @@ def get_bucketing_plan(
499501
reduce_scatter_plan.append(current_rs_bucket)
500502
current_rs_bucket = []
501503
else:
502-
comp = estimate_comp_time(
503-
sched, snode, verbose=False, comp_cache=comp_cache
504-
)
504+
# [TODO]ruisizhang: for now, we only consider TP and CP, whose comm are AG & RS
505+
# For TP and CP, we consider the node as a "COMP" node with exposed communication as Comp time
506+
# the memory is the data fetched by the communication.
507+
if is_collective(snode.node):
508+
comp = comm_cache.get_comm_time(
509+
bucked_node[0].layout.size,
510+
bucked_node[1].layout.size,
511+
comm_func,
512+
calibrated=True,
513+
)
514+
memory = bucked_node[0].layout.size
515+
else:
516+
comp = estimate_comp_time(
517+
sched, snode, verbose=False, comp_cache=comp_cache
518+
)
519+
memory = max(
520+
abs(memories_at_nodes[idx + 1] - memories_at_nodes[release_steps[-1]]),
521+
heuristic_info["next_step_memory"],
522+
)
505523
heuristic_info["next_step_comp"] += comp
506-
heuristic_info["next_step_memory"] = max(
507-
abs(memories_at_nodes[idx + 1] - memories_at_nodes[release_steps[-1]]),
508-
heuristic_info["next_step_memory"],
509-
)
524+
heuristic_info["next_step_memory"] = memory
510525
total_comp_time -= comp
511526

512527
if len(current_ag_bucket) > 0 or len(all_gather_plan) == 0:

torch/_inductor/simple_fsdp/reorder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def _check_ir_node_fsdp(ir_node: "ir.Operation") -> bool:
7272

7373
for n in ir_node_origins:
7474
meta_data = n.meta.get("stack_trace", {})
75-
# TODO(ruisizhang123): hack to get FSDP node (the FSDP AG/RS are created from torch_spmd)
75+
# TODO(ruisizhang123): hack to get FSDP node (the SimpleFSDP AG/RS are created with parametrization)
7676
if "parametrization" in meta_data:
7777
is_fsdp = True
7878
return is_fsdp

0 commit comments

Comments
 (0)