Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions docarray/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,17 @@ def _iter_file_extensions(ps):
num_docs += 1
if size is not None and num_docs >= size:
break


def _shallow_copy_doc(doc):
cls = doc.__class__
shallow_copy = cls.__new__(cls)

field_set = set(doc.__fields_set__)
object.__setattr__(shallow_copy, '__fields_set__', field_set)

for field_name, field_ in doc.__fields__.items():
val = doc.__getattr__(field_name)
setattr(shallow_copy, field_name, val)

return shallow_copy
118 changes: 111 additions & 7 deletions docarray/index/backends/in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,16 @@
import numpy as np

from docarray import BaseDoc, DocList
from docarray.array.any_array import AnyDocArray
from docarray.helper import _shallow_copy_doc
from docarray.index.abstract import BaseDocIndex, _raise_not_supported
from docarray.index.backends.helper import (
_collect_query_args,
_execute_find_and_filter_query,
)
from docarray.typing import AnyTensor, NdArray
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.utils._internal._typing import safe_issubclass
from docarray.utils.filter import filter_docs
from docarray.utils.find import (
FindResult,
Expand Down Expand Up @@ -69,6 +72,11 @@ def __init__(
self._docs = DocList.__class_getitem__(
cast(Type[BaseDoc], self._schema)
).load_binary(file=index_file_path)

data_by_columns = self._get_col_value_dict(self._docs)
self._update_subindex_data(self._docs)
self._index_subindex(data_by_columns)

else:
self._logger.warning(
f'Index file does not exist: {index_file_path}. '
Expand Down Expand Up @@ -101,6 +109,13 @@ def python_type_to_db_type(self, python_type: Type) -> Any:
"""
return python_type

@property
def out_schema(self) -> Type[BaseDoc]:
"""Return the original schema (without the parent_id from new_schema type)"""
if self._is_subindex:
return self._ori_schema
return cast(Type[BaseDoc], self._schema)

class QueryBuilder(BaseDocIndex.QueryBuilder):
def __init__(self, query: Optional[List[Tuple[str, Dict]]] = None):
super().__init__()
Expand Down Expand Up @@ -152,6 +167,12 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):
# implementing the public option because conversion to column dict is not needed
docs = self._validate_docs(docs)
self._docs.extend(docs)

# Add parent_id to all sub-index documents and store sub-index documents
data_by_columns = self._get_col_value_dict(docs)
self._update_subindex_data(docs)
self._index_subindex(data_by_columns)

self._rebuild_embedding()

def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]):
Expand Down Expand Up @@ -184,6 +205,17 @@ def _del_items(self, doc_ids: Sequence[str]):

:param doc_ids: ids to delete from the Document Store
"""
for field_, type_, _ in self._flatten_schema(cast(Type[BaseDoc], self._schema)):
if safe_issubclass(type_, AnyDocArray):
for id in doc_ids:
doc_ = self._get_items([id])
if len(doc_) == 0:
raise KeyError(
f"The document (id = '{id}') does not exist in the ExactNNIndexer."
)
sub_ids = [sub_doc.id for sub_doc in getattr(doc_[0], field_)]
del self._subindices[field_][sub_ids]

indices = []
for i, doc in enumerate(self._docs):
if doc.id in doc_ids:
Expand All @@ -192,21 +224,58 @@ def _del_items(self, doc_ids: Sequence[str]):
del self._docs[indices]
self._rebuild_embedding()

def _ori_items(self, doc: BaseDoc) -> BaseDoc:
"""
The Indexer's backend stores parent_id to support nested data. However,
this method enables us to retrieve the original items in their original
type, which is what the user interacts with.

:param doc: The input document in New_Schema format from the Indexer's backend.
:return: The input document with its original schema.
"""

ori_doc = _shallow_copy_doc(doc)
for field_name, type_, _ in self._flatten_schema(
cast(Type[BaseDoc], self.out_schema)
):
if safe_issubclass(type_, AnyDocArray):
_list = getattr(ori_doc, field_name)
for i, nested_doc in enumerate(_list):
sub_indexer: InMemoryExactNNIndex = cast(
InMemoryExactNNIndex, self._subindices[field_name]
)
nested_doc = self._subindices[field_name]._ori_schema(
**nested_doc.__dict__
)

_list[i] = sub_indexer._ori_items(nested_doc)

return ori_doc

def _get_items(
self, doc_ids: Sequence[str]
self, doc_ids: Sequence[str], raw: bool = False
) -> Union[Sequence[TSchema], Sequence[Dict[str, Any]]]:
"""Get Documents from the index, by `id`.
If no document is found, a KeyError is raised.

:param doc_ids: ids to get from the Document index
:param raw: if raw, output the new_schema type (with parent id)
:return: Sequence of Documents, sorted corresponding to the order of `doc_ids`.
Duplicate `doc_ids` can be omitted in the output.
"""
indices = []

out_docs = []
for i, doc in enumerate(self._docs):
if doc.id in doc_ids:
indices.append(i)
return self._docs[indices]
if raw:
out_docs.append(doc)
else:
ori_doc = self._ori_items(doc)
schema_cls = cast(Type[BaseDoc], self.out_schema)
new_doc = schema_cls(**ori_doc.__dict__)
out_docs.append(new_doc)

return out_docs

def execute_query(self, query: List[Tuple[str, Dict]], *args, **kwargs) -> Any:
"""
Expand Down Expand Up @@ -267,9 +336,17 @@ def find(
metric=config['space'],
cache=self._embedding_map,
)
docs_with_schema = DocList.__class_getitem__(cast(Type[BaseDoc], self._schema))(
docs
)

docs_ = []
for doc in docs:
ori_doc = self._ori_items(doc)
schema_cls = cast(Type[BaseDoc], self.out_schema)
docs_.append(schema_cls(**ori_doc.__dict__))

docs_with_schema = DocList.__class_getitem__(
cast(Type[BaseDoc], self.out_schema)
)(docs_)

return FindResult(documents=docs_with_schema, scores=scores)

def _find(
Expand Down Expand Up @@ -359,3 +436,30 @@ def _text_search_batched(
def persist(self, file: str = 'in_memory_index.bin') -> None:
"""Persist InMemoryExactNNIndex into a binary file."""
self._docs.save_binary(file=file)

def _get_root_doc_id(self, id: str, root: str, sub: str) -> str:
"""Get the root_id given the id of a subindex Document and the root and subindex name

:param id: id of the subindex Document
:param root: root index name
:param sub: subindex name
:return: the root_id of the Document
"""
subindex: InMemoryExactNNIndex = cast(
InMemoryExactNNIndex, self._subindices[root]
)

if not sub:
sub_doc = subindex._get_items([id], raw=True)
parent_id = (
sub_doc[0]['parent_id']
if isinstance(sub_doc[0], dict)
else sub_doc[0].parent_id
)
return parent_id
else:
fields = sub.split('__')
cur_root_id = subindex._get_root_doc_id(
id, fields[0], '__'.join(fields[1:])
)
return self._get_root_doc_id(cur_root_id, root, '')
1 change: 0 additions & 1 deletion tests/index/hnswlib/test_subindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ def test_subindex_get(index):
doc = index['1']
assert type(doc) == MyDoc
assert doc.id == '1'

assert len(doc.docs) == 5
assert type(doc.docs[0]) == SimpleDoc
assert doc.docs[0].id == 'docs-1-0'
Expand Down
53 changes: 53 additions & 0 deletions tests/index/in_memory/test_in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,3 +249,56 @@ def find_similar_docs(
matches, scores = find_similar_docs(db, queries, 'embedding', 5)
assert len(matches) == num_queries
assert len(matches[0]) == 5


def test_nested_document_find():
from numpy import all

from docarray.typing import VideoUrl

class VideoDoc(BaseDoc):
url: VideoUrl
tensor_video: NdArray[256]

class MyDoc(BaseDoc):
docs: DocList[VideoDoc]
tensor: NdArray[256]

doc_index = InMemoryExactNNIndex[MyDoc]()

index_docs = [
MyDoc(
id=f'{i}',
docs=DocList[VideoDoc](
[
VideoDoc(
url=f'http://example.ai/videos/{i}-{j}',
tensor_video=(np.ones(256)) * i,
)
for j in range(10)
]
),
tensor=np.ones(256),
)
for i in range(10)
]

# index the Documents
doc_index.index(index_docs)

root_docs, sub_docs, scores = doc_index.find_subindex(
np.ones(256), subindex='docs', search_field='tensor_video', limit=5
)

assert doc_index.num_docs() == 10
assert doc_index._subindices['docs'].num_docs() == 100

assert type(sub_docs) == DocList[VideoDoc]
assert type(sub_docs[0]) == VideoDoc
assert type(root_docs[0]) == MyDoc
assert len(scores) == 5
assert all(scores) == 1.0

del doc_index['0']
assert doc_index.num_docs() == 9
assert doc_index._subindices['docs'].num_docs() == 90
Loading