Skip to content

Commit 132c74b

Browse files
committed
gh-89687: fix get_type_hints with dataclasses __init__ generation
1 parent 1481384 commit 132c74b

File tree

5 files changed

+180
-3
lines changed

5 files changed

+180
-3
lines changed

Lib/dataclasses.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -536,10 +536,25 @@ def _field_assign(frozen, name, value, self_name):
536536
return f' {self_name}.{name}={value}'
537537

538538

539-
def _field_init(f, frozen, globals, self_name, slots):
539+
def _field_init(f, frozen, globals, self_name, slots, module):
540540
# Return the text of the line in the body of __init__ that will
541541
# initialize this field.
542542

543+
if f.init and isinstance(f.type, str):
544+
from typing import ForwardRef # `typing` is a heavy import
545+
# We need to resolve this string type into a real `ForwardRef` object,
546+
# because otherwise we might end up with unsolvable annotations.
547+
# For example:
548+
# def __init__(self, d: collections.OrderedDict) -> None:
549+
# We won't be able to resolve `collections.OrderedDict`
550+
# with wrong `module` param, when placed in a different module. #45524
551+
try:
552+
f.type = ForwardRef(f.type, module=module, is_class=True)
553+
except SyntaxError:
554+
# We don't want to fail class creation
555+
# when `ForwardRef` cannot be constructed.
556+
pass
557+
543558
default_name = f'__dataclass_dflt_{f.name}__'
544559
if f.default_factory is not MISSING:
545560
if f.init:
@@ -616,7 +631,7 @@ def _init_param(f):
616631

617632

618633
def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
619-
self_name, func_builder, slots):
634+
self_name, func_builder, slots, module):
620635
# fields contains both real fields and InitVar pseudo-fields.
621636

622637
# Make sure we don't have fields without defaults following fields
@@ -643,7 +658,7 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
643658

644659
body_lines = []
645660
for f in fields:
646-
line = _field_init(f, frozen, locals, self_name, slots)
661+
line = _field_init(f, frozen, locals, self_name, slots, module)
647662
# line is None means that this field doesn't require
648663
# initialization (it's a pseudo-field). Just skip it.
649664
if line:
@@ -1093,6 +1108,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
10931108
else 'self',
10941109
func_builder,
10951110
slots,
1111+
cls.__module__,
10961112
)
10971113

10981114
_set_new_attribute(cls, '__replace__', _replace)

Lib/test/test_dataclasses/__init__.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4140,6 +4140,130 @@ def test_text_annotations(self):
41404140
{'foo': dataclass_textanno.Foo,
41414141
'return': type(None)})
41424142

