
Commit bc379ae

Revert "Still run TritonBundler with BundledAOTAutogradCache, save autotune results (#158048)"
This reverts commit 8e57cdb. Reverted #158048 on behalf of https://github.com/jeffdaily due to ROCm failures caused by a unit test introduced in that PR, with no pre-merge signal available ([comment](#158048 (comment))).
1 parent b1a0c34 commit bc379ae

File tree

4 files changed (+3, -79 lines):

test/dynamo/test_package.py
torch/_dynamo/precompile_context.py
torch/_inductor/compile_fx.py
torch/_inductor/runtime/autotune_cache.py

test/dynamo/test_package.py

Lines changed: 0 additions & 34 deletions
@@ -15,7 +15,6 @@
 from torch._dynamo.package import CompilePackage, DiskDynamoStore, DynamoCache
 from torch._dynamo.precompile_context import PrecompileContext
 from torch._functorch import config as functorch_config
-from torch._inductor.mock_cache import global_stats, PatchCaches, Stats
 from torch._inductor.runtime.runtime_utils import cache_dir
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -429,39 +428,6 @@ def fn2(x):
         self.assertEqual(expected, [result1, result2])
         self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)
 
-    @parametrize("device", ("cuda", "xpu"))
-    @torch._dynamo.config.patch(caching_precompile=True)
-    def test_automatic_dynamo_autotune_cache(self, device):
-        if device == "cuda" and not HAS_CUDA:
-            raise unittest.SkipTest("Requires CUDA/Triton")
-        if device == "xpu" and not HAS_XPU:
-            raise unittest.SkipTest("Requires XPU/Triton")
-
-        def fn(x, y):
-            return x.sin() + y
-
-        arg1 = torch.randn(3, 3, device=device)
-        arg2 = torch.randn(3, 3, device=device)
-        expected = fn(arg1, arg2).clone()
-
-        with PatchCaches():
-            compiled_fn1 = torch.compile(fn, mode="max-autotune")
-            result = compiled_fn1(arg1, arg2).clone()
-            self.assertEqual(expected, result)
-            self.assertEqual(global_stats.autotune_local, Stats(1, 0, 1))
-            DynamoCache.clear()
-
-            total_frames = torch._dynamo.convert_frame.FRAME_COUNTER
-            self._save_and_reload(
-                expected_backends=1, expected_dynamo=1, expected_autotune=1
-            )
-            compiled_fn1 = torch.compile(fn, mode="max-autotune")
-            with torch.compiler.set_stance("fail_on_recompile"):
-                result1 = compiled_fn1(arg1, arg2).clone()
-            self.assertEqual(expected, result1)
-            self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)
-            self.assertEqual(global_stats.autotune_local, Stats(2, 1, 1))
-
     @parametrize("device", ("cpu", "cuda", "xpu"))
     @torch._dynamo.config.patch(caching_precompile=True)
     def test_automatic_dynamo_recompiles(self, device):
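
Note on coverage: the deletion above removes the end-to-end check that autotune results survive a precompile save/reload. Below is a condensed, standalone sketch of the first half of the deleted test, using only calls visible in the removed code; it assumes a CUDA/Triton-capable machine and omits the test-harness save/reload helper (self._save_and_reload), so the exact Stats counts may not reproduce outside the harness.

import torch
import torch._dynamo
from torch._inductor.mock_cache import PatchCaches, global_stats


def fn(x, y):
    return x.sin() + y


if torch.cuda.is_available():
    with torch._dynamo.config.patch(caching_precompile=True):
        arg1 = torch.randn(3, 3, device="cuda")
        arg2 = torch.randn(3, 3, device="cuda")
        expected = fn(arg1, arg2).clone()

        with PatchCaches():
            # First compile with autotuning; the deleted test expected one local
            # autotune-cache write here: global_stats.autotune_local == Stats(1, 0, 1).
            compiled_fn = torch.compile(fn, mode="max-autotune")
            result = compiled_fn(arg1, arg2).clone()
            torch.testing.assert_close(result, expected)
            print(global_stats.autotune_local)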

torch/_dynamo/precompile_context.py

Lines changed: 2 additions & 7 deletions
@@ -70,8 +70,7 @@ class PrecompileContext(CacheArtifactManager):
 
     The following artifact types are supported by PrecompileContext:
      - BundledAOTAutogradCacheArtifact
-     - DynamoCodeStateArtifact
-     - AutotuneCacheArtifact (regular autotune results, same as Megacache)
+     - CodeStateArtifact (from torch._dynamo.package once available)
     """
 
     # Protected by the compile_lock
@@ -150,12 +149,8 @@ def populate_caches(artifacts: CacheArtifactsResult) -> CacheInfo:
         artifacts_by_key = {}
         cache_info = CacheInfo()
         for artifact in chain(*artifacts.values()):
-            if artifact.type() == "autotune":
-                # Populate autotune cache artifacts
-                artifact.populate_cache()
-            else:
-                artifacts_by_key[artifact.key] = artifact
             cache_info.add(artifact)
+            artifacts_by_key[artifact.key] = artifact
 
         from torch._dynamo.package import _BackendId, DynamoCache
 

torch/_inductor/compile_fx.py

Lines changed: 1 addition & 28 deletions
@@ -909,37 +909,10 @@ def _compile_fx_inner(
         else:
             log.debug("Failed to generate FX cache key")
 
-        if torch._functorch.config.bundled_autograd_cache:
-            assert mb_compiled_graph is None
-            assert cache_info is None
-            # When using bundled autograd cache, we still want
-            # to use the TritonBundler, but we don't want to save
-            # the results here. The results will get saved directly
-            # to AOTAutogradCache.
-            TritonBundler.begin_compile()
-            try:
-                mb_compiled_graph = fx_codegen_and_compile(
-                    gm, example_inputs, inputs_to_check, **graph_kwargs
-                )
-                assert mb_compiled_graph is not None
-                (
-                    triton_bundle,
-                    triton_bundler_meta,
-                ) = TritonBundler.collect()
-                mb_compiled_graph.set_triton_bundle(triton_bundle)
-            except (ShortenTraceback, SkipFrame):
-                raise
-            except Exception as e:
-                raise InductorError(e, currentframe()).with_traceback(
-                    e.__traceback__
-                ) from None
-            finally:
-                TritonBundler.end_compile()
-
         # CACHE BYPASS: Compile the graph, don't save it to the cache
         # (this can happen either because cache was disabled, or we
         # determined the input is uncacheable)
-        elif cache_info is None or cache_info["cache_state"] == "bypass":
+        if cache_info is None or cache_info["cache_state"] == "bypass":
             assert mb_compiled_graph is None
             log.debug(
                 "FX cache bypass reason: %s",

torch/_inductor/runtime/autotune_cache.py

Lines changed: 0 additions & 10 deletions
@@ -35,7 +35,6 @@
 from typing_extensions import override
 
 import torch
-from torch._dynamo.precompile_context import PrecompileContext
 from torch._inductor.runtime.runtime_utils import cache_dir
 from torch.compiler._cache import (
     CacheArtifact,
@@ -126,7 +125,6 @@ def create(
     ) -> Optional[AutotuneCache]:
         cache = AutotuneCache(configs_hash)
         key = AutotuneCache._prepare_key(filename)
-
         cache._setup_local_cache(inductor_meta, os.path.dirname(filename), key)
         cache._setup_remote_autotune_cache(inductor_meta, key)
         if cache.local_cache or cache.remote_cache:
@@ -302,10 +300,6 @@ def save(
         CacheArtifactManager.record_artifact(
             AutotuneCacheArtifact.type(), autotune_artifact_key, data
         )
-        if torch._dynamo.config.caching_precompile:
-            PrecompileContext.record_artifact(
-                AutotuneCacheArtifact.type(), autotune_artifact_key, data
-            )
 
         if log.isEnabledFor(logging.DEBUG):
            type_str = "coordesc" if found_by_coordesc else "heuristic"
@@ -631,10 +625,6 @@ def _get(self, key: str, sample: Optional[Sample]) -> Optional[JsonDataTy]:
             CacheArtifactManager.record_artifact(
                 AutotuneCacheArtifact.type(), autotune_artifact_key, result
             )
-            if torch._dynamo.config.caching_precompile:
-                PrecompileContext.record_artifact(
-                    AutotuneCacheArtifact.type(), autotune_artifact_key, result
-                )
             return result
 
     @override
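
After this revert, autotune results are still recorded through the regular CacheArtifactManager path kept above (the same path Megacache uses); they are simply no longer duplicated into PrecompileContext when caching_precompile is enabled. A hedged sketch of carrying those artifacts across processes with the torch.compiler Megacache entry points (save_cache_artifacts / load_cache_artifacts, assumed from current PyTorch) follows.

import torch


def fn(x, y):
    return x.sin() + y


if torch.cuda.is_available():
    compiled = torch.compile(fn, mode="max-autotune")
    compiled(torch.randn(3, 3, device="cuda"), torch.randn(3, 3, device="cuda"))

    # Serialize everything CacheArtifactManager recorded, autotune results included.
    saved = torch.compiler.save_cache_artifacts()
    if saved is not None:
        artifact_bytes, cache_info = saved
        # In a fresh process, repopulate the local caches from the serialized blob.
        torch.compiler.load_cache_artifacts(artifact_bytes)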
