Skip to content

Commit f65e023

Browse files
Charlotte Gerhaheralexcg1
andauthored
feat: add in-memory doc index (#1441)
Signed-off-by: anna-charlotte <charlotte.gerhaher@jina.ai> Signed-off-by: Charlotte Gerhaher <charlotte.gerhaher@jina.ai> Co-authored-by: Alex Cureton-Griffiths <alexcg1@users.noreply.github.com>
1 parent 7febaca commit f65e023

10 files changed

Lines changed: 696 additions & 43 deletions

File tree

docarray/index/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import types
22
from typing import TYPE_CHECKING
33

4+
from docarray.index.backends.in_memory import InMemoryDocIndex
45
from docarray.utils._internal.misc import (
56
_get_path_from_docarray_root_level,
67
import_library,
@@ -13,7 +14,7 @@
1314
from docarray.index.backends.qdrant import QdrantDocumentIndex # noqa: F401
1415
from docarray.index.backends.weaviate import WeaviateDocumentIndex # noqa: F401
1516

16-
__all__ = []
17+
__all__ = ['InMemoryDocIndex']
1718

1819

1920
def __getattr__(name: str):

docarray/index/backends/helper.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
from typing import Any, Dict, List, Tuple, Type, cast
2+
3+
from docarray import BaseDoc, DocList
4+
from docarray.index.abstract import BaseDocIndex
5+
from docarray.utils.filter import filter_docs
6+
from docarray.utils.find import FindResult
7+
8+
9+
def _collect_query_args(method_name: str): # TODO: use partialmethod instead
10+
def inner(self, *args, **kwargs):
11+
if args:
12+
raise ValueError(
13+
f'Positional arguments are not supported for '
14+
f'`{type(self)}.{method_name}`.'
15+
f' Use keyword arguments instead.'
16+
)
17+
updated_query = self._queries + [(method_name, kwargs)]
18+
return type(self)(updated_query)
19+
20+
return inner
21+
22+
23+
def _execute_find_and_filter_query(
24+
doc_index: BaseDocIndex, query: List[Tuple[str, Dict]]
25+
) -> FindResult:
26+
"""
27+
Executes all find calls from query first using `doc_index.find()`,
28+
and filtering queries after that using DocArray's `filter_docs()`.
29+
30+
Text search is not supported.
31+
"""
32+
docs_found = DocList.__class_getitem__(cast(Type[BaseDoc], doc_index._schema))([])
33+
filter_conditions = []
34+
doc_to_score: Dict[BaseDoc, Any] = {}
35+
for op, op_kwargs in query:
36+
if op == 'find':
37+
docs, scores = doc_index.find(**op_kwargs)
38+
docs_found.extend(docs)
39+
doc_to_score.update(zip(docs.__getattribute__('id'), scores))
40+
elif op == 'filter':
41+
filter_conditions.append(op_kwargs['filter_query'])
42+
else:
43+
raise ValueError(f'Query operation is not supported: {op}')
44+
45+
doc_index._logger.debug(f'Executing query {query}')
46+
docs_filtered = docs_found
47+
for cond in filter_conditions:
48+
docs_cls = DocList.__class_getitem__(cast(Type[BaseDoc], doc_index._schema))
49+
docs_filtered = docs_cls(filter_docs(docs_filtered, cond))
50+
51+
doc_index._logger.debug(f'{len(docs_filtered)} results found')
52+
docs_and_scores = zip(
53+
docs_filtered, (doc_to_score[doc.id] for doc in docs_filtered)
54+
)
55+
docs_sorted = sorted(docs_and_scores, key=lambda x: x[1])
56+
out_docs, out_scores = zip(*docs_sorted)
57+
58+
return FindResult(documents=out_docs, scores=out_scores)

docarray/index/backends/hnswlib.py

Lines changed: 9 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,14 @@
2727
_raise_not_composable,
2828
_raise_not_supported,
2929
)
30+
from docarray.index.backends.helper import (
31+
_collect_query_args,
32+
_execute_find_and_filter_query,
33+
)
3034
from docarray.proto import DocProto
3135
from docarray.typing.tensor.abstract_tensor import AbstractTensor
3236
from docarray.typing.tensor.ndarray import NdArray
3337
from docarray.utils._internal.misc import import_library, is_np_int
34-
from docarray.utils.filter import filter_docs
3538
from docarray.utils.find import _FindResult, _FindResultBatched
3639

3740
if TYPE_CHECKING:
@@ -61,20 +64,6 @@
6164
T = TypeVar('T', bound='HnswDocumentIndex')
6265

6366

64-
def _collect_query_args(method_name: str): # TODO: use partialmethod instead
65-
def inner(self, *args, **kwargs):
66-
if args:
67-
raise ValueError(
68-
f'Positional arguments are not supported for '
69-
f'`{type(self)}.{method_name}`.'
70-
f' Use keyword arguments instead.'
71-
)
72-
updated_query = self._queries + [(method_name, kwargs)]
73-
return type(self)(updated_query)
74-
75-
return inner
76-
77-
7867
class HnswDocumentIndex(BaseDocIndex, Generic[TSchema]):
7968
def __init__(self, db_config=None, **kwargs):
8069
"""Initialize HnswDocumentIndex"""
@@ -232,7 +221,7 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):
232221

233222
def execute_query(self, query: List[Tuple[str, Dict]], *args, **kwargs) -> Any:
234223
"""
235-
Execute a query on the WeaviateDocumentIndex.
224+
Execute a query on the HnswDocumentIndex.
236225
237226
Can take two kinds of inputs:
238227
@@ -249,31 +238,11 @@ def execute_query(self, query: List[Tuple[str, Dict]], *args, **kwargs) -> Any:
249238
raise ValueError(
250239
f'args and kwargs not supported for `execute_query` on {type(self)}'
251240
)
252-
253-
ann_docs = DocList.__class_getitem__(cast(Type[BaseDoc], self._schema))([])
254-
filter_conditions = []
255-
doc_to_score: Dict[BaseDoc, Any] = {}
256-
for op, op_kwargs in query:
257-
if op == 'find':
258-
docs, scores = self.find(**op_kwargs)
259-
ann_docs.extend(docs)
260-
doc_to_score.update(zip(docs.__getattribute__('id'), scores))
261-
elif op == 'filter':
262-
filter_conditions.append(op_kwargs['filter_query'])
263-
264-
self._logger.debug(f'Executing query {query}')
265-
docs_filtered = ann_docs
266-
for cond in filter_conditions:
267-
docs_cls = DocList.__class_getitem__(cast(Type[BaseDoc], self._schema))
268-
docs_filtered = docs_cls(filter_docs(docs_filtered, cond))
269-
270-
self._logger.debug(f'{len(docs_filtered)} results found')
271-
docs_and_scores = zip(
272-
docs_filtered, (doc_to_score[doc.id] for doc in docs_filtered)
241+
find_res = _execute_find_and_filter_query(
242+
doc_index=self,
243+
query=query,
273244
)
274-
docs_sorted = sorted(docs_and_scores, key=lambda x: x[1])
275-
out_docs, out_scores = zip(*docs_sorted)
276-
return _FindResult(documents=out_docs, scores=out_scores)
245+
return find_res
277246

278247
def _find_batched(
279248
self,

0 commit comments

Comments
 (0)