Skip to content

gh-132805: annotationlib: Fix handling of non-constant values in FORWARDREF #132812

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
175 changes: 132 additions & 43 deletions Lib/annotationlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class Format(enum.IntEnum):
"__weakref__",
"__arg__",
"__globals__",
"__extra_names__",
"__code__",
"__ast_node__",
"__cell__",
Expand Down Expand Up @@ -82,6 +83,7 @@ def __init__(
# is created through __class__ assignment on a _Stringifier object.
self.__globals__ = None
self.__cell__ = None
self.__extra_names__ = None
# These are initially None but serve as a cache and may be set to a non-None
# value later.
self.__code__ = None
Expand Down Expand Up @@ -151,6 +153,8 @@ def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None):
if not self.__forward_is_class__ or param_name not in globals:
globals[param_name] = param
locals.pop(param_name, None)
if self.__extra_names__:
locals = {**locals, **self.__extra_names__}

arg = self.__forward_arg__
if arg.isidentifier() and not keyword.iskeyword(arg):
Expand Down Expand Up @@ -231,6 +235,10 @@ def __eq__(self, other):
and self.__forward_is_class__ == other.__forward_is_class__
and self.__cell__ == other.__cell__
and self.__owner__ == other.__owner__
and (
(tuple(sorted(self.__extra_names__.items())) if self.__extra_names__ else None) ==
(tuple(sorted(other.__extra_names__.items())) if other.__extra_names__ else None)
)
)

def __hash__(self):
Expand All @@ -241,6 +249,7 @@ def __hash__(self):
self.__forward_is_class__,
self.__cell__,
self.__owner__,
tuple(sorted(self.__extra_names__.items())) if self.__extra_names__ else None,
))

def __or__(self, other):
Expand Down Expand Up @@ -274,6 +283,7 @@ def __init__(
cell=None,
*,
stringifier_dict,
extra_names=None,
):
# Either an AST node or a simple str (for the common case where a ForwardRef
# represent a single name).
Expand All @@ -285,49 +295,91 @@ def __init__(
self.__code__ = None
self.__ast_node__ = node
self.__globals__ = globals
self.__extra_names__ = extra_names
self.__cell__ = cell
self.__owner__ = owner
self.__stringifier_dict__ = stringifier_dict

def __convert_to_ast(self, other):
if isinstance(other, _Stringifier):
if isinstance(other.__ast_node__, str):
return ast.Name(id=other.__ast_node__)
return other.__ast_node__
elif isinstance(other, slice):
return ast.Name(id=other.__ast_node__), other.__extra_names__
return other.__ast_node__, other.__extra_names__
elif (
# In STRING format we don't bother with the create_unique_name() dance;
# it's better to emit the repr() of the object instead of an opaque name.
self.__stringifier_dict__.format == Format.STRING
or other is None
or type(other) in (str, int, float, bool, complex)
):
return ast.Constant(value=other), None
elif type(other) is dict:
extra_names = {}
keys = []
values = []
for key, value in other.items():
new_key, new_extra_names = self.__convert_to_ast(key)
if new_extra_names is not None:
extra_names.update(new_extra_names)
keys.append(new_key)
new_value, new_extra_names = self.__convert_to_ast(value)
if new_extra_names is not None:
extra_names.update(new_extra_names)
values.append(new_value)
return ast.Dict(keys, values), extra_names
elif type(other) in (list, tuple, set):
extra_names = {}
elts = []
for elt in other:
new_elt, new_extra_names = self.__convert_to_ast(elt)
if new_extra_names is not None:
extra_names.update(new_extra_names)
elts.append(new_elt)
ast_class = {list: ast.List, tuple: ast.Tuple, set: ast.Set}[type(other)]
return ast_class(elts), extra_names
else:
name = self.__stringifier_dict__.create_unique_name()
return ast.Name(id=name), {name: other}

def __convert_to_ast_getitem(self, other):
if isinstance(other, slice):
extra_names = {}

def conv(obj):
if obj is None:
return None
new_obj, new_extra_names = self.__convert_to_ast(obj)
if new_extra_names is not None:
extra_names.update(new_extra_names)
return new_obj

return ast.Slice(
lower=(
self.__convert_to_ast(other.start)
if other.start is not None
else None
),
upper=(
self.__convert_to_ast(other.stop)
if other.stop is not None
else None
),
step=(
self.__convert_to_ast(other.step)
if other.step is not None
else None
),
)
lower=conv(other.start),
upper=conv(other.stop),
step=conv(other.step),
), extra_names
else:
return ast.Constant(value=other)
return self.__convert_to_ast(other)

def __get_ast(self):
node = self.__ast_node__
if isinstance(node, str):
return ast.Name(id=node)
return node

def __make_new(self, node):
def __make_new(self, node, extra_names=None):
new_extra_names = {}
if self.__extra_names__ is not None:
new_extra_names.update(self.__extra_names__)
if extra_names is not None:
new_extra_names.update(extra_names)
stringifier = _Stringifier(
node,
self.__globals__,
self.__owner__,
self.__forward_is_class__,
stringifier_dict=self.__stringifier_dict__,
extra_names=new_extra_names or None,
)
self.__stringifier_dict__.stringifiers.append(stringifier)
return stringifier
Expand All @@ -343,27 +395,37 @@ def __getitem__(self, other):
if self.__ast_node__ == "__classdict__":
raise KeyError
if isinstance(other, tuple):
elts = [self.__convert_to_ast(elt) for elt in other]
extra_names = {}
elts = []
for elt in other:
new_elt, new_extra_names = self.__convert_to_ast_getitem(elt)
if new_extra_names is not None:
extra_names.update(new_extra_names)
elts.append(new_elt)
other = ast.Tuple(elts)
else:
other = self.__convert_to_ast(other)
other, extra_names = self.__convert_to_ast_getitem(other)
assert isinstance(other, ast.AST), repr(other)
return self.__make_new(ast.Subscript(self.__get_ast(), other))
return self.__make_new(ast.Subscript(self.__get_ast(), other), extra_names)

