diff --git a/docarray/array/doc_list/doc_list.py b/docarray/array/doc_list/doc_list.py index 29a617e2c5f..951256ef2ce 100644 --- a/docarray/array/doc_list/doc_list.py +++ b/docarray/array/doc_list/doc_list.py @@ -10,6 +10,7 @@ Type, TypeVar, Union, + cast, overload, ) @@ -267,8 +268,18 @@ def validate( ): from docarray.array.doc_vec.doc_vec import DocVec - if isinstance(value, (cls, DocVec)): + if isinstance(value, cls): return value + elif isinstance(value, DocVec): + if ( + issubclass(value.doc_type, cls.doc_type) + or value.doc_type == cls.doc_type + ): + return cast(T, value.to_doc_list()) + else: + raise ValueError( + f'DocList[value.doc_type] is not compatible with {cls}' + ) elif isinstance(value, cls): return cls(value) elif isinstance(value, Iterable): diff --git a/docarray/array/doc_vec/doc_vec.py b/docarray/array/doc_vec/doc_vec.py index e27f6882fe9..52009d83175 100644 --- a/docarray/array/doc_vec/doc_vec.py +++ b/docarray/array/doc_vec/doc_vec.py @@ -278,6 +278,14 @@ def validate( ) -> T: if isinstance(value, cls): return value + elif isinstance(value, DocList): + if ( + issubclass(value.doc_type, cls.doc_type) + or value.doc_type == cls.doc_type + ): + return cast(T, value.to_doc_vec()) + else: + raise ValueError(f'DocVec[value.doc_type] is not compatible with {cls}') elif isinstance(value, DocList.__class_getitem__(cls.doc_type)): return cast(T, value.to_doc_vec()) elif isinstance(value, Sequence): diff --git a/tests/units/array/stack/test_array_stacked.py b/tests/units/array/stack/test_array_stacked.py index 385fd9f8a8a..cf78ddd7b41 100644 --- a/tests/units/array/stack/test_array_stacked.py +++ b/tests/units/array/stack/test_array_stacked.py @@ -359,7 +359,7 @@ def test_to_device(): def test_to_device_with_nested_da(): class Video(BaseDoc): - images: DocList[ImageDoc] + images: DocVec[ImageDoc] da_image = DocVec[ImageDoc]( [ImageDoc(tensor=torch.zeros(3, 5))], tensor_type=TorchTensor