|
22 | 22 | from docarray.array.any_array import AnyDocArray |
23 | 23 | from docarray.helper import _shallow_copy_doc |
24 | 24 | from docarray.index.abstract import BaseDocIndex, _raise_not_supported |
25 | | -from docarray.index.backends.helper import ( |
26 | | - _collect_query_args, |
27 | | - _execute_find_and_filter_query, |
28 | | -) |
| 25 | +from docarray.index.backends.helper import _collect_query_args |
29 | 26 | from docarray.typing import AnyTensor, NdArray |
30 | 27 | from docarray.typing.tensor.abstract_tensor import AbstractTensor |
31 | 28 | from docarray.utils._internal._typing import safe_issubclass |
@@ -293,12 +290,44 @@ def execute_query(self, query: List[Tuple[str, Dict]], *args, **kwargs) -> Any: |
293 | 290 | raise ValueError( |
294 | 291 | f'args and kwargs not supported for `execute_query` on {type(self)}' |
295 | 292 | ) |
296 | | - find_res = _execute_find_and_filter_query( |
297 | | - doc_index=self, |
298 | | - query=query, |
299 | | - reverse_order=True, |
300 | | - ) |
301 | | - return find_res |
| 293 | + return self._find_and_filter(query) |
| 294 | + |
| 295 | + def _find_and_filter(self, query: List[Tuple[str, Dict]]) -> FindResult: |
| 296 | + """ |
| 297 | + The function executes search operations such as 'find' and 'filter' in the order |
| 298 | + they appear in the query. The 'find' operation performs a vector similarity search. |
| 299 | + The 'filter' operation filters out documents based on a filter query. |
| 300 | + The documents are finally sorted based on their scores. |
| 301 | +
|
| 302 | + :param query: The query to execute. |
| 303 | + :return: A tuple of retrieved documents and their scores. |
| 304 | + """ |
| 305 | + out_docs = self._docs |
| 306 | + doc_to_score: Dict[BaseDoc, Any] = {} |
| 307 | + for op, op_kwargs in query: |
| 308 | + if op == 'find': |
| 309 | + out_docs, scores = find( |
| 310 | + index=out_docs, |
| 311 | + query=op_kwargs['query'], |
| 312 | + search_field=op_kwargs['search_field'], |
| 313 | + limit=op_kwargs.get('limit', len(out_docs)), |
| 314 | + metric=self._column_infos[op_kwargs['search_field']].config[ |
| 315 | + 'space' |
| 316 | + ], |
| 317 | + ) |
| 318 | + doc_to_score.update(zip(out_docs.id, scores)) |
| 319 | + elif op == 'filter': |
| 320 | + out_docs = filter_docs(out_docs, op_kwargs['filter_query']) |
| 321 | + if 'limit' in op_kwargs: |
| 322 | + out_docs = out_docs[: op_kwargs['limit']] |
| 323 | + else: |
| 324 | + raise ValueError(f'Query operation is not supported: {op}') |
| 325 | + |
| 326 | + scores_and_docs = zip([doc_to_score[doc.id] for doc in out_docs], out_docs) |
| 327 | + sorted_lists = sorted(scores_and_docs, reverse=True) |
| 328 | + out_scores, out_docs = zip(*sorted_lists) |
| 329 | + |
| 330 | + return FindResult(documents=out_docs, scores=out_scores) |
302 | 331 |
|
303 | 332 | def find( |
304 | 333 | self, |
|
0 commit comments