Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions docarray/utils/create_dynamic_doc_class.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from docarray import DocList, BaseDoc
from docarray.typing import AnyTensor
from typing import Any, Dict, List, Optional, Type, Union

from pydantic import create_model
from pydantic.fields import FieldInfo
from typing import Dict, List, Any, Union, Optional, Type
from docarray.utils._internal._typing import safe_issubclass

from docarray import BaseDoc, DocList
from docarray.typing import AnyTensor
from docarray.utils._internal._typing import safe_issubclass

RESERVED_KEYS = [
'type',
Expand Down Expand Up @@ -71,6 +72,7 @@ def _get_field_type_from_schema(
cached_models: Dict[str, Any],
is_tensor: bool = False,
num_recursions: int = 0,
definitions: Optional[Dict] = None,
) -> type:
"""
Private method used to extract the corresponding field type from the schema.
Expand All @@ -80,8 +82,11 @@ def _get_field_type_from_schema(
:param cached_models: Parameter used when this method is called recursively to reuse partial nested classes.
:param is_tensor: Boolean used to distinguish between a tensor and a list
:param num_recursions: Number of recursions to properly handle nested types (Dict, List, etc.)
:param definitions: Parameter used when this method is called recursively to reuse root definitions of other schemas.
:return: A type created from the schema
"""
if not definitions:
definitions = {}
field_type = field_schema.get('type', None)
tensor_shape = field_schema.get('tensor/array shape', None)
ret: Any
Expand All @@ -96,6 +101,7 @@ def _get_field_type_from_schema(
root_schema['definitions'][ref_name],
ref_name,
cached_models=cached_models,
definitions=definitions,
)
)
else:
Expand All @@ -107,6 +113,7 @@ def _get_field_type_from_schema(
cached_models=cached_models,
is_tensor=tensor_shape is not None,
num_recursions=0,
definitions=definitions,
)
) # No Union of Lists
ret = Union[tuple(any_of_types)]
Expand Down Expand Up @@ -154,19 +161,21 @@ def _get_field_type_from_schema(
if obj_ref:
ref_name = obj_ref.split('/')[-1]
ret = create_base_doc_from_schema(
root_schema['definitions'][ref_name],
definitions[ref_name],
ref_name,
cached_models=cached_models,
definitions=definitions,
)
else:
ret = Any
else: # object reference in definitions
if obj_ref:
ref_name = obj_ref.split('/')[-1]
doc_type = create_base_doc_from_schema(
root_schema['definitions'][ref_name],
definitions[ref_name],
ref_name,
cached_models=cached_models,
definitions=definitions,
)
ret = DocList[doc_type]
else:
Expand All @@ -182,6 +191,7 @@ def _get_field_type_from_schema(
cached_models=cached_models,
is_tensor=tensor_shape is not None,
num_recursions=num_recursions + 1,
definitions=definitions,
)
else:
if num_recursions > 0:
Expand All @@ -196,7 +206,10 @@ def _get_field_type_from_schema(


def create_base_doc_from_schema(
schema: Dict[str, Any], base_doc_name: str, cached_models: Optional[Dict] = None
schema: Dict[str, Any],
base_doc_name: str,
cached_models: Optional[Dict] = None,
definitions: Optional[Dict] = None,
) -> Type:
"""
Dynamically create a `BaseDoc` subclass from a `schema` of another `BaseDoc`.
Expand Down Expand Up @@ -230,8 +243,12 @@ class MyDoc(BaseDoc):
:param schema: The schema of the original `BaseDoc` where DocLists are passed as regular Lists of Documents.
:param base_doc_name: The name of the new pydantic model created.
:param cached_models: Parameter used when this method is called recursively to reuse partial nested classes.
:param definitions: Parameter used when this method is called recursively to reuse root definitions of other schemas.
:return: A BaseDoc class dynamically created following the `schema`.
"""
if not definitions:
definitions = schema.get('definitions', {})

cached_models = cached_models if cached_models is not None else {}
fields: Dict[str, Any] = {}
if base_doc_name in cached_models:
Expand All @@ -245,6 +262,7 @@ class MyDoc(BaseDoc):
cached_models=cached_models,
is_tensor=False,
num_recursions=0,
definitions=definitions,
)
fields[field_name] = (
field_type,
Expand Down
24 changes: 17 additions & 7 deletions tests/units/util/test_create_dynamic_code_class.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
from typing import Any, Dict, List, Optional, Union

import numpy as np
import pytest
from typing import List, Dict, Union, Any
from pydantic import Field

from docarray import BaseDoc, DocList
from docarray.documents import TextDoc
from docarray.typing import AnyTensor, ImageUrl
from docarray.utils.create_dynamic_doc_class import (
create_base_doc_from_schema,
create_pure_python_type_model,
)
import numpy as np
from typing import Optional
from docarray import BaseDoc, DocList
from docarray.typing import AnyTensor, ImageUrl
from docarray.documents import TextDoc
from pydantic import Field


@pytest.mark.parametrize('transformation', ['proto', 'json'])
def test_create_pydantic_model_from_schema(transformation):
class Nested2Doc(BaseDoc):
value: str

class Nested1Doc(BaseDoc):
nested: Nested2Doc

class CustomDoc(BaseDoc):
tensor: Optional[AnyTensor]
url: ImageUrl
Expand All @@ -26,6 +33,7 @@ class CustomDoc(BaseDoc):
u: Union[str, int]
lu: List[Union[str, int]] = [0, 1, 2]
tags: Optional[Dict[str, Any]] = None
nested: Nested1Doc

CustomDocCopy = create_pure_python_type_model(CustomDoc)
new_custom_doc_model = create_base_doc_from_schema(
Expand All @@ -43,6 +51,7 @@ class CustomDoc(BaseDoc):
single_text=TextDoc(text='single hey ha', embedding=np.zeros(2)),
u='a',
lu=[3, 4],
nested=Nested1Doc(nested=Nested2Doc(value='hello world')),
)
]
)
Expand Down Expand Up @@ -77,6 +86,7 @@ class CustomDoc(BaseDoc):
assert custom_partial_da[0].u == 'a'
assert custom_partial_da[0].single_text.text == 'single hey ha'
assert custom_partial_da[0].single_text.embedding.shape == (2,)
assert original_back[0].nested.nested.value == 'hello world'

assert len(original_back) == 1
assert original_back[0].url == 'photo.jpg'
Expand Down