Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions docarray/base_doc/mixins/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,6 @@ def _get_content_from_node_proto(
)

return_field: Any

if docarray_type in content_type_dict:
return_field = content_type_dict[docarray_type].from_protobuf(
getattr(value, content_key)
Expand All @@ -308,13 +307,18 @@ def _get_content_from_node_proto(
f'{field_type} is not supported for proto deserialization'
)
elif content_key == 'doc_array':
if field_name is None:
if field_type is not None and field_name is None:
return_field = field_type.from_protobuf(getattr(value, content_key))
elif field_name is not None:
return_field = cls._get_field_annotation_array(
field_name
).from_protobuf(
getattr(value, content_key)
) # we get to the parent class
else:
raise ValueError(
'field_name cannot be None when trying to deserialize a BaseDoc'
'field_name and field_type cannot be None when trying to deserialize a DocArray'
)
return_field = cls._get_field_annotation_array(field_name).from_protobuf(
getattr(value, content_key)
) # we get to the parent class
elif content_key is None:
return_field = None
elif docarray_type is None:
Expand All @@ -330,8 +334,6 @@ def _get_content_from_node_proto(
elif content_key in arg_to_container.keys():
if field_name and field_name in cls._docarray_fields():
field_type = cls._get_field_inner_type(field_name)
else:
field_type = None

if isinstance(field_type, GenericAlias):
field_type = get_args(field_type)[0]
Expand Down
2 changes: 1 addition & 1 deletion tests/units/array/test_array_from_to_json.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Dict, List

import numpy as np
import pytest
Expand Down
39 changes: 39 additions & 0 deletions tests/units/array/test_array_proto.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import pytest
from typing import Dict, List

from docarray import BaseDoc, DocList
from docarray.base_doc import AnyDoc
Expand Down Expand Up @@ -111,3 +112,41 @@ class BasisUnion(BaseDoc):
docs_basic = DocList[BasisUnion]([BasisUnion(ud="hello")])
docs_copy = DocList[BasisUnion].from_protobuf(docs_basic.to_protobuf())
assert docs_copy == docs_basic


class MySimpleDoc(BaseDoc):
title: str


class MyComplexDoc(BaseDoc):
content_dict_doclist: Dict[str, DocList[MySimpleDoc]]
content_dict_list: Dict[str, List[MySimpleDoc]]
aux_dict: Dict[str, int]


def test_to_from_proto_complex():
da = DocList[MyComplexDoc](
[
MyComplexDoc(
content_dict_doclist={
'test1': DocList[MySimpleDoc](
[MySimpleDoc(title='123'), MySimpleDoc(title='456')]
)
},
content_dict_list={
'test1': [MySimpleDoc(title='123'), MySimpleDoc(title='456')]
},
aux_dict={'a': 0},
)
]
)
da2 = DocList[MyComplexDoc].from_protobuf(da.to_protobuf())
assert len(da2) == 1
d2 = da2[0]
assert d2.aux_dict == {'a': 0}
assert len(d2.content_dict_doclist['test1']) == 2
assert d2.content_dict_doclist['test1'][0].title == '123'
assert d2.content_dict_doclist['test1'][1].title == '456'
assert len(d2.content_dict_list['test1']) == 2
assert d2.content_dict_list['test1'][0].title == '123'
assert d2.content_dict_list['test1'][1].title == '456'
63 changes: 62 additions & 1 deletion tests/units/document/test_from_to_bytes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
from typing import Dict, List

from docarray import BaseDoc
from docarray import BaseDoc, DocList
from docarray.documents import ImageDoc
from docarray.typing import NdArray

Expand All @@ -11,6 +12,16 @@ class MyDoc(BaseDoc):
image: ImageDoc


class MySimpleDoc(BaseDoc):
title: str


class MyComplexDoc(BaseDoc):
content_dict_doclist: Dict[str, DocList[MySimpleDoc]]
content_dict_list: Dict[str, List[MySimpleDoc]]
aux_dict: Dict[str, int]


@pytest.mark.parametrize('protocol', ['protobuf', 'pickle'])
@pytest.mark.parametrize('compress', ['lz4', 'bz2', 'lzma', 'zlib', 'gzip', None])
def test_to_from_bytes(protocol, compress):
Expand Down Expand Up @@ -39,3 +50,53 @@ def test_to_from_base64(protocol, compress):
assert d2.text == 'hello'
assert d2.embedding.tolist() == [1, 2, 3, 4, 5]
assert d2.image.url == 'aux.png'


@pytest.mark.parametrize('protocol', ['protobuf', 'pickle'])
@pytest.mark.parametrize('compress', ['lz4', 'bz2', 'lzma', 'zlib', 'gzip', None])
def test_to_from_bytes_complex(protocol, compress):
d = MyComplexDoc(
content_dict_doclist={
'test1': DocList[MySimpleDoc](
[MySimpleDoc(title='123'), MySimpleDoc(title='456')]
)
},
content_dict_list={
'test1': [MySimpleDoc(title='123'), MySimpleDoc(title='456')]
},
aux_dict={'a': 0},
)
bstr = d.to_bytes(protocol=protocol, compress=compress)
d2 = MyComplexDoc.from_bytes(bstr, protocol=protocol, compress=compress)
assert d2.aux_dict == {'a': 0}
assert len(d2.content_dict_doclist['test1']) == 2
assert d2.content_dict_doclist['test1'][0].title == '123'
assert d2.content_dict_doclist['test1'][1].title == '456'
assert len(d2.content_dict_list['test1']) == 2
assert d2.content_dict_list['test1'][0].title == '123'
assert d2.content_dict_list['test1'][1].title == '456'


@pytest.mark.parametrize('protocol', ['protobuf', 'pickle'])
@pytest.mark.parametrize('compress', ['lz4', 'bz2', 'lzma', 'zlib', 'gzip', None])
def test_to_from_base64_complex(protocol, compress):
d = MyComplexDoc(
content_dict_doclist={
'test1': DocList[MySimpleDoc](
[MySimpleDoc(title='123'), MySimpleDoc(title='456')]
)
},
content_dict_list={
'test1': [MySimpleDoc(title='123'), MySimpleDoc(title='456')]
},
aux_dict={'a': 0},
)
bstr = d.to_base64(protocol=protocol, compress=compress)
d2 = MyComplexDoc.from_base64(bstr, protocol=protocol, compress=compress)
assert d2.aux_dict == {'a': 0}
assert len(d2.content_dict_doclist['test1']) == 2
assert d2.content_dict_doclist['test1'][0].title == '123'
assert d2.content_dict_doclist['test1'][1].title == '456'
assert len(d2.content_dict_list['test1']) == 2
assert d2.content_dict_list['test1'][0].title == '123'
assert d2.content_dict_list['test1'][1].title == '456'