Skip to content

gh-137530: generate an __annotate__ function for dataclasses __init__ #137711

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 80 additions & 14 deletions Lib/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,9 +441,11 @@ def __init__(self, globals):
self.locals = {}
self.overwrite_errors = {}
self.unconditional_adds = {}
self.method_annotations = {}

def add_fn(self, name, args, body, *, locals=None, return_type=MISSING,
overwrite_error=False, unconditional_add=False, decorator=None):
overwrite_error=False, unconditional_add=False, decorator=None,
annotation_fields=None):
if locals is not None:
self.locals.update(locals)

Expand All @@ -464,16 +466,14 @@ def add_fn(self, name, args, body, *, locals=None, return_type=MISSING,

self.names.append(name)

if return_type is not MISSING:
self.locals[f'__dataclass_{name}_return_type__'] = return_type
return_annotation = f'->__dataclass_{name}_return_type__'
else:
return_annotation = ''
if annotation_fields is not None:
self.method_annotations[name] = (annotation_fields, return_type)

args = ','.join(args)
body = '\n'.join(body)

# Compute the text of the entire function, add it to the text we're generating.
self.src.append(f'{f' {decorator}\n' if decorator else ''} def {name}({args}){return_annotation}:\n{body}')
self.src.append(f'{f' {decorator}\n' if decorator else ''} def {name}({args}):\n{body}')

def add_fns_to_class(self, cls):
# The source to all of the functions we're generating.
Expand Down Expand Up @@ -509,6 +509,15 @@ def add_fns_to_class(self, cls):
# Now that we've generated the functions, assign them into cls.
for name, fn in zip(self.names, fns):
fn.__qualname__ = f"{cls.__qualname__}.{fn.__name__}"

try:
annotation_fields, return_type = self.method_annotations[name]
except KeyError:
pass
else:
annotate_fn = _make_annotate_function(cls, name, annotation_fields, return_type)
fn.__annotate__ = annotate_fn

if self.unconditional_adds.get(name, False):
setattr(cls, name, fn)
else:
Expand All @@ -524,6 +533,44 @@ def add_fns_to_class(self, cls):
raise TypeError(error_msg)


def _make_annotate_function(__class__, method_name, annotation_fields, return_type):
# Create an __annotate__ function for a dataclass
# Try to return annotations in the same format as they would be
# from a regular __init__ function

def __annotate__(format, /):
Format = annotationlib.Format
match format:
case Format.VALUE | Format.FORWARDREF | Format.STRING:
cls_annotations = {}
for base in reversed(__class__.__mro__):
cls_annotations.update(
annotationlib.get_annotations(base, format=format)
)

new_annotations = {}
for k in annotation_fields:
new_annotations[k] = cls_annotations[k]

if return_type is not MISSING:
if format == Format.STRING:
new_annotations["return"] = annotationlib.type_repr(return_type)
else:
new_annotations["return"] = return_type

return new_annotations

case _:
raise NotImplementedError(format)

# This is a flag for _add_slots to know it needs to regenerate this method
# In order to remove references to the original class when it is replaced
__annotate__._generated_by_dataclasses = True
__annotate__.__qualname__ = f"{__class__.__qualname__}.{method_name}.__annotate__"

return __annotate__


def _field_assign(frozen, name, value, self_name):
# If we're a frozen class, then assign to our fields in __init__
# via object.__setattr__. Otherwise, just use a simple
Expand Down Expand Up @@ -612,7 +659,7 @@ def _init_param(f):
elif f.default_factory is not MISSING:
# There's a factory function. Set a marker.
default = '=__dataclass_HAS_DEFAULT_FACTORY__'
return f'{f.name}:__dataclass_type_{f.name}__{default}'
return f'{f.name}{default}'


def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
Expand All @@ -635,11 +682,10 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
raise TypeError(f'non-default argument {f.name!r} '
f'follows default argument {seen_default.name!r}')

locals = {**{f'__dataclass_type_{f.name}__': f.type for f in fields},
**{'__dataclass_HAS_DEFAULT_FACTORY__': _HAS_DEFAULT_FACTORY,
'__dataclass_builtins_object__': object,
}
}
annotation_fields = [f.name for f in fields if f.init]

locals = {'__dataclass_HAS_DEFAULT_FACTORY__': _HAS_DEFAULT_FACTORY,
'__dataclass_builtins_object__': object}

body_lines = []
for f in fields:
Expand Down Expand Up @@ -670,7 +716,8 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
[self_name] + _init_params,
body_lines,
locals=locals,
return_type=None)
return_type=None,
annotation_fields=annotation_fields)


def _frozen_get_del_attr(cls, fields, func_builder):
Expand Down Expand Up @@ -1336,6 +1383,25 @@ def _add_slots(cls, is_frozen, weakref_slot, defined_fields):
or _update_func_cell_for__class__(member.fdel, cls, newcls)):
break

# Get new annotations to remove references to the original class
# in forward references
newcls_ann = annotationlib.get_annotations(
newcls, format=annotationlib.Format.FORWARDREF)

# Fix references in dataclass Fields
for f in getattr(newcls, _FIELDS).values():
try:
ann = newcls_ann[f.name]
except KeyError:
pass
else:
f.type = ann

# Fix the class reference in the __annotate__ method
init_annotate = newcls.__init__.__annotate__
if getattr(init_annotate, "_generated_by_dataclasses", False):
_update_func_cell_for__class__(init_annotate, cls, newcls)

return newcls


