Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion docarray/array/doc_list/doc_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Type,
TypeVar,
Union,
cast,
overload,
)

Expand Down Expand Up @@ -267,8 +268,18 @@ def validate(
):
from docarray.array.doc_vec.doc_vec import DocVec

if isinstance(value, (cls, DocVec)):
if isinstance(value, cls):
return value
elif isinstance(value, DocVec):
if (
issubclass(value.doc_type, cls.doc_type)
or value.doc_type == cls.doc_type
):
return cast(T, value.to_doc_list())
else:
raise ValueError(
f'DocList[value.doc_type] is not compatible with {cls}'
)
elif isinstance(value, cls):
return cls(value)
elif isinstance(value, Iterable):
Expand Down
8 changes: 8 additions & 0 deletions docarray/array/doc_vec/doc_vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,14 @@ def validate(
) -> T:
if isinstance(value, cls):
return value
elif isinstance(value, DocList):
if (
issubclass(value.doc_type, cls.doc_type)
or value.doc_type == cls.doc_type
):
return cast(T, value.to_doc_vec())
else:
raise ValueError(f'DocVec[value.doc_type] is not compatible with {cls}')
elif isinstance(value, DocList.__class_getitem__(cls.doc_type)):
return cast(T, value.to_doc_vec())
elif isinstance(value, Sequence):
Expand Down
2 changes: 1 addition & 1 deletion tests/units/array/stack/test_array_stacked.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def test_to_device():

def test_to_device_with_nested_da():
class Video(BaseDoc):
images: DocList[ImageDoc]
images: DocVec[ImageDoc]

da_image = DocVec[ImageDoc](
[ImageDoc(tensor=torch.zeros(3, 5))], tensor_type=TorchTensor
Expand Down