4143+
def test_dataclass_from_another_module(self):
4144+
# see bpo-45524
4145+
from test.test_dataclasses import dataclass_textanno
4146+
from dataclasses import dataclass
4147+
4148+
@dataclass
4149+
class Default(dataclass_textanno.Bar):
4150+
pass
4151+
4152+
@dataclass(init=False)
4153+
class WithInitFalse(dataclass_textanno.Bar):
4154+
pass
4155+
4156+
@dataclass(init=False)
4157+
class CustomInit(dataclass_textanno.Bar):
4158+
def __init__(self, foo: dataclass_textanno.Foo) -> None:
4159+
pass
4160+
4161+
@dataclass
4162+
class FutureInitChild(dataclass_textanno.WithFutureInit):
4163+
pass
4164+
4165+
classes = [
4166+
Default,
4167+
WithInitFalse,
4168+
CustomInit,
4169+
dataclass_textanno.WithFutureInit,
4170+
FutureInitChild,
4171+
]
4172+
for klass in classes:
4173+
with self.subTest(klass=klass):
4174+
self.assertEqual(
4175+
get_type_hints(klass),
4176+
{'foo': dataclass_textanno.Foo},
4177+
)
4178+
self.assertEqual(get_type_hints(klass.__new__), {})
4179+
self.assertEqual(
4180+
get_type_hints(klass.__init__),
4181+
{'foo': dataclass_textanno.Foo, 'return': type(None)},
4182+
)
4183+
4184+
def test_dataclass_from_proxy_module(self):
4185+
# see bpo-45524
4186+
from test.test_dataclasses import dataclass_textanno
4187+
from test.test_dataclasses import dataclass_textanno2
4188+
from dataclasses import dataclass
4189+
4190+
@dataclass
4191+
class Default(dataclass_textanno2.Child):
4192+
pass
4193+
4194+
@dataclass(init=False)
4195+
class WithInitFalse(dataclass_textanno2.Child):
4196+
pass
4197+
4198+
@dataclass(init=False)
4199+
class CustomInit(dataclass_textanno2.Child):
4200+
def __init__(
4201+
self,
4202+
foo: dataclass_textanno.Foo,
4203+
custom: dataclass_textanno2.Custom,
4204+
) -> None:
4205+
pass
4206+
4207+
@dataclass
4208+
class FutureInitChild(dataclass_textanno2.WithFutureInit):
4209+
pass
4210+
4211+
classes = [
4212+
Default,
4213+
WithInitFalse,
4214+
CustomInit,
4215+
dataclass_textanno2.WithFutureInit,
4216+
FutureInitChild,
4217+
]
4218+
for klass in classes:
4219+
with self.subTest(klass=klass):
4220+
self.assertEqual(
4221+
get_type_hints(klass),
4222+
{
4223+
'foo': dataclass_textanno.Foo,
4224+
'custom': dataclass_textanno2.Custom,
4225+
},
4226+
)
4227+
self.assertEqual(get_type_hints(klass.__new__), {})
4228+
self.assertEqual(
4229+
get_type_hints(klass.__init__),
4230+
{
4231+
'foo': dataclass_textanno.Foo,
4232+
'custom': dataclass_textanno2.Custom,
4233+
'return': type(None),
4234+
},
4235+
)
4236+
4237+
def test_dataclass_proxy_modules_matching_name_override(self):
4238+
# see bpo-45524
4239+
from test.test_dataclasses import dataclass_textanno2
4240+
from dataclasses import dataclass
4241+
4242+
@dataclass
4243+
class Default(dataclass_textanno2.WithMatchingNameOverride):
4244+
pass
4245+
4246+
classes = [
4247+
Default,
4248+
dataclass_textanno2.WithMatchingNameOverride
4249+
]
4250+
for klass in classes:
4251+
with self.subTest(klass=klass):
4252+
self.assertEqual(
4253+
get_type_hints(klass),
4254+
{
4255+
'foo': dataclass_textanno2.Foo,
4256+
},
4257+
)
4258+
self.assertEqual(get_type_hints(klass.__new__), {})
4259+
self.assertEqual(
4260+
get_type_hints(klass.__init__),
4261+
{
4262+
'foo': dataclass_textanno2.Foo,
4263+
'return': type(None),
4264+
},
4265+
)
4266+
41434267

41444268
ByMakeDataClass = make_dataclass('ByMakeDataClass', [('x', int)])
41454269
ManualModuleMakeDataClass = make_dataclass('ManualModuleMakeDataClass',

Lib/test/test_dataclasses/dataclass_textanno.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,9 @@ class Foo:
1010
@dataclasses.dataclass
1111
class Bar:
1212
foo: Foo
13+
14+
15+
@dataclasses.dataclass(init=False)
16+
class WithFutureInit(Bar):
17+
def __init__(self, foo: Foo) -> None:
18+
pass
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from __future__ import annotations
2+
3+
import dataclasses
4+
5+
# We need to be sure that `Foo` is not in scope
6+
from test.test_dataclasses import dataclass_textanno
7+
8+
9+
class Custom:
10+
pass
11+
12+
13+
@dataclasses.dataclass
14+
class Child(dataclass_textanno.Bar):
15+
custom: Custom
16+
17+
18+
class Foo: # matching name with `dataclass_testanno.Foo`
19+
pass
20+
21+
22+
@dataclasses.dataclass
23+
class WithMatchingNameOverride(dataclass_textanno.Bar):
24+
foo: Foo # Existing `foo` annotation should be overridden
25+
26+
27+
@dataclasses.dataclass(init=False)
28+
class WithFutureInit(Child):
29+
def __init__(self, foo: dataclass_textanno.Foo, custom: Custom) -> None:
30+
pass
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix ``typing.get_type_hints()`` failure on ``@dataclass`` hierarchies in different modules.

0 commit comments

Comments
 (0)