From 08d0fa3e68d30a90bbe9490d93d63a9afd0c2485 Mon Sep 17 00:00:00 2001
From: Joan Martinez
Date: Fri, 15 Dec 2023 14:08:02 +0100
Subject: [PATCH] fix: fix issue serializing deserializing complex schemas

Signed-off-by: Joan Martinez
---
 docarray/base_doc/mixins/io.py               | 18 +++---
 tests/units/array/test_array_from_to_json.py |  2 +-
 tests/units/array/test_array_proto.py        | 39 ++++++++++++
 tests/units/document/test_from_to_bytes.py   | 63 +++++++++++++++++++-
 4 files changed, 112 insertions(+), 10 deletions(-)

diff --git a/docarray/base_doc/mixins/io.py b/docarray/base_doc/mixins/io.py
index 0f371d21abf..3121c45c445 100644
--- a/docarray/base_doc/mixins/io.py
+++ b/docarray/base_doc/mixins/io.py
@@ -285,7 +285,6 @@ def _get_content_from_node_proto(
         )
 
         return_field: Any
-
         if docarray_type in content_type_dict:
             return_field = content_type_dict[docarray_type].from_protobuf(
                 getattr(value, content_key)
@@ -308,13 +307,18 @@ def _get_content_from_node_proto(
                     f'{field_type} is not supported for proto deserialization'
                 )
         elif content_key == 'doc_array':
-            if field_name is None:
+            if field_type is not None and field_name is None:
+                return_field = field_type.from_protobuf(getattr(value, content_key))
+            elif field_name is not None:
+                return_field = cls._get_field_annotation_array(
+                    field_name
+                ).from_protobuf(
+                    getattr(value, content_key)
+                )  # we get to the parent class
+            else:
                 raise ValueError(
-                    'field_name cannot be None when trying to deserialize a BaseDoc'
+                    'field_name and field_type cannot be None when trying to deserialize a DocArray'
                 )
-            return_field = cls._get_field_annotation_array(field_name).from_protobuf(
-                getattr(value, content_key)
-            )  # we get to the parent class
         elif content_key is None:
             return_field = None
         elif docarray_type is None:
@@ -330,8 +334,6 @@ def _get_content_from_node_proto(
         elif content_key in arg_to_container.keys():
             if field_name and field_name in cls._docarray_fields():
                 field_type = cls._get_field_inner_type(field_name)
-            else:
-                field_type = None
 
             if isinstance(field_type, GenericAlias):
                 field_type = get_args(field_type)[0]
diff --git a/tests/units/array/test_array_from_to_json.py b/tests/units/array/test_array_from_to_json.py
index 726c7520455..f257a22ac86 100644
--- a/tests/units/array/test_array_from_to_json.py
+++ b/tests/units/array/test_array_from_to_json.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, Dict, List
 
 import numpy as np
 import pytest
diff --git a/tests/units/array/test_array_proto.py b/tests/units/array/test_array_proto.py
index 495474dc1c4..916412461ed 100644
--- a/tests/units/array/test_array_proto.py
+++ b/tests/units/array/test_array_proto.py
@@ -1,5 +1,6 @@
 import numpy as np
 import pytest
+from typing import Dict, List
 
 from docarray import BaseDoc, DocList
 from docarray.base_doc import AnyDoc
@@ -111,3 +112,41 @@ class BasisUnion(BaseDoc):
     docs_basic = DocList[BasisUnion]([BasisUnion(ud="hello")])
     docs_copy = DocList[BasisUnion].from_protobuf(docs_basic.to_protobuf())
     assert docs_copy == docs_basic
+
+
+class MySimpleDoc(BaseDoc):
+    title: str
+
+
+class MyComplexDoc(BaseDoc):
+    content_dict_doclist: Dict[str, DocList[MySimpleDoc]]
+    content_dict_list: Dict[str, List[MySimpleDoc]]
+    aux_dict: Dict[str, int]
+
+
+def test_to_from_proto_complex():
+    da = DocList[MyComplexDoc](
+        [
+            MyComplexDoc(
+                content_dict_doclist={
+                    'test1': DocList[MySimpleDoc](
+                        [MySimpleDoc(title='123'), MySimpleDoc(title='456')]
+                    )
+                },
+                content_dict_list={
+                    'test1': [MySimpleDoc(title='123'), MySimpleDoc(title='456')]
+                },
+                aux_dict={'a': 0},
+            )
+        ]
+    )
+    da2 = DocList[MyComplexDoc].from_protobuf(da.to_protobuf())
+    assert len(da2) == 1
+    d2 = da2[0]
+    assert d2.aux_dict == {'a': 0}
+    assert len(d2.content_dict_doclist['test1']) == 2
+    assert d2.content_dict_doclist['test1'][0].title == '123'
+    assert d2.content_dict_doclist['test1'][1].title == '456'
+    assert len(d2.content_dict_list['test1']) == 2
+    assert d2.content_dict_list['test1'][0].title == '123'
+    assert d2.content_dict_list['test1'][1].title == '456'
diff --git a/tests/units/document/test_from_to_bytes.py b/tests/units/document/test_from_to_bytes.py
index 5a3eb620780..25917b0aca2 100644
--- a/tests/units/document/test_from_to_bytes.py
+++ b/tests/units/document/test_from_to_bytes.py
@@ -1,6 +1,7 @@
 import pytest
+from typing import Dict, List
 
-from docarray import BaseDoc
+from docarray import BaseDoc, DocList
 from docarray.documents import ImageDoc
 from docarray.typing import NdArray
 
@@ -11,6 +12,16 @@ class MyDoc(BaseDoc):
     image: ImageDoc
 
 
+class MySimpleDoc(BaseDoc):
+    title: str
+
+
+class MyComplexDoc(BaseDoc):
+    content_dict_doclist: Dict[str, DocList[MySimpleDoc]]
+    content_dict_list: Dict[str, List[MySimpleDoc]]
+    aux_dict: Dict[str, int]
+
+
 @pytest.mark.parametrize('protocol', ['protobuf', 'pickle'])
 @pytest.mark.parametrize('compress', ['lz4', 'bz2', 'lzma', 'zlib', 'gzip', None])
 def test_to_from_bytes(protocol, compress):
@@ -39,3 +50,53 @@ def test_to_from_base64(protocol, compress):
     assert d2.text == 'hello'
     assert d2.embedding.tolist() == [1, 2, 3, 4, 5]
     assert d2.image.url == 'aux.png'
+
+
+@pytest.mark.parametrize('protocol', ['protobuf', 'pickle'])
+@pytest.mark.parametrize('compress', ['lz4', 'bz2', 'lzma', 'zlib', 'gzip', None])
+def test_to_from_bytes_complex(protocol, compress):
+    d = MyComplexDoc(
+        content_dict_doclist={
+            'test1': DocList[MySimpleDoc](
+                [MySimpleDoc(title='123'), MySimpleDoc(title='456')]
+            )
+        },
+        content_dict_list={
+            'test1': [MySimpleDoc(title='123'), MySimpleDoc(title='456')]
+        },
+        aux_dict={'a': 0},
+    )
+    bstr = d.to_bytes(protocol=protocol, compress=compress)
+    d2 = MyComplexDoc.from_bytes(bstr, protocol=protocol, compress=compress)
+    assert d2.aux_dict == {'a': 0}
+    assert len(d2.content_dict_doclist['test1']) == 2
+    assert d2.content_dict_doclist['test1'][0].title == '123'
+    assert d2.content_dict_doclist['test1'][1].title == '456'
+    assert len(d2.content_dict_list['test1']) == 2
+    assert d2.content_dict_list['test1'][0].title == '123'
+    assert d2.content_dict_list['test1'][1].title == '456'
+
+
+@pytest.mark.parametrize('protocol', ['protobuf', 'pickle'])
+@pytest.mark.parametrize('compress', ['lz4', 'bz2', 'lzma', 'zlib', 'gzip', None])
+def test_to_from_base64_complex(protocol, compress):
+    d = MyComplexDoc(
+        content_dict_doclist={
+            'test1': DocList[MySimpleDoc](
+                [MySimpleDoc(title='123'), MySimpleDoc(title='456')]
+            )
+        },
+        content_dict_list={
+            'test1': [MySimpleDoc(title='123'), MySimpleDoc(title='456')]
+        },
+        aux_dict={'a': 0},
+    )
+    bstr = d.to_base64(protocol=protocol, compress=compress)
+    d2 = MyComplexDoc.from_base64(bstr, protocol=protocol, compress=compress)
+    assert d2.aux_dict == {'a': 0}
+    assert len(d2.content_dict_doclist['test1']) == 2
+    assert d2.content_dict_doclist['test1'][0].title == '123'
+    assert d2.content_dict_doclist['test1'][1].title == '456'
+    assert len(d2.content_dict_list['test1']) == 2
+    assert d2.content_dict_list['test1'][0].title == '123'
+    assert d2.content_dict_list['test1'][1].title == '456'