Skip to content

Commit c96707a

Browse files
authored
feat: InMemoryExactNNIndex pre filtering (#1713)
Signed-off-by: jupyterjazz <saba.sturua@jina.ai>
1 parent 2a866ae commit c96707a

2 files changed

Lines changed: 64 additions & 24 deletions

File tree

docarray/index/backends/in_memory.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,7 @@
2222
from docarray.array.any_array import AnyDocArray
2323
from docarray.helper import _shallow_copy_doc
2424
from docarray.index.abstract import BaseDocIndex, _raise_not_supported
25-
from docarray.index.backends.helper import (
26-
_collect_query_args,
27-
_execute_find_and_filter_query,
28-
)
25+
from docarray.index.backends.helper import _collect_query_args
2926
from docarray.typing import AnyTensor, NdArray
3027
from docarray.typing.tensor.abstract_tensor import AbstractTensor
3128
from docarray.utils._internal._typing import safe_issubclass
@@ -293,12 +290,44 @@ def execute_query(self, query: List[Tuple[str, Dict]], *args, **kwargs) -> Any:
293290
raise ValueError(
294291
f'args and kwargs not supported for `execute_query` on {type(self)}'
295292
)
296-
find_res = _execute_find_and_filter_query(
297-
doc_index=self,
298-
query=query,
299-
reverse_order=True,
300-
)
301-
return find_res
293+
return self._find_and_filter(query)
294+
295+
def _find_and_filter(self, query: List[Tuple[str, Dict]]) -> FindResult:
296+
"""
297+
The function executes search operations such as 'find' and 'filter' in the order
298+
they appear in the query. The 'find' operation performs a vector similarity search.
299+
The 'filter' operation filters out documents based on a filter query.
300+
The documents are finally sorted based on their scores.
301+
302+
:param query: The query to execute.
303+
:return: A tuple of retrieved documents and their scores.
304+
"""
305+
out_docs = self._docs
306+
doc_to_score: Dict[BaseDoc, Any] = {}
307+
for op, op_kwargs in query:
308+
if op == 'find':
309+
out_docs, scores = find(
310+
index=out_docs,
311+
query=op_kwargs['query'],
312+
search_field=op_kwargs['search_field'],
313+
limit=op_kwargs.get('limit', len(out_docs)),
314+
metric=self._column_infos[op_kwargs['search_field']].config[
315+
'space'
316+
],
317+
)
318+
doc_to_score.update(zip(out_docs.id, scores))
319+
elif op == 'filter':
320+
out_docs = filter_docs(out_docs, op_kwargs['filter_query'])
321+
if 'limit' in op_kwargs:
322+
out_docs = out_docs[: op_kwargs['limit']]
323+
else:
324+
raise ValueError(f'Query operation is not supported: {op}')
325+
326+
scores_and_docs = zip([doc_to_score[doc.id] for doc in out_docs], out_docs)
327+
sorted_lists = sorted(scores_and_docs, reverse=True)
328+
out_scores, out_docs = zip(*sorted_lists)
329+
330+
return FindResult(documents=out_docs, scores=out_scores)
302331

303332
def find(
304333
self,

tests/index/in_memory/test_in_memory.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
tf_available = is_tf_available()
1616
if tf_available:
1717
import tensorflow as tf
18-
from docarray.typing import TensorFlowTensor
1918

2019

2120
class SchemaDoc(BaseDoc):
@@ -165,37 +164,49 @@ def test_with_text_doc_torch():
165164
assert len(r) == 5
166165

167166

168-
def test_concatenated_queries(doc_index):
169-
query = SchemaDoc(text='query', price=0, tensor=np.ones(10))
170-
167+
def test_query_builder_pre_filtering(doc_index):
171168
q = (
172169
doc_index.build_query()
173-
.find(query=query, search_field='tensor', limit=5)
174-
.filter(filter_query={'price': {'$neq': 5}})
170+
.filter(filter_query={'price': {'$lte': 3}})
171+
.find(query=np.ones(10), search_field='tensor', limit=5)
175172
.build()
176173
)
177174

178175
docs, scores = doc_index.execute_query(q)
179176

180177
assert len(docs) == 4
178+
for doc in docs:
179+
assert doc.price <= 3
181180

182181

183-
@pytest.mark.parametrize(
184-
'find_limit, filter_limit, expected_docs', [(10, 3, 3), (5, None, 1)]
185-
)
186-
def test_query_builder_limits(doc_index, find_limit, filter_limit, expected_docs):
187-
query = SchemaDoc(text='query', price=3, tensor=np.array([3] * 10))
182+
def test_query_builder_post_filtering(doc_index):
183+
q = (
184+
doc_index.build_query()
185+
.find(query=np.ones(10), search_field='tensor')
186+
.filter(filter_query={'price': {'$gt': 3}}, limit=5)
187+
.build()
188+
)
188189

190+
docs, scores = doc_index.execute_query(q)
191+
192+
assert len(docs) == 5
193+
for doc in docs:
194+
assert doc.price > 3
195+
196+
197+
def test_query_builder_pre_post_filtering(doc_index):
189198
q = (
190199
doc_index.build_query()
191-
.find(query=query, search_field='tensor', limit=find_limit)
192-
.filter(filter_query={'price': {'$lte': 5}}, limit=filter_limit)
200+
.filter(filter_query={'price': {'$lte': 3}})
201+
.find(query=np.ones(10), search_field='tensor')
202+
.filter(filter_query={'text': {'$eq': 'hello 1'}})
193203
.build()
194204
)
195205

196206
docs, scores = doc_index.execute_query(q)
197207

198-
assert len(docs) == expected_docs
208+
assert len(docs) == 1
209+
assert docs[0].text == 'hello 1' and docs[0].price <= 3
199210

200211

201212
def test_filter(doc_index):

0 commit comments

Comments
 (0)