3
3
from pydantic import BaseModel , create_model
4
4
from pydantic .fields import FieldInfo
5
5
6
+ from docarray .base_doc .doc import BaseDocWithoutId
6
7
from docarray import BaseDoc , DocList
7
8
from docarray .typing import AnyTensor
8
9
from docarray .utils ._internal ._typing import safe_issubclass
@@ -50,16 +51,19 @@ class MyDoc(BaseDoc):
50
51
:param model: The input model
51
52
:return: A new subclass of BaseDoc, where every DocList type in the schema is replaced by List.
52
53
"""
53
- if is_pydantic_v2 :
54
- raise NotImplementedError (
55
- 'This method is not supported in Pydantic 2.0. Please use Pydantic 1.8.2 or lower.'
56
- )
57
-
58
54
fields : Dict [str , Any ] = {}
59
- for field_name , field in model .__annotations__ .items ():
60
- if field_name not in model .__fields__ :
55
+ import copy
56
+
57
+ fields_copy = copy .deepcopy (model .__fields__ )
58
+ annotations_copy = copy .deepcopy (model .__annotations__ )
59
+ for field_name , field in annotations_copy .items ():
60
+ if field_name not in fields_copy :
61
61
continue
62
- field_info = model .__fields__ [field_name ].field_info
62
+
63
+ if is_pydantic_v2 :
64
+ field_info = fields_copy [field_name ]
65
+ else :
66
+ field_info = fields_copy [field_name ].field_info
63
67
try :
64
68
if safe_issubclass (field , DocList ):
65
69
t : Any = field .doc_type
@@ -68,9 +72,8 @@ class MyDoc(BaseDoc):
68
72
fields [field_name ] = (field , field_info )
69
73
except TypeError :
70
74
fields [field_name ] = (field , field_info )
71
- return create_model (
72
- model .__name__ , __base__ = model , __validators__ = model .__validators__ , ** fields
73
- )
75
+
76
+ return create_model (model .__name__ , __base__ = model , __doc__ = model .__doc__ , ** fields )
74
77
75
78
76
79
def _get_field_annotation_from_schema (
@@ -201,6 +204,8 @@ def _get_field_annotation_from_schema(
201
204
num_recursions = num_recursions + 1 ,
202
205
definitions = definitions ,
203
206
)
207
+ elif field_type == 'null' :
208
+ ret = None
204
209
else :
205
210
if num_recursions > 0 :
206
211
raise ValueError (
@@ -255,14 +260,18 @@ class MyDoc(BaseDoc):
255
260
:return: A BaseDoc class dynamically created following the `schema`.
256
261
"""
257
262
if not definitions :
258
- definitions = schema .get ('definitions' , {})
263
+ definitions = (
264
+ schema .get ('definitions' , {}) if not is_pydantic_v2 else schema .get ('$defs' )
265
+ )
259
266
260
267
cached_models = cached_models if cached_models is not None else {}
261
268
fields : Dict [str , Any ] = {}
262
269
if base_doc_name in cached_models :
263
270
return cached_models [base_doc_name ]
271
+ has_id = False
264
272
for field_name , field_schema in schema .get ('properties' , {}).items ():
265
-
273
+ if field_name == 'id' :
274
+ has_id = True
266
275
field_type = _get_field_annotation_from_schema (
267
276
field_schema = field_schema ,
268
277
field_name = field_name ,
@@ -272,17 +281,43 @@ class MyDoc(BaseDoc):
272
281
num_recursions = 0 ,
273
282
definitions = definitions ,
274
283
)
275
- fields [field_name ] = (
276
- field_type ,
277
- FieldInfo (default = field_schema .pop ('default' , None ), ** field_schema ),
278
- )
284
+ if not is_pydantic_v2 :
285
+ field_schema ['default' ] = field_schema .get ('default' , None )
286
+ fields [field_name ] = (
287
+ field_type ,
288
+ FieldInfo (** field_schema ),
289
+ )
290
+ else :
291
+ field_kwargs = {}
292
+ field_json_schema_extra = {}
293
+ for k , v in field_schema .items ():
294
+ if k in FieldInfo .__slots__ :
295
+ field_kwargs [k ] = v
296
+ else :
297
+ field_json_schema_extra [k ] = v
298
+ fields [field_name ] = (
299
+ field_type ,
300
+ FieldInfo (
301
+ json_schema_extra = field_json_schema_extra ,
302
+ ** field_kwargs ,
303
+ ),
304
+ )
279
305
280
- model = create_model (base_doc_name , __base__ = BaseDoc , ** fields )
281
- model .__config__ .title = schema .get ('title' , model .__config__ .title )
306
+ base_model = BaseDoc if has_id else BaseDocWithoutId
307
+ model = create_model (base_doc_name , __base__ = base_model , ** fields )
308
+ if not is_pydantic_v2 :
309
+ model .__config__ .title = schema .get ('title' , model .__config__ .title )
310
+ else :
311
+ set_title = schema .get ('title' , model .model_config .get ('title' , None ))
312
+ if set_title :
313
+ model .model_config ['title' ] = set_title
282
314
283
315
for k in RESERVED_KEYS :
284
316
if k in schema :
285
317
schema .pop (k )
286
- model .__config__ .schema_extra = schema
318
+ if not is_pydantic_v2 :
319
+ model .__config__ .schema_extra = schema
320
+ else :
321
+ model .model_config ['json_schema_extra' ] = schema
287
322
cached_models [base_doc_name ] = model
288
323
return model
0 commit comments