-
Notifications
You must be signed in to change notification settings - Fork 24.9k
Allow exposing more functions during initial template expansion #159554
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,7 +17,7 @@ | |
from concurrent.futures import as_completed, ThreadPoolExecutor | ||
from io import StringIO | ||
from types import ModuleType | ||
from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union | ||
from typing import Any, Callable, Literal, NamedTuple, Optional, TYPE_CHECKING, Union | ||
from typing_extensions import Self | ||
from unittest.mock import patch | ||
|
||
|
@@ -167,38 +167,49 @@ class PartialRender: | |
of replacements after the initial render. | ||
""" | ||
|
||
FINALIZED_HOOK: object = object() | ||
HookFn = Callable[[], str] | ||
MaybeHookFn = Union[HookFn, Literal["finalized"]] | ||
|
||
def __init__(self, code, replacement_hooks) -> None: | ||
def __init__(self, code: str, replacement_hooks: dict[str, MaybeHookFn]) -> None: | ||
super().__init__() | ||
self._code = code | ||
self.replacement_hooks = replacement_hooks | ||
self._code: str = code | ||
self.replacement_hooks: dict[str, PartialRender.MaybeHookFn] = replacement_hooks | ||
|
||
@property | ||
def code(self): | ||
def code(self) -> str: | ||
""" | ||
The fully rendered code. Will **error** if any hooks have yet to be | ||
finalized. | ||
""" | ||
remaining_active_hooks = [ | ||
key | ||
for key, fn in self.replacement_hooks.items() | ||
if fn is not self.FINALIZED_HOOK | ||
key for key, fn in self.replacement_hooks.items() if fn != "finalized" | ||
] | ||
assert len(remaining_active_hooks) == 0, ( | ||
f"The following hooks have not yet been finalized:\n {remaining_active_hooks=}" | ||
) | ||
return self._code | ||
|
||
def finalize_hook(self, hook_key: str, strict=True) -> None: | ||
def finalize_hook(self, hook_key: str, strict: bool = True) -> None: | ||
""" | ||
Finalize a hook by name. | ||
|
||
:param strict: If ``True``, raise an error if the hook wasn't found. | ||
|
||
NOTE: Will **error** if the hook has already been finalized. | ||
""" | ||
if hook_key not in self.replacement_hooks: | ||
if strict: | ||
raise RuntimeError( | ||
f"{hook_key} not registered in self.replacement_hooks" | ||
) | ||
else: | ||
return | ||
assert self.replacement_hooks[hook_key] is not self.FINALIZED_HOOK, ( | ||
"hook_key can only be called once" | ||
) | ||
self._code = self._code.replace(hook_key, self.replacement_hooks[hook_key]()) | ||
self.replacement_hooks[hook_key] = self.FINALIZED_HOOK | ||
|
||
hook = self.replacement_hooks[hook_key] | ||
assert hook != "finalized", "hook_key can only be called once" | ||
self._code = self._code.replace(hook_key, hook()) | ||
|
||
self.replacement_hooks[hook_key] = "finalized" | ||
|
||
def finalize_remaining(self) -> str: | ||
""" | ||
|
@@ -209,11 +220,17 @@ def finalize_remaining(self) -> str: | |
finalize active hooks. | ||
""" | ||
for key, fn in self.replacement_hooks.items(): | ||
if fn is not self.FINALIZED_HOOK: | ||
if fn != "finalized": | ||
self.finalize_hook(key) | ||
return self.code | ||
|
||
def finalize_all(self) -> str: | ||
""" | ||
Finalize all active hooks. | ||
|
||
NOTE: unlike ``finalize_remaining``, this method will **error** if any | ||
hook has already been finalized. | ||
""" | ||
for key in self.replacement_hooks: | ||
self.finalize_hook(key) | ||
return self.code | ||
|
@@ -432,6 +449,9 @@ def __init__( | |
# by adding all inputs. | ||
self.prologue_loads_all_inputs = prologue_loads_all_inputs | ||
|
||
# Extra functions to be exposed during partial template rendering. | ||
self.extra_template_env_fns: list[Callable[..., Any]] = [] | ||
|
||
def input_dependent_preserved_state(self) -> str: | ||
# Not adding self.args.output_buffers on purpose. But we do not need to reproduce it on a cache hit. | ||
# (never accessed). | ||
|
@@ -603,8 +623,7 @@ def hook(): | |
arg_defs, *_ = self.args.python_argdefs() | ||
return f"{', '.join(x.full_name() for x in arg_defs)}" | ||
|
||
self.render_hooks["<ARGDEFS>"] = hook | ||
return "<ARGDEFS>" | ||
return self._register_hook("<ARGDEFS>", hook, allow_overwriting=True) | ||
|
||
def gen_defines(self): | ||
return self.defines | ||
|
@@ -682,9 +701,7 @@ def hook(): | |
code.splice(renames.getvalue()) | ||
return code.getvalue() | ||
|
||
assert "<DEF_KERNEL>" not in self.render_hooks | ||
self.render_hooks["<DEF_KERNEL>"] = hook | ||
return "<DEF_KERNEL>" | ||
return self._register_hook("<DEF_KERNEL>", hook) | ||
|
||
def size(self, name: str, index: int): | ||
""" | ||
|
@@ -977,9 +994,7 @@ def hook(): | |
|
||
return textwrap.indent(self.body.getvalue(), " " * indent_width).strip() | ||
|
||
assert hook_key not in self.render_hooks | ||
self.render_hooks[hook_key] = hook | ||
return hook_key | ||
return self._register_hook(hook_key, hook) | ||
|
||
def store_output( | ||
self, | ||
|
@@ -1061,9 +1076,48 @@ def hook(): | |
|
||
return textwrap.indent(self.body.getvalue(), " " * indent_width).strip() | ||
|
||
assert "<STORE_OUTPUT>" not in self.render_hooks | ||
self.render_hooks["<STORE_OUTPUT>"] = hook | ||
return "<STORE_OUTPUT>" | ||
return self._register_hook("<STORE_OUTPUT>", hook) | ||
|
||
def _register_hook( | ||
self, | ||
hook_name: str, | ||
hook_fn: PartialRender.HookFn, | ||
*, | ||
allow_overwriting: bool = False, | ||
) -> str: | ||
""" | ||
Register a hook function with a name. | ||
|
||
``hook_name`` should match the string that will be replaced via | ||
``hook_fn``, and should not already be in use for a hook. | ||
|
||
If ``allow_overwriting`` is ``False``, will assert that there isn't | ||
currently a registered hook of the same name before registering the new | ||
one. | ||
""" | ||
|
||
if not allow_overwriting: | ||
assert hook_name not in self.render_hooks, ( | ||
f"Tried to register the hook {hook_name} multiple times. If " | ||
"desired, pass allow_overwriting=True to _register_hook" | ||
) | ||
self.render_hooks[hook_name] = hook_fn | ||
return hook_name | ||
|
||
def _register_extra_template_env_fns(self, *fns: Callable[..., Any]): | ||
""" | ||
Register some extra functions to expose when performing the initial | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you explain more what "to expose" means here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i guess you could say it means "to make available to be used by jinja expressions": jinja template strings can include expressions inside double curly brackets, and these expressions can include things like calls to python functions. the shall i add more info to the comment along these lines? |
||
template render. | ||
|
||
These can be used to, for example, implement extra replacement hooks, | ||
if the given function: | ||
|
||
* Returns the name of their hook, which should also be the string to | ||
replace via the hook function. The convention is to use the format | ||
<HOOK_NAME>. | ||
* Assigns the corresponding entry in ``self.render_hooks`` to a hook function. | ||
""" | ||
self.extra_template_env_fns.extend(fns) | ||
|
||
def render(self, template, kwargs, record_input_dependent_tracked_event=False): | ||
if record_input_dependent_tracked_event: | ||
|
@@ -1083,6 +1137,7 @@ def render(self, template, kwargs, record_input_dependent_tracked_event=False): | |
self.modification, | ||
self.gen_argdefs, | ||
self.gen_defines, | ||
*self.extra_template_env_fns, | ||
] | ||
} | ||
return PartialRender( | ||
|
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
did you consider using None as proxy for finalized?
MaybeHookFn being Optional[HookFn]
you can use FINALIZED=None.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah, to be honest i kept it as something other than
None
because it actually used to beNone
but got changed to something else; tbh i'm not totally sure the reason for the change (i can't find any references toFINALIZED_HOOK
outside ofPartialRender
). i just made it a literal cause i couldn't see a way to type-annotate the rawobject()
nicely.