Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions docarray/utils/create_dynamic_doc_class.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from docarray import DocList, BaseDoc
from docarray.typing import AnyTensor
from pydantic import create_model
from pydantic.fields import FieldInfo
from typing import Dict, List, Any, Union, Optional, Type
from docarray.utils._internal._typing import safe_issubclass

Expand Down Expand Up @@ -36,14 +37,15 @@ class MyDoc(BaseDoc):
"""
fields: Dict[str, Any] = {}
for field_name, field in model.__annotations__.items():
field_info = model.__fields__[field_name].field_info
try:
if safe_issubclass(field, DocList):
t: Any = field.doc_type
fields[field_name] = (List[t], {})
fields[field_name] = (List[t], field_info)
else:
fields[field_name] = (field, {})
fields[field_name] = (field, field_info)
except TypeError:
fields[field_name] = (field, {})
fields[field_name] = (field, field_info)
return create_model(
model.__name__, __base__=model, __validators__=model.__validators__, **fields
)
Expand Down Expand Up @@ -222,6 +224,7 @@ class MyDoc(BaseDoc):
if base_doc_name in cached_models:
return cached_models[base_doc_name]
for field_name, field_schema in schema.get('properties', {}).items():

field_type = _get_field_type_from_schema(
field_schema=field_schema,
field_name=field_name,
Expand All @@ -230,7 +233,10 @@ class MyDoc(BaseDoc):
is_tensor=False,
num_recursions=0,
)
fields[field_name] = (field_type, field_schema.get('description'))
fields[field_name] = (
field_type,
FieldInfo(default=field_schema.pop('default', None), **field_schema),
)

model = create_model(base_doc_name, __base__=BaseDoc, **fields)
cached_models[base_doc_name] = model
Expand Down
18 changes: 18 additions & 0 deletions tests/units/util/test_create_dynamic_code_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from docarray import BaseDoc, DocList
from docarray.typing import AnyTensor, ImageUrl
from docarray.documents import TextDoc
from pydantic import Field


@pytest.mark.parametrize('transformation', ['proto', 'json'])
Expand Down Expand Up @@ -238,3 +239,20 @@ class ResultTestDoc(BaseDoc):

assert len(original_back) == 0
assert len(custom_da) == 0


def test_create_with_field_info():
class CustomDoc(BaseDoc):
a: str = Field(examples=['Example here'], another_extra='I am another extra')

CustomDocCopy = create_pure_python_type_model(CustomDoc)
new_custom_doc_model = create_base_doc_from_schema(
CustomDocCopy.schema(), 'CustomDoc'
)
assert new_custom_doc_model.schema().get('properties')['a']['examples'] == [
'Example here'
]
assert (
new_custom_doc_model.schema().get('properties')['a']['another_extra']
== 'I am another extra'
)