Skip to content

Commit f7bd4d0

Browse files
author
Joan Fontanals Martinez
committed
fix: reconstruct on filter
2 parents 02849c4 + 19aec21 commit f7bd4d0

File tree

1 file changed

+27
-7
lines changed

1 file changed

+27
-7
lines changed

docarray/index/backends/hnswlib.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def __init__(self, db_config=None, **kwargs):
138138
self._column_names: List[str] = []
139139
self._create_docs_table()
140140
self._sqlite_conn.commit()
141+
self._num_docs = 0 # recompute again when needed
141142
self._logger.info(f'{self.__class__.__name__} has been initialized')
142143

143144
@property
@@ -279,6 +280,7 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):
279280
self._index(data_by_columns, docs_validated, **kwargs)
280281
self._send_docs_to_sqlite(docs_validated)
281282
self._sqlite_conn.commit()
283+
self._num_docs = 0 # recompute again when needed
282284

283285
def execute_query(self, query: List[Tuple[str, Dict]], *args, **kwargs) -> Any:
284286
"""
@@ -329,7 +331,19 @@ def _filter(
329331
limit: int,
330332
) -> DocList:
331333
rows = self._execute_filter(filter_query=filter_query, limit=limit)
332-
return DocList[self.out_schema](self._doc_from_bytes(blob) for _, blob in rows) # type: ignore[name-defined]
334+
hashed_ids = [doc_id for doc_id, _ in rows]
335+
embeddings: OrderedDict[str, list] = OrderedDict()
336+
for col_name, index in self._hnsw_indices.items():
337+
embeddings[col_name] = index.get_items(hashed_ids)
338+
339+
docs = DocList.__class_getitem__(cast(Type[BaseDoc], self.out_schema))()
340+
for i, row in enumerate(rows):
341+
reconstruct_embeddings = {}
342+
for col_name in embeddings.keys():
343+
reconstruct_embeddings[col_name] = embeddings[col_name][i]
344+
docs.append(self._doc_from_bytes(row[1], reconstruct_embeddings))
345+
346+
return docs
333347

334348
def _filter_batched(
335349
self,
@@ -379,6 +393,7 @@ def _del_items(self, doc_ids: Sequence[str]):
379393

380394
self._delete_docs_from_sqlite(doc_ids)
381395
self._sqlite_conn.commit()
396+
self._num_docs = 0 # recompute again when needed
382397

383398
def _get_items(self, doc_ids: Sequence[str], out: bool = True) -> Sequence[TSchema]:
384399
"""Get Documents from the hnswlib index, by `id`.
@@ -403,7 +418,9 @@ def num_docs(self) -> int:
403418
"""
404419
Get the number of documents.
405420
"""
406-
return self._get_num_docs_sqlite()
421+
if self._num_docs == 0:
422+
self._num_docs = self._get_num_docs_sqlite()
423+
return self._num_docs
407424

408425
###############################################
409426
# Helpers #
@@ -641,15 +658,18 @@ def accept_hashed_ids(id):
641658
"""Accepts IDs that are in hashed_ids."""
642659
return id in hashed_ids # type: ignore[operator]
643660

644-
# Choose the appropriate filter function based on whether hashed_ids was provided
645-
extra_kwargs = {}
646-
if hashed_ids:
647-
extra_kwargs['filter'] = accept_hashed_ids
661+
extra_kwargs = {'filter': accept_hashed_ids} if hashed_ids else {}
648662

649663
# If hashed_ids is provided, k is the minimum of limit and the length of hashed_ids; else it is limit
650664
k = min(limit, len(hashed_ids)) if hashed_ids else limit
651665
index = self._hnsw_indices[search_field]
652-
labels, distances = index.knn_query(queries, k=int(limit), **extra_kwargs)
666+
667+
try:
668+
labels, distances = index.knn_query(queries, k=k, **extra_kwargs)
669+
except RuntimeError:
670+
k = min(k, self.num_docs())
671+
labels, distances = index.knn_query(queries, k=k, **extra_kwargs)
672+
653673
result_das = [
654674
self._get_docs_sqlite_hashed_id(
655675
ids_per_query.tolist(),

0 commit comments

Comments
 (0)