Skip to content

Allow exposing more functions during initial template expansion #159554

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 8 additions & 36 deletions test/inductor/test_select_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from torch._inductor.ir import FixedLayout
from torch._inductor.select_algorithm import (
autotune_select_algorithm,
PartialRender,
TritonTemplate,
TritonTemplateKernel,
)
Expand Down Expand Up @@ -455,48 +454,21 @@ def test_finalized_subclass_hooks(self):
hook_identifier = "# CUSTOM_HOOK"

class ExtensionTritonTemplateKernel(TritonTemplateKernel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._register_extra_template_env_fns(
self.custom_hook,
)

def custom_hook(self) -> str:
"""
Custom hook that just returns a test string for
validation
Custom hook that just returns a test string for validation
"""

def hook() -> str:
return hook_identifier

assert "<CUSTOM_HOOK>" not in self.render_hooks
self.render_hooks["<CUSTOM_HOOK>"] = hook
return "<CUSTOM_HOOK>"

def render(
self, template, kwargs, record_input_dependent_tracked_event=False
):
if record_input_dependent_tracked_event:
self.cached_replay_events = []

template_env = {
fn.__name__: self.record_input_dependent_tracked_event()(fn)
if record_input_dependent_tracked_event
else fn
for fn in [
self.def_kernel,
self.size,
self.stride,
self.store_output,
self.load_input,
self.make_load,
self.modification,
self.gen_argdefs,
self.gen_defines,
# This function registers a hook that the scheduler does
# not directly finalize
self.custom_hook,
]
}
return PartialRender(
template.render(**template_env, **kwargs),
self.render_hooks,
)
return self._register_hook("<CUSTOM_HOOK>", hook)

class ExtensionTritonTemplate(TritonTemplate):
kernel_type = ExtensionTritonTemplateKernel
Expand Down
109 changes: 82 additions & 27 deletions torch/_inductor/select_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from concurrent.futures import as_completed, ThreadPoolExecutor
from io import StringIO
from types import ModuleType
from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, Literal, NamedTuple, Optional, TYPE_CHECKING, Union
from typing_extensions import Self
from unittest.mock import patch

Expand Down Expand Up @@ -167,38 +167,49 @@ class PartialRender:
of replacements after the initial render.
"""

FINALIZED_HOOK: object = object()
HookFn = Callable[[], str]
MaybeHookFn = Union[HookFn, Literal["finalized"]]

def __init__(self, code, replacement_hooks) -> None:
def __init__(self, code: str, replacement_hooks: dict[str, MaybeHookFn]) -> None:
super().__init__()
self._code = code
self.replacement_hooks = replacement_hooks
self._code: str = code
self.replacement_hooks: dict[str, PartialRender.MaybeHookFn] = replacement_hooks

@property
def code(self):
def code(self) -> str:
"""
The fully rendered code. Will **error** if any hooks have yet to be
finalized.
"""
remaining_active_hooks = [
key
for key, fn in self.replacement_hooks.items()
if fn is not self.FINALIZED_HOOK
key for key, fn in self.replacement_hooks.items() if fn != "finalized"
]
assert len(remaining_active_hooks) == 0, (
f"The following hooks have not yet been finalized:\n {remaining_active_hooks=}"
)
return self._code

def finalize_hook(self, hook_key: str, strict=True) -> None:
def finalize_hook(self, hook_key: str, strict: bool = True) -> None:
"""
Finalize a hook by name.

:param strict: If ``True``, raise an error if the hook wasn't found.

NOTE: Will **error** if the hook has already been finalized.
"""
if hook_key not in self.replacement_hooks:
if strict:
raise RuntimeError(
f"{hook_key} not registered in self.replacement_hooks"
)
else:
return
assert self.replacement_hooks[hook_key] is not self.FINALIZED_HOOK, (
"hook_key can only be called once"
)
self._code = self._code.replace(hook_key, self.replacement_hooks[hook_key]())
self.replacement_hooks[hook_key] = self.FINALIZED_HOOK

hook = self.replacement_hooks[hook_key]
assert hook != "finalized", "hook_key can only be called once"
self._code = self._code.replace(hook_key, hook())

self.replacement_hooks[hook_key] = "finalized"
Copy link
Contributor

@laithsakka laithsakka Aug 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

did you consider using None as proxy for finalized?
MaybeHookFn being Optional[HookFn]

you can use FINALIZED=None.

Copy link
Contributor Author

@charlie-wt charlie-wt Aug 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, to be honest i kept it as something other than None because it actually used to be None but got changed to something else; tbh i'm not totally sure the reason for the change (i can't find any references to FINALIZED_HOOK outside of PartialRender). i just made it a literal cause i couldn't see a way to type-annotate the raw object() nicely.


def finalize_remaining(self) -> str:
"""
Expand All @@ -209,11 +220,17 @@ def finalize_remaining(self) -> str:
finalize active hooks.
"""
for key, fn in self.replacement_hooks.items():
if fn is not self.FINALIZED_HOOK:
if fn != "finalized":
self.finalize_hook(key)
return self.code

def finalize_all(self) -> str:
"""
Finalize all active hooks.

NOTE: unlike ``finalize_remaining``, this method will **error** if any
hook has already been finalized.
"""
for key in self.replacement_hooks:
self.finalize_hook(key)
return self.code
Expand Down Expand Up @@ -432,6 +449,9 @@ def __init__(
# by adding all inputs.
self.prologue_loads_all_inputs = prologue_loads_all_inputs

# Extra functions to be exposed during partial template rendering.
self.extra_template_env_fns: list[Callable[..., Any]] = []

def input_dependent_preserved_state(self) -> str:
# Not adding self.args.output_buffers on purpose. But we do not need to reproduce it on a cache hit.
# (never accessed).
Expand Down Expand Up @@ -603,8 +623,7 @@ def hook():
arg_defs, *_ = self.args.python_argdefs()
return f"{', '.join(x.full_name() for x in arg_defs)}"

self.render_hooks["<ARGDEFS>"] = hook
return "<ARGDEFS>"
return self._register_hook("<ARGDEFS>", hook, allow_overwriting=True)

def gen_defines(self):
return self.defines
Expand Down Expand Up @@ -682,9 +701,7 @@ def hook():
code.splice(renames.getvalue())
return code.getvalue()

assert "<DEF_KERNEL>" not in self.render_hooks
self.render_hooks["<DEF_KERNEL>"] = hook
return "<DEF_KERNEL>"
return self._register_hook("<DEF_KERNEL>", hook)

def size(self, name: str, index: int):
"""
Expand Down Expand Up @@ -977,9 +994,7 @@ def hook():

return textwrap.indent(self.body.getvalue(), " " * indent_width).strip()

assert hook_key not in self.render_hooks
self.render_hooks[hook_key] = hook
return hook_key
return self._register_hook(hook_key, hook)

def store_output(
self,
Expand Down Expand Up @@ -1061,9 +1076,48 @@ def hook():

return textwrap.indent(self.body.getvalue(), " " * indent_width).strip()

assert "<STORE_OUTPUT>" not in self.render_hooks
self.render_hooks["<STORE_OUTPUT>"] = hook
return "<STORE_OUTPUT>"
return self._register_hook("<STORE_OUTPUT>", hook)

def _register_hook(
self,
hook_name: str,
hook_fn: PartialRender.HookFn,
*,
allow_overwriting: bool = False,
) -> str:
"""
Register a hook function with a name.

``hook_name`` should match the string that will be replaced via
``hook_fn``, and should not already be in use for a hook.

If ``allow_overwriting`` is ``False``, will assert that there isn't
currently a registered hook of the same name before registering the new
one.
"""

if not allow_overwriting:
assert hook_name not in self.render_hooks, (
f"Tried to register the hook {hook_name} multiple times. If "
"desired, pass allow_overwriting=True to _register_hook"
)
self.render_hooks[hook_name] = hook_fn
return hook_name

def _register_extra_template_env_fns(self, *fns: Callable[..., Any]):
"""
Register some extra functions to expose when performing the initial
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you explain more what "to expose" means here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i guess you could say it means "to make available to be used by jinja expressions": jinja template strings can include expressions inside double curly brackets, and these expressions can include things like calls to python functions. the template_env dictionary being passed to template.render inside TritonTemplateKernel.render is laying out what functions you want to be in scope for any expressions in the template being rendered.

shall i add more info to the comment along these lines?

template render.

These can be used to, for example, implement extra replacement hooks,
if the given function:

* Returns the name of their hook, which should also be the string to
replace via the hook function. The convention is to use the format
<HOOK_NAME>.
* Assigns the corresponding entry in ``self.render_hooks`` to a hook function.
"""
self.extra_template_env_fns.extend(fns)

def render(self, template, kwargs, record_input_dependent_tracked_event=False):
if record_input_dependent_tracked_event:
Expand All @@ -1083,6 +1137,7 @@ def render(self, template, kwargs, record_input_dependent_tracked_event=False):
self.modification,
self.gen_argdefs,
self.gen_defines,
*self.extra_template_env_fns,
]
}
return PartialRender(
Expand Down
Loading