[Precompile] Integrate PrecompileContext with CompilePackage #155384

Draft pull request: wants to merge 1 commit into base branch gh/jamesjwu/163/base
67 changes: 18 additions & 49 deletions test/dynamo/test_package.py
@@ -9,48 +9,18 @@
import torch._inductor.test_case
import torch.onnx.operators
import torch.utils.cpp_extension
from torch._dynamo.package import _CompilePackage
from torch._dynamo.package import _CompilePackage, DynamoStore
from torch._inductor.runtime.runtime_utils import cache_dir


class TestStorageContext:
def __init__(self, path: str):
self.path = path
self.backends = {}

def _write_pickle(self, data, *path: str):
with open(os.path.join(self.path, *path) + ".pickle", "wb") as f:
pickle.dump(data, f)

def write_dynamo(self, dynamo):
self._write_pickle(dynamo, "dynamo")

def write_backend(self, backend_id):
os.makedirs(os.path.join(self.path, backend_id), exist_ok=True)
self._write_pickle(self.backends[backend_id], backend_id, "fx_graph")

def _read_pickle(self, *path):
with open(os.path.join(self.path, *path) + ".pickle", "rb") as f:
return pickle.load(f)

def read_backend(self, backend_id):
return self._read_pickle(backend_id, "fx_graph")

def read_dynamo(self):
return self._read_pickle("dynamo")

def add_backend(self, backend_id, backend):
self.backends[backend_id] = backend


class TestPackage(torch._inductor.test_case.TestCase):
def context(self):
def path(self):
path = os.path.join(cache_dir(), f"package_{self.id()}")
os.makedirs(path, exist_ok=True)
return TestStorageContext(path)
return path

def test_basic_fn(self):
ctx = self.context()
ctx = DynamoStore()

def fn(x):
return x + 1
@@ -62,8 +32,8 @@ def fn(x):
compiled_fn = torch._dynamo.optimize(backend="eager", package=package)(fn)
expected = compiled_fn(*args)
for backend_id, backend in package.cached_backends.items():
ctx.add_backend(backend_id, backend)
package.save(ctx)
ctx.record_eager_backend(backend_id, backend)
ctx.save_package(package, self.path())

# Loading
torch._dynamo.reset()
@@ -74,13 +44,13 @@ def fn(x):
):
compiled_fn(*args)

package = _CompilePackage(fn, ctx.read_dynamo())
package, backends = ctx.load_package(fn, self.path())
compiled_fn = torch._dynamo.optimize(package=package)(fn)
package.install(ctx)
package.install(backends)
self.assertEqual(expected, compiled_fn(*args))

def test_graph_break_bomb(self):
ctx = self.context()
ctx = DynamoStore()

def fn(x, l, r):
if l > r:
@@ -109,8 +79,8 @@ def guard_filter_fn(guards):
for args in args_list:
compiled_fn(*args)
for backend_id, backend in package.cached_backends.items():
ctx.add_backend(backend_id, backend)
package.save(ctx)
ctx.record_eager_backend(backend_id, backend)
ctx.save_package(package, self.path())

# Loading
torch._dynamo.reset()
@@ -121,12 +91,11 @@ def guard_filter_fn(guards):
"Detected recompile when torch.compile stance is 'fail_on_recompile'",
):
compiled_fn(*args)
dynamo_artifacts = ctx.read_dynamo()
package = _CompilePackage(fn, dynamo_artifacts)
package, backends = ctx.load_package(fn, self.path())
compiled_fn = torch._dynamo.optimize(
backend="eager", package=package, guard_filter_fn=guard_filter_fn
)(fn)
package.install(ctx)
package.install(backends)
for args in args_list:
self.assertEqual(compiled_fn(*args), args[0].sum())

@@ -137,7 +106,7 @@ def guard_filter_fn(guards):
compiled_fn(torch.tensor(N), 0, N - 1)

def test_dynamic_shape(self):
ctx = self.context()
ctx = DynamoStore()

def fn(x):
return x + x.shape[0]
@@ -154,8 +123,8 @@ def fn(x):
compiled_fn = torch._dynamo.optimize(backend="eager", package=package)(fn)
compiled_fn(*args)
for backend_id, backend in package.cached_backends.items():
ctx.add_backend(backend_id, backend)
package.save(ctx)
ctx.record_eager_backend(backend_id, backend)
ctx.save_package(package, self.path())

# Loading
torch._dynamo.reset()
@@ -166,9 +135,9 @@ def fn(x):
):
compiled_fn(*args1)

package = _CompilePackage(fn, ctx.read_dynamo())
package, backends = ctx.load_package(fn, self.path())
compiled_fn = torch._dynamo.optimize(package=package)(fn)
package.install(ctx)
package.install(backends)

self.assertEqual(expected1, compiled_fn(*args1))

80 changes: 76 additions & 4 deletions torch/_dynamo/package.py
@@ -8,6 +8,9 @@
from a different process or host.
"""

from abc import ABC, abstractmethod

import os
import contextlib
import dataclasses
import functools
@@ -23,6 +26,8 @@
from typing import Any, NewType, Optional

import torch
from torch._dynamo.precompile_context import PrecompileContext, PrecompileCacheArtifact
from torch.compiler._cache import CacheArtifactFactory
import torch._inductor.package

from .bytecode_transformation import get_code_keys
@@ -142,7 +147,7 @@
python_code=a.python_code,
python_module=a.python_module,
function_names=list({*a.function_names, *b.function_names}),
guarded_codes=merged_codes,

Check failure on line 150 in torch/_dynamo/package.py (GitHub Actions / lintrunner-noclang / linux-job): MYPY [arg-type]: Argument "guarded_codes" to "_DynamoCodeCacheEntry" has incompatible type "dict[bytes, _GuardedCodeCacheEntry]"; expected "list[_GuardedCodeCacheEntry]"
import_sources=merged_imports,
backend_ids=list({*a.backend_ids, *b.backend_ids}),
)
@@ -159,7 +164,7 @@
def merge(
cls, a: "_DynamoCacheEntry", b: "_DynamoCacheEntry"
) -> "_DynamoCacheEntry":
merged_codes = {}

Check failure on line 167 in torch/_dynamo/package.py (GitHub Actions / lintrunner-noclang / linux-job): MYPY [var-annotated]: Need type annotation for "merged_codes" (hint: "merged_codes: dict[<type>, <type>] = ...")
for code in itertools.chain(a.codes, b.codes):
py_code = code.python_code
if existing_code := merged_codes.get(py_code):
@@ -168,6 +173,14 @@
merged_codes[py_code] = code
return _DynamoCacheEntry(codes=list(merged_codes.values()))

@CacheArtifactFactory.register
class _DynamoCacheArtifact(PrecompileCacheArtifact[_DynamoCacheEntry]):
@staticmethod
def type() -> str:
return "precompile_dynamo"

def after_deserialization(self) -> _DynamoCacheEntry:
return pickle.loads(self.content)

class _CompilePackage:
"""
@@ -183,7 +196,7 @@
updates with compiled functions and resume functions.
"""

def __init__(self, fn, dynamo: Optional[_DynamoCacheEntry] = None):

Check failure on line 199 in torch/_dynamo/package.py (GitHub Actions / lintrunner-noclang / linux-job): MYPY [no-untyped-def]: Function is missing a type annotation for one or more arguments
self._innermost_fn = None
self._codes: dict[types.CodeType, _DynamoCodeCacheEntry] = {}

@@ -196,7 +209,7 @@
self._initialize(fn, dynamo)
self.validate()

def _initialize(self, fn, dynamo: Optional[_DynamoCacheEntry] = None):

Check failure on line 212 in torch/_dynamo/package.py (GitHub Actions / lintrunner-noclang / linux-job): MYPY [no-untyped-def]: Function is missing a type annotation for one or more arguments
Check failure on line 212 in torch/_dynamo/package.py (GitHub Actions / lintrunner-noclang / linux-job): MYPY [no-untyped-def]: Function is missing a return type annotation
from .eval_frame import innermost_fn

self._innermost_fn = innermost_fn(fn)
@@ -320,7 +333,7 @@

_reset_precompile_entries(self._innermost_fn.__code__)

def install(self, storage_context) -> None:
def install(self, backends: dict[_BackendId, Any]) -> None:
"""
Sync the package states to the compiled function. This includes the following actions:
1. Clean up the previously installed states.
@@ -343,7 +356,11 @@
fn = types.FunctionType(code, module.__dict__, function_name)
self._install_global(module, function_name, fn)
for backend_id in entry.backend_ids:
backend = storage_context.read_backend(backend_id)
if backend_id not in backends:
raise RuntimeError(
f"Backend {backend_id} is not found in the given backends"
)
backend = backends[backend_id]
torch._dynamo.eval_frame.skip_code(
innermost_fn(backend).__code__, recursive=True
)
@@ -371,8 +388,63 @@

def save(self, storage_context) -> None:
self.validate()
for backend_id in self.backend_ids:
storage_context.write_backend(backend_id)
# # TODO: this shouldn't be needed
# for backend_id in self.backend_ids:
# storage_context.write_backend(backend_id)
storage_context.write_dynamo(
_DynamoCacheEntry(codes=list(self._codes.values()))
)


@CacheArtifactFactory.register
class EagerCacheArtifact(PrecompileCacheArtifact[Any]):
@staticmethod
def type() -> str:
return "precompile_eager"

def after_deserialization(self) -> Any:
return pickle.loads(self.content)



class DynamoStore:

def record_package(self, package: _CompilePackage) -> None:
# records a package to PrecompileContext
cache_entry = _DynamoCacheEntry(codes=list(package._codes.values()))
pickled_result = pickle.dumps(cache_entry)
PrecompileContext.record_artifact(_DynamoCacheArtifact.type(), key=package.source_id, content = pickled_result)

def record_eager_backend(self, backend_id: _BackendId, backend: Any) -> None:
# Records eager fx graphs to PrecompileContext for testing
pickled_result = pickle.dumps(backend)
PrecompileContext.record_artifact(EagerCacheArtifact.type(), key=backend_id, content = pickled_result)

def save_package(self, package: _CompilePackage, path: str) -> None:
# saves a package to a given path
backend_content = {}
for code in package._codes.values():
for backend_id in code.backend_ids:
backend_content[backend_id] = PrecompileContext.serialize_artifact_by_key(backend_id)
cache_entry = _DynamoCacheEntry(codes=list(package._codes.values()))
try:
with open(os.path.join(path, "dynamo"), "wb") as dynamo_path:
pickle.dump(cache_entry, dynamo_path)
with open(os.path.join(path, "backends"), "wb") as backend_path:
pickle.dump(backend_content, backend_path)
except Exception as e:
raise RuntimeError(f"Failed to save package to {path}: {e}")

def load_package(self, fn, path: str) -> tuple[_CompilePackage, dict[_BackendId, Any]]:
# loads a package from a given path and installs proper backends to it
try:
with open(os.path.join(path, "dynamo"), "rb") as dynamo_path:
cache_entry = pickle.load(dynamo_path)
with open(os.path.join(path, "backends"), "rb") as backend_path:
backend_content = pickle.load(backend_path)
except Exception as e:
raise RuntimeError(f"Failed to load package from path {path}: {e}")
for backend_id, backend in backend_content.items():
backend_content[backend_id] = backend.after_deserialization()
package = _CompilePackage(fn, cache_entry)
return package, backend_content
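
Taken together with the test changes above, the new DynamoStore provides a save/load round trip for a compiled package. Below is a minimal sketch of that flow, mirroring test_basic_fn; the toy fn, the example inputs, and the cache directory path are illustrative, and the API names come from the code added in this diff:

    import os

    import torch
    from torch._dynamo.package import _CompilePackage, DynamoStore

    def fn(x):  # toy function, for illustration only
        return x + 1

    args = (torch.randn(3),)           # illustrative inputs
    path = "/tmp/dynamo_package_demo"  # hypothetical cache directory
    os.makedirs(path, exist_ok=True)

    # Compile once, record the eager backends, and persist the package.
    store = DynamoStore()
    package = _CompilePackage(fn)
    compiled_fn = torch._dynamo.optimize(backend="eager", package=package)(fn)
    expected = compiled_fn(*args)
    for backend_id, backend in package.cached_backends.items():
        store.record_eager_backend(backend_id, backend)
    store.save_package(package, path)

    # Later (e.g. in a fresh process): reload the package and install its backends.
    torch._dynamo.reset()
    package, backends = store.load_package(fn, path)
    compiled_fn = torch._dynamo.optimize(backend="eager", package=package)(fn)
    package.install(backends)
    assert torch.equal(expected, compiled_fn(*args))

Note that record_eager_backend must run before save_package, since save_package looks each backend up in PrecompileContext by its backend id via serialize_artifact_by_key.
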
13 changes: 7 additions & 6 deletions torch/_dynamo/precompile_context.py
@@ -1,4 +1,4 @@
from abc import abstractmethod
from abc import ABC, abstractmethod
from collections import defaultdict
from itertools import chain
from typing import Any, Generic, Optional, TypeVar
@@ -23,6 +23,7 @@
T = TypeVar("T")



class PrecompileCacheArtifact(CacheArtifact, Generic[T]):
"""
Data for each cache artifact that will be serialized and deserialized by
Expand Down Expand Up @@ -78,7 +79,7 @@ class PrecompileContext(CacheArtifactManager):
# This allows us to implement serialize_by_key easily.
# On call to `serialize()`, all cache artifacts in _new_cache_artifacts_by_key
# are transferred to _new_cache_artifacts before serialization.
_new_cache_artifacts_by_key: dict[str, list[CacheArtifact]] = defaultdict(list)
_new_cache_artifacts_by_key: dict[str, CacheArtifact] = {}
_new_cache_artifacts: CacheArtifactsResult = defaultdict(list)
# Keep a separate seen artifacts list to avoid unnecessary duplicates
# This list will not be cleared between serialize() calls
@@ -111,7 +112,7 @@ def record_artifact(
artifact = CacheArtifactFactory.encode_create(artifact_type, key, content)
if artifact in cls._seen_artifacts:
return
cls._new_cache_artifacts_by_key[key].append(artifact)
cls._new_cache_artifacts_by_key[key] = artifact
cls._seen_artifacts.add(artifact)

@classmethod
@@ -120,16 +121,16 @@ def _save_artifacts_by_type(cls) -> None:
We normally record artifacts by key, but serialization expects them to be organized
by artifact type. This function transfers artifacts from _new_cache_artifacts_by_key to _new_cache_artifacts
"""
for artifact in chain(*cls._new_cache_artifacts_by_key.values()):
for artifact in cls._new_cache_artifacts_by_key.values():
cls._new_cache_artifacts[artifact.__class__.type()].append(artifact)
cls._new_cache_artifacts_by_key.clear()

@classmethod
def serialize_artifacts_by_key(cls, key: str) -> list[CacheArtifact]:
def serialize_artifact_by_key(cls, key: str) -> Optional[CacheArtifact]:
"""
Serialize all artifacts with the given key returned in a list.
"""
return cls._new_cache_artifacts_by_key.get(key, [])
return cls._new_cache_artifacts_by_key.get(key, None)

@classmethod
def serialize(cls) -> Optional[tuple[bytes, CacheInfo]]:
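
For reference, a small sketch of the per-key bookkeeping after this change: _new_cache_artifacts_by_key now maps each key to a single CacheArtifact, so recording the same key again overwrites the earlier artifact, and serialize_artifact_by_key returns one artifact (or None) rather than a list. The key and payloads below are illustrative; importing torch._dynamo.package is assumed to register the precompile_eager artifact type added in this PR:

    import pickle

    import torch._dynamo.package  # assumed to register the "precompile_eager" artifact type from this PR
    from torch._dynamo.precompile_context import PrecompileContext

    key = "__compiled_fn_1"  # illustrative backend id

    # Recording the same key twice keeps only the most recent artifact.
    PrecompileContext.record_artifact("precompile_eager", key=key, content=pickle.dumps("first"))
    PrecompileContext.record_artifact("precompile_eager", key=key, content=pickle.dumps("second"))

    artifact = PrecompileContext.serialize_artifact_by_key(key)
    assert artifact is not None  # a single CacheArtifact for this key, not a list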