From 942ed17019045f473c18f249fcc2e9ef623dd8a4 Mon Sep 17 00:00:00 2001 From: Joan Fontanals Martinez Date: Sat, 29 Jul 2023 16:19:00 +0200 Subject: [PATCH] fix: create more info from dynamic Signed-off-by: Joan Fontanals Martinez --- docarray/utils/create_dynamic_doc_class.py | 14 ++++++++++---- .../util/test_create_dynamic_code_class.py | 18 ++++++++++++++++++ 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/docarray/utils/create_dynamic_doc_class.py b/docarray/utils/create_dynamic_doc_class.py index 254b7013f36..6b85e97ee0a 100644 --- a/docarray/utils/create_dynamic_doc_class.py +++ b/docarray/utils/create_dynamic_doc_class.py @@ -1,6 +1,7 @@ from docarray import DocList, BaseDoc from docarray.typing import AnyTensor from pydantic import create_model +from pydantic.fields import FieldInfo from typing import Dict, List, Any, Union, Optional, Type from docarray.utils._internal._typing import safe_issubclass @@ -36,14 +37,15 @@ class MyDoc(BaseDoc): """ fields: Dict[str, Any] = {} for field_name, field in model.__annotations__.items(): + field_info = model.__fields__[field_name].field_info try: if safe_issubclass(field, DocList): t: Any = field.doc_type - fields[field_name] = (List[t], {}) + fields[field_name] = (List[t], field_info) else: - fields[field_name] = (field, {}) + fields[field_name] = (field, field_info) except TypeError: - fields[field_name] = (field, {}) + fields[field_name] = (field, field_info) return create_model( model.__name__, __base__=model, __validators__=model.__validators__, **fields ) @@ -222,6 +224,7 @@ class MyDoc(BaseDoc): if base_doc_name in cached_models: return cached_models[base_doc_name] for field_name, field_schema in schema.get('properties', {}).items(): + field_type = _get_field_type_from_schema( field_schema=field_schema, field_name=field_name, @@ -230,7 +233,10 @@ class MyDoc(BaseDoc): is_tensor=False, num_recursions=0, ) - fields[field_name] = (field_type, field_schema.get('description')) + fields[field_name] = ( + field_type, + FieldInfo(default=field_schema.pop('default', None), **field_schema), + ) model = create_model(base_doc_name, __base__=BaseDoc, **fields) cached_models[base_doc_name] = model diff --git a/tests/units/util/test_create_dynamic_code_class.py b/tests/units/util/test_create_dynamic_code_class.py index ff7f6551403..2e3243bfeea 100644 --- a/tests/units/util/test_create_dynamic_code_class.py +++ b/tests/units/util/test_create_dynamic_code_class.py @@ -9,6 +9,7 @@ from docarray import BaseDoc, DocList from docarray.typing import AnyTensor, ImageUrl from docarray.documents import TextDoc +from pydantic import Field @pytest.mark.parametrize('transformation', ['proto', 'json']) @@ -238,3 +239,20 @@ class ResultTestDoc(BaseDoc): assert len(original_back) == 0 assert len(custom_da) == 0 + + +def test_create_with_field_info(): + class CustomDoc(BaseDoc): + a: str = Field(examples=['Example here'], another_extra='I am another extra') + + CustomDocCopy = create_pure_python_type_model(CustomDoc) + new_custom_doc_model = create_base_doc_from_schema( + CustomDocCopy.schema(), 'CustomDoc' + ) + assert new_custom_doc_model.schema().get('properties')['a']['examples'] == [ + 'Example here' + ] + assert ( + new_custom_doc_model.schema().get('properties')['a']['another_extra'] + == 'I am another extra' + )