@@ -23,6 +23,16 @@ class _FindResult(NamedTuple):
2323 scores : AnyTensor
2424
2525
26+ class FindResultBatched (NamedTuple ):
27+ documents : List [DocList ]
28+ scores : List [AnyTensor ]
29+
30+
31+ class _FindResultBatched (NamedTuple ):
32+ documents : Union [List [DocList ], List [List [Dict [str , Any ]]]]
33+ scores : List [AnyTensor ]
34+
35+
2636def find (
2737 index : AnyDocArray ,
2838 query : Union [AnyTensor , BaseDoc ],
@@ -95,15 +105,16 @@ class MyDocument(BaseDoc):
95105 and the second element contains the corresponding scores.
96106 """
97107 query = _extract_embedding_single (query , search_field )
98- return find_batched (
108+ docs , scores = find_batched (
99109 index = index ,
100110 query = query ,
101111 search_field = search_field ,
102112 metric = metric ,
103113 limit = limit ,
104114 device = device ,
105115 descending = descending ,
106- )[0 ]
116+ )
117+ return FindResult (documents = docs [0 ], scores = scores [0 ])
107118
108119
109120def find_batched (
@@ -114,7 +125,7 @@ def find_batched(
114125 limit : int = 10 ,
115126 device : Optional [str ] = None ,
116127 descending : Optional [bool ] = None ,
117- ) -> List [ FindResult ] :
128+ ) -> FindResultBatched :
118129 """
119130 Find the closest Documents in the index to the queries.
120131 Supports PyTorch and NumPy embeddings.
@@ -142,23 +153,23 @@ class MyDocument(BaseDoc):
142153
143154 # use DocList as query
144155 query = DocList[MyDocument]([MyDocument(embedding=torch.rand(128)) for _ in range(3)])
145- results = find_batched(
156+ docs, scores = find_batched(
146157 index=index,
147158 query=query,
148159 search_field='embedding',
149160 metric='cosine_sim',
150161 )
151- top_matches, scores = results [0]
162+ top_matches, scores = docs[0], scores [0]
152163
153164 # use tensor as query
154165 query = torch.rand(3, 128)
155- results = find_batched(
166+ docs, scores = find_batched(
156167 index=index,
157168 query=query,
158169 search_field='embedding',
159170 metric='cosine_sim',
160171 )
161- top_matches, scores = results [0]
172+ top_matches, scores = docs[0], scores [0]
162173 ```
163174
164175 ---
@@ -176,8 +187,8 @@ class MyDocument(BaseDoc):
176187 can be either `cpu` or a `cuda` device.
177188 :param descending: sort the results in descending order.
178189 Per default, this is chosen based on the `metric` argument.
179- :return: a list of named tuples of the form (DocList, AnyTensor),
180- where the first element contains the closes matches for each query,
190+ :return: A named tuple of the form (List[DocList], List[AnyTensor]),
191+ where the first element contains one DocList of the closest matches per query,
181192 and the second element contains the corresponding scores.
182193 """
183194 if descending is None :
@@ -197,14 +208,17 @@ class MyDocument(BaseDoc):
197208 dists , k = limit , device = device , descending = descending
198209 )
199210
200- results = []
201- for indices_per_query , scores_per_query in zip (top_indices , top_scores ):
211+ batched_docs : List [DocList ] = []
212+ scores = []
213+ for _ , (indices_per_query , scores_per_query ) in enumerate (
214+ zip (top_indices , top_scores )
215+ ):
202216 docs_per_query : DocList = DocList ([])
203217 for idx in indices_per_query : # workaround until #930 is fixed
204218 docs_per_query .append (index [idx ])
205- docs_per_query = DocList (docs_per_query )
206- results .append (FindResult ( scores = scores_per_query , documents = docs_per_query ) )
207- return results
219+ batched_docs . append ( DocList (docs_per_query ) )
220+ scores .append (scores_per_query )
221+ return FindResultBatched ( documents = batched_docs , scores = scores )
208222
209223
210224def _extract_embedding_single (
0 commit comments