def __getattr__(self, attr):
return self.__make_new(ast.Attribute(self.__get_ast(), attr))

def __call__(self, *args, **kwargs):
return self.__make_new(
ast.Call(
self.__get_ast(),
[self.__convert_to_ast(arg) for arg in args],
[
ast.keyword(key, self.__convert_to_ast(value))
for key, value in kwargs.items()
],
)
)
extra_names = {}
ast_args = []
for arg in args:
new_arg, new_extra_names = self.__convert_to_ast(arg)
if new_extra_names is not None:
extra_names.update(new_extra_names)
ast_args.append(new_arg)
ast_kwargs = []
for key, value in kwargs.items():
new_value, new_extra_names = self.__convert_to_ast(value)
if new_extra_names is not None:
extra_names.update(new_extra_names)
ast_kwargs.append(ast.keyword(key, new_value))
return self.__make_new(ast.Call(self.__get_ast(), ast_args, ast_kwargs), extra_names)

def __iter__(self):
yield self.__make_new(ast.Starred(self.__get_ast()))
Expand All @@ -378,8 +440,9 @@ def __format__(self, format_spec):

def _make_binop(op: ast.AST):
def binop(self, other):
rhs, extra_names = self.__convert_to_ast(other)
return self.__make_new(
ast.BinOp(self.__get_ast(), op, self.__convert_to_ast(other))
ast.BinOp(self.__get_ast(), op, rhs), extra_names
)

return binop
Expand All @@ -402,8 +465,9 @@ def binop(self, other):

def _make_rbinop(op: ast.AST):
def rbinop(self, other):
new_other, extra_names = self.__convert_to_ast(other)
return self.__make_new(
ast.BinOp(self.__convert_to_ast(other), op, self.__get_ast())
ast.BinOp(new_other, op, self.__get_ast()), extra_names
)

return rbinop
Expand All @@ -426,12 +490,14 @@ def rbinop(self, other):

def _make_compare(op):
def compare(self, other):
rhs, extra_names = self.__convert_to_ast(other)
return self.__make_new(
ast.Compare(
left=self.__get_ast(),
ops=[op],
comparators=[self.__convert_to_ast(other)],
)
comparators=[rhs],
),
extra_names,
)

return compare
Expand Down Expand Up @@ -459,13 +525,15 @@ def unary_op(self):


class _StringifierDict(dict):
def __init__(self, namespace, globals=None, owner=None, is_class=False):
def __init__(self, namespace, *, globals=None, owner=None, is_class=False, format):
super().__init__(namespace)
self.namespace = namespace
self.globals = globals
self.owner = owner
self.is_class = is_class
self.stringifiers = []
self.next_id = 1
self.format = format

def __missing__(self, key):
fwdref = _Stringifier(
Expand All @@ -478,6 +546,11 @@ def __missing__(self, key):
self.stringifiers.append(fwdref)
return fwdref

def create_unique_name(self):
name = f"__annotationlib_name_{self.next_id}__"
self.next_id += 1
return name


def call_evaluate_function(evaluate, format, *, owner=None):
"""Call an evaluate function. Evaluate functions are normally generated for
Expand Down Expand Up @@ -521,7 +594,7 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
# possibly constants if the annotate function uses them directly). We then
# convert each of those into a string to get an approximation of the
# original source.
globals = _StringifierDict({})
globals = _StringifierDict({}, format=format)
if annotate.__closure__:
freevars = annotate.__code__.co_freevars
new_closure = []
Expand All @@ -544,9 +617,9 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
)
annos = func(Format.VALUE_WITH_FAKE_GLOBALS)
if _is_evaluate:
return annos if isinstance(annos, str) else repr(annos)
return _stringify_single(annos)
return {
key: val if isinstance(val, str) else repr(val)
key: _stringify_single(val)
for key, val in annos.items()
}
elif format == Format.FORWARDREF:
Expand All @@ -569,7 +642,13 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
# that returns a bool and an defined set of attributes.
namespace = {**annotate.__builtins__, **annotate.__globals__}
is_class = isinstance(owner, type)
globals = _StringifierDict(namespace, annotate.__globals__, owner, is_class)
globals = _StringifierDict(
namespace,
globals=annotate.__globals__,
owner=owner,
is_class=is_class,
format=format,
)
if annotate.__closure__:
freevars = annotate.__code__.co_freevars
new_closure = []
Expand Down Expand Up @@ -619,6 +698,16 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False):
raise ValueError(f"Invalid format: {format!r}")


def _stringify_single(anno):
if anno is ...:
return "..."
# We have to handle str specially to support PEP 563 stringified annotations.
elif isinstance(anno, str):
return anno
else:
return repr(anno)


def get_annotate_function(obj):
"""Get the __annotate__ function for an object.

Expand Down
Loading
Loading