Skip to content

gh-89687: fix get_type_hints with dataclasses __init__ generation #137168

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Lib/annotationlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,11 @@ def evaluate(

arg = self.__forward_arg__
if arg.isidentifier() and not keyword.iskeyword(arg):
if self.__forward_module__ is not None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is correct, since we infer globals from __forward_module__ if no explicit globals are given. And if explicit globals are given, we shouldn't be using __forward_module__.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think current behavior is correct. Was there any discussion?

return getattr(
sys.modules.get(self.__forward_module__, None),
arg,
)
if arg in locals:
return locals[arg]
elif arg in globals:
Expand Down
22 changes: 19 additions & 3 deletions Lib/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,10 +536,25 @@ def _field_assign(frozen, name, value, self_name):
return f' {self_name}.{name}={value}'


def _field_init(f, frozen, globals, self_name, slots):
def _field_init(f, frozen, globals, self_name, slots, module):
# Return the text of the line in the body of __init__ that will
# initialize this field.

if f.init and isinstance(f.type, str):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something like this is good, but I'd do it like this:

diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py
index 53b3b54cfb3..e3d25fb0840 100644
--- a/Lib/dataclasses.py
+++ b/Lib/dataclasses.py
@@ -789,7 +789,10 @@ def _get_field(cls, a_name, a_type, default_kw_only):
 
     # Only at this point do we know the name and the type.  Set them.
     f.name = a_name
-    f.type = a_type
+    if isinstance(a_type, str):
+        f.type = annotationlib.ForwardRef(a_type, owner=cls)
+    else:
+        f.type = a_type
 
     # Assume it's a normal field until proven otherwise.  We're next
     # going to decide if it's a ClassVar or InitVar, everything else

ForwardRef now lives in annotationlib, not typing. There's no need to catch SyntaxError as it doesn't parse its input in the constructor.

from typing import ForwardRef # `typing` is a heavy import
# We need to resolve this string type into a real `ForwardRef` object,
# because otherwise we might end up with unsolvable annotations.
# For example:
# def __init__(self, d: collections.OrderedDict) -> None:
# We won't be able to resolve `collections.OrderedDict`
# with wrong `module` param, when placed in a different module. #45524
try:
f.type = ForwardRef(f.type, module=module, is_class=True)
except SyntaxError:
# We don't want to fail class creation
# when `ForwardRef` cannot be constructed.
pass

default_name = f'__dataclass_dflt_{f.name}__'
if f.default_factory is not MISSING:
if f.init:
Expand Down Expand Up @@ -616,7 +631,7 @@ def _init_param(f):


def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
self_name, func_builder, slots):
self_name, func_builder, slots, module):
# fields contains both real fields and InitVar pseudo-fields.

# Make sure we don't have fields without defaults following fields
Expand All @@ -643,7 +658,7 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,

body_lines = []
for f in fields:
line = _field_init(f, frozen, locals, self_name, slots)
line = _field_init(f, frozen, locals, self_name, slots, module)
# line is None means that this field doesn't require
# initialization (it's a pseudo-field). Just skip it.
if line:
Expand Down Expand Up @@ -1093,6 +1108,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
else 'self',
func_builder,
slots,
cls.__module__,
)

_set_new_attribute(cls, '__replace__', _replace)
Expand Down
124 changes: 124 additions & 0 deletions Lib/test/test_dataclasses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4140,6 +4140,130 @@ def test_text_annotations(self):
{'foo': dataclass_textanno.Foo,
'return': type(None)})

def test_dataclass_from_another_module(self):
# see bpo-45524
from test.test_dataclasses import dataclass_textanno
from dataclasses import dataclass

@dataclass
class Default(dataclass_textanno.Bar):
pass

@dataclass(init=False)
class WithInitFalse(dataclass_textanno.Bar):
pass

@dataclass(init=False)
class CustomInit(dataclass_textanno.Bar):
def __init__(self, foo: dataclass_textanno.Foo) -> None:
pass

