Skip to content

Commit 83d2236

Browse files
author
Joan Fontanals
authored
feat: enable dynamic doc with Pydantic v2 (docarray#1795)
Signed-off-by: Joan Fontanals Martinez <joan.martinez@jina.ai>
1 parent 92de15e commit 83d2236

File tree

2 files changed

+65
-27
lines changed

2 files changed

+65
-27
lines changed

docarray/utils/create_dynamic_doc_class.py

Lines changed: 55 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pydantic import BaseModel, create_model
44
from pydantic.fields import FieldInfo
55

6+
from docarray.base_doc.doc import BaseDocWithoutId
67
from docarray import BaseDoc, DocList
78
from docarray.typing import AnyTensor
89
from docarray.utils._internal._typing import safe_issubclass
@@ -50,16 +51,19 @@ class MyDoc(BaseDoc):
5051
:param model: The input model
5152
:return: A new subclass of BaseDoc, where every DocList type in the schema is replaced by List.
5253
"""
53-
if is_pydantic_v2:
54-
raise NotImplementedError(
55-
'This method is not supported in Pydantic 2.0. Please use Pydantic 1.8.2 or lower.'
56-
)
57-
5854
fields: Dict[str, Any] = {}
59-
for field_name, field in model.__annotations__.items():
60-
if field_name not in model.__fields__:
55+
import copy
56+
57+
fields_copy = copy.deepcopy(model.__fields__)
58+
annotations_copy = copy.deepcopy(model.__annotations__)
59+
for field_name, field in annotations_copy.items():
60+
if field_name not in fields_copy:
6161
continue
62-
field_info = model.__fields__[field_name].field_info
62+
63+
if is_pydantic_v2:
64+
field_info = fields_copy[field_name]
65+
else:
66+
field_info = fields_copy[field_name].field_info
6367
try:
6468
if safe_issubclass(field, DocList):
6569
t: Any = field.doc_type
@@ -68,9 +72,8 @@ class MyDoc(BaseDoc):
6872
fields[field_name] = (field, field_info)
6973
except TypeError:
7074
fields[field_name] = (field, field_info)
71-
return create_model(
72-
model.__name__, __base__=model, __validators__=model.__validators__, **fields
73-
)
75+
76+
return create_model(model.__name__, __base__=model, __doc__=model.__doc__, **fields)
7477

7578

7679
def _get_field_annotation_from_schema(
@@ -201,6 +204,8 @@ def _get_field_annotation_from_schema(
201204
num_recursions=num_recursions + 1,
202205
definitions=definitions,
203206
)
207+
elif field_type == 'null':
208+
ret = None
204209
else:
205210
if num_recursions > 0:
206211
raise ValueError(
@@ -255,14 +260,18 @@ class MyDoc(BaseDoc):
255260
:return: A BaseDoc class dynamically created following the `schema`.
256261
"""
257262
if not definitions:
258-
definitions = schema.get('definitions', {})
263+
definitions = (
264+
schema.get('definitions', {}) if not is_pydantic_v2 else schema.get('$defs')
265+
)
259266

260267
cached_models = cached_models if cached_models is not None else {}
261268
fields: Dict[str, Any] = {}
262269
if base_doc_name in cached_models:
263270
return cached_models[base_doc_name]
271+
has_id = False
264272
for field_name, field_schema in schema.get('properties', {}).items():
265-
273+
if field_name == 'id':
274+
has_id = True
266275
field_type = _get_field_annotation_from_schema(
267276
field_schema=field_schema,
268277
field_name=field_name,
@@ -272,17 +281,43 @@ class MyDoc(BaseDoc):
272281
num_recursions=0,
273282
definitions=definitions,
274283
)
275-
fields[field_name] = (
276-
field_type,
277-
FieldInfo(default=field_schema.pop('default', None), **field_schema),
278-
)
284+
if not is_pydantic_v2:
285+
field_schema['default'] = field_schema.get('default', None)
286+
fields[field_name] = (
287+
field_type,
288+
FieldInfo(**field_schema),
289+
)
290+
else:
291+
field_kwargs = {}
292+
field_json_schema_extra = {}
293+
for k, v in field_schema.items():
294+
if k in FieldInfo.__slots__:
295+
field_kwargs[k] = v
296+
else:
297+
field_json_schema_extra[k] = v
298+
fields[field_name] = (
299+
field_type,
300+
FieldInfo(
301+
json_schema_extra=field_json_schema_extra,
302+
**field_kwargs,
303+
),
304+
)
279305

280-
model = create_model(base_doc_name, __base__=BaseDoc, **fields)
281-
model.__config__.title = schema.get('title', model.__config__.title)
306+
base_model = BaseDoc if has_id else BaseDocWithoutId
307+
model = create_model(base_doc_name, __base__=base_model, **fields)
308+
if not is_pydantic_v2:
309+
model.__config__.title = schema.get('title', model.__config__.title)
310+
else:
311+
set_title = schema.get('title', model.model_config.get('title', None))
312+
if set_title:
313+
model.model_config['title'] = set_title
282314

283315
for k in RESERVED_KEYS:
284316
if k in schema:
285317
schema.pop(k)
286-
model.__config__.schema_extra = schema
318+
if not is_pydantic_v2:
319+
model.__config__.schema_extra = schema
320+
else:
321+
model.model_config['json_schema_extra'] = schema
287322
cached_models[base_doc_name] = model
288323
return model

tests/units/util/test_create_dynamic_code_class.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,13 @@
77
from docarray import BaseDoc, DocList
88
from docarray.documents import TextDoc
99
from docarray.typing import AnyTensor, ImageUrl
10-
from docarray.utils._internal.pydantic import is_pydantic_v2
1110
from docarray.utils.create_dynamic_doc_class import (
1211
create_base_doc_from_schema,
1312
create_pure_python_type_model,
1413
)
14+
from docarray.utils._internal.pydantic import is_pydantic_v2
1515

1616

17-
@pytest.mark.skipif(is_pydantic_v2, reason="Not working with pydantic v2 for now")
1817
@pytest.mark.parametrize('transformation', ['proto', 'json'])
1918
def test_create_pydantic_model_from_schema(transformation):
2019
class Nested2Doc(BaseDoc):
@@ -26,7 +25,7 @@ class Nested1Doc(BaseDoc):
2625
classvar: ClassVar[str] = 'classvar1'
2726

2827
class CustomDoc(BaseDoc):
29-
tensor: Optional[AnyTensor]
28+
tensor: Optional[AnyTensor] = None
3029
url: ImageUrl
3130
lll: List[List[List[int]]] = [[[5]]]
3231
fff: List[List[List[float]]] = [[[5.2]]]
@@ -80,7 +79,10 @@ class CustomDoc(BaseDoc):
8079
assert len(custom_partial_da) == 1
8180
assert custom_partial_da[0].url == 'photo.jpg'
8281
assert custom_partial_da[0].lll == [[[40]]]
83-
assert custom_partial_da[0].lu == ['3', '4'] # Union validates back to string
82+
if is_pydantic_v2:
83+
assert custom_partial_da[0].lu == [3, 4]
84+
else:
85+
assert custom_partial_da[0].lu == ['3', '4'] # Union validates back to string
8486
assert custom_partial_da[0].fff == [[[40.2]]]
8587
assert custom_partial_da[0].di == {'a': 2}
8688
assert custom_partial_da[0].d == {'b': 'a'}
@@ -99,7 +101,10 @@ class CustomDoc(BaseDoc):
99101
assert len(original_back) == 1
100102
assert original_back[0].url == 'photo.jpg'
101103
assert original_back[0].lll == [[[40]]]
102-
assert original_back[0].lu == ['3', '4'] # Union validates back to string
104+
if is_pydantic_v2:
105+
assert original_back[0].lu == [3, 4] # Union validates back to string
106+
else:
107+
assert original_back[0].lu == ['3', '4'] # Union validates back to string
103108
assert original_back[0].fff == [[[40.2]]]
104109
assert original_back[0].di == {'a': 2}
105110
assert original_back[0].d == {'b': 'a'}
@@ -174,7 +179,6 @@ class ResultTestDoc(BaseDoc):
174179
assert doc.ia == f'ID {i}'
175180

176181

177-
@pytest.mark.skipif(is_pydantic_v2, reason="Not working with pydantic v2 for now")
178182
@pytest.mark.parametrize('transformation', ['proto', 'json'])
179183
def test_create_empty_doc_list_from_schema(transformation):
180184
class CustomDoc(BaseDoc):
@@ -260,7 +264,6 @@ class ResultTestDoc(BaseDoc):
260264
assert len(custom_da) == 0
261265

262266

263-
@pytest.mark.skipif(is_pydantic_v2, reason="Not working with pydantic v2 for now")
264267
def test_create_with_field_info():
265268
class CustomDoc(BaseDoc):
266269
"""Here I have the description of the class"""

0 commit comments

Comments
 (0)