Skip to content

Commit 111f6de

Browse files
committed
turn off reorder_for_peak_memory in case of collectives
ghstack-source-id: 23bbb05 Pull Request resolved: #155271
1 parent 6c05f2f commit 111f6de

File tree

3 files changed

+15
-1
lines changed

3 files changed

+15
-1
lines changed

torch/_inductor/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,10 @@ def prologue_fusion_enabled() -> bool:
 # enable operator reordering for peak memory optimization
 reorder_for_peak_memory = True
 
+# reorder_for_peak_memory has performance regression for models with collectives
+# so we by default disable it for models with collectives
+disable_peak_mem_reorder_with_collectives = True
+
 # runtime estimation function for ops
 # for built-in estimation function, pass in "default"; for user-defined estimation function, pass in the function handle
 estimate_op_runtime = "default"

torch/_inductor/memory.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
 from torch._utils_internal import signpost_event
 from torch.utils._ordered_set import OrderedSet
 
+from . import config
 from .ir import MultiOutputLayout, NoneLayout
-from .utils import get_dtype_size, is_wait
+from .utils import contains_collective_or_wait, get_dtype_size, is_wait
 from .virtualized import V
 

@@ -648,6 +649,11 @@ def reorder_for_peak_memory(
 
     torch_log.info("Reordering for peak memory -- %d nodes", len(nodes))
 
+    if config.disable_peak_mem_reorder_with_collectives and contains_collective_or_wait(
+        nodes
+    ):
+        return nodes
+
     estimated_peak_memory, name_to_freeable_input_buf = prepare_planning_info(
         nodes,
         name_to_buf,

torch/_inductor/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2341,6 +2341,10 @@ def contains_wait(snode: BaseSchedulerNode) -> bool:
23412341
return is_wait(snode.node)
23422342

23432343

2344+
def contains_collective_or_wait(snodes: list[BaseSchedulerNode]) -> bool:
2345+
return any(contains_collective(snode) or contains_wait(snode) for snode in snodes)
2346+
2347+
23442348
def is_fallback_op(
23452349
node: Optional[Operation],
23462350
op: Union[torch._ops.OpOverload, Collection[torch._ops.OpOverload]],

0 commit comments

Comments
 (0)