From 5a52cf92c0bc1bc129e0a2439819ca78ec2b1d49 Mon Sep 17 00:00:00 2001 From: agaraman0 Date: Wed, 24 May 2023 15:19:46 +0530 Subject: [PATCH 1/5] fix: issue 1545,DocList to DocVec casting and vice versa Signed-off-by: agaraman0 --- docarray/array/doc_list/doc_list.py | 4 +++- docarray/array/doc_vec/doc_vec.py | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/docarray/array/doc_list/doc_list.py b/docarray/array/doc_list/doc_list.py index 29a617e2c5f..30e203182a3 100644 --- a/docarray/array/doc_list/doc_list.py +++ b/docarray/array/doc_list/doc_list.py @@ -267,8 +267,10 @@ def validate( ): from docarray.array.doc_vec.doc_vec import DocVec - if isinstance(value, (cls, DocVec)): + if isinstance(value, cls): return value + elif isinstance(value, DocVec): + return value.to_doc_list() elif isinstance(value, cls): return cls(value) elif isinstance(value, Iterable): diff --git a/docarray/array/doc_vec/doc_vec.py b/docarray/array/doc_vec/doc_vec.py index e27f6882fe9..a96bf86cfd2 100644 --- a/docarray/array/doc_vec/doc_vec.py +++ b/docarray/array/doc_vec/doc_vec.py @@ -278,6 +278,8 @@ def validate( ) -> T: if isinstance(value, cls): return value + elif isinstance(value, DocList): + return value.to_doc_vec() elif isinstance(value, DocList.__class_getitem__(cls.doc_type)): return cast(T, value.to_doc_vec()) elif isinstance(value, Sequence): From b834d4f3745613a0033634627af6454d8aecf713 Mon Sep 17 00:00:00 2001 From: agaraman0 Date: Sun, 4 Jun 2023 10:01:24 +0530 Subject: [PATCH 2/5] fix: mypy errors Signed-off-by: agaraman0 --- docarray/array/doc_list/doc_list.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/docarray/array/doc_list/doc_list.py b/docarray/array/doc_list/doc_list.py index 30e203182a3..951256ef2ce 100644 --- a/docarray/array/doc_list/doc_list.py +++ b/docarray/array/doc_list/doc_list.py @@ -10,6 +10,7 @@ Type, TypeVar, Union, + cast, overload, ) @@ -270,7 +271,15 @@ def validate( if isinstance(value, cls): return value elif isinstance(value, DocVec): - return value.to_doc_list() + if ( + issubclass(value.doc_type, cls.doc_type) + or value.doc_type == cls.doc_type + ): + return cast(T, value.to_doc_list()) + else: + raise ValueError( + f'DocList[value.doc_type] is not compatible with {cls}' + ) elif isinstance(value, cls): return cls(value) elif isinstance(value, Iterable): From c78ae1c6ed18029cc3d0712b41e313c1a4a7f59b Mon Sep 17 00:00:00 2001 From: agaraman0 Date: Sun, 4 Jun 2023 10:07:38 +0530 Subject: [PATCH 3/5] fix: black errors Signed-off-by: agaraman0 --- docarray/array/doc_vec/doc_vec.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/docarray/array/doc_vec/doc_vec.py b/docarray/array/doc_vec/doc_vec.py index a96bf86cfd2..f5578647af8 100644 --- a/docarray/array/doc_vec/doc_vec.py +++ b/docarray/array/doc_vec/doc_vec.py @@ -279,7 +279,13 @@ def validate( if isinstance(value, cls): return value elif isinstance(value, DocList): - return value.to_doc_vec() + if ( + issubclass(value.doc_type, cls.doc_type) + or value.doc_type == cls.doc_type + ): + return cast(T, value.to_doc_list()) + else: + raise ValueError(f'DocVec[value.doc_type] is not compatible with {cls}') elif isinstance(value, DocList.__class_getitem__(cls.doc_type)): return cast(T, value.to_doc_vec()) elif isinstance(value, Sequence): From d4aa3dfe896cea842ee792b422cc54524f4e3701 Mon Sep 17 00:00:00 2001 From: agaraman0 Date: Sun, 4 Jun 2023 10:19:40 +0530 Subject: [PATCH 4/5] fix: doc_vec broken changes Signed-off-by: agaraman0 --- docarray/array/doc_vec/doc_vec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docarray/array/doc_vec/doc_vec.py b/docarray/array/doc_vec/doc_vec.py index f5578647af8..52009d83175 100644 --- a/docarray/array/doc_vec/doc_vec.py +++ b/docarray/array/doc_vec/doc_vec.py @@ -283,7 +283,7 @@ def validate( issubclass(value.doc_type, cls.doc_type) or value.doc_type == cls.doc_type ): - return cast(T, value.to_doc_list()) + return cast(T, value.to_doc_vec()) else: raise ValueError(f'DocVec[value.doc_type] is not compatible with {cls}') elif isinstance(value, DocList.__class_getitem__(cls.doc_type)): From e0e1a52bf51873ed95fe1bfff04d574d3d699997 Mon Sep 17 00:00:00 2001 From: agaraman0 Date: Wed, 7 Jun 2023 09:35:18 +0530 Subject: [PATCH 5/5] fix: failing test case Signed-off-by: agaraman0 --- tests/units/array/stack/test_array_stacked.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/units/array/stack/test_array_stacked.py b/tests/units/array/stack/test_array_stacked.py index 385fd9f8a8a..cf78ddd7b41 100644 --- a/tests/units/array/stack/test_array_stacked.py +++ b/tests/units/array/stack/test_array_stacked.py @@ -359,7 +359,7 @@ def test_to_device(): def test_to_device_with_nested_da(): class Video(BaseDoc): - images: DocList[ImageDoc] + images: DocVec[ImageDoc] da_image = DocVec[ImageDoc]( [ImageDoc(tensor=torch.zeros(3, 5))], tensor_type=TorchTensor