@@ -138,6 +138,7 @@ def __init__(self, db_config=None, **kwargs):
138
138
self ._column_names : List [str ] = []
139
139
self ._create_docs_table ()
140
140
self ._sqlite_conn .commit ()
141
+ self ._num_docs = 0 # recompute again when needed
141
142
self ._logger .info (f'{ self .__class__ .__name__ } has been initialized' )
142
143
143
144
@property
@@ -279,6 +280,7 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):
279
280
self ._index (data_by_columns , docs_validated , ** kwargs )
280
281
self ._send_docs_to_sqlite (docs_validated )
281
282
self ._sqlite_conn .commit ()
283
+ self ._num_docs = 0 # recompute again when needed
282
284
283
285
def execute_query (self , query : List [Tuple [str , Dict ]], * args , ** kwargs ) -> Any :
284
286
"""
@@ -329,7 +331,19 @@ def _filter(
329
331
limit : int ,
330
332
) -> DocList :
331
333
rows = self ._execute_filter (filter_query = filter_query , limit = limit )
332
- return DocList [self .out_schema ](self ._doc_from_bytes (blob ) for _ , blob in rows ) # type: ignore[name-defined]
334
+ hashed_ids = [doc_id for doc_id , _ in rows ]
335
+ embeddings : OrderedDict [str , list ] = OrderedDict ()
336
+ for col_name , index in self ._hnsw_indices .items ():
337
+ embeddings [col_name ] = index .get_items (hashed_ids )
338
+
339
+ docs = DocList .__class_getitem__ (cast (Type [BaseDoc ], self .out_schema ))()
340
+ for i , row in enumerate (rows ):
341
+ reconstruct_embeddings = {}
342
+ for col_name in embeddings .keys ():
343
+ reconstruct_embeddings [col_name ] = embeddings [col_name ][i ]
344
+ docs .append (self ._doc_from_bytes (row [1 ], reconstruct_embeddings ))
345
+
346
+ return docs
333
347
334
348
def _filter_batched (
335
349
self ,
@@ -379,6 +393,7 @@ def _del_items(self, doc_ids: Sequence[str]):
379
393
380
394
self ._delete_docs_from_sqlite (doc_ids )
381
395
self ._sqlite_conn .commit ()
396
+ self ._num_docs = 0 # recompute again when needed
382
397
383
398
def _get_items (self , doc_ids : Sequence [str ], out : bool = True ) -> Sequence [TSchema ]:
384
399
"""Get Documents from the hnswlib index, by `id`.
@@ -403,7 +418,9 @@ def num_docs(self) -> int:
403
418
"""
404
419
Get the number of documents.
405
420
"""
406
- return self ._get_num_docs_sqlite ()
421
+ if self ._num_docs == 0 :
422
+ self ._num_docs = self ._get_num_docs_sqlite ()
423
+ return self ._num_docs
407
424
408
425
###############################################
409
426
# Helpers #
@@ -641,15 +658,18 @@ def accept_hashed_ids(id):
641
658
"""Accepts IDs that are in hashed_ids."""
642
659
return id in hashed_ids # type: ignore[operator]
643
660
644
- # Choose the appropriate filter function based on whether hashed_ids was provided
645
- extra_kwargs = {}
646
- if hashed_ids :
647
- extra_kwargs ['filter' ] = accept_hashed_ids
661
+ extra_kwargs = {'filter' : accept_hashed_ids } if hashed_ids else {}
648
662
649
663
# If hashed_ids is provided, k is the minimum of limit and the length of hashed_ids; else it is limit
650
664
k = min (limit , len (hashed_ids )) if hashed_ids else limit
651
665
index = self ._hnsw_indices [search_field ]
652
- labels , distances = index .knn_query (queries , k = int (limit ), ** extra_kwargs )
666
+
667
+ try :
668
+ labels , distances = index .knn_query (queries , k = k , ** extra_kwargs )
669
+ except RuntimeError :
670
+ k = min (k , self .num_docs ())
671
+ labels , distances = index .knn_query (queries , k = k , ** extra_kwargs )
672
+
653
673
result_das = [
654
674
self ._get_docs_sqlite_hashed_id (
655
675
ids_per_query .tolist (),
0 commit comments