Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docarray/array/doc_list/doc_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing_inspect import is_union_type

from docarray.array.any_array import AnyDocArray
from docarray.array.doc_list.io import IOMixinArray
from docarray.array.doc_list.io import IOMixinDocList
from docarray.array.doc_list.pushpull import PushPullMixin
from docarray.array.list_advance_indexing import IndexIterType, ListAdvancedIndexing
from docarray.base_doc import AnyDoc, BaseDoc
Expand All @@ -41,7 +41,7 @@
class DocList(
ListAdvancedIndexing[T_doc],
PushPullMixin,
IOMixinArray,
IOMixinDocList,
AnyDocArray[T_doc],
):
"""
Expand Down
65 changes: 46 additions & 19 deletions docarray/array/doc_list/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Type,
TypeVar,
Union,
cast,
)

import orjson
Expand All @@ -40,9 +41,12 @@
if TYPE_CHECKING:
import pandas as pd

from docarray.array.doc_vec.doc_vec import DocVec
from docarray.array.doc_vec.io import IOMixinDocVec
from docarray.proto import DocListProto
from docarray.typing.tensor.abstract_tensor import AbstractTensor

T = TypeVar('T', bound='IOMixinArray')
T = TypeVar('T', bound='IOMixinDocList')
T_doc = TypeVar('T_doc', bound=BaseDoc)

ARRAY_PROTOCOLS = {'protobuf-array', 'pickle-array', 'json-array'}
Expand Down Expand Up @@ -96,7 +100,7 @@ def __getitem__(self, item: slice):
return self.content[item]


class IOMixinArray(Iterable[T_doc]):
class IOMixinDocList(Iterable[T_doc]):
doc_type: Type[T_doc]

@abstractmethod
Expand Down Expand Up @@ -515,8 +519,6 @@ class Person(BaseDoc):
doc_dict = _access_path_dict_to_nested_dict(access_path2val)
docs.append(doc_type.parse_obj(doc_dict))

if not isinstance(docs, cls):
return cls(docs)
return docs

def to_dataframe(self) -> 'pd.DataFrame':
Expand Down Expand Up @@ -577,11 +579,13 @@ def _load_binary_all(
protocol: Optional[str],
compress: Optional[str],
show_progress: bool,
tensor_type: Optional[Type['AbstractTensor']] = None,
):
"""Read a `DocList` object from a binary file
:param protocol: protocol to use. It can be 'pickle-array', 'protobuf-array', 'pickle' or 'protobuf'
:param compress: compress algorithm to use between `lz4`, `bz2`, `lzma`, `zlib`, `gzip`
:param show_progress: show progress bar, only works when protocol is `pickle` or `protobuf`
:param tensor_type: only relevant for DocVec; tensor_type of the DocVec
:return: a `DocList`
"""
with file_ctx as fp:
Expand All @@ -603,12 +607,20 @@ def _load_binary_all(
proto = cls._get_proto_class()()
proto.ParseFromString(d)

return cls.from_protobuf(proto)
if tensor_type is not None:
cls_ = cast('IOMixinDocVec', cls)
return cls_.from_protobuf(proto, tensor_type=tensor_type)
else:
return cls.from_protobuf(proto)
elif protocol is not None and protocol == 'pickle-array':
return pickle.loads(d)

elif protocol is not None and protocol == 'json-array':
return cls.from_json(d)
if tensor_type is not None:
cls_ = cast('IOMixinDocVec', cls)
return cls_.from_json(d, tensor_type=tensor_type)
else:
return cls.from_json(d)

# Binary format for streaming case
else:
Expand Down Expand Up @@ -658,6 +670,10 @@ def _load_binary_all(
pbar.update(
t, advance=1, total_size=str(filesize.decimal(_total_size))
)
if tensor_type is not None:
cls__ = cast(Type['DocVec'], cls)
# mypy doesn't realize that cls_ is callable
return cls__(docs, tensor_type=tensor_type) # type: ignore
return cls(docs)

@classmethod
Expand Down Expand Up @@ -724,6 +740,27 @@ def _load_binary_stream(
t, advance=1, total_size=str(filesize.decimal(_total_size))
)

@staticmethod
def _get_file_context(
    file: Union[str, bytes, pathlib.Path, io.BufferedReader, _LazyRequestReader],
    protocol: str,
    compress: Optional[str] = None,
) -> Tuple[Union[nullcontext, io.BufferedReader], Optional[str], Optional[str]]:
    """Resolve ``file`` into a readable context manager plus the effective
    protocol and compression to use when loading.

    In-memory or already-open sources are wrapped in a no-op context and keep
    the caller-supplied ``protocol``/``compress``; on-disk paths are opened in
    binary mode, with protocol/compress possibly derived from the file path.

    :param file: source to read from — a path, raw bytes, an open buffered
        reader, or a lazy HTTP reader
    :param protocol: serialization protocol requested by the caller
    :param compress: compression algorithm requested by the caller
    :return: tuple of (context manager yielding the readable object,
        effective protocol, effective compression)
    :raises FileNotFoundError: if ``file`` is a path that does not exist
    """
    # Already-materialized sources need no opening; keep caller's settings.
    if isinstance(file, (io.BufferedReader, _LazyRequestReader, bytes)):
        return nullcontext(file), protocol, compress

    # Path-like sources (str, Path, LocalPath, PurePath all pass the
    # existence check) may override protocol/compress via their extension.
    if isinstance(file, (str, pathlib.Path)) and os.path.exists(file):
        resolved_protocol, resolved_compress = _protocol_and_compress_from_file_path(
            file, protocol, compress
        )
        return open(file, 'rb'), resolved_protocol, resolved_compress

    raise FileNotFoundError(f'cannot find file {file}')

@classmethod
def load_binary(
cls: Type[T],
Expand Down Expand Up @@ -753,19 +790,9 @@ def load_binary(
:return: a `DocList` object

"""
load_protocol: Optional[str] = protocol
load_compress: Optional[str] = compress
file_ctx: Union[nullcontext, io.BufferedReader]
if isinstance(file, (io.BufferedReader, _LazyRequestReader, bytes)):
file_ctx = nullcontext(file)
# by checking path existence we allow file to be of type Path, LocalPath, PurePath and str
elif isinstance(file, (str, pathlib.Path)) and os.path.exists(file):
load_protocol, load_compress = _protocol_and_compress_from_file_path(
file, protocol, compress
)
file_ctx = open(file, 'rb')
else:
raise FileNotFoundError(f'cannot find file {file}')
file_ctx, load_protocol, load_compress = cls._get_file_context(
file, protocol, compress
)
if streaming:
if load_protocol not in SINGLE_PROTOCOLS:
raise ValueError(
Expand Down
Loading