1
1
from docarray import DocList , BaseDoc
2
2
from docarray .typing import AnyTensor
3
3
from pydantic import create_model
4
- from typing import Dict , List , Any , Union , Optional , Tuple , Type
5
- from typing_extensions import TypeAlias
4
+ from typing import Dict , List , Any , Union , Optional , Type
6
5
7
6
8
- def create_new_model_cast_doclist_to_list (model : BaseDoc ) -> BaseDoc :
9
- fields : Dict [str , Tuple [ Type , Dict ] ] = {}
7
+ def create_new_model_cast_doclist_to_list (model : Any ) -> BaseDoc :
8
+ fields : Dict [str , Any ] = {}
10
9
for field_name , field in model .__annotations__ .items ():
11
10
try :
12
11
if issubclass (field , DocList ):
13
- fields [field_name ] = (List [field .doc_type ], {})
12
+ t : Any = field .doc_type
13
+ fields [field_name ] = (List [t ], {})
14
14
else :
15
15
fields [field_name ] = (field , {})
16
16
except TypeError :
@@ -30,7 +30,7 @@ def _get_field_from_type(
30
30
):
31
31
field_type = field_schema .get ('type' , None )
32
32
tensor_shape = field_schema .get ('tensor/array shape' , None )
33
- ret : TypeAlias
33
+ ret : Any
34
34
if 'anyOf' in field_schema :
35
35
any_of_types = []
36
36
for any_of_schema in field_schema ['anyOf' ]:
@@ -82,15 +82,14 @@ def _get_field_from_type(
82
82
for rec in range (num_recursions ):
83
83
ret = List [ret ]
84
84
elif field_type == 'object' or field_type is None :
85
+ doc_type : Any
85
86
if 'additionalProperties' in field_schema : # handle Dictionaries
86
87
additional_props = field_schema ['additionalProperties' ]
87
88
if additional_props .get ('type' ) == 'object' :
88
- ret = Dict [
89
- str ,
90
- create_base_doc_from_schema (
91
- additional_props , field_name , cached_models = cached_models
92
- ),
93
- ]
89
+ doc_type = create_base_doc_from_schema (
90
+ additional_props , field_name , cached_models = cached_models
91
+ )
92
+ ret = Dict [str , doc_type ]
94
93
else :
95
94
ret = Dict [str , Any ]
96
95
else :
@@ -110,19 +109,17 @@ def _get_field_from_type(
110
109
else : # object reference in definitions
111
110
if obj_ref :
112
111
ref_name = obj_ref .split ('/' )[- 1 ]
113
- ret = DocList [
114
- create_base_doc_from_schema (
115
- root_schema ['definitions' ][ref_name ],
116
- ref_name ,
117
- cached_models = cached_models ,
118
- )
119
- ]
112
+ doc_type = create_base_doc_from_schema (
113
+ root_schema ['definitions' ][ref_name ],
114
+ ref_name ,
115
+ cached_models = cached_models ,
116
+ )
117
+ ret = DocList [doc_type ]
120
118
else :
121
- ret = DocList [
122
- create_base_doc_from_schema (
123
- field_schema , field_name , cached_models = cached_models
124
- )
125
- ]
119
+ doc_type = create_base_doc_from_schema (
120
+ field_schema , field_name , cached_models = cached_models
121
+ )
122
+ ret = DocList [doc_type ]
126
123
elif field_type == 'array' :
127
124
ret = _get_field_from_type (
128
125
field_schema = field_schema .get ('items' , {}),
@@ -148,7 +145,7 @@ def create_base_doc_from_schema(
148
145
schema : Dict [str , Any ], model_name : str , cached_models : Optional [Dict ] = None
149
146
) -> Type :
150
147
cached_models = cached_models if cached_models is not None else {}
151
- fields = {}
148
+ fields : Dict [ str , Any ] = {}
152
149
if model_name in cached_models :
153
150
return cached_models [model_name ]
154
151
for field_name , field_schema in schema .get ('properties' , {}).items ():
0 commit comments