Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions docarray/utils/find.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__all__ = ['find', 'find_batched']

from typing import Any, Dict, List, NamedTuple, Optional, Type, Union, cast
from typing import Any, Dict, List, NamedTuple, Optional, Type, Union, cast, Tuple

from typing_inspect import is_union_type

Expand Down Expand Up @@ -141,6 +141,8 @@ def find_batched(
search using approximate nearest neighbours search or hybrid search or
multi vector search please take a look at the [`BaseDoc`][docarray.base_doc.doc.BaseDoc]

!!! note
Only non-None embeddings will be considered from the `index` array

---

Expand Down Expand Up @@ -204,8 +206,10 @@ class MyDocument(BaseDoc):
comp_backend = embedding_type.get_comp_backend()

# extract embeddings from query and index
index_embeddings = _extract_embeddings(index, search_field, embedding_type)
query_embeddings = _extract_embeddings(query, search_field, embedding_type)
index_embeddings, valid_idx = _extract_embeddings(
index, search_field, embedding_type
)
query_embeddings, _ = _extract_embeddings(query, search_field, embedding_type)

# compute distances and return top results
metric_fn = getattr(comp_backend.Metrics, metric)
Expand All @@ -215,14 +219,17 @@ class MyDocument(BaseDoc):
)

batched_docs: List[DocList] = []
candidate_index = index
if valid_idx is not None and len(valid_idx) < len(index):
candidate_index = index[valid_idx]
scores = []
for _, (indices_per_query, scores_per_query) in enumerate(
zip(top_indices, top_scores)
):
doc_type = cast(Type[BaseDoc], index.doc_type)
docs_per_query: DocList = DocList.__class_getitem__(doc_type)()
for idx in indices_per_query: # workaround until #930 is fixed
docs_per_query.append(index[int(idx)])
docs_per_query.append(candidate_index[int(idx)])
batched_docs.append(docs_per_query)
scores.append(scores_per_query)
return FindResultBatched(documents=batched_docs, scores=scores)
Expand Down Expand Up @@ -255,17 +262,23 @@ def _extract_embeddings(
data: Union[AnyDocArray, BaseDoc, AnyTensor],
search_field: str,
embedding_type: Type,
) -> AnyTensor:
) -> Tuple[AnyTensor, Optional[List[int]]]:
"""Extract the embeddings from the data.

:param data: the data
:param search_field: the embedding field
:param embedding_type: type of the embedding: torch.Tensor, numpy.ndarray etc.
:return: the embeddings
:return: a tuple of the embeddings and optionally a list of the non-null indices
"""
emb: AnyTensor
valid_idx = None
if isinstance(data, DocList):
emb_list = list(AnyDocArray._traverse(data, search_field))
emb_valid = [
(emb, i)
for i, emb in enumerate(AnyDocArray._traverse(data, search_field))
if emb is not None
]
emb_list, valid_idx = zip(*emb_valid)
Comment thread
JoanFM marked this conversation as resolved.
emb = embedding_type._docarray_stack(emb_list)
elif isinstance(data, (DocVec, BaseDoc)):
emb = next(AnyDocArray._traverse(data, search_field))
Expand All @@ -274,7 +287,7 @@ def _extract_embeddings(

if len(emb.shape) == 1:
emb = emb.get_comp_backend().reshape(array=emb, shape=(1, -1))
return emb
return emb, valid_idx


def _da_attr_type(docs: AnyDocArray, access_path: str) -> Type[AnyTensor]:
Expand Down
21 changes: 21 additions & 0 deletions tests/index/in_memory/test_in_memory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pytest
from pydantic import Field
from typing import Optional

from docarray import BaseDoc, DocList
from docarray.index.backends.in_memory import InMemoryExactNNIndex
Expand Down Expand Up @@ -141,3 +142,23 @@ def test_save_and_load(doc_index, tmpdir):
)

assert newer_doc_index.num_docs() == 0


def test_index_with_None_embedding():
class DocTest(BaseDoc):
index: int
embedding: Optional[NdArray[4]]

# Some of the documents have the embedding field set to None
dl = DocList[DocTest](
[
DocTest(index=i, embedding=np.random.rand(4) if i % 2 else None)
for i in range(100)
]
)

index = InMemoryExactNNIndex[DocTest](dl)
res = index.find(np.random.rand(4), search_field="embedding", limit=70)
assert len(res.documents) == 50
for doc in res.documents:
assert doc.index % 2 != 0