18
18
bucket_reduce_scatters ,
19
19
get_fx_node ,
20
20
)
21
+ from .reorder import _check_ir_node_fsdp
21
22
22
23
23
24
def get_dynamic_memory_threshold (
@@ -349,9 +350,10 @@ def get_bucketing_plan(
349
350
if has_reduce_scatter :
350
351
peak_memory = peak_memory + config .simplefsdp .peak_memory_offset
351
352
for idx , snode in enumerate (snodes ):
353
+ # we only bucket on FSDP comm
352
354
if is_collective (
353
355
snode .node , op = torch .ops ._c10d_functional .all_gather_into_tensor .default
354
- ):
356
+ ) and _check_ir_node_fsdp ( snode . node ) :
355
357
current_ag_bucket .append (snode )
356
358
estimated_comm , comm_size_inp , comm_size_out = estimate_bucketed_node (
357
359
current_ag_bucket ,
@@ -473,7 +475,7 @@ def get_bucketing_plan(
473
475
) = 0 , 0
474
476
elif is_collective (
475
477
snode .node , op = torch .ops ._c10d_functional .reduce_scatter_tensor .default
476
- ):
478
+ ) and _check_ir_node_fsdp ( snode . node ) :
477
479
current_rs_bucket .append (snode )
478
480
heuristic_info ["this_step_rs_comm" ], _ , rs_comm_size_out = (
479
481
estimate_bucketed_node (
@@ -499,14 +501,27 @@ def get_bucketing_plan(
499
501
reduce_scatter_plan .append (current_rs_bucket )
500
502
current_rs_bucket = []
501
503
else :
502
- comp = estimate_comp_time (
503
- sched , snode , verbose = False , comp_cache = comp_cache
504
- )
504
+ # [TODO]ruisizhang: for now, we only consider TP and CP, whose comm are AG & RS
505
+ # For TP and CP, we consider the node as a "COMP" node with exposed communication as Comp time
506
+ # the memory is the data fetched by the communication.
507
+ if is_collective (snode .node ):
508
+ comp = comm_cache .get_comm_time (
509
+ bucked_node [0 ].layout .size ,
510
+ bucked_node [1 ].layout .size ,
511
+ comm_func ,
512
+ calibrated = True ,
513
+ )
514
+ memory = bucked_node [0 ].layout .size
515
+ else :
516
+ comp = estimate_comp_time (
517
+ sched , snode , verbose = False , comp_cache = comp_cache
518
+ )
519
+ memory = max (
520
+ abs (memories_at_nodes [idx + 1 ] - memories_at_nodes [release_steps [- 1 ]]),
521
+ heuristic_info ["next_step_memory" ],
522
+ )
505
523
heuristic_info ["next_step_comp" ] += comp
506
- heuristic_info ["next_step_memory" ] = max (
507
- abs (memories_at_nodes [idx + 1 ] - memories_at_nodes [release_steps [- 1 ]]),
508
- heuristic_info ["next_step_memory" ],
509
- )
524
+ heuristic_info ["next_step_memory" ] = memory
510
525
total_comp_time -= comp
511
526
512
527
if len (current_ag_bucket ) > 0 or len (all_gather_plan ) == 0 :
0 commit comments