@@ -203,6 +203,79 @@ def reorder_with_only_dfs(
        outp = compiled_model(self.inputs)
        self.assertTrue(same(outp, outp_corr))

+    @mock.patch.object(config, "allow_buffer_reuse", False)
+    def test_mutation_size_propogation(self):
+        """
+        This tests correct size propagation in the case of mutations.
+        In this example, buf1 is a mutation of buf0; we should have:
+        * buf0 has size_alloc 2048 and size_free 0;
+        * buf1 has size_alloc 0 and size_free 2048.
+        This is because
+        - when buf1 is created, no additional memory is used; and
+        - the 2048 bytes of memory can only be released when buf1 is freed.
+        Similar arguments apply to buf2 and buf3, buf4 and buf5, etc.
+        """
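+        # note: every buffer checked below holds a 32x32 bfloat16 tensor,
+        # i.e. 32 * 32 * 2 = 2048 bytes, which is where the 2048 comes from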
+
+        # use a Triton custom kernel to create a small example with mutations
+        @triton.jit
+        def convert_to_bf16_kernel(
+            input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr
+        ):
+            pid = tl.program_id(axis=0)
+            block_start = pid * BLOCK_SIZE
+            offsets = block_start + tl.arange(0, BLOCK_SIZE)
+            mask = offsets < n_elements
+            x = tl.load(input_ptr + offsets, mask=mask)
+            x_bf16 = x.to(tl.bfloat16)
+            tl.store(output_ptr + offsets, x_bf16, mask=mask)
+
+        def convert_to_bf16(x):
+            output = torch.empty_like(x, dtype=torch.bfloat16)
+            n_elements = x.numel()
+            BLOCK_SIZE = 1024
+            grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
+            convert_to_bf16_kernel[grid](
+                x.flatten(), output.flatten(), n_elements, BLOCK_SIZE
+            )
+            return output.view(x.shape)
+
+        # wrap the original method so each buffer's (size_alloc, size_free) is recorded
+        buffer_info = {}
+        og_method = memory.assign_memory_planning_info_for_scheduler_buffers
+
+        def assign_memory_planning_info_for_scheduler_buffers_with_records(
+            nodes, name_to_buf
+        ):
+            og_method(nodes, name_to_buf)
+            for buf_name, buf in name_to_buf.items():
+                buffer_info[buf_name] = (
+                    buf.mpi_buffer.size_alloc,
+                    buf.mpi_buffer.size_free,
+                )
+
+        # test example and checks
+        def f(a, p):
+            for e in a:
+                e = convert_to_bf16(e)
+                p = p @ e
+            return p
+
+        a = [torch.randn(32, 32, device=GPU_TYPE) for _ in range(4)]
+        p = torch.ones(a[0].size(), dtype=torch.bfloat16, device=GPU_TYPE)
+
+        with mock.patch.object(
+            memory,
+            "assign_memory_planning_info_for_scheduler_buffers",
+            assign_memory_planning_info_for_scheduler_buffers_with_records,
+        ):
+            f_compiled = torch.compile(f)
+            f_compiled(a, p)
+            for buf_name in ["buf0", "buf2", "buf4", "buf6"]:
+                self.assertEqual(buffer_info[buf_name], (2048, 0))
+
+            for buf_name in ["buf1", "buf3", "buf5", "buf7"]:
+                self.assertEqual(buffer_info[buf_name], (0, 2048))
+
    @unittest.skipIf(
        not torch.cuda.is_available()
        or torch.cuda.get_device_properties().total_memory < int(1e10),
@@ -228,4 +301,7 @@ def f(a, b, c):
    from torch._inductor.test_case import run_tests

    if HAS_GPU:
+        import triton
+        from triton import language as tl
+
        run_tests()