Commit 86670b3

xuanzhang816 authored and pytorchmergebot committed
[PT2][memory] mutation size correctness (#157562)
Pull Request resolved: #157562
Approved by: https://github.com/yf225
1 parent c78bbdf commit 86670b3

File tree

3 files changed: +89, -18 lines


test/distributed/test_compute_comm_reordering.py

Lines changed: 1 addition & 1 deletion
@@ -179,9 +179,9 @@ def func(a):
             .check("extern_kernels.mm")
             .check("triton_poi_fused_relu")
             .check("torch.ops._c10d_functional.all_reduce_.default")
-            .check("extern_kernels.mm")
             .check("torch.ops._c10d_functional.wait_tensor.default")
             .check("extern_kernels.mm")
+            .check("extern_kernels.mm")
             .run(code)
         )
         out = compiled(inputs)

test/inductor/test_memory.py

Lines changed: 76 additions & 0 deletions
@@ -203,6 +203,79 @@ def reorder_with_only_dfs(
         outp = compiled_model(self.inputs)
         self.assertTrue(same(outp, outp_corr))
 
+    @mock.patch.object(config, "allow_buffer_reuse", False)
+    def test_mutation_size_propogation(self):
+        """
+        This tests correct size propogation in the case of mutations.
+        In this example, buf1 is a mutation of buf0; we should have:
+        * buf0: has size_alloc 2048 and size_free 0;
+        * buf1: has size_alloc 0 and size_free 2048.
+        This is because
+        - when buf1 is created, no additional memory is used; and
+        - the 2048 bytes of memory can only be released when buf1 is freed.
+        Similar arguments for buf2 and buf3, buf4 and buf5, etc.
+        """
+
+        # using triton custom kernel to creat small example with mutations
+        @triton.jit
+        def convert_to_bf16_kernel(
+            input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr
+        ):
+            pid = tl.program_id(axis=0)
+            block_start = pid * BLOCK_SIZE
+            offsets = block_start + tl.arange(0, BLOCK_SIZE)
+            mask = offsets < n_elements
+            x = tl.load(input_ptr + offsets, mask=mask)
+            x_bf16 = x.to(tl.bfloat16)
+            tl.store(output_ptr + offsets, x_bf16, mask=mask)
+
+        def convert_to_bf16(x):
+            output = torch.empty_like(x, dtype=torch.bfloat16)
+            n_elements = x.numel()
+            BLOCK_SIZE = 1024
+            grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
+            convert_to_bf16_kernel[grid](
+                x.flatten(), output.flatten(), n_elements, BLOCK_SIZE
+            )
+            return output.view(x.shape)
+
+        # create a custom function to record the buffer size information
+        buffer_info = {}
+        og_method = memory.assign_memory_planning_info_for_scheduler_buffers
+
+        def assign_memory_planning_info_for_scheduler_buffers_with_records(
+            nodes, name_to_buf
+        ):
+            og_method(nodes, name_to_buf)
+            for buf_name, buf in name_to_buf.items():
+                buffer_info[buf_name] = (
+                    buf.mpi_buffer.size_alloc,
+                    buf.mpi_buffer.size_free,
+                )
+
+        # test example and checks
+        def f(a, p):
+            for e in a:
+                e = convert_to_bf16(e)
+                p = p @ e
+            return p
+
+        a = [torch.randn(32, 32, device=GPU_TYPE) for _ in range(4)]
+        p = torch.ones(a[0].size(), dtype=torch.bfloat16, device=GPU_TYPE)
+
+        with mock.patch.object(
+            memory,
+            "assign_memory_planning_info_for_scheduler_buffers",
+            assign_memory_planning_info_for_scheduler_buffers_with_records,
+        ):
+            f_compiled = torch.compile(f)
+            f_compiled(a, p)
+            for buf_name in ["buf0", "buf2", "buf4", "buf6"]:
+                self.assertEqual(buffer_info[buf_name], (2048, 0))
+
+            for buf_name in ["buf1", "buf3", "buf5", "buf7"]:
+                self.assertEqual(buffer_info[buf_name], (0, 2048))
+
     @unittest.skipIf(
         not torch.cuda.is_available()
         or torch.cuda.get_device_properties().total_memory < int(1e10),
@@ -228,4 +301,7 @@ def f(a, b, c):
     from torch._inductor.test_case import run_tests
 
     if HAS_GPU:
+        import triton
+        from triton import language as tl
+
         run_tests()
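The docstring in the new test pins down the intended bookkeeping: the buffer that owns the storage carries the allocation, and the buffer that mutates it carries the free (2048 bytes here because each converted tensor is 32 x 32 bfloat16 elements, i.e. 32 * 32 * 2 bytes). A rough standalone sketch of that rule, independent of Inductor's scheduler types and using hypothetical buffer names and a hypothetical helper:

# Minimal sketch of the (size_alloc, size_free) rule the test asserts.
# `sizes` maps a buffer to its byte size; `mutations` maps a mutating
# buffer (e.g. "buf1") to the buffer it mutates (e.g. "buf0").
# Names, sizes, and the helper are illustrative only.
def propagate_mutation_sizes(sizes, mutations):
    # By default a buffer both allocates and frees its own size.
    info = {name: (size, size) for name, size in sizes.items()}
    for mutating, mutated in mutations.items():
        alloc, free = info.get(mutated, (0, 0))
        # The mutation allocates nothing but becomes the point where
        # the memory can finally be released ...
        info[mutating] = (0, free)
        # ... while the mutated buffer keeps its allocation and no longer
        # frees anything, since its storage lives on through the mutation.
        info[mutated] = (alloc, 0)
    return info

info = propagate_mutation_sizes({"buf0": 2048}, {"buf1": "buf0"})
assert info["buf0"] == (2048, 0)
assert info["buf1"] == (0, 2048)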

torch/_inductor/memory.py

Lines changed: 12 additions & 17 deletions
@@ -10,7 +10,7 @@
 from torch.utils._ordered_set import OrderedSet
 
 from .ir import MultiOutputLayout, NoneLayout
-from .utils import get_dtype_size, is_wait
+from .utils import get_dtype_size
 from .virtualized import V
 
 
@@ -147,23 +147,18 @@ def _compute_and_update_buf_size(
         sched_buf: SchedulerBuffer, user_of_MultiOutputLayout: bool = False
     ) -> int:
         if isinstance(sched_buf.node.layout, NoneLayout):
-            _size = 0
-            # for a wait tensor op, its schedulerBuffer NoneLayout layout. However,
-            # the schedulerBuffer is treated as a mutation of the collective output
-            # so it needs to inherit the size of the collectives
-            if (
-                sched_buf.defining_op
-                and is_wait(sched_buf.defining_op.node)
-                and sched_buf.get_mutations()
-            ):
+            # mutations should inherit the size of the mutated buffer
+            if sched_buf.get_mutations():
                 mutated_buf_name = sched_buf.get_mutations()[0]
-                _size = (
-                    sched_buf_to_size[mutated_buf_name][1]
-                    if mutated_buf_name in sched_buf_to_size
-                    else 0
-                )
-            sched_buf_to_size[sched_buf.get_name()] = (_size, _size)
-            return _size
+                if mutated_buf_name in sched_buf_to_size:
+                    (_size_alloc, _size_free) = sched_buf_to_size[mutated_buf_name]
+                else:
+                    (_size_alloc, _size_free) = (0, 0)
+                sched_buf_to_size[sched_buf.get_name()] = (0, _size_free)
+                sched_buf_to_size[mutated_buf_name] = (_size_alloc, 0)
+            else:
+                sched_buf_to_size[sched_buf.get_name()] = (0, 0)
+            return 0
         elif isinstance(sched_buf.node.layout, MultiOutputLayout):
             size_alloc = 0
             for user in sched_buf.users:
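The rewritten branch drops the is_wait special case: any NoneLayout buffer that mutates another buffer now inherits that buffer's size, and the pair is split so that the allocation stays attributed to the mutated buffer while the free moves to the mutating one. A rough worked example of how that update plays out over a chain of mutations (hypothetical names and sizes, plain dicts instead of SchedulerBuffer objects):

# (size_alloc, size_free) bookkeeping, seeded with a 2048-byte buf0.
sched_buf_to_size = {"buf0": (2048, 2048)}

def apply_mutation(mutating: str, mutated: str) -> None:
    size_alloc, size_free = sched_buf_to_size.get(mutated, (0, 0))
    # The mutating buffer allocates nothing and inherits the right to free.
    sched_buf_to_size[mutating] = (0, size_free)
    # The mutated buffer keeps its allocation but can no longer free it.
    sched_buf_to_size[mutated] = (size_alloc, 0)

apply_mutation("buf1", "buf0")  # buf1 mutates buf0
apply_mutation("buf2", "buf1")  # buf2 mutates buf1
# Result: buf0 -> (2048, 0), buf1 -> (0, 0), buf2 -> (0, 2048):
# the 2048 bytes are charged once at buf0 and released once at buf2,
# which the old (_size, _size) assignment double-counted.
print(sched_buf_to_size)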
