diff --git a/docarray/index/abstract.py b/docarray/index/abstract.py index 434d9734c82..2f8fddf70a1 100644 --- a/docarray/index/abstract.py +++ b/docarray/index/abstract.py @@ -230,6 +230,16 @@ def execute_query(self, query: Any, *args, **kwargs) -> Any: """ ... + @abstractmethod + def _doc_exists(self, doc_id: str) -> bool: + """ + Checks if a given document exists in the index. + + :param doc_id: The id of a document to check. + :return: True if the document exists in the index, False otherwise. + """ + ... + @abstractmethod def _find( self, @@ -403,6 +413,21 @@ def __delitem__(self, key: Union[str, Sequence[str]]): # delete data self._del_items(key) + def __contains__(self, item: BaseDoc) -> bool: + """ + Checks if a given document exists in the index. + + :param item: The document to check. + It must be an instance of BaseDoc or its subclass. + :return: True if the document exists in the index, False otherwise. + """ + if safe_issubclass(type(item), BaseDoc): + return self._doc_exists(str(item.id)) + else: + raise TypeError( + f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'" + ) + def configure(self, runtime_config=None, **kwargs): """ Configure the DocumentIndex. @@ -1170,14 +1195,6 @@ def _get_root_doc_id(self, id: str, root: str, sub: str) -> str: ) return self._get_root_doc_id(cur_root_id, root, '') - def __contains__(self, item: BaseDoc) -> bool: - """Checks if a given BaseDoc item is contained in the index. - - :param item: the given BaseDoc - :return: if the given BaseDoc item is contained in the index - """ - return False # Will be overridden by backends - def subindex_contains(self, item: BaseDoc) -> bool: """Checks if a given BaseDoc item is contained in the index or any of its subindices. diff --git a/docarray/index/backends/elastic.py b/docarray/index/backends/elastic.py index e201dca4ac0..7981ba3d4e8 100644 --- a/docarray/index/backends/elastic.py +++ b/docarray/index/backends/elastic.py @@ -669,16 +669,11 @@ def _format_response(self, response: Any) -> Tuple[List[Dict], List[Any]]: def _refresh(self, index_name: str): self._client.indices.refresh(index=index_name) - def __contains__(self, item: BaseDoc) -> bool: - if safe_issubclass(type(item), BaseDoc): - if len(item.id) == 0: - return False - ret = self._client_mget([item.id]) - return ret["docs"][0]["found"] - else: - raise TypeError( - f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'" - ) + def _doc_exists(self, doc_id: str) -> bool: + if len(doc_id) == 0: + return False + ret = self._client_mget([doc_id]) + return ret["docs"][0]["found"] ############################################### # API Wrappers # diff --git a/docarray/index/backends/hnswlib.py b/docarray/index/backends/hnswlib.py index b13e143a6c6..00124c7fbd0 100644 --- a/docarray/index/backends/hnswlib.py +++ b/docarray/index/backends/hnswlib.py @@ -393,18 +393,11 @@ def _get_items(self, doc_ids: Sequence[str], out: bool = True) -> Sequence[TSche raise KeyError(f'No document with id {doc_ids} found') return out_docs - def __contains__(self, item: BaseDoc): - if safe_issubclass(type(item), BaseDoc): - hash_id = self._to_hashed_id(item.id) - self._sqlite_cursor.execute( - f"SELECT data FROM docs WHERE doc_id = '{hash_id}'" - ) - rows = self._sqlite_cursor.fetchall() - return len(rows) > 0 - else: - raise TypeError( - f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'" - ) + def _doc_exists(self, doc_id: str) -> bool: + hash_id = self._to_hashed_id(doc_id) + self._sqlite_cursor.execute(f"SELECT data FROM docs WHERE doc_id = '{hash_id}'") + rows = self._sqlite_cursor.fetchall() + return len(rows) > 0 def num_docs(self) -> int: """ diff --git a/docarray/index/backends/in_memory.py b/docarray/index/backends/in_memory.py index 62ee3f0ffee..8a8132c2394 100644 --- a/docarray/index/backends/in_memory.py +++ b/docarray/index/backends/in_memory.py @@ -431,13 +431,8 @@ def _text_search_batched( ) -> _FindResultBatched: raise NotImplementedError(f'{type(self)} does not support text search.') - def __contains__(self, item: BaseDoc): - if safe_issubclass(type(item), BaseDoc): - return any(doc.id == item.id for doc in self._docs) - else: - raise TypeError( - f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'" - ) + def _doc_exists(self, doc_id: str) -> bool: + return any(doc.id == doc_id for doc in self._docs) def persist(self, file: Optional[str] = None) -> None: """Persist InMemoryExactNNIndex into a binary file.""" diff --git a/docarray/index/backends/qdrant.py b/docarray/index/backends/qdrant.py index 0ddf77a1e96..5288da19881 100644 --- a/docarray/index/backends/qdrant.py +++ b/docarray/index/backends/qdrant.py @@ -317,21 +317,16 @@ def num_docs(self) -> int: """ return self._client.count(collection_name=self.collection_name).count - def __contains__(self, item: BaseDoc) -> bool: - if safe_issubclass(type(item), BaseDoc): - response, _ = self._client.scroll( - collection_name=self.index_name, - scroll_filter=rest.Filter( - must=[ - rest.HasIdCondition(has_id=[self._to_qdrant_id(item.id)]), - ], - ), - ) - return len(response) > 0 - else: - raise TypeError( - f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'" - ) + def _doc_exists(self, doc_id: str) -> bool: + response, _ = self._client.scroll( + collection_name=self.index_name, + scroll_filter=rest.Filter( + must=[ + rest.HasIdCondition(has_id=[self._to_qdrant_id(doc_id)]), + ], + ), + ) + return len(response) > 0 def _del_items(self, doc_ids: Sequence[str]): items = self._get_items(doc_ids) diff --git a/docarray/index/backends/redis.py b/docarray/index/backends/redis.py index e82652d86e4..937c77efdaa 100644 --- a/docarray/index/backends/redis.py +++ b/docarray/index/backends/redis.py @@ -377,7 +377,7 @@ def _del_items(self, doc_ids: Sequence[str]) -> None: ): self._client.delete(*batch) - def _doc_exists(self, doc_id) -> bool: + def _doc_exists(self, doc_id: str) -> bool: """ Checks if a document exists in the index. @@ -610,18 +610,3 @@ def _text_search_batched( scores.append(results.scores) return _FindResultBatched(documents=docs, scores=scores) - - def __contains__(self, item: BaseDoc) -> bool: - """ - Checks if a given document exists in the index. - - :param item: The document to check. - It must be an instance of BaseDoc or its subclass. - :return: True if the document exists in the index, False otherwise. - """ - if safe_issubclass(type(item), BaseDoc): - return self._doc_exists(item.id) - else: - raise TypeError( - f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'" - ) diff --git a/docarray/index/backends/weaviate.py b/docarray/index/backends/weaviate.py index 20b1c649c3b..b001888dd98 100644 --- a/docarray/index/backends/weaviate.py +++ b/docarray/index/backends/weaviate.py @@ -760,25 +760,20 @@ def _filter_by_parent_id(self, id: str) -> Optional[List[str]]: ] return ids - def __contains__(self, item: BaseDoc) -> bool: - if safe_issubclass(type(item), BaseDoc): - result = ( - self._client.query.get(self.index_name, ['docarrayid']) - .with_where( - { - "path": ['docarrayid'], - "operator": "Equal", - "valueString": f'{item.id}', - } - ) - .do() - ) - docs = result["data"]["Get"][self.index_name] - return docs is not None and len(docs) > 0 - else: - raise TypeError( - f"item must be an instance of BaseDoc or its subclass, not '{type(item).__name__}'" + def _doc_exists(self, doc_id: str) -> bool: + result = ( + self._client.query.get(self.index_name, ['docarrayid']) + .with_where( + { + "path": ['docarrayid'], + "operator": "Equal", + "valueString": f'{doc_id}', + } ) + .do() + ) + docs = result["data"]["Get"][self.index_name] + return docs is not None and len(docs) > 0 class QueryBuilder(BaseDocIndex.QueryBuilder): def __init__(self, document_index): diff --git a/tests/index/base_classes/test_base_doc_store.py b/tests/index/base_classes/test_base_doc_store.py index 69b63c57e88..09f46ee4535 100644 --- a/tests/index/base_classes/test_base_doc_store.py +++ b/tests/index/base_classes/test_base_doc_store.py @@ -97,6 +97,9 @@ def python_type_to_db_type(self, x): def num_docs(self): return 3 + def _doc_exists(self, doc_id: str) -> bool: + return False + _index = _identity _del_items = _identity _get_items = _identity diff --git a/tests/index/base_classes/test_configs.py b/tests/index/base_classes/test_configs.py index 7b7efbea596..b2a5f0ecfd5 100644 --- a/tests/index/base_classes/test_configs.py +++ b/tests/index/base_classes/test_configs.py @@ -35,7 +35,6 @@ class DBConfig(BaseDocIndex.DBConfig): @dataclass class RuntimeConfig(BaseDocIndex.RuntimeConfig): - default_ef: int = 50 @@ -61,6 +60,7 @@ def python_type_to_db_type(self, x): _filter_batched = _identity _text_search = _identity _text_search_batched = _identity + _doc_exists = _identity def test_defaults():