| 1 | +# Owner(s): ["module: inductor"] |
| 2 | +import json |
| 3 | +import os |
| 4 | +import shutil |
| 5 | +import tempfile |
| 6 | +import unittest |
| 7 | +from typing import Any |
| 8 | + |
| 9 | +import torch |
| 10 | +import torch.nn as nn |
| 11 | +from torch._inductor import config as inductor_config |
| 12 | +from torch._inductor.lookup_table_recorder import ( |
| 13 | + clear, |
| 14 | + DirectoryRecordBackend, |
| 15 | + emit_backend, |
| 16 | + EmitBackend, |
| 17 | + get_lookup_table_recorder, |
| 18 | + LookupTableEntry, |
| 19 | + record_backend, |
| 20 | + RecordBackend, |
| 21 | +) |
| 22 | +from torch._inductor.test_case import run_tests, TestCase |
| 23 | +from torch._inductor.utils import fresh_cache |
| 24 | +from torch.testing._internal.inductor_utils import HAS_CUDA |
| 25 | +from torch.utils._triton import has_triton_tma_device |
| 26 | + |
| 27 | + |
| 28 | +class TestEmitBackend(EmitBackend): |
| 29 | + """Test emit backend that captures emitted entries""" |
| 30 | + |
| 31 | + def __init__(self): |
| 32 | + self.emitted_entries: list[LookupTableEntry] = [] |
| 33 | + |
| 34 | + def emit(self, entry: LookupTableEntry): |
| 35 | + self.emitted_entries.append(entry) |
| 36 | + |
| 37 | + |
| 38 | +class TestRecordBackend(RecordBackend): |
| 39 | + """Test record backend that captures dumped data""" |
| 40 | + |
| 41 | + def __init__(self): |
| 42 | + self.dumped_data: dict[str, list[dict[str, Any]]] = {} |
| 43 | + |
| 44 | + def dump(self, data: dict[str, list[dict[str, Any]]]): |
| 45 | + self.dumped_data = data.copy() |
| 46 | + |
| 47 | + |
| 48 | +class SimpleMMModel(nn.Module): |
| 49 | + """Simple model that performs matrix multiplication""" |
| 50 | + |
| 51 | + def forward(self, a, b): |
| 52 | + return torch.mm(a, b) |
| 53 | + |
| 54 | + |
| 55 | +def force_recorder_reset(): |
| 56 | + """Force the global recorder to be recreated on next access""" |
| 57 | + import torch._inductor.lookup_table_recorder as recorder_module |
| 58 | + |
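| | +    # Dropping the module-level singleton makes get_lookup_table_recorder() rebuild it (and re-run backend registration) on the next call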
| 59 | + recorder_module._lookup_table_recorder = None |
| 60 | + |
| 61 | + |
| 62 | +@unittest.skipIf(not HAS_CUDA, "CUDA not available") |
| 63 | +class TestLookupTableRecorder(TestCase): |
| 64 | + """Test suite for lookup table recorder functionality""" |
| 65 | + |
| 66 | + def setUp(self): |
| 67 | + torch._dynamo.reset() |
| 68 | + self.device = torch.device("cuda") |
| 69 | + self.temp_dir = tempfile.mkdtemp() |
| 70 | + |
| 71 | + # Store original config values |
| 72 | + self.original_recorder_emit = ( |
| 73 | + inductor_config.template_lookup_table_config.recorder_emit |
| 74 | + ) |
| 75 | + self.original_recorder_record_dir = ( |
| 76 | + inductor_config.template_lookup_table_config.recorder_record_dir |
| 77 | + ) |
| 78 | + self.original_lookup_table = inductor_config.template_lookup_table |
| 79 | + |
| 80 | + # Clear any existing recorder |
| 81 | + clear() |
| 82 | + force_recorder_reset() |
| 83 | + |
| 84 | + def tearDown(self): |
| 85 | + # Restore original config |
| 86 | + inductor_config.template_lookup_table_config.recorder_emit = ( |
| 87 | + self.original_recorder_emit |
| 88 | + ) |
| 89 | + inductor_config.template_lookup_table_config.recorder_record_dir = ( |
| 90 | + self.original_recorder_record_dir |
| 91 | + ) |
| 92 | + inductor_config.template_lookup_table = self.original_lookup_table |
| 93 | + |
| 94 | + # Clean up temp files |
| 95 | + shutil.rmtree(self.temp_dir, ignore_errors=True) |
| 96 | + |
| 97 | + # Clear recorder |
| 98 | + clear() |
| 99 | + force_recorder_reset() |
| 100 | + |
| 101 | + def create_simple_mm_tensors(self): |
| 102 | + """Create small test tensors for torch.mm""" |
| 103 | + return [ |
| 104 | + torch.randn(64, 32, device=self.device, dtype=torch.float16), |
| 105 | + torch.randn(32, 64, device=self.device, dtype=torch.float16), |
| 106 | + ] |
| 107 | + |
| 108 | + def compile_and_run_mm(self, config_patches=None): |
| 109 | + """Compile and execute a simple mm operation""" |
| 110 | + default_config = {"max_autotune_gemm": True} |
| 111 | + if config_patches: |
| 112 | + default_config.update(config_patches) |
| 113 | + |
| 114 | + with inductor_config.patch(default_config): |
| 115 | + model = SimpleMMModel().to(self.device) |
| 116 | + tensors = self.create_simple_mm_tensors() |
| 117 | + compiled_model = torch.compile(model, mode="max-autotune") |
| 118 | + return compiled_model(*tensors) |
| 119 | + |
| 120 | + @fresh_cache() |
| 121 | + def test_emit_works(self): |
| 122 | + """Test that emit functionality works correctly""" |
| 123 | + # Force recorder reset and create test backend |
| 124 | + force_recorder_reset() |
| 125 | + test_emit_backend = TestEmitBackend() |
| 126 | + |
| 127 | + # Get fresh recorder and add our test backend |
| 128 | + recorder = get_lookup_table_recorder() |
| 129 | + recorder.add_backend(test_emit_backend) |
| 130 | + |
| 131 | + # Trigger compilation with autotuning |
| 132 | + self.compile_and_run_mm() |
| 133 | + |
| 134 | + # Verify entries were emitted |
| 135 | + self.assertGreater( |
| 136 | + len(test_emit_backend.emitted_entries), |
| 137 | + 0, |
| 138 | + "Expected at least one entry to be emitted", |
| 139 | + ) |
| 140 | + |
| 141 | + # Check structure of emitted entries |
| 142 | + for entry in test_emit_backend.emitted_entries: |
| 143 | + self.assertIsInstance(entry, LookupTableEntry) |
| 144 | + self.assertIsInstance(entry.key, str) |
| 145 | + self.assertIsInstance(entry.value, dict) |
| 146 | + self.assertIn("template_id", entry.value) |
| 147 | + |
| 148 | + @fresh_cache() |
| 149 | + def test_directory_record_backend(self): |
| 150 | + """Test that DirectoryRecordBackend creates timestamped files correctly""" |
| 151 | + # Setup directory for dumping |
| 152 | + dump_dir = os.path.join(self.temp_dir, "test_dump_dir") |
| 153 | + |
| 154 | + # Force recorder reset and add DirectoryRecordBackend |
| 155 | + force_recorder_reset() |
| 156 | + recorder = get_lookup_table_recorder() |
| 157 | + dir_backend = DirectoryRecordBackend(dump_dir) |
| 158 | + recorder.add_backend(dir_backend) |
| 159 | + |
| 160 | + # Trigger compilation with autotuning |
| 161 | + self.compile_and_run_mm() |
| 162 | + |
| 163 | + # Trigger dump |
| 164 | + recorder.dump() |
| 165 | + |
| 166 | + # Verify directory was created |
| 167 | + self.assertTrue(os.path.exists(dump_dir), "Dump directory should be created") |
| 168 | + self.assertTrue(os.path.isdir(dump_dir), "Dump path should be a directory") |
| 169 | + |
| 170 | + # Find the generated file |
| 171 | + files = os.listdir(dump_dir) |
| 172 | + json_files = [ |
| 173 | + f for f in files if f.endswith(".json") and f.startswith("inductor_lut_") |
| 174 | + ] |
| 175 | + self.assertEqual(len(json_files), 1, "Should have exactly one JSON file") |
| 176 | + |
| 177 | + # Verify filename format (inductor_lut_YYYYMMDD_HHMMSS_mmm.json) |
| 178 | + filename = json_files[0] |
| 179 | + self.assertTrue( |
| 180 | + filename.startswith("inductor_lut_"), |
| 181 | + "Filename should start with 'inductor_lut_'", |
| 182 | + ) |
| 183 | + self.assertTrue(filename.endswith(".json"), "Filename should end with '.json'") |
| 184 | + |
| 185 | + # Extract timestamp part and verify format |
| 186 | + timestamp_part = filename[len("inductor_lut_") : -len(".json")] |
| 187 | + parts = timestamp_part.split("_") |
| 188 | + self.assertEqual( |
| 189 | + len(parts), 3, "Timestamp should have 3 parts: date, time, milliseconds" |
| 190 | + ) |
| 191 | + |
| 192 | + date_part, time_part, ms_part = parts |
| 193 | + self.assertEqual(len(date_part), 8, "Date part should be 8 digits (YYYYMMDD)") |
| 194 | + self.assertEqual(len(time_part), 6, "Time part should be 6 digits (HHMMSS)") |
| 195 | + self.assertEqual(len(ms_part), 3, "Millisecond part should be 3 digits") |
| 196 | + |
| 197 | + # Verify file contains valid JSON |
| 198 | + filepath = os.path.join(dump_dir, filename) |
| 199 | + with open(filepath) as f: |
| 200 | + data = json.load(f) |
| 201 | + |
| 202 | + self.assertIsInstance(data, dict) |
| 203 | + self.assertGreater(len(data), 0, "Expected at least one entry in dump") |
| 204 | + |
| 205 | + # Check structure of dumped data |
| 206 | + for key, configs in data.items(): |
| 207 | + self.assertIsInstance(key, str) |
| 208 | + self.assertIsInstance(configs, list) |
| 209 | + self.assertGreater(len(configs), 0) |
| 210 | + for config in configs: |
| 211 | + self.assertIsInstance(config, dict) |
| 212 | + self.assertIn("template_id", config) |
| 213 | + |
| 214 | + @fresh_cache() |
| 215 | + def test_end_to_end_workflow(self): |
| 216 | + """Test complete workflow from recording to reading and feeding back to inductor""" |
| 217 | + # Step 1: Record a lookup table during first compilation |
| 218 | + dump_dir = os.path.join(self.temp_dir, "e2e_test_dir") |
| 219 | + |
| 220 | + # Force recorder reset and add DirectoryRecordBackend |
| 221 | + force_recorder_reset() |
| 222 | + recorder = get_lookup_table_recorder() |
| 223 | + dir_backend = DirectoryRecordBackend(dump_dir) |
| 224 | + recorder.add_backend(dir_backend) |
| 225 | + |
| 226 | + # First compilation - this should record entries |
| 227 | + _ = self.compile_and_run_mm() |
| 228 | + |
| 229 | + # Dump the recorded table |
| 230 | + recorder.dump() |
| 231 | + |
| 232 | + # Verify directory and file were created |
| 233 | + self.assertTrue(os.path.exists(dump_dir), "Dump directory should be created") |
| 234 | + files = os.listdir(dump_dir) |
| 235 | + json_files = [ |
| 236 | + f for f in files if f.endswith(".json") and f.startswith("inductor_lut_") |
| 237 | + ] |
| 238 | + self.assertEqual(len(json_files), 1, "Should have exactly one JSON file") |
| 239 | + |
| 240 | + # Step 2: Read the table from the file |
| 241 | + dump_file = os.path.join(dump_dir, json_files[0]) |
| 242 | + with open(dump_file) as f: |
| 243 | + recorded_table = json.load(f) |
| 244 | + |
| 245 | + self.assertGreater(len(recorded_table), 0, "Should have recorded some entries") |
| 246 | + |
| 247 | + # Step 3: Configure inductor to use the recorded table |
| 248 | + inductor_config.template_lookup_table = recorded_table |
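| | +        # With the table in place, the next compilation can draw its template choices from the recorded entries (the test only asserts it compiles cleanly)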
| 249 | + |
| 250 | + # Clear the recorder to start fresh |
| 251 | + clear() |
| 252 | + force_recorder_reset() |
| 253 | + |
| 254 | + # Step 4: Compile the same operation again |
| 255 | +        # This should succeed without raising any errors
| 256 | + _ = self.compile_and_run_mm() |
| 257 | + |
| 258 | + @fresh_cache() |
| 259 | + def test_recorder_clear_functionality(self): |
| 260 | + """Test that clear functionality works correctly""" |
| 261 | + # Setup recording |
| 262 | + force_recorder_reset() |
| 263 | + recorder = get_lookup_table_recorder() |
| 264 | + test_record_backend = TestRecordBackend() |
| 265 | + recorder.add_backend(test_record_backend) |
| 266 | + |
| 267 | + # Trigger compilation to populate data |
| 268 | + self.compile_and_run_mm() |
| 269 | + |
| 270 | + # Verify data exists |
| 271 | + self.assertGreater(len(recorder.data), 0) |
| 272 | + |
| 273 | + # Clear and verify data is gone |
| 274 | + recorder.clear() |
| 275 | + self.assertEqual(len(recorder.data), 0) |
| 276 | + |
| 277 | + @fresh_cache() |
| 278 | + def test_decorator_registration_conditional(self): |
| 279 | + """Test that decorator registration respects should_register parameter""" |
| 280 | + |
| 281 | +        # Create emit and record backend classes with decorators - for each kind, one should register and one shouldn't
| 282 | + @emit_backend(should_register=True) |
| 283 | + class ShouldRegisterEmitBackend(EmitBackend): |
| 284 | + def emit(self, entry: LookupTableEntry): |
| 285 | + pass |
| 286 | + |
| 287 | + @emit_backend(should_register=False) |
| 288 | + class ShouldNotRegisterEmitBackend(EmitBackend): |
| 289 | + def emit(self, entry: LookupTableEntry): |
| 290 | + pass |
| 291 | + |
| 292 | + @record_backend(should_register=True) |
| 293 | + class ShouldRegisterRecordBackend(RecordBackend): |
| 294 | + def dump(self, data: dict[str, list[dict[str, Any]]]): |
| 295 | + pass |
| 296 | + |
| 297 | + @record_backend(should_register=False) |
| 298 | + class ShouldNotRegisterRecordBackend(RecordBackend): |
| 299 | + def dump(self, data: dict[str, list[dict[str, Any]]]): |
| 300 | + pass |
| 301 | + |
| 302 | + # Force recorder reset to trigger re-registration with new backends |
| 303 | + force_recorder_reset() |
| 304 | + recorder = get_lookup_table_recorder() |
| 305 | + |
| 306 | + # Check that only the backends with should_register=True were registered |
| 307 | + registered_emit_types = [type(backend) for backend in recorder.emit_backends] |
| 308 | + registered_record_types = [ |
| 309 | + type(backend) for backend in recorder.record_backends |
| 310 | + ] |
| 311 | + |
| 312 | + # Should have registered the "should register" backends |
| 313 | + self.assertIn(ShouldRegisterEmitBackend, registered_emit_types) |
| 314 | + self.assertIn(ShouldRegisterRecordBackend, registered_record_types) |
| 315 | + |
| 316 | + # Should NOT have registered the "should not register" backends |
| 317 | + self.assertNotIn(ShouldNotRegisterEmitBackend, registered_emit_types) |
| 318 | + self.assertNotIn(ShouldNotRegisterRecordBackend, registered_record_types) |
| 319 | + |
| 320 | + @unittest.skipIf( |
| 321 | + not has_triton_tma_device(), "Need device-side TMA support in Triton" |
| 322 | + ) |
| 323 | + @fresh_cache() |
| 324 | + def test_tma_entries_recorded(self): |
| 325 | + """Test that TMA-specific entries are recorded when TMA is enabled""" |
| 326 | + # Create custom record backend to capture data |
| 327 | + test_record_backend = TestRecordBackend() |
| 328 | + |
| 329 | +        # Enable recording (TMA itself is enabled via the compile config patch below)
| 330 | + inductor_config.template_lookup_table_config.recorder_emit = True |
| 331 | + |
| 332 | + # Get recorder and add our test backend |
| 333 | + recorder = get_lookup_table_recorder() |
| 334 | + self.assertIsNotNone(recorder) |
| 335 | + assert recorder is not None # Type narrowing for mypy |
| 336 | + recorder.add_backend(test_record_backend) |
| 337 | + |
| 338 | + # Trigger compilation with TMA enabled |
| 339 | + self.compile_and_run_mm({"triton.enable_persistent_tma_matmul": True}) |
| 340 | + |
| 341 | + # Trigger dump to populate our test backend |
| 342 | + recorder.dump() |
| 343 | + |
| 344 | + # Check if any entries contain TMA-related values |
| 345 | + tma_entries_found = False |
| 346 | + for configs in test_record_backend.dumped_data.values(): |
| 347 | + for config in configs: |
| 348 | + # Check for TMA in template_id or TMA-specific parameters |
| 349 | + if ( |
| 350 | + "tma" in config.get("template_id", "").lower() |
| 351 | + or "TMA_SIZE" in config |
| 352 | + or "NUM_SMS" in config |
| 353 | + or "TMA_EXPERIMENTAL_API" in config |
| 354 | + ): |
| 355 | + tma_entries_found = True |
| 356 | + break |
| 357 | + if tma_entries_found: |
| 358 | + break |
| 359 | + |
| 360 | + self.assertTrue( |
| 361 | + tma_entries_found, |
| 362 | + "Expected to find at least one entry with TMA-related values", |
| 363 | + ) |
| 364 | + |
| 365 | + |
| 366 | +if __name__ == "__main__": |
| 367 | + from torch._inductor.utils import is_big_gpu |
| 368 | + from torch.testing._internal.inductor_utils import HAS_CPU, HAS_GPU |
| 369 | + |
| 370 | + # Set env to make it work in CI |
| 371 | + if HAS_GPU and HAS_CPU and is_big_gpu(): |
| 372 | + run_tests() |