diff --git a/docarray/utils/create_dynamic_doc_class.py b/docarray/utils/create_dynamic_doc_class.py index 115257720b..a3f86aad2c 100644 --- a/docarray/utils/create_dynamic_doc_class.py +++ b/docarray/utils/create_dynamic_doc_class.py @@ -1,10 +1,11 @@ -from docarray import DocList, BaseDoc -from docarray.typing import AnyTensor +from typing import Any, Dict, List, Optional, Type, Union + from pydantic import create_model from pydantic.fields import FieldInfo -from typing import Dict, List, Any, Union, Optional, Type -from docarray.utils._internal._typing import safe_issubclass +from docarray import BaseDoc, DocList +from docarray.typing import AnyTensor +from docarray.utils._internal._typing import safe_issubclass RESERVED_KEYS = [ 'type', @@ -71,6 +72,7 @@ def _get_field_type_from_schema( cached_models: Dict[str, Any], is_tensor: bool = False, num_recursions: int = 0, + definitions: Optional[Dict] = None, ) -> type: """ Private method used to extract the corresponding field type from the schema. @@ -80,8 +82,11 @@ def _get_field_type_from_schema( :param cached_models: Parameter used when this method is called recursively to reuse partial nested classes. :param is_tensor: Boolean used to tell between tensor and list :param num_recursions: Number of recursions to properly handle nested types (Dict, List, etc ..) + :param definitions: Parameter used when this method is called recursively to reuse root definitions of other schemas. :return: A type created from the schema """ + if not definitions: + definitions = {} field_type = field_schema.get('type', None) tensor_shape = field_schema.get('tensor/array shape', None) ret: Any @@ -96,6 +101,7 @@ def _get_field_type_from_schema( root_schema['definitions'][ref_name], ref_name, cached_models=cached_models, + definitions=definitions, ) ) else: @@ -107,6 +113,7 @@ def _get_field_type_from_schema( cached_models=cached_models, is_tensor=tensor_shape is not None, num_recursions=0, + definitions=definitions, ) ) # No Union of Lists ret = Union[tuple(any_of_types)] @@ -154,9 +161,10 @@ def _get_field_type_from_schema( if obj_ref: ref_name = obj_ref.split('/')[-1] ret = create_base_doc_from_schema( - root_schema['definitions'][ref_name], + definitions[ref_name], ref_name, cached_models=cached_models, + definitions=definitions, ) else: ret = Any @@ -164,9 +172,10 @@ def _get_field_type_from_schema( if obj_ref: ref_name = obj_ref.split('/')[-1] doc_type = create_base_doc_from_schema( - root_schema['definitions'][ref_name], + definitions[ref_name], ref_name, cached_models=cached_models, + definitions=definitions, ) ret = DocList[doc_type] else: @@ -182,6 +191,7 @@ def _get_field_type_from_schema( cached_models=cached_models, is_tensor=tensor_shape is not None, num_recursions=num_recursions + 1, + definitions=definitions, ) else: if num_recursions > 0: @@ -196,7 +206,10 @@ def _get_field_type_from_schema( def create_base_doc_from_schema( - schema: Dict[str, Any], base_doc_name: str, cached_models: Optional[Dict] = None + schema: Dict[str, Any], + base_doc_name: str, + cached_models: Optional[Dict] = None, + definitions: Optional[Dict] = None, ) -> Type: """ Dynamically create a `BaseDoc` subclass from a `schema` of another `BaseDoc`. @@ -230,8 +243,12 @@ class MyDoc(BaseDoc): :param schema: The schema of the original `BaseDoc` where DocLists are passed as regular Lists of Documents. :param base_doc_name: The name of the new pydantic model created. :param cached_models: Parameter used when this method is called recursively to reuse partial nested classes. + :param definitions: Parameter used when this method is called recursively to reuse root definitions of other schemas. :return: A BaseDoc class dynamically created following the `schema`. """ + if not definitions: + definitions = schema.get('definitions', {}) + cached_models = cached_models if cached_models is not None else {} fields: Dict[str, Any] = {} if base_doc_name in cached_models: @@ -245,6 +262,7 @@ class MyDoc(BaseDoc): cached_models=cached_models, is_tensor=False, num_recursions=0, + definitions=definitions, ) fields[field_name] = ( field_type, diff --git a/tests/units/util/test_create_dynamic_code_class.py b/tests/units/util/test_create_dynamic_code_class.py index 3e785d943e..848a1dd805 100644 --- a/tests/units/util/test_create_dynamic_code_class.py +++ b/tests/units/util/test_create_dynamic_code_class.py @@ -1,19 +1,26 @@ +from typing import Any, Dict, List, Optional, Union + +import numpy as np import pytest -from typing import List, Dict, Union, Any +from pydantic import Field + +from docarray import BaseDoc, DocList +from docarray.documents import TextDoc +from docarray.typing import AnyTensor, ImageUrl from docarray.utils.create_dynamic_doc_class import ( create_base_doc_from_schema, create_pure_python_type_model, ) -import numpy as np -from typing import Optional -from docarray import BaseDoc, DocList -from docarray.typing import AnyTensor, ImageUrl -from docarray.documents import TextDoc -from pydantic import Field @pytest.mark.parametrize('transformation', ['proto', 'json']) def test_create_pydantic_model_from_schema(transformation): + class Nested2Doc(BaseDoc): + value: str + + class Nested1Doc(BaseDoc): + nested: Nested2Doc + class CustomDoc(BaseDoc): tensor: Optional[AnyTensor] url: ImageUrl @@ -26,6 +33,7 @@ class CustomDoc(BaseDoc): u: Union[str, int] lu: List[Union[str, int]] = [0, 1, 2] tags: Optional[Dict[str, Any]] = None + nested: Nested1Doc CustomDocCopy = create_pure_python_type_model(CustomDoc) new_custom_doc_model = create_base_doc_from_schema( @@ -43,6 +51,7 @@ class CustomDoc(BaseDoc): single_text=TextDoc(text='single hey ha', embedding=np.zeros(2)), u='a', lu=[3, 4], + nested=Nested1Doc(nested=Nested2Doc(value='hello world')), ) ] ) @@ -77,6 +86,7 @@ class CustomDoc(BaseDoc): assert custom_partial_da[0].u == 'a' assert custom_partial_da[0].single_text.text == 'single hey ha' assert custom_partial_da[0].single_text.embedding.shape == (2,) + assert original_back[0].nested.nested.value == 'hello world' assert len(original_back) == 1 assert original_back[0].url == 'photo.jpg'