@dataclass
class FutureInitChild(dataclass_textanno.WithFutureInit):
pass

classes = [
Default,
WithInitFalse,
CustomInit,
dataclass_textanno.WithFutureInit,
FutureInitChild,
]
for klass in classes:
with self.subTest(klass=klass):
self.assertEqual(
get_type_hints(klass),
{'foo': dataclass_textanno.Foo},
)
self.assertEqual(get_type_hints(klass.__new__), {})
self.assertEqual(
get_type_hints(klass.__init__),
{'foo': dataclass_textanno.Foo, 'return': type(None)},
)

def test_dataclass_from_proxy_module(self):
# see bpo-45524
from test.test_dataclasses import dataclass_textanno
from test.test_dataclasses import dataclass_textanno2
from dataclasses import dataclass

@dataclass
class Default(dataclass_textanno2.Child):
pass

@dataclass(init=False)
class WithInitFalse(dataclass_textanno2.Child):
pass

@dataclass(init=False)
class CustomInit(dataclass_textanno2.Child):
def __init__(
self,
foo: dataclass_textanno.Foo,
custom: dataclass_textanno2.Custom,
) -> None:
pass

@dataclass
class FutureInitChild(dataclass_textanno2.WithFutureInit):
pass

classes = [
Default,
WithInitFalse,
CustomInit,
dataclass_textanno2.WithFutureInit,
FutureInitChild,
]
for klass in classes:
with self.subTest(klass=klass):
self.assertEqual(
get_type_hints(klass),
{
'foo': dataclass_textanno.Foo,
'custom': dataclass_textanno2.Custom,
},
)
self.assertEqual(get_type_hints(klass.__new__), {})
self.assertEqual(
get_type_hints(klass.__init__),
{
'foo': dataclass_textanno.Foo,
'custom': dataclass_textanno2.Custom,
'return': type(None),
},
)

def test_dataclass_proxy_modules_matching_name_override(self):
# see bpo-45524
from test.test_dataclasses import dataclass_textanno2
from dataclasses import dataclass

@dataclass
class Default(dataclass_textanno2.WithMatchingNameOverride):
pass

classes = [
Default,
dataclass_textanno2.WithMatchingNameOverride
]
for klass in classes:
with self.subTest(klass=klass):
self.assertEqual(
get_type_hints(klass),
{
'foo': dataclass_textanno2.Foo,
},
)
self.assertEqual(get_type_hints(klass.__new__), {})
self.assertEqual(
get_type_hints(klass.__init__),
{
'foo': dataclass_textanno2.Foo,
'return': type(None),
},
)


ByMakeDataClass = make_dataclass('ByMakeDataClass', [('x', int)])
ManualModuleMakeDataClass = make_dataclass('ManualModuleMakeDataClass',
Expand Down
6 changes: 6 additions & 0 deletions Lib/test/test_dataclasses/dataclass_textanno.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,9 @@ class Foo:
@dataclasses.dataclass
class Bar:
foo: Foo


@dataclasses.dataclass(init=False)
class WithFutureInit(Bar):
def __init__(self, foo: Foo) -> None:
pass
30 changes: 30 additions & 0 deletions Lib/test/test_dataclasses/dataclass_textanno2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from __future__ import annotations

import dataclasses

# We need to be sure that `Foo` is not in scope
from test.test_dataclasses import dataclass_textanno


class Custom:
pass


@dataclasses.dataclass
class Child(dataclass_textanno.Bar):
custom: Custom


class Foo: # matching name with `dataclass_testanno.Foo`
pass


@dataclasses.dataclass
class WithMatchingNameOverride(dataclass_textanno.Bar):
foo: Foo # Existing `foo` annotation should be overridden


@dataclasses.dataclass(init=False)
class WithFutureInit(Child):
def __init__(self, foo: dataclass_textanno.Foo, custom: Custom) -> None:
pass
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix ``typing.get_type_hints()`` failure on ``@dataclass`` hierarchies in different modules.
Loading