Skip to content

Commit fc1af11

Browse files
authored
refactor: change return type of find batched (#1339)
* refactor: change return type of find batched Signed-off-by: Nikos Pitsillos <npitsillos@gmail.com> * refactor: change find batched to match return type Signed-off-by: Nikos Pitsillos <npitsillos@gmail.com> * test: correct tests for find batched Signed-off-by: Nikos Pitsillos <npitsillos@gmail.com> * fix: find return type Signed-off-by: Nikos Pitsillos <npitsillos@gmail.com> * refactor: change scores type Signed-off-by: Nikos Pitsillos <npitsillos@gmail.com> * refactor: change findresultbatched score type Signed-off-by: Nikos Pitsillos <npitsillos@gmail.com> * refactor: change elastic return Signed-off-by: Nikos Pitsillos <npitsillos@gmail.com> * fix: elastic format response Signed-off-by: Nikos Pitsillos <npitsillos@gmail.com> --------- Signed-off-by: Nikos Pitsillos <npitsillos@gmail.com>
1 parent c48d6c0 commit fc1af11

6 files changed

Lines changed: 96 additions & 86 deletions

File tree

docarray/index/abstract.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
Iterable,
1111
List,
1212
Mapping,
13-
NamedTuple,
1413
Optional,
1514
Sequence,
1615
Tuple,
@@ -30,7 +29,12 @@
3029
from docarray.typing.tensor.abstract_tensor import AbstractTensor
3130
from docarray.utils._internal._typing import is_tensor_union
3231
from docarray.utils._internal.misc import import_library
33-
from docarray.utils.find import FindResult, _FindResult
32+
from docarray.utils.find import (
33+
FindResult,
34+
FindResultBatched,
35+
_FindResult,
36+
_FindResultBatched,
37+
)
3438

3539
if TYPE_CHECKING:
3640
import tensorflow as tf # type: ignore
@@ -47,16 +51,6 @@
4751
TSchema = TypeVar('TSchema', bound=BaseDoc)
4852

4953

50-
class FindResultBatched(NamedTuple):
51-
documents: List[DocList]
52-
scores: List[np.ndarray]
53-
54-
55-
class _FindResultBatched(NamedTuple):
56-
documents: Union[List[DocList], List[List[Dict[str, Any]]]]
57-
scores: List[np.ndarray]
58-
59-
6054
def _raise_not_composable(name):
6155
def _inner(self, *args, **kwargs):
6256
raise NotImplementedError(

docarray/index/backends/elastic.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,12 @@
2828

2929
import docarray.typing
3030
from docarray import BaseDoc
31-
from docarray.index.abstract import (
32-
BaseDocIndex,
33-
_ColumnInfo,
34-
_FindResultBatched,
35-
_raise_not_composable,
36-
)
31+
from docarray.index.abstract import BaseDocIndex, _ColumnInfo, _raise_not_composable
3732
from docarray.typing import AnyTensor
3833
from docarray.typing.tensor.abstract_tensor import AbstractTensor
3934
from docarray.typing.tensor.ndarray import NdArray
4035
from docarray.utils._internal.misc import is_tf_available, is_torch_available
41-
from docarray.utils.find import _FindResult
36+
from docarray.utils.find import _FindResult, _FindResultBatched
4237

4338
TSchema = TypeVar('TSchema', bound=BaseDoc)
4439
T = TypeVar('T', bound='ElasticDocIndex')
@@ -387,7 +382,7 @@ def _find_batched(
387382
das, scores = zip(
388383
*[self._format_response(resp) for resp in responses['responses']]
389384
)
390-
return _FindResultBatched(documents=list(das), scores=np.array(scores))
385+
return _FindResultBatched(documents=list(das), scores=scores)
391386

392387
def _filter(
393388
self,
@@ -445,9 +440,7 @@ def _text_search_batched(
445440
das, scores = zip(
446441
*[self._format_response(resp) for resp in responses['responses']]
447442
)
448-
return _FindResultBatched(
449-
documents=list(das), scores=np.array(scores, dtype=object)
450-
)
443+
return _FindResultBatched(documents=list(das), scores=scores)
451444

452445
###############################################
453446
# Helpers #
@@ -544,7 +537,7 @@ def _format_response(self, response: Any) -> Tuple[List[Dict], NdArray]:
544537
docs.append(doc_dict)
545538
scores.append(result['_score'])
546539

547-
return docs, parse_obj_as(NdArray, scores)
540+
return docs, [parse_obj_as(NdArray, np.array(s)) for s in scores]
548541

549542
def _refresh(self, index_name: str):
550543
self._client.indices.refresh(index=index_name)

docarray/index/backends/hnswlib.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from docarray.index.abstract import (
2525
BaseDocIndex,
2626
_ColumnInfo,
27-
_FindResultBatched,
2827
_raise_not_composable,
2928
_raise_not_supported,
3029
)
@@ -33,7 +32,7 @@
3332
from docarray.typing.tensor.abstract_tensor import AbstractTensor
3433
from docarray.utils._internal.misc import import_library, is_np_int
3534
from docarray.utils.filter import filter_docs
36-
from docarray.utils.find import _FindResult
35+
from docarray.utils.find import _FindResult, _FindResultBatched
3736

3837
if TYPE_CHECKING:
3938
import hnswlib

docarray/index/backends/qdrant.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,35 @@
11
import uuid
22
from dataclasses import dataclass, field
33
from typing import (
4-
TypeVar,
5-
Generic,
6-
Optional,
7-
cast,
8-
Sequence,
94
Any,
10-
Union,
11-
List,
125
Dict,
136
Generator,
14-
Type,
7+
Generic,
8+
List,
9+
Optional,
10+
Sequence,
1511
Tuple,
12+
Type,
13+
TypeVar,
14+
Union,
15+
cast,
1616
)
1717

1818
import numpy as np
19+
import qdrant_client
1920
from grpc import RpcError # type: ignore[import]
21+
from qdrant_client.conversions import common_types as types
22+
from qdrant_client.http import models as rest
2023
from qdrant_client.http.exceptions import UnexpectedResponse
2124

2225
import docarray.typing.id
2326
from docarray import BaseDoc, DocList
2427
from docarray.index.abstract import (
2528
BaseDocIndex,
26-
_FindResultBatched,
2729
_ColumnInfo,
30+
_FindResultBatched,
2831
_raise_not_composable,
2932
)
30-
31-
import qdrant_client
32-
from qdrant_client.conversions import common_types as types
33-
from qdrant_client.http import models as rest
34-
3533
from docarray.typing import NdArray
3634
from docarray.typing.tensor.abstract_tensor import AbstractTensor
3735
from docarray.utils._internal.misc import torch_imported
@@ -391,7 +389,10 @@ def _find_batched(
391389
for response in responses
392390
],
393391
scores=[
394-
np.array([point.score for point in response]) for response in responses
392+
NdArray._docarray_from_native(
393+
np.array([point.score for point in response])
394+
)
395+
for response in responses
395396
],
396397
)
397398

@@ -454,7 +455,10 @@ def _text_search_batched(
454455
# semantic search over vectors. Thus, each document is scored with a value of 1
455456
return _FindResultBatched(
456457
documents=documents_batched,
457-
scores=[np.ones(len(docs)) for docs in documents_batched],
458+
scores=[
459+
NdArray._docarray_from_native(np.ones(len(docs)))
460+
for docs in documents_batched
461+
],
458462
)
459463

460464
def _build_point_from_row(self, row: Dict[str, Any]) -> rest.PointStruct:

docarray/utils/find.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,16 @@ class _FindResult(NamedTuple):
2323
scores: AnyTensor
2424

2525

26+
class FindResultBatched(NamedTuple):
27+
documents: List[DocList]
28+
scores: List[AnyTensor]
29+
30+
31+
class _FindResultBatched(NamedTuple):
32+
documents: Union[List[DocList], List[List[Dict[str, Any]]]]
33+
scores: List[AnyTensor]
34+
35+
2636
def find(
2737
index: AnyDocArray,
2838
query: Union[AnyTensor, BaseDoc],
@@ -95,15 +105,16 @@ class MyDocument(BaseDoc):
95105
and the second element contains the corresponding scores.
96106
"""
97107
query = _extract_embedding_single(query, search_field)
98-
return find_batched(
108+
docs, scores = find_batched(
99109
index=index,
100110
query=query,
101111
search_field=search_field,
102112
metric=metric,
103113
limit=limit,
104114
device=device,
105115
descending=descending,
106-
)[0]
116+
)
117+
return FindResult(documents=docs[0], scores=scores[0])
107118

108119

109120
def find_batched(
@@ -114,7 +125,7 @@ def find_batched(
114125
limit: int = 10,
115126
device: Optional[str] = None,
116127
descending: Optional[bool] = None,
117-
) -> List[FindResult]:
128+
) -> FindResultBatched:
118129
"""
119130
Find the closest Documents in the index to the queries.
120131
Supports PyTorch and NumPy embeddings.
@@ -142,23 +153,23 @@ class MyDocument(BaseDoc):
142153
143154
# use DocList as query
144155
query = DocList[MyDocument]([MyDocument(embedding=torch.rand(128)) for _ in range(3)])
145-
results = find_batched(
156+
docs, scores = find_batched(
146157
index=index,
147158
query=query,
148159
search_field='embedding',
149160
metric='cosine_sim',
150161
)
151-
top_matches, scores = results[0]
162+
top_matches, scores = docs[0], scores[0]
152163
153164
# use tensor as query
154165
query = torch.rand(3, 128)
155-
results = find_batched(
166+
docs, scores = find_batched(
156167
index=index,
157168
query=query,
158169
search_field='embedding',
159170
metric='cosine_sim',
160171
)
161-
top_matches, scores = results[0]
172+
top_matches, scores = docs[0], scores[0]
162173
```
163174
164175
---
@@ -176,8 +187,8 @@ class MyDocument(BaseDoc):
176187
can be either `cpu` or a `cuda` device.
177188
:param descending: sort the results in descending order.
178189
Per default, this is chosen based on the `metric` argument.
179-
:return: a list of named tuples of the form (DocList, AnyTensor),
180-
where the first element contains the closes matches for each query,
190+
:return: A named tuple of the form (DocList, AnyTensor),
191+
where the first element contains the closest matches for each query,
181192
and the second element contains the corresponding scores.
182193
"""
183194
if descending is None:
@@ -197,14 +208,17 @@ class MyDocument(BaseDoc):
197208
dists, k=limit, device=device, descending=descending
198209
)
199210

200-
results = []
201-
for indices_per_query, scores_per_query in zip(top_indices, top_scores):
211+
batched_docs: List[DocList] = []
212+
scores = []
213+
for _, (indices_per_query, scores_per_query) in enumerate(
214+
zip(top_indices, top_scores)
215+
):
202216
docs_per_query: DocList = DocList([])
203217
for idx in indices_per_query: # workaround until #930 is fixed
204218
docs_per_query.append(index[idx])
205-
docs_per_query = DocList(docs_per_query)
206-
results.append(FindResult(scores=scores_per_query, documents=docs_per_query))
207-
return results
219+
batched_docs.append(DocList(docs_per_query))
220+
scores.append(scores_per_query)
221+
return FindResultBatched(documents=batched_docs, scores=scores)
208222

209223

210224
def _extract_embedding_single(

0 commit comments

Comments
 (0)