Skip to content

Commit e370f0a

Browse files
committed
[inductor] add lookup table recorder
\# Why make it easier for users to generate lookup tables and use them \# What - infrastructure to record lookup tables from autotuning - sample implementation for recording to directory - sample implementation for emitting to log.debug (key/value) \## caveats - right now it just records mm_templates and everything that inductor considers. There are some architectural changes needed to make it record e.g. a topk after autotuning. once that is ready, this is modular enough to adjust to recording only topk, however there is value now in being able to record a simple table, see the format, and manually edit it down to the topk entries using the autotuning logs \# Testing using ``` \#!/bin/bash \# Create a temporary directory for the lookup table dumps TEMP_DIR=$(mktemp -d) echo "Created temporary directory for lookup table dumps: $TEMP_DIR" \# Set environment variables to enable verbose output and recording export TORCH_LOGS="+inductor" export TORCH_INDUCTOR_LOOKUP_TABLE_RECORD_DIR=$TEMP_DIR export PYTORCH_DEBUG=1 \# Run the Python script python3 -c " import os import torch import logging from torch._inductor import config as inductor_config from torch._inductor.lookup_table_recorder import dump \# Configure logging to see emit messages logging.basicConfig(level=logging.DEBUG) \# Enable TMA for matmul inductor_config.triton.enable_persistent_tma_matmul = True \# Create large tensors with bfloat16 dtype print('Creating 1024x1024 bfloat16 tensors for matrix multiplication...') a = torch.randn(1024, 1024, device='cuda', dtype=torch.bfloat16) b = torch.randn(1024, 1024, device='cuda', dtype=torch.bfloat16) \# Compile and run the matrix multiplication print('Compiling and running torch.mm with TMA...') compiled_mm = torch.compile(torch.mm, mode='max-autotune') result = compiled_mm(a, b) \# Force synchronization to ensure compilation is complete torch.cuda.synchronize() \# Dump the lookup table print('Dumping lookup table...') dump() print('\\nMatrix multiplication completed successfully!') " 2>&1 | tee /tmp/recorder_output.log \# Check if emit logic works by grepping for LookupTable entries echo -e "\n\n=== CHECKING EMIT FUNCTIONALITY ===" if grep -q "LookupTable:" /tmp/recorder_output.log; then echo "✅ Emit functionality is working! Found LookupTable entries in the log." else echo "❌ Emit functionality not detected. No LookupTable entries found in the log." fi \# Display the dumped lookup table echo -e "\n\n=== DUMPED LOOKUP TABLE CONTENTS ===" LATEST_JSON=$(ls -t $TEMP_DIR/inductor_lut_*.json | head -1) if [ -f "$LATEST_JSON" ]; then echo "Found lookup table file: $LATEST_JSON" echo "File size: $(du -h $LATEST_JSON | cut -f1)" echo -e "\nFirst 20 lines of the lookup table:" head -n 20 $LATEST_JSON # Check for TMA entries echo -e "\n=== CHECKING FOR TMA ENTRIES ===" if grep -q "tma\|TMA_SIZE\|NUM_SMS" $LATEST_JSON; then echo "✅ TMA entries found in the lookup table!" echo -e "\nSample TMA entry:" grep -m 1 -A 10 -B 2 "tma\|TMA_SIZE\|NUM_SMS" $LATEST_JSON else echo "❌ No TMA entries found in the lookup table." fi else echo "❌ No lookup table JSON file found in $TEMP_DIR" fi echo -e "\n\nLookup table files are available in: $TEMP_DIR" echo "Log file is available at: /tmp/recorder_output.log" ``` ``` === CHECKING EMIT FUNCTIONALITY === ✅ Emit functionality is working! Found LookupTable entries in the log. === DUMPED LOOKUP TABLE CONTENTS === Found lookup table file: /tmp/tmp.L9pydR3sH4/inductor_lut_20250723_221836_641.json File size: 12K First 20 lines of the lookup table: { "NVIDIA H100+mm+((torch.bfloat16, [1024, 1024], [1024, 1]), (torch.bfloat16, [1024, 1024], [1024, 1]))": [ { "template_id": "mm", "EVEN_K": true, "ALLOW_TF32": false, "USE_FAST_ACCUM": false, "ACC_TYPE": "tl.float32", "num_stages": 1, "num_warps": 2, "BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 16, "hint_override": null, "GROUP_M": 8, "template_hash": "0717af5834e39dcca7ea817f896b8d85b4886422da7a3ab5f6911b4cfe568896" }, { "template_id": "mm", "EVEN_K": true, === CHECKING FOR TMA ENTRIES === ✅ TMA entries found in the lookup table! Sample TMA entry: }, { "template_id": "mm_persistent_tma", "EVEN_K": true, "ALLOW_TF32": false, "USE_FAST_ACCUM": false, "ACC_TYPE": "tl.float32", "num_stages": 3, "num_warps": 8, "BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64, "hint_override": null, Lookup table files are available in: /tmp/tmp.L9pydR3sH4 Log file is available at: /tmp/recorder_output.log ``` ghstack-source-id: da0b7c4 Pull Request resolved: #158987
1 parent 203cacb commit e370f0a

File tree

4 files changed

+739
-11
lines changed

4 files changed

+739
-11
lines changed

0 commit comments

Comments
 (0)