diff --git a/test/inductor/test_select_algorithm.py b/test/inductor/test_select_algorithm.py index d2cd77fe5cd29..1b3cbeec1fac3 100644 --- a/test/inductor/test_select_algorithm.py +++ b/test/inductor/test_select_algorithm.py @@ -17,7 +17,6 @@ from torch._inductor.ir import FixedLayout from torch._inductor.select_algorithm import ( autotune_select_algorithm, - PartialRender, TritonTemplate, TritonTemplateKernel, ) @@ -455,48 +454,21 @@ def test_finalized_subclass_hooks(self): hook_identifier = "# CUSTOM_HOOK" class ExtensionTritonTemplateKernel(TritonTemplateKernel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._register_extra_template_env_fns( + self.custom_hook, + ) + def custom_hook(self) -> str: """ - Custom hook that just returns a test string for - validation + Custom hook that just returns a test string for validation """ def hook() -> str: return hook_identifier - assert "<CUSTOM_HOOK>" not in self.render_hooks - self.render_hooks["<CUSTOM_HOOK>"] = hook - return "<CUSTOM_HOOK>" - - def render( - self, template, kwargs, record_input_dependent_tracked_event=False - ): - if record_input_dependent_tracked_event: - self.cached_replay_events = [] - - template_env = { - fn.__name__: self.record_input_dependent_tracked_event()(fn) - if record_input_dependent_tracked_event - else fn - for fn in [ - self.def_kernel, - self.size, - self.stride, - self.store_output, - self.load_input, - self.make_load, - self.modification, - self.gen_argdefs, - self.gen_defines, - # This function registers a hook that the scheduler does - # not directly finalize - self.custom_hook, - ] - } - return PartialRender( - template.render(**template_env, **kwargs), - self.render_hooks, - ) + return self._register_hook("<CUSTOM_HOOK>", hook) class ExtensionTritonTemplate(TritonTemplate): kernel_type = ExtensionTritonTemplateKernel diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 903d616bb91eb..60b7c026260f0 100644 --- a/torch/_inductor/select_algorithm.py +++ 
b/torch/_inductor/select_algorithm.py @@ -17,7 +17,7 @@ from concurrent.futures import as_completed, ThreadPoolExecutor from io import StringIO from types import ModuleType -from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, Literal, NamedTuple, Optional, TYPE_CHECKING, Union from typing_extensions import Self from unittest.mock import patch @@ -167,26 +167,36 @@ class PartialRender: of replacements after the initial render. """ - FINALIZED_HOOK: object = object() + HookFn = Callable[[], str] + MaybeHookFn = Union[HookFn, Literal["finalized"]] - def __init__(self, code, replacement_hooks) -> None: + def __init__(self, code: str, replacement_hooks: dict[str, MaybeHookFn]) -> None: super().__init__() - self._code = code - self.replacement_hooks = replacement_hooks + self._code: str = code + self.replacement_hooks: dict[str, PartialRender.MaybeHookFn] = replacement_hooks @property - def code(self): + def code(self) -> str: + """ + The fully rendered code. Will **error** if any hooks have yet to be + finalized. + """ remaining_active_hooks = [ - key - for key, fn in self.replacement_hooks.items() - if fn is not self.FINALIZED_HOOK + key for key, fn in self.replacement_hooks.items() if fn != "finalized" ] assert len(remaining_active_hooks) == 0, ( f"The following hooks have not yet been finalized:\n {remaining_active_hooks=}" ) return self._code - def finalize_hook(self, hook_key: str, strict=True) -> None: + def finalize_hook(self, hook_key: str, strict: bool = True) -> None: + """ + Finalize a hook by name. + + :param strict: If ``True``, raise an error if the hook wasn't found. + + NOTE: Will **error** if the hook has already been finalized. 
+ """ if hook_key not in self.replacement_hooks: if strict: raise RuntimeError( @@ -194,11 +204,12 @@ def finalize_hook(self, hook_key: str, strict=True) -> None: ) else: return - assert self.replacement_hooks[hook_key] is not self.FINALIZED_HOOK, ( - "hook_key can only be called once" - ) - self._code = self._code.replace(hook_key, self.replacement_hooks[hook_key]()) - self.replacement_hooks[hook_key] = self.FINALIZED_HOOK + + hook = self.replacement_hooks[hook_key] + assert hook != "finalized", "hook_key can only be called once" + self._code = self._code.replace(hook_key, hook()) + + self.replacement_hooks[hook_key] = "finalized" def finalize_remaining(self) -> str: """ @@ -209,11 +220,17 @@ def finalize_remaining(self) -> str: finalize active hooks. """ for key, fn in self.replacement_hooks.items(): - if fn is not self.FINALIZED_HOOK: + if fn != "finalized": self.finalize_hook(key) return self.code def finalize_all(self) -> str: + """ + Finalize all active hooks. + + NOTE: unlike ``finalize_remaining``, this method will **error** if any + hook has already been finalized. + """ for key in self.replacement_hooks: self.finalize_hook(key) return self.code @@ -432,6 +449,9 @@ def __init__( # by adding all inputs. self.prologue_loads_all_inputs = prologue_loads_all_inputs + # Extra functions to be exposed during partial template rendering. + self.extra_template_env_fns: list[Callable[..., Any]] = [] + def input_dependent_preserved_state(self) -> str: # Not adding self.args.output_buffers on purpose. But we do not need to reproduce it on a cache hit. # (never accessed). 
@@ -603,8 +623,7 @@ def hook(): arg_defs, *_ = self.args.python_argdefs() return f"{', '.join(x.full_name() for x in arg_defs)}" - self.render_hooks["<ARGDEFS>"] = hook - return "<ARGDEFS>" + return self._register_hook("<ARGDEFS>", hook, allow_overwriting=True) def gen_defines(self): return self.defines @@ -682,9 +701,7 @@ def hook(): code.splice(renames.getvalue()) return code.getvalue() - assert "<DEF_KERNEL>" not in self.render_hooks - self.render_hooks["<DEF_KERNEL>"] = hook - return "<DEF_KERNEL>" + return self._register_hook("<DEF_KERNEL>", hook) def size(self, name: str, index: int): """ @@ -977,9 +994,7 @@ def hook(): return textwrap.indent(self.body.getvalue(), " " * indent_width).strip() - assert hook_key not in self.render_hooks - self.render_hooks[hook_key] = hook - return hook_key + return self._register_hook(hook_key, hook) def store_output( self, @@ -1061,9 +1076,48 @@ def hook(): return textwrap.indent(self.body.getvalue(), " " * indent_width).strip() - assert "<STORE_OUTPUT>" not in self.render_hooks - self.render_hooks["<STORE_OUTPUT>"] = hook - return "<STORE_OUTPUT>" + return self._register_hook("<STORE_OUTPUT>", hook) + + def _register_hook( + self, + hook_name: str, + hook_fn: PartialRender.HookFn, + *, + allow_overwriting: bool = False, + ) -> str: + """ + Register a hook function with a name. + + ``hook_name`` should match the string that will be replaced via + ``hook_fn``, and should not already be in use for a hook. + + If ``allow_overwriting`` is ``False``, will assert that there isn't + currently a registered hook of the same name before registering the new + one. + """ + + if not allow_overwriting: + assert hook_name not in self.render_hooks, ( + f"Tried to register the hook {hook_name} multiple times. If " + "desired, pass allow_overwriting=True to _register_hook" + ) + self.render_hooks[hook_name] = hook_fn + return hook_name + + def _register_extra_template_env_fns(self, *fns: Callable[..., Any]): + """ + Register some extra functions to expose when performing the initial + template render. 
+ + These can be used to, for example, implement extra replacement hooks, + if the given function: + + * Returns the name of their hook, which should also be the string to + replace via the hook function. The convention is to use the format + <HOOK_NAME>. + * Assigns the corresponding entry in ``self.render_hooks`` to a hook function. + """ + self.extra_template_env_fns.extend(fns) def render(self, template, kwargs, record_input_dependent_tracked_event=False): if record_input_dependent_tracked_event: @@ -1083,6 +1137,7 @@ def render(self, template, kwargs, record_input_dependent_tracked_event=False): self.modification, self.gen_argdefs, self.gen_defines, + *self.extra_template_env_fns, ] } return PartialRender(