From 3b49d337f47cb3aaf40e824a62ba6b9e8bb1ff6e Mon Sep 17 00:00:00 2001 From: James Wu Date: Fri, 6 Jun 2025 17:59:59 -0700 Subject: [PATCH] Testing ground [ghstack-poisoned] --- test/dynamo/test_package.py | 67 +++++++----------------- torch/_dynamo/package.py | 80 +++++++++++++++++++++++++++-- torch/_dynamo/precompile_context.py | 13 ++--- 3 files changed, 101 insertions(+), 59 deletions(-) diff --git a/test/dynamo/test_package.py b/test/dynamo/test_package.py index 0ec0ee0ed568..a7f7779559c7 100644 --- a/test/dynamo/test_package.py +++ b/test/dynamo/test_package.py @@ -9,48 +9,18 @@ import torch._inductor.test_case import torch.onnx.operators import torch.utils.cpp_extension -from torch._dynamo.package import _CompilePackage +from torch._dynamo.package import _CompilePackage, DynamoStore from torch._inductor.runtime.runtime_utils import cache_dir -class TestStorageContext: - def __init__(self, path: str): - self.path = path - self.backends = {} - - def _write_pickle(self, data, *path: str): - with open(os.path.join(self.path, *path) + ".pickle", "wb") as f: - pickle.dump(data, f) - - def write_dynamo(self, dynamo): - self._write_pickle(dynamo, "dynamo") - - def write_backend(self, backend_id): - os.makedirs(os.path.join(self.path, backend_id), exist_ok=True) - self._write_pickle(self.backends[backend_id], backend_id, "fx_graph") - - def _read_pickle(self, *path): - with open(os.path.join(self.path, *path) + ".pickle", "rb") as f: - return pickle.load(f) - - def read_backend(self, backend_id): - return self._read_pickle(backend_id, "fx_graph") - - def read_dynamo(self): - return self._read_pickle("dynamo") - - def add_backend(self, backend_id, backend): - self.backends[backend_id] = backend - - class TestPackage(torch._inductor.test_case.TestCase): - def context(self): + def path(self): path = os.path.join(cache_dir(), f"package_{self.id()}") os.makedirs(path, exist_ok=True) - return TestStorageContext(path) + return path def test_basic_fn(self): - ctx = self.context() + ctx = DynamoStore() def fn(x): return x + 1 @@ -62,8 +32,8 @@ def fn(x): compiled_fn = torch._dynamo.optimize(backend="eager", package=package)(fn) expected = compiled_fn(*args) for backend_id, backend in package.cached_backends.items(): - ctx.add_backend(backend_id, backend) - package.save(ctx) + ctx.record_eager_backend(backend_id, backend) + ctx.save_package(package, self.path()) # Loading torch._dynamo.reset() @@ -74,13 +44,13 @@ def fn(x): ): compiled_fn(*args) - package = _CompilePackage(fn, ctx.read_dynamo()) + package, backends = ctx.load_package(fn, self.path()) compiled_fn = torch._dynamo.optimize(package=package)(fn) - package.install(ctx) + package.install(backends) self.assertEqual(expected, compiled_fn(*args)) def test_graph_break_bomb(self): - ctx = self.context() + ctx = DynamoStore() def fn(x, l, r): if l > r: @@ -109,8 +79,8 @@ def guard_filter_fn(guards): for args in args_list: compiled_fn(*args) for backend_id, backend in package.cached_backends.items(): - ctx.add_backend(backend_id, backend) - package.save(ctx) + ctx.record_eager_backend(backend_id, backend) + ctx.save_package(package, self.path()) # Loading torch._dynamo.reset() @@ -121,12 +91,11 @@ def guard_filter_fn(guards): "Detected recompile when torch.compile stance is 'fail_on_recompile'", ): compiled_fn(*args) - dynamo_artifacts = ctx.read_dynamo() - package = _CompilePackage(fn, dynamo_artifacts) + package, backends = ctx.load_package(fn, self.path()) compiled_fn = torch._dynamo.optimize( backend="eager", package=package, 
guard_filter_fn=guard_filter_fn )(fn) - package.install(ctx) + package.install(backends) for args in args_list: self.assertEqual(compiled_fn(*args), args[0].sum()) @@ -137,7 +106,7 @@ def guard_filter_fn(guards): compiled_fn(torch.tensor(N), 0, N - 1) def test_dynamic_shape(self): - ctx = self.context() + ctx = DynamoStore() def fn(x): return x + x.shape[0] @@ -154,8 +123,8 @@ def fn(x): compiled_fn = torch._dynamo.optimize(backend="eager", package=package)(fn) compiled_fn(*args) for backend_id, backend in package.cached_backends.items(): - ctx.add_backend(backend_id, backend) - package.save(ctx) + ctx.record_eager_backend(backend_id, backend) + ctx.save_package(package, self.path()) # Loading torch._dynamo.reset() @@ -166,9 +135,9 @@ def fn(x): ): compiled_fn(*args1) - package = _CompilePackage(fn, ctx.read_dynamo()) + package, backends = ctx.load_package(fn, self.path()) compiled_fn = torch._dynamo.optimize(package=package)(fn) - package.install(ctx) + package.install(backends) self.assertEqual(expected1, compiled_fn(*args1)) diff --git a/torch/_dynamo/package.py b/torch/_dynamo/package.py index edd1af1bcc77..65cf6737698d 100644 --- a/torch/_dynamo/package.py +++ b/torch/_dynamo/package.py @@ -8,6 +8,9 @@ from a different process or host. """ +from abc import ABC, abstractmethod + +import os import contextlib import dataclasses import functools @@ -23,6 +26,8 @@ from typing import Any, NewType, Optional import torch +from torch._dynamo.precompile_context import PrecompileContext, PrecompileCacheArtifact +from torch.compiler._cache import CacheArtifactFactory import torch._inductor.package from .bytecode_transformation import get_code_keys @@ -168,6 +173,14 @@ def merge( merged_codes[py_code] = code return _DynamoCacheEntry(codes=list(merged_codes.values())) +@CacheArtifactFactory.register +class _DynamoCacheArtifact(PrecompileCacheArtifact[_DynamoCacheEntry]): + @staticmethod + def type() -> str: + return "precompile_dynamo" + + def after_deserialization(self) -> _DynamoCacheEntry: + return pickle.loads(self.content) class _CompilePackage: """ @@ -320,7 +333,7 @@ def uninstall(self) -> None: _reset_precompile_entries(self._innermost_fn.__code__) - def install(self, storage_context) -> None: + def install(self, backends: dict[_BackendId, Any]) -> None: """ Sync the package states to the compiled function. This includes the following actions: 1. Clean up the previously installed states. 
@@ -343,7 +356,11 @@ def install(self, storage_context) -> None: fn = types.FunctionType(code, module.__dict__, function_name) self._install_global(module, function_name, fn) for backend_id in entry.backend_ids: - backend = storage_context.read_backend(backend_id) + if backend_id not in backends: + raise RuntimeError( + f"Backend {backend_id} is not found in the given backends" + ) + backend = backends[backend_id] torch._dynamo.eval_frame.skip_code( innermost_fn(backend).__code__, recursive=True ) @@ -371,8 +388,63 @@ def install(self, storage_context) -> None: def save(self, storage_context) -> None: self.validate() - for backend_id in self.backend_ids: - storage_context.write_backend(backend_id) + # # TODO: this shouldn't be needed + # for backend_id in self.backend_ids: + # storage_context.write_backend(backend_id) storage_context.write_dynamo( _DynamoCacheEntry(codes=list(self._codes.values())) ) + + +@CacheArtifactFactory.register +class EagerCacheArtifact(PrecompileCacheArtifact[Any]): + @staticmethod + def type() -> str: + return "precompile_eager" + + def after_deserialization(self) -> Any: + return pickle.loads(self.content) + + + +class DynamoStore: + + def record_package(self, package: _CompilePackage) -> None: + # records a package to PrecompileContext + cache_entry = _DynamoCacheEntry(codes=list(package._codes.values())) + pickled_result = pickle.dumps(cache_entry) + PrecompileContext.record_artifact(_DynamoCacheArtifact.type(), key=package.source_id, content = pickled_result) + + def record_eager_backend(self, backend_id: _BackendId, backend: Any) -> None: + # Records eager fx graphs to PrecompileContext for testing + pickled_result = pickle.dumps(backend) + PrecompileContext.record_artifact(EagerCacheArtifact.type(), key=backend_id, content = pickled_result) + + def save_package(self, package: _CompilePackage, path: str) -> None: + # saves a package to a given path + backend_content = {} + for code in package._codes.values(): + for backend_id in code.backend_ids: + backend_content[backend_id] = PrecompileContext.serialize_artifact_by_key(backend_id) + cache_entry = _DynamoCacheEntry(codes=list(package._codes.values())) + try: + with open(os.path.join(path, "dynamo"), "wb") as dynamo_path: + pickle.dump(cache_entry, dynamo_path) + with open(os.path.join(path, "backends"), "wb") as backend_path: + pickle.dump(backend_content, backend_path) + except Exception as e: + raise RuntimeError(f"Failed to save package to {path}: {e}") + + def load_package(self, fn, path: str) -> tuple[_CompilePackage, dict[_BackendId, Any]]: + # loads a package from a given path and installs proper backends to it + try: + with open(os.path.join(path, "dynamo"), "rb") as dynamo_path: + cache_entry = pickle.load(dynamo_path) + with open(os.path.join(path, "backends"), "rb") as backend_path: + backend_content = pickle.load(backend_path) + except Exception as e: + raise RuntimeError(f"Failed to load package from path {path}: {e}") + for backend_id, backend in backend_content.items(): + backend_content[backend_id] = backend.after_deserialization() + package = _CompilePackage(fn, cache_entry) + return package, backend_content diff --git a/torch/_dynamo/precompile_context.py b/torch/_dynamo/precompile_context.py index 821e0f738e4e..69e8d0fb2b86 100644 --- a/torch/_dynamo/precompile_context.py +++ b/torch/_dynamo/precompile_context.py @@ -1,4 +1,4 @@ -from abc import abstractmethod +from abc import ABC, abstractmethod from collections import defaultdict from itertools import chain from typing import Any, 
Generic, Optional, TypeVar @@ -23,6 +23,7 @@ T = TypeVar("T") + class PrecompileCacheArtifact(CacheArtifact, Generic[T]): """ Data for each cache artifact that will be serialized and deserialized by @@ -78,7 +79,7 @@ class PrecompileContext(CacheArtifactManager): # This allows us to implement serialize_by_key easily. # On call to `serialize()`, all cache artifacts in _new_cache_artifacts_by_key # are transferred to _new_cache_artifacts before serialization. - _new_cache_artifacts_by_key: dict[str, list[CacheArtifact]] = defaultdict(list) + _new_cache_artifacts_by_key: dict[str, CacheArtifact] = {} _new_cache_artifacts: CacheArtifactsResult = defaultdict(list) # Keep a seperate seen artifacts list to make avoid unnecessary duplicates # This list will not be cleared between serialize() calls @@ -111,7 +112,7 @@ def record_artifact( artifact = CacheArtifactFactory.encode_create(artifact_type, key, content) if artifact in cls._seen_artifacts: return - cls._new_cache_artifacts_by_key[key].append(artifact) + cls._new_cache_artifacts_by_key[key] = artifact cls._seen_artifacts.add(artifact) @classmethod @@ -120,16 +121,16 @@ def _save_artifacts_by_type(cls) -> None: We normally record artifacts by key, but serialization expects them to be organized by artifact type. This function transfers artifacts from _new_cache_artifacts_by_key to _new_cache_artifacts """ - for artifact in chain(*cls._new_cache_artifacts_by_key.values()): + for artifact in cls._new_cache_artifacts_by_key.values(): cls._new_cache_artifacts[artifact.__class__.type()].append(artifact) cls._new_cache_artifacts_by_key.clear() @classmethod - def serialize_artifacts_by_key(cls, key: str) -> list[CacheArtifact]: + def serialize_artifact_by_key(cls, key: str) -> Optional[CacheArtifact]: """ Serialize all artifacts with the given key returned in a list. """ - return cls._new_cache_artifacts_by_key.get(key, []) + return cls._new_cache_artifacts_by_key.get(key, None) @classmethod def serialize(cls) -> Optional[tuple[bytes, CacheInfo]]:
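
For reviewers: below is a minimal sketch of the save/load round trip this patch enables, distilled from test_basic_fn in test/dynamo/test_package.py. The `_CompilePackage(fn)` construction and the on-disk path are assumptions for illustration (the package construction happens outside the visible hunks); the DynamoStore calls mirror the test changes above.

import os
import torch
from torch._dynamo.package import _CompilePackage, DynamoStore

def fn(x):
    return x + 1

args = (torch.randn(3, 2),)
path = "/tmp/dynamo_package_example"  # hypothetical location; the test uses cache_dir()
os.makedirs(path, exist_ok=True)

store = DynamoStore()
package = _CompilePackage(fn)  # assumed constructor; not shown in this diff

# Compile once, record the eager backends, and serialize everything to `path`.
compiled_fn = torch._dynamo.optimize(backend="eager", package=package)(fn)
expected = compiled_fn(*args)
for backend_id, backend in package.cached_backends.items():
    store.record_eager_backend(backend_id, backend)
store.save_package(package, path)

# Later, e.g. in a fresh process: reload the package plus its backends and
# install them so the compiled code runs without recompiling.
torch._dynamo.reset()
package, backends = store.load_package(fn, path)
compiled_fn = torch._dynamo.optimize(package=package)(fn)
package.install(backends)
assert torch.equal(expected, compiled_fn(*args))

Design note: save_package pulls each recorded backend out of PrecompileContext via serialize_artifact_by_key and pickles the dynamo cache entry and the backend map as two files ("dynamo" and "backends") under the given path; load_package reads them back, calls after_deserialization() on each artifact, and returns the package together with the backend dict that _CompilePackage.install now expects.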