
Commit fea3fa7

[inductor] add lookup table recorder
# Why

Make it easier for users to generate lookup tables and use them.

# What

- infrastructure to record lookup tables from autotuning
- sample implementation for recording to a directory
- sample implementation for emitting to log.debug (key/value)

## Caveats

Right now it just records mm templates and everything that inductor considers. Some architectural changes are needed before it can record e.g. a topk after autotuning. Once that is ready, this is modular enough to adjust to recording only the topk. However, there is value now in being able to record a simple table, see the format, and manually edit it down to the topk entries using the autotuning logs.

# Testing

Using

```
#!/bin/bash
# Create a temporary directory for the lookup table dumps
TEMP_DIR=$(mktemp -d)
echo "Created temporary directory for lookup table dumps: $TEMP_DIR"

# Set environment variables to enable verbose output and recording
export TORCH_LOGS="+inductor"
export TORCH_INDUCTOR_LOOKUP_TABLE_RECORD_DIR=$TEMP_DIR
export PYTORCH_DEBUG=1

# Run the Python script
python3 -c "
import os
import torch
import logging
from torch._inductor import config as inductor_config
from torch._inductor.lookup_table_recorder import dump

# Configure logging to see emit messages
logging.basicConfig(level=logging.DEBUG)

# Enable TMA for matmul
inductor_config.triton.enable_persistent_tma_matmul = True

# Create large tensors with bfloat16 dtype
print('Creating 1024x1024 bfloat16 tensors for matrix multiplication...')
a = torch.randn(1024, 1024, device='cuda', dtype=torch.bfloat16)
b = torch.randn(1024, 1024, device='cuda', dtype=torch.bfloat16)

# Compile and run the matrix multiplication
print('Compiling and running torch.mm with TMA...')
compiled_mm = torch.compile(torch.mm, mode='max-autotune')
result = compiled_mm(a, b)

# Force synchronization to ensure compilation is complete
torch.cuda.synchronize()

# Dump the lookup table
print('Dumping lookup table...')
dump()

print('\\nMatrix multiplication completed successfully!')
" 2>&1 | tee /tmp/recorder_output.log

# Check if emit logic works by grepping for LookupTable entries
echo -e "\n\n=== CHECKING EMIT FUNCTIONALITY ==="
if grep -q "LookupTable:" /tmp/recorder_output.log; then
    echo "✅ Emit functionality is working! Found LookupTable entries in the log."
else
    echo "❌ Emit functionality not detected. No LookupTable entries found in the log."
fi

# Display the dumped lookup table
echo -e "\n\n=== DUMPED LOOKUP TABLE CONTENTS ==="
LATEST_JSON=$(ls -t $TEMP_DIR/inductor_lut_*.json | head -1)
if [ -f "$LATEST_JSON" ]; then
    echo "Found lookup table file: $LATEST_JSON"
    echo "File size: $(du -h $LATEST_JSON | cut -f1)"
    echo -e "\nFirst 20 lines of the lookup table:"
    head -n 20 $LATEST_JSON

    # Check for TMA entries
    echo -e "\n=== CHECKING FOR TMA ENTRIES ==="
    if grep -q "tma\|TMA_SIZE\|NUM_SMS" $LATEST_JSON; then
        echo "✅ TMA entries found in the lookup table!"
        echo -e "\nSample TMA entry:"
        grep -m 1 -A 10 -B 2 "tma\|TMA_SIZE\|NUM_SMS" $LATEST_JSON
    else
        echo "❌ No TMA entries found in the lookup table."
    fi
else
    echo "❌ No lookup table JSON file found in $TEMP_DIR"
fi

echo -e "\n\nLookup table files are available in: $TEMP_DIR"
echo "Log file is available at: /tmp/recorder_output.log"
```

which produces

```
=== CHECKING EMIT FUNCTIONALITY ===
✅ Emit functionality is working! Found LookupTable entries in the log.

=== DUMPED LOOKUP TABLE CONTENTS ===
Found lookup table file: /tmp/tmp.L9pydR3sH4/inductor_lut_20250723_221836_641.json
File size: 12K

First 20 lines of the lookup table:
{
  "NVIDIA H100+mm+((torch.bfloat16, [1024, 1024], [1024, 1]), (torch.bfloat16, [1024, 1024], [1024, 1]))": [
    {
      "template_id": "mm",
      "EVEN_K": true,
      "ALLOW_TF32": false,
      "USE_FAST_ACCUM": false,
      "ACC_TYPE": "tl.float32",
      "num_stages": 1,
      "num_warps": 2,
      "BLOCK_M": 32,
      "BLOCK_N": 32,
      "BLOCK_K": 16,
      "hint_override": null,
      "GROUP_M": 8,
      "template_hash": "0717af5834e39dcca7ea817f896b8d85b4886422da7a3ab5f6911b4cfe568896"
    },
    {
      "template_id": "mm",
      "EVEN_K": true,

=== CHECKING FOR TMA ENTRIES ===
✅ TMA entries found in the lookup table!

Sample TMA entry:
    },
    {
      "template_id": "mm_persistent_tma",
      "EVEN_K": true,
      "ALLOW_TF32": false,
      "USE_FAST_ACCUM": false,
      "ACC_TYPE": "tl.float32",
      "num_stages": 3,
      "num_warps": 8,
      "BLOCK_M": 128,
      "BLOCK_N": 256,
      "BLOCK_K": 64,
      "hint_override": null,

Lookup table files are available in: /tmp/tmp.L9pydR3sH4
Log file is available at: /tmp/recorder_output.log
```

ghstack-source-id: f378305
Pull Request resolved: #158987
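For orientation, here is a minimal sketch of the record-then-reuse loop this enables, assuming the API exercised by the new test file below (`get_lookup_table_recorder`, `DirectoryRecordBackend`, `recorder.dump()`, and the `inductor_config.template_lookup_table` config); the `/tmp/lut` path is illustrative, not part of the change:

```python
import glob
import json

import torch
from torch._inductor import config as inductor_config
from torch._inductor.lookup_table_recorder import (
    DirectoryRecordBackend,
    get_lookup_table_recorder,
)

# Record: attach a directory backend, autotune once, then dump to JSON
recorder = get_lookup_table_recorder()
recorder.add_backend(DirectoryRecordBackend("/tmp/lut"))  # illustrative path

a = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
torch.compile(torch.mm, mode="max-autotune")(a, b)
recorder.dump()  # writes /tmp/lut/inductor_lut_<timestamp>.json

# Reuse: load the recorded table (optionally hand-edited down to the best
# entries using the autotuning logs) and feed it back to inductor. The
# timestamped filenames sort lexicographically, so max() picks the latest;
# this assumes at least one dump file exists.
latest = max(glob.glob("/tmp/lut/inductor_lut_*.json"))
with open(latest) as f:
    inductor_config.template_lookup_table = json.load(f)
```

The end-to-end test in the file below does essentially the same round trip.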
1 parent 5e64c25 commit fea3fa7


4 files changed: +650 −9 lines changed

Lines changed: 372 additions & 0 deletions
@@ -0,0 +1,372 @@
# Owner(s): ["module: inductor"]
import json
import os
import shutil
import tempfile
import unittest
from typing import Any

import torch
import torch.nn as nn
from torch._inductor import config as inductor_config
from torch._inductor.lookup_table_recorder import (
    clear,
    DirectoryRecordBackend,
    emit_backend,
    EmitBackend,
    get_lookup_table_recorder,
    LookupTableEntry,
    record_backend,
    RecordBackend,
)
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import fresh_cache
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.utils._triton import has_triton_tma_device


class TestEmitBackend(EmitBackend):
    """Test emit backend that captures emitted entries"""

    def __init__(self):
        self.emitted_entries: list[LookupTableEntry] = []

    def emit(self, entry: LookupTableEntry):
        self.emitted_entries.append(entry)


class TestRecordBackend(RecordBackend):
    """Test record backend that captures dumped data"""

    def __init__(self):
        self.dumped_data: dict[str, list[dict[str, Any]]] = {}

    def dump(self, data: dict[str, list[dict[str, Any]]]):
        self.dumped_data = data.copy()


class SimpleMMModel(nn.Module):
    """Simple model that performs matrix multiplication"""

    def forward(self, a, b):
        return torch.mm(a, b)


def force_recorder_reset():
    """Force the global recorder to be recreated on next access"""
    import torch._inductor.lookup_table_recorder as recorder_module

    recorder_module._lookup_table_recorder = None


@unittest.skipIf(not HAS_CUDA, "CUDA not available")
class TestLookupTableRecorder(TestCase):
    """Test suite for lookup table recorder functionality"""

    def setUp(self):
        torch._dynamo.reset()
        self.device = torch.device("cuda")
        self.temp_dir = tempfile.mkdtemp()

        # Store original config values
        self.original_recorder_emit = (
            inductor_config.template_lookup_table_config.recorder_emit
        )
        self.original_recorder_record_dir = (
            inductor_config.template_lookup_table_config.recorder_record_dir
        )
        self.original_lookup_table = inductor_config.template_lookup_table

        # Clear any existing recorder
        clear()
        force_recorder_reset()

    def tearDown(self):
        # Restore original config
        inductor_config.template_lookup_table_config.recorder_emit = (
            self.original_recorder_emit
        )
        inductor_config.template_lookup_table_config.recorder_record_dir = (
            self.original_recorder_record_dir
        )
        inductor_config.template_lookup_table = self.original_lookup_table

        # Clean up temp files
        shutil.rmtree(self.temp_dir, ignore_errors=True)

        # Clear recorder
        clear()
        force_recorder_reset()

    def create_simple_mm_tensors(self):
        """Create small test tensors for torch.mm"""
        return [
            torch.randn(64, 32, device=self.device, dtype=torch.float16),
            torch.randn(32, 64, device=self.device, dtype=torch.float16),
        ]

    def compile_and_run_mm(self, config_patches=None):
        """Compile and execute a simple mm operation"""
        default_config = {"max_autotune_gemm": True}
        if config_patches:
            default_config.update(config_patches)

        with inductor_config.patch(default_config):
            model = SimpleMMModel().to(self.device)
            tensors = self.create_simple_mm_tensors()
            compiled_model = torch.compile(model, mode="max-autotune")
            return compiled_model(*tensors)

    @fresh_cache()
    def test_emit_works(self):
        """Test that emit functionality works correctly"""
        # Force recorder reset and create test backend
        force_recorder_reset()
        test_emit_backend = TestEmitBackend()

        # Get fresh recorder and add our test backend
        recorder = get_lookup_table_recorder()
        recorder.add_backend(test_emit_backend)

        # Trigger compilation with autotuning
        self.compile_and_run_mm()

        # Verify entries were emitted
        self.assertGreater(
            len(test_emit_backend.emitted_entries),
            0,
            "Expected at least one entry to be emitted",
        )

        # Check structure of emitted entries
        for entry in test_emit_backend.emitted_entries:
            self.assertIsInstance(entry, LookupTableEntry)
            self.assertIsInstance(entry.key, str)
            self.assertIsInstance(entry.value, dict)
            self.assertIn("template_id", entry.value)

    @fresh_cache()
    def test_directory_record_backend(self):
        """Test that DirectoryRecordBackend creates timestamped files correctly"""
        # Setup directory for dumping
        dump_dir = os.path.join(self.temp_dir, "test_dump_dir")

        # Force recorder reset and add DirectoryRecordBackend
        force_recorder_reset()
        recorder = get_lookup_table_recorder()
        dir_backend = DirectoryRecordBackend(dump_dir)
        recorder.add_backend(dir_backend)

        # Trigger compilation with autotuning
        self.compile_and_run_mm()

        # Trigger dump
        recorder.dump()

        # Verify directory was created
        self.assertTrue(os.path.exists(dump_dir), "Dump directory should be created")
        self.assertTrue(os.path.isdir(dump_dir), "Dump path should be a directory")

        # Find the generated file
        files = os.listdir(dump_dir)
        json_files = [
            f for f in files if f.endswith(".json") and f.startswith("inductor_lut_")
        ]
        self.assertEqual(len(json_files), 1, "Should have exactly one JSON file")

        # Verify filename format (inductor_lut_YYYYMMDD_HHMMSS_mmm.json)
        filename = json_files[0]
        self.assertTrue(
            filename.startswith("inductor_lut_"),
            "Filename should start with 'inductor_lut_'",
        )
        self.assertTrue(filename.endswith(".json"), "Filename should end with '.json'")

        # Extract timestamp part and verify format
        timestamp_part = filename[len("inductor_lut_") : -len(".json")]
        parts = timestamp_part.split("_")
        self.assertEqual(
            len(parts), 3, "Timestamp should have 3 parts: date, time, milliseconds"
        )

        date_part, time_part, ms_part = parts
        self.assertEqual(len(date_part), 8, "Date part should be 8 digits (YYYYMMDD)")
        self.assertEqual(len(time_part), 6, "Time part should be 6 digits (HHMMSS)")
        self.assertEqual(len(ms_part), 3, "Millisecond part should be 3 digits")

        # Verify file contains valid JSON
        filepath = os.path.join(dump_dir, filename)
        with open(filepath) as f:
            data = json.load(f)

        self.assertIsInstance(data, dict)
        self.assertGreater(len(data), 0, "Expected at least one entry in dump")

        # Check structure of dumped data
        for key, configs in data.items():
            self.assertIsInstance(key, str)
            self.assertIsInstance(configs, list)
            self.assertGreater(len(configs), 0)
            for config in configs:
                self.assertIsInstance(config, dict)
                self.assertIn("template_id", config)

    @fresh_cache()
    def test_end_to_end_workflow(self):
        """Test complete workflow from recording to reading and feeding back to inductor"""
        # Step 1: Record a lookup table during first compilation
        dump_dir = os.path.join(self.temp_dir, "e2e_test_dir")

        # Force recorder reset and add DirectoryRecordBackend
        force_recorder_reset()
        recorder = get_lookup_table_recorder()
        dir_backend = DirectoryRecordBackend(dump_dir)
        recorder.add_backend(dir_backend)

        # First compilation - this should record entries
        _ = self.compile_and_run_mm()

        # Dump the recorded table
        recorder.dump()

        # Verify directory and file were created
        self.assertTrue(os.path.exists(dump_dir), "Dump directory should be created")
        files = os.listdir(dump_dir)
        json_files = [
            f for f in files if f.endswith(".json") and f.startswith("inductor_lut_")
        ]
        self.assertEqual(len(json_files), 1, "Should have exactly one JSON file")

        # Step 2: Read the table from the file
        dump_file = os.path.join(dump_dir, json_files[0])
        with open(dump_file) as f:
            recorded_table = json.load(f)

        self.assertGreater(len(recorded_table), 0, "Should have recorded some entries")

        # Step 3: Configure inductor to use the recorded table
        inductor_config.template_lookup_table = recorded_table

        # Clear the recorder to start fresh
        clear()
        force_recorder_reset()

        # Step 4: Compile the same operation again
        # This should work and throw no errors
        _ = self.compile_and_run_mm()

    @fresh_cache()
    def test_recorder_clear_functionality(self):
        """Test that clear functionality works correctly"""
        # Setup recording
        force_recorder_reset()
        recorder = get_lookup_table_recorder()
        test_record_backend = TestRecordBackend()
        recorder.add_backend(test_record_backend)

        # Trigger compilation to populate data
        self.compile_and_run_mm()

        # Verify data exists
        self.assertGreater(len(recorder.data), 0)

        # Clear and verify data is gone
        recorder.clear()
        self.assertEqual(len(recorder.data), 0)

    @fresh_cache()
    def test_decorator_registration_conditional(self):
        """Test that decorator registration respects should_register parameter"""

        # Create two backend classes with decorators - one should register, one shouldn't
        @emit_backend(should_register=True)
        class ShouldRegisterEmitBackend(EmitBackend):
            def emit(self, entry: LookupTableEntry):
                pass

        @emit_backend(should_register=False)
        class ShouldNotRegisterEmitBackend(EmitBackend):
            def emit(self, entry: LookupTableEntry):
                pass

        @record_backend(should_register=True)
        class ShouldRegisterRecordBackend(RecordBackend):
            def dump(self, data: dict[str, list[dict[str, Any]]]):
                pass

        @record_backend(should_register=False)
        class ShouldNotRegisterRecordBackend(RecordBackend):
            def dump(self, data: dict[str, list[dict[str, Any]]]):
                pass

        # Force recorder reset to trigger re-registration with new backends
        force_recorder_reset()
        recorder = get_lookup_table_recorder()

        # Check that only the backends with should_register=True were registered
        registered_emit_types = [type(backend) for backend in recorder.emit_backends]
        registered_record_types = [
            type(backend) for backend in recorder.record_backends
        ]

        # Should have registered the "should register" backends
        self.assertIn(ShouldRegisterEmitBackend, registered_emit_types)
        self.assertIn(ShouldRegisterRecordBackend, registered_record_types)

        # Should NOT have registered the "should not register" backends
        self.assertNotIn(ShouldNotRegisterEmitBackend, registered_emit_types)
        self.assertNotIn(ShouldNotRegisterRecordBackend, registered_record_types)

    @unittest.skipIf(
        not has_triton_tma_device(), "Need device-side TMA support in Triton"
    )
    @fresh_cache()
    def test_tma_entries_recorded(self):
        """Test that TMA-specific entries are recorded when TMA is enabled"""
        # Create custom record backend to capture data
        test_record_backend = TestRecordBackend()

        # Enable TMA and recording
        inductor_config.template_lookup_table_config.recorder_emit = True

        # Get recorder and add our test backend
        recorder = get_lookup_table_recorder()
        self.assertIsNotNone(recorder)
        assert recorder is not None  # Type narrowing for mypy
        recorder.add_backend(test_record_backend)

        # Trigger compilation with TMA enabled
        self.compile_and_run_mm({"triton.enable_persistent_tma_matmul": True})

        # Trigger dump to populate our test backend
        recorder.dump()

        # Check if any entries contain TMA-related values
        tma_entries_found = False
        for configs in test_record_backend.dumped_data.values():
            for config in configs:
                # Check for TMA in template_id or TMA-specific parameters
                if (
                    "tma" in config.get("template_id", "").lower()
                    or "TMA_SIZE" in config
                    or "NUM_SMS" in config
                    or "TMA_EXPERIMENTAL_API" in config
                ):
                    tma_entries_found = True
                    break
            if tma_entries_found:
                break

        self.assertTrue(
            tma_entries_found,
            "Expected to find at least one entry with TMA-related values",
        )


if __name__ == "__main__":
    from torch._inductor.utils import is_big_gpu
    from torch.testing._internal.inductor_utils import HAS_CPU, HAS_GPU

    # Set env to make it work in CI
    if HAS_GPU and HAS_CPU and is_big_gpu():
        run_tests()
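As a sketch of how downstream code might plug its own backends into the recorder, based on the decorator and ABC usage in `test_decorator_registration_conditional` above; the class names `PrintEmitBackend` and `InMemoryRecordBackend` are hypothetical, only the decorator/ABC pattern comes from the test:

```python
from typing import Any

from torch._inductor.lookup_table_recorder import (
    emit_backend,
    EmitBackend,
    LookupTableEntry,
    record_backend,
    RecordBackend,
)


# Hypothetical emit backend: prints each autotuned entry as it is recorded.
# Per the test above, classes decorated with should_register=True show up as
# instances on the recorder's emit_backends after the recorder is (re)created;
# should_register=False defines the class without registering it.
@emit_backend(should_register=True)
class PrintEmitBackend(EmitBackend):
    def emit(self, entry: LookupTableEntry) -> None:
        print(f"LookupTable: {entry.key} -> {entry.value}")


# Hypothetical record backend: keeps the most recently dumped table in memory.
@record_backend(should_register=True)
class InMemoryRecordBackend(RecordBackend):
    def __init__(self) -> None:
        self.table: dict[str, list[dict[str, Any]]] = {}

    def dump(self, data: dict[str, list[dict[str, Any]]]) -> None:
        self.table = data.copy()
```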
