Skip to content

Commit 2937e25

Browse files
authored
fix: make DocList compatible with BaseDocWithoutId (docarray#1805)
Signed-off-by: samsja <sami.jaghouar@hotmail.fr>
1 parent 2a1cc9e commit 2937e25

File tree

3 files changed

+27
-15
lines changed

3 files changed

+27
-15
lines changed

docarray/array/any_array.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import numpy as np
2121

22-
from docarray.base_doc import BaseDoc
22+
from docarray.base_doc.doc import BaseDocWithoutId
2323
from docarray.display.document_array_summary import DocArraySummary
2424
from docarray.exceptions.exceptions import UnusableObjectError
2525
from docarray.typing.abstract_type import AbstractType
@@ -30,7 +30,7 @@
3030
from docarray.typing.tensor.abstract_tensor import AbstractTensor
3131

3232
T = TypeVar('T', bound='AnyDocArray')
33-
T_doc = TypeVar('T_doc', bound=BaseDoc)
33+
T_doc = TypeVar('T_doc', bound=BaseDocWithoutId)
3434
IndexIterType = Union[slice, Iterable[int], Iterable[bool], None]
3535

3636
UNUSABLE_ERROR_MSG = (
@@ -42,18 +42,18 @@
4242

4343

4444
class AnyDocArray(Sequence[T_doc], Generic[T_doc], AbstractType):
45-
doc_type: Type[BaseDoc]
46-
__typed_da__: Dict[Type['AnyDocArray'], Dict[Type[BaseDoc], Type]] = {}
45+
doc_type: Type[BaseDocWithoutId]
46+
__typed_da__: Dict[Type['AnyDocArray'], Dict[Type[BaseDocWithoutId], Type]] = {}
4747

4848
def __repr__(self):
4949
return f'<{self.__class__.__name__} (length={len(self)})>'
5050

5151
@classmethod
52-
def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]):
52+
def __class_getitem__(cls, item: Union[Type[BaseDocWithoutId], TypeVar, str]):
5353
if not isinstance(item, type):
5454
return Generic.__class_getitem__.__func__(cls, item) # type: ignore
5555
# this do nothing that checking that item is valid type var or str
56-
if not safe_issubclass(item, BaseDoc):
56+
if not safe_issubclass(item, BaseDocWithoutId):
5757
raise ValueError(
5858
f'{cls.__name__}[item] item should be a Document not a {item} '
5959
)
@@ -66,7 +66,7 @@ def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]):
6666
global _DocArrayTyped
6767

6868
class _DocArrayTyped(cls): # type: ignore
69-
doc_type: Type[BaseDoc] = cast(Type[BaseDoc], item)
69+
doc_type: Type[BaseDocWithoutId] = cast(Type[BaseDocWithoutId], item)
7070

7171
for field in _DocArrayTyped.doc_type._docarray_fields().keys():
7272

docarray/array/doc_list/doc_list.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@
1616

1717
from pydantic import parse_obj_as
1818
from typing_extensions import SupportsIndex
19-
from typing_inspect import is_union_type, is_typevar
19+
from typing_inspect import is_typevar, is_union_type
2020

2121
from docarray.array.any_array import AnyDocArray
2222
from docarray.array.doc_list.io import IOMixinDocList
2323
from docarray.array.doc_list.pushpull import PushPullMixin
2424
from docarray.array.list_advance_indexing import IndexIterType, ListAdvancedIndexing
25-
from docarray.base_doc import AnyDoc, BaseDoc
25+
from docarray.base_doc import AnyDoc
26+
from docarray.base_doc.doc import BaseDocWithoutId
2627
from docarray.typing import NdArray
2728
from docarray.utils._internal.pydantic import is_pydantic_v2
2829

@@ -40,7 +41,7 @@
4041
from docarray.typing.tensor.abstract_tensor import AbstractTensor
4142

4243
T = TypeVar('T', bound='DocList')
43-
T_doc = TypeVar('T_doc', bound=BaseDoc)
44+
T_doc = TypeVar('T_doc', bound=BaseDocWithoutId)
4445

4546

4647
class DocList(
@@ -120,7 +121,7 @@ class Image(BaseDoc):
120121
121122
"""
122123

123-
doc_type: Type[BaseDoc] = AnyDoc
124+
doc_type: Type[BaseDocWithoutId] = AnyDoc
124125

125126
def __init__(
126127
self,
@@ -229,7 +230,7 @@ def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]):
229230
not is_union_type(field_type)
230231
and is_field_required
231232
and isinstance(field_type, type)
232-
and safe_issubclass(field_type, BaseDoc)
233+
and safe_issubclass(field_type, BaseDocWithoutId)
233234
):
234235
# calling __class_getitem__ ourselves is a hack otherwise mypy complain
235236
# most likely a bug in mypy though
@@ -273,7 +274,7 @@ def to_doc_vec(
273274
@classmethod
274275
def _docarray_validate(
275276
cls: Type[T],
276-
value: Union[T, Iterable[BaseDoc]],
277+
value: Union[T, Iterable[BaseDocWithoutId]],
277278
):
278279
from docarray.array.doc_vec.doc_vec import DocVec
279280

@@ -333,9 +334,9 @@ def __getitem__(self, item):
333334
return super().__getitem__(item)
334335

335336
@classmethod
336-
def __class_getitem__(cls, item: Union[Type[BaseDoc], TypeVar, str]):
337+
def __class_getitem__(cls, item: Union[Type[BaseDocWithoutId], TypeVar, str]):
337338

338-
if isinstance(item, type) and safe_issubclass(item, BaseDoc):
339+
if isinstance(item, type) and safe_issubclass(item, BaseDocWithoutId):
339340
return AnyDocArray.__class_getitem__.__func__(cls, item) # type: ignore
340341
if (
341342
isinstance(item, object)
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from docarray import DocList
2+
from docarray.base_doc.doc import BaseDocWithoutId
3+
4+
5+
def test_doc_list():
6+
class A(BaseDocWithoutId):
7+
text: str
8+
9+
cls_doc_list = DocList[A]
10+
11+
assert isinstance(cls_doc_list, type)

0 commit comments

Comments
 (0)