Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docarray/index/backends/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def _execute_find_and_filter_query(
"""
docs_found = DocList.__class_getitem__(cast(Type[BaseDoc], doc_index._schema))([])
filter_conditions = []
filter_limit = None
doc_to_score: Dict[BaseDoc, Any] = {}
for op, op_kwargs in query:
if op == 'find':
Expand All @@ -39,6 +40,7 @@ def _execute_find_and_filter_query(
doc_to_score.update(zip(docs.__getattribute__('id'), scores))
elif op == 'filter':
filter_conditions.append(op_kwargs['filter_query'])
filter_limit = op_kwargs.get('limit')
Comment thread
JoanFM marked this conversation as resolved.
else:
raise ValueError(f'Query operation is not supported: {op}')

Expand All @@ -48,6 +50,9 @@ def _execute_find_and_filter_query(
docs_cls = DocList.__class_getitem__(cast(Type[BaseDoc], doc_index._schema))
docs_filtered = docs_cls(filter_docs(docs_filtered, cond))

if filter_limit:
docs_filtered = docs_filtered[:filter_limit]

doc_index._logger.debug(f'{len(docs_filtered)} results found')
docs_and_scores = zip(
docs_filtered, (doc_to_score[doc.id] for doc in docs_filtered)
Expand Down
2 changes: 1 addition & 1 deletion docarray/index/backends/in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def filter(
"""
self._logger.debug(f'Executing `filter` for the query {filter_query}')

docs = filter_docs(docs=self._docs, query=filter_query)
docs = filter_docs(docs=self._docs, query=filter_query)[:limit]
return cast(DocList, docs)

def _filter(self, filter_query: Any, limit: int) -> Union[DocList, List[Dict]]:
Expand Down
27 changes: 27 additions & 0 deletions tests/index/hnswlib/test_find.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,3 +302,30 @@ class MyDoc(BaseDoc):
for q, matches in zip(queries, docs_responses):
assert len(matches) == 10
assert q.id == matches[0].id


@pytest.mark.parametrize(
    'find_limit, filter_limit, expected_docs', [(10, 3, 3), (5, None, 5)]
)
def test_query_builder_limits(find_limit, filter_limit, expected_docs, tmp_path):
    """Check that `limit` on both `find` and `filter` is honoured by a built query.

    Ten documents with prices 0..9 are indexed; the query chains a vector
    search on `tensor` with a `price <= 5` filter, each with its own limit,
    and the number of returned documents is compared to `expected_docs`.
    """

    class SimpleSchema(BaseDoc):
        tensor: NdArray[10] = Field(space='l2')
        price: int

    index = HnswDocumentIndex[SimpleSchema](work_dir=str(tmp_path))

    # Document i carries a constant tensor of value i and price i.
    index.index(
        [SimpleSchema(tensor=np.array([i] * 10), price=i) for i in range(10)]
    )

    query_doc = SimpleSchema(tensor=np.array([3] * 10), price=3)
    price_at_most_five = {'price': {'$lte': 5}}

    # Assemble the query one stage at a time: search, then filter.
    builder = index.build_query()
    builder = builder.find(query=query_doc, search_field='tensor', limit=find_limit)
    builder = builder.filter(filter_query=price_at_most_five, limit=filter_limit)
    built_query = builder.build()

    docs, _scores = index.execute_query(built_query)

    assert len(docs) == expected_docs
39 changes: 39 additions & 0 deletions tests/index/in_memory/test_in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,45 @@ def test_concatenated_queries(doc_index):
assert len(docs) == 4


@pytest.mark.parametrize(
    'find_limit, filter_limit, expected_docs', [(10, 3, 3), (5, None, 3)]
)
def test_query_builder_limits(doc_index, find_limit, filter_limit, expected_docs):
    """Check that `limit` on both `find` and `filter` is honoured by a built query.

    The query chains a vector search on `tensor` with a `price <= 5` filter,
    each with its own limit, against the `doc_index` fixture (contents are
    defined by the fixture, not in this test).
    """
    query_doc = SchemaDoc(text='query', price=3, tensor=np.array([3] * 10))
    price_at_most_five = {'price': {'$lte': 5}}

    # Assemble the query one stage at a time: search, then filter.
    builder = doc_index.build_query()
    builder = builder.find(query=query_doc, search_field='tensor', limit=find_limit)
    builder = builder.filter(filter_query=price_at_most_five, limit=filter_limit)
    built_query = builder.build()

    docs, _scores = doc_index.execute_query(built_query)

    assert len(docs) == expected_docs


def test_filter(doc_index):
    """Exercise `doc_index.filter` with several operators, with and without `limit`."""
    # Equality: exactly one document is expected to match.
    exact = doc_index.filter({'price': {'$eq': 3}})
    assert len(exact) == 1
    assert exact[0].price == 3

    # Upper bound, no explicit limit.
    cheap = doc_index.filter({'price': {'$lte': 5}})
    assert len(cheap) == 6
    assert all(doc.price <= 5 for doc in cheap)

    # Lower bound, truncated by the limit.
    pricey = doc_index.filter({'price': {'$gte': 5}}, limit=3)
    assert len(pricey) == 3
    assert all(doc.price >= 5 for doc in pricey)

    # Inequality with a limit larger than the number of matches.
    not_two = doc_index.filter({'price': {'$neq': 2}}, limit=10)
    assert len(not_two) == 9
    assert all(doc.price != 2 for doc in not_two)


def test_save_and_load(doc_index, tmpdir):
initial_num_docs = doc_index.num_docs()

Expand Down