Expand Down
136 changes: 135 additions & 1 deletion Lib/test/test_dataclasses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2471,6 +2471,132 @@ def __init__(self, a):
self.assertEqual(D(5).a, 10)


class TestInitAnnotate(unittest.TestCase):
    # Exercises the __annotate__ function attached to a generated __init__
    # See: https://github.com/python/cpython/issues/137530

    @staticmethod
    def _init_annotations(klass, fmt=annotationlib.Format.VALUE):
        # Annotations of klass.__init__, evaluated in the requested format.
        return annotationlib.get_annotations(klass.__init__, format=fmt)

    def test_annotate_function(self):
        # Every annotation can be resolved immediately
        @dataclass
        class A:
            a: int

        Format = annotationlib.Format
        by_value, by_forwardref, by_string = [
            self._init_annotations(A, fmt)
            for fmt in (Format.VALUE, Format.FORWARDREF, Format.STRING)
        ]

        self.assertEqual(by_value, {'a': int, 'return': None})
        self.assertEqual(by_forwardref, {'a': int, 'return': None})
        self.assertEqual(by_string, {'a': 'int', 'return': 'None'})

        # The generated function carries the dataclasses marker attribute
        self.assertTrue(A.__init__.__annotate__._generated_by_dataclasses)

    def test_annotate_function_forwardref(self):
        # The annotation names something that is not bound yet
        @dataclass
        class B:
            b: undefined

        Format = annotationlib.Format

        # While the name cannot be resolved, VALUE has to raise
        with self.assertRaises(NameError):
            self._init_annotations(B, Format.VALUE)

        by_forwardref, by_string = [
            self._init_annotations(B, fmt)
            for fmt in (Format.FORWARDREF, Format.STRING)
        ]

        self.assertEqual(
            by_forwardref,
            {
                'b': support.EqualToForwardRef('undefined', owner=B, is_class=True),
                'return': None,
            },
        )
        self.assertEqual(by_string, {'b': 'undefined', 'return': 'None'})

        # Binding the name makes VALUE and FORWARDREF resolve; STRING is
        # expected to stay exactly as it was
        undefined = int

        by_value, by_forwardref, by_string = [
            self._init_annotations(B, fmt)
            for fmt in (Format.VALUE, Format.FORWARDREF, Format.STRING)
        ]

        self.assertEqual(by_value, {'b': int, 'return': None})
        self.assertEqual(by_forwardref, {'b': int, 'return': None})
        self.assertEqual(by_string, {'b': 'undefined', 'return': 'None'})

    def test_annotate_function_init_false(self):
        # Attributes declared with `init=False` must not show up in the
        # annotations of the __init__ function
        @dataclass
        class C:
            c: str = field(init=False)

        self.assertEqual(self._init_annotations(C), {'return': None})

    def test_annotate_function_contains_forwardref(self):
        # An annotation object that merely contains an unresolved name
        @dataclass
        class D:
            d: list[undefined]

        Format = annotationlib.Format

        with self.assertRaises(NameError):
            self._init_annotations(D)

        self.assertEqual(
            self._init_annotations(D, Format.FORWARDREF),
            {
                "d": list[support.EqualToForwardRef("undefined", is_class=True, owner=D)],
                "return": None,
            },
        )

        self.assertEqual(
            self._init_annotations(D, Format.STRING),
            {"d": "list[undefined]", "return": "None"},
        )

        # Repeat the checks once the name is bound
        undefined = str

        # VALUE should now resolve
        self.assertEqual(
            self._init_annotations(D),
            {"d": list[str], "return": None},
        )

        self.assertEqual(
            self._init_annotations(D, Format.FORWARDREF),
            {"d": list[str], "return": None},
        )

        self.assertEqual(
            self._init_annotations(D, Format.STRING),
            {"d": "list[undefined]", "return": "None"},
        )

    def test_annotate_function_not_replaced(self):
        # A hand-written __init__ must keep its own __annotate__
        @dataclass(slots=True)
        class E:
            x: str

            def __init__(self, x: int) -> None:
                self.x = x

        self.assertEqual(
            self._init_annotations(E),
            {"x": int, "return": None},
        )

        user_annotate = E.__init__.__annotate__
        self.assertFalse(hasattr(user_annotate, "_generated_by_dataclasses"))

    def test_init_false_forwardref(self):
        # Currently this raises a NameError even though the ForwardRef
        # is not in the __init__ method

        @dataclass
        class F:
            not_in_init: list[undefined] = field(init=False, default=None)
            in_init: int

        self.assertEqual(
            self._init_annotations(F, annotationlib.Format.FORWARDREF),
            {"in_init": int, "return": None},
        )

        with self.assertRaises(NameError):
            self._init_annotations(F)  # NameError on not_in_init


class TestRepr(unittest.TestCase):
def test_repr(self):
@dataclass
Expand Down Expand Up @@ -3831,7 +3957,15 @@ def method(self) -> int:

return SlotsTest

for make in (make_simple, make_with_annotations, make_with_annotations_and_method):
# Factory returning a slotted dataclass whose two field annotations mention
# the name `undefined`, once directly and once inside list[...].
# NOTE(review): `undefined` is not bound anywhere in the code visible here,
# so these are presumably unresolved forward references - confirm against the
# enclosing test.
def make_with_forwardref():
@dataclass(slots=True)
class SlotsTest:
x: undefined
y: list[undefined]

return SlotsTest

for make in (make_simple, make_with_annotations, make_with_annotations_and_method, make_with_forwardref):
with self.subTest(make=make):
C = make()
support.gc_collect()
Expand Down
Loading