Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions docarray/index/backends/in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
from docarray.utils.find import (
FindResult,
FindResultBatched,
_da_attr_type,
_extract_embeddings,
_FindResult,
_FindResultBatched,
find,
Expand Down Expand Up @@ -85,6 +87,8 @@ def __init__(
cast(Type[BaseDoc], self._schema)
)()

self._embedding_map: Dict[str, Tuple[AnyTensor, Optional[List[int]]]] = {}
Comment thread
JoanFM marked this conversation as resolved.

def python_type_to_db_type(self, python_type: Type) -> Any:
"""Map python type to database type.
Takes any python type and returns the corresponding database column type.
Expand Down Expand Up @@ -146,6 +150,7 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):
# implementing the public option because conversion to column dict is not needed
docs = self._validate_docs(docs)
self._docs.extend(docs)
self._rebuild_embedding()

def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]):
raise NotImplementedError
Expand All @@ -156,6 +161,22 @@ def num_docs(self) -> int:
"""
return len(self._docs)

def _rebuild_embedding(self):
"""
Reconstructs the embeddings map for each field. This is performed to store pre-stacked
embeddings, thereby optimizing performance by avoiding repeated stacking of embeddings.

Note: '_embedding_map' is a dictionary mapping fields to their corresponding embeddings.
"""
if self.num_docs() == 0:
self._embedding_map = dict()
else:
for field_, embedding in self._embedding_map.items():
embedding_type = _da_attr_type(self._docs, field_)
self._embedding_map[field_] = _extract_embeddings(
self._docs, field_, embedding_type
)

def _del_items(self, doc_ids: Sequence[str]):
"""Delete Documents from the index.

Expand All @@ -167,6 +188,7 @@ def _del_items(self, doc_ids: Sequence[str]):
indices.append(i)

del self._docs[indices]
self._rebuild_embedding()

def _get_items(
self, doc_ids: Sequence[str]
Expand Down Expand Up @@ -241,6 +263,7 @@ def find(
search_field=search_field,
limit=limit,
metric=config['space'],
cache=self._embedding_map,
)
docs_with_schema = DocList.__class_getitem__(cast(Type[BaseDoc], self._schema))(
docs
Expand Down Expand Up @@ -285,6 +308,7 @@ def find_batched(
search_field=search_field,
limit=limit,
metric=config['space'],
cache=self._embedding_map,
)

return find_res
Expand Down
24 changes: 16 additions & 8 deletions docarray/utils/find.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__all__ = ['find', 'find_batched']

from typing import Any, Dict, List, NamedTuple, Optional, Type, Union, cast, Tuple
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, Union, cast

from typing_inspect import is_union_type

Expand Down Expand Up @@ -47,6 +47,7 @@ def find(
limit: int = 10,
device: Optional[str] = None,
descending: Optional[bool] = None,
cache: Optional[Dict[str, Tuple[AnyTensor, Optional[List[int]]]]] = None,
) -> FindResult:
"""
Find the closest Documents in the index to the query.
Expand Down Expand Up @@ -119,6 +120,7 @@ class MyDocument(BaseDoc):
limit=limit,
device=device,
descending=descending,
cache=cache,
)
return FindResult(documents=docs[0], scores=scores[0])

