File tree Expand file tree Collapse file tree 3 files changed +15
-1
lines changed Expand file tree Collapse file tree 3 files changed +15
-1
lines changed Original file line number Diff line number Diff line change @@ -375,6 +375,10 @@ def prologue_fusion_enabled() -> bool:
375
375
# enable operator reordering for peak memory optimization
376
376
reorder_for_peak_memory = True
377
377
378
+ # reorder_for_peak_memory has performance regression for models with collectives
379
+ # so we by default disable it for models with collectives
380
+ disable_peak_mem_reorder_with_collectives = True
381
+
378
382
# runtime estimation function for ops
379
383
# for built-in estimation function, pass in "default"; for user-defined estimation function, pass in the function handle
380
384
estimate_op_runtime = "default"
Original file line number Diff line number Diff line change 9
9
from torch ._utils_internal import signpost_event
10
10
from torch .utils ._ordered_set import OrderedSet
11
11
12
+ from . import config
12
13
from .ir import MultiOutputLayout , NoneLayout
13
- from .utils import get_dtype_size , is_wait
14
+ from .utils import contains_collective_or_wait , get_dtype_size , is_wait
14
15
from .virtualized import V
15
16
16
17
@@ -648,6 +649,11 @@ def reorder_for_peak_memory(
648
649
649
650
torch_log .info ("Reordering for peak memory -- %d nodes" , len (nodes ))
650
651
652
+ if config .disable_peak_mem_reorder_with_collectives and contains_collective_or_wait (
653
+ nodes
654
+ ):
655
+ return nodes
656
+
651
657
estimated_peak_memory , name_to_freeable_input_buf = prepare_planning_info (
652
658
nodes ,
653
659
name_to_buf ,
Original file line number Diff line number Diff line change @@ -2341,6 +2341,10 @@ def contains_wait(snode: BaseSchedulerNode) -> bool:
2341
2341
return is_wait (snode .node )
2342
2342
2343
2343
2344
+ def contains_collective_or_wait (snodes : list [BaseSchedulerNode ]) -> bool :
2345
+ return any (contains_collective (snode ) or contains_wait (snode ) for snode in snodes )
2346
+
2347
+
2344
2348
def is_fallback_op (
2345
2349
node : Optional [Operation ],
2346
2350
op : Union [torch ._ops .OpOverload , Collection [torch ._ops .OpOverload ]],
You can’t perform that action at this time.
0 commit comments