Expand All @@ -131,6 +133,7 @@ def find_batched(
limit: int = 10,
device: Optional[str] = None,
descending: Optional[bool] = None,
cache: Optional[Dict[str, Tuple[AnyTensor, Optional[List[int]]]]] = None,
) -> FindResultBatched:
"""
Find the closest Documents in the index to the queries.
Expand Down Expand Up @@ -206,9 +209,17 @@ class MyDocument(BaseDoc):
comp_backend = embedding_type.get_comp_backend()

# extract embeddings from query and index
index_embeddings, valid_idx = _extract_embeddings(
index, search_field, embedding_type
)
if cache is not None and search_field in cache:
index_embeddings, valid_idx = cache[search_field]
else:
index_embeddings, valid_idx = _extract_embeddings(
index, search_field, embedding_type
)
if cache is not None:
cache[search_field] = (
index_embeddings,
valid_idx,
) # cache embedding for next query
query_embeddings, _ = _extract_embeddings(query, search_field, embedding_type)

# compute distances and return top results
Expand All @@ -226,10 +237,7 @@ class MyDocument(BaseDoc):
for _, (indices_per_query, scores_per_query) in enumerate(
zip(top_indices, top_scores)
):
doc_type = cast(Type[BaseDoc], index.doc_type)
docs_per_query: DocList = DocList.__class_getitem__(doc_type)()
for idx in indices_per_query: # workaround until #930 is fixed
docs_per_query.append(candidate_index[int(idx)])
docs_per_query: DocList = candidate_index[indices_per_query]
batched_docs.append(docs_per_query)
scores.append(scores_per_query)
return FindResultBatched(documents=batched_docs, scores=scores)
Expand Down
91 changes: 89 additions & 2 deletions tests/index/in_memory/test_in_memory.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import Optional

import numpy as np
import pytest
from pydantic import Field
from typing import Optional
from torch import rand

from docarray import BaseDoc, DocList
from docarray.index.backends.in_memory import InMemoryExactNNIndex
from docarray.typing import NdArray
from docarray.typing import NdArray, TorchTensor


class SchemaDoc(BaseDoc):
Expand Down Expand Up @@ -162,3 +164,88 @@ class DocTest(BaseDoc):
assert len(res.documents) == 50
for doc in res.documents:
assert doc.index % 2 != 0


def test_index_avoid_stack_embedding():
class MyDoc(BaseDoc):
embedding1: TorchTensor
embedding2: TorchTensor
embedding3: TorchTensor

data = DocList[MyDoc](
[
MyDoc(
embedding1=rand(128),
embedding2=rand(128),
embedding3=rand(128),
)
for _ in range(10)
]
)

db = InMemoryExactNNIndex[MyDoc](data)

query = MyDoc(
embedding1=rand(128),
embedding2=rand(128),
embedding3=rand(128),
)

for i in range(3):
db.find(query, search_field=f"embedding{i + 1}")
assert len(db._embedding_map) == i + 1

data_copy = data.copy()

for i in range(9):
db._del_items(data_copy[i].id)
assert db._embedding_map["embedding1"][0].shape[0] == db.num_docs()

db._del_items(data_copy[9].id) # Delete the last element
assert len(db._embedding_map) == 0


def test_index_find_speedup():
class MyDocument(BaseDoc):
embedding: TorchTensor
embedding2: TorchTensor
embedding3: TorchTensor

def generate_doc_list(num_docs: int, dims: int) -> DocList[MyDocument]:
return DocList[MyDocument](
[
MyDocument(
embedding=rand(dims),
embedding2=rand(dims),
embedding3=rand(dims),
)
for _ in range(num_docs)
]
)

def create_inmemory_index(
data_list: DocList[MyDocument],
) -> InMemoryExactNNIndex[MyDocument]:
return InMemoryExactNNIndex[MyDocument](data_list)

def find_similar_docs(
index: InMemoryExactNNIndex[MyDocument],
queries: DocList[MyDocument],
search_field: str = 'embedding',
limit: int = 5,
) -> tuple:
return index.find_batched(queries, search_field=search_field, limit=limit)

# Generating document lists
num_docs, num_queries, dims = 2000, 1000, 128
data_list = generate_doc_list(num_docs, dims)
queries = generate_doc_list(num_queries, dims)

# Creating index
db = create_inmemory_index(data_list)

# Finding similar documents
for _ in range(5):
matches, scores = find_similar_docs(db, queries, 'embedding', 5)
assert len(matches) == num_queries
assert len(matches[0]) == 5