Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
75 commits
Select commit Hold shift + click to select a range
603012d
feat: subindex init
AnneYang720 Apr 20, 2023
8c11e5e
feat: subindex init and index for elastic
AnneYang720 Apr 20, 2023
64a823c
feat: add parent_id to nested doclist
AnneYang720 Apr 20, 2023
867ab36
feat: get with subindex
AnneYang720 Apr 21, 2023
9c66ba8
feat: subindex index for hnsw and del
AnneYang720 Apr 21, 2023
b514e3a
feat: _validate_search_field subindex check
AnneYang720 Apr 21, 2023
f8a00ed
Merge remote-tracking branch 'origin/main' into feat-subindex
AnneYang720 Apr 21, 2023
ff7f5d5
feat: find support subindex
AnneYang720 Apr 21, 2023
b3eb8f2
feat: subindex for find_batched
AnneYang720 Apr 21, 2023
741c0de
feat: have dynamic parent_id
AnneYang720 Apr 25, 2023
def048e
fix: type judge in _convert_dict_to_doc
AnneYang720 Apr 25, 2023
4cc8521
feat: subindex del for hnswlib
AnneYang720 Apr 25, 2023
3bfc976
feat: filter_subindex
AnneYang720 Apr 25, 2023
9d53e33
Merge branch 'main' into feat-subindex
AnneYang720 Apr 25, 2023
3726410
feat: subindex find that returns two results
AnneYang720 Apr 26, 2023
4959fb0
Merge branch 'main' into feat-subindex
AnneYang720 Apr 26, 2023
b6c2bd4
fix: no parent_id in nest schema
AnneYang720 Apr 26, 2023
d1b3ada
fix: mypy
AnneYang720 Apr 26, 2023
44b656e
test: subindex tests for elastic
AnneYang720 Apr 26, 2023
8bcf7cc
fix: mypy
AnneYang720 Apr 26, 2023
aaef7b8
fix: __getitem__
AnneYang720 Apr 26, 2023
34af8a8
test: subindex tests for hnswlib
AnneYang720 Apr 26, 2023
9f644b0
fix: subindex init of elatic
AnneYang720 Apr 26, 2023
1ffe260
test: subindex tests for elastic v7
AnneYang720 Apr 26, 2023
5fd4e78
Merge branch 'main' into feat-subindex
AnneYang720 Apr 26, 2023
bff28a5
feat: create new schema for subindex
AnneYang720 Apr 26, 2023
19feecf
Merge branch 'main' into feat-subindex
AnneYang720 Apr 26, 2023
e88646f
fix: subindex return type
AnneYang720 Apr 27, 2023
c8b12cd
fix: mypy
AnneYang720 Apr 27, 2023
ef07185
test: add type check for subindex
AnneYang720 Apr 27, 2023
174127a
feat: store original subindex schema
AnneYang720 Apr 28, 2023
560fedb
fix: mypy
AnneYang720 Apr 28, 2023
510c609
fix: _convert_dict_to_doc
AnneYang720 Apr 28, 2023
2666539
Merge branch 'main' into feat-subindex
AnneYang720 Apr 28, 2023
f27ba7a
fix: subindex doc generation
AnneYang720 Apr 28, 2023
a25cc5c
test: add tests for abstrac methods with subindex
AnneYang720 Apr 28, 2023
fac39aa
refactor: rename params
AnneYang720 Apr 28, 2023
7e93002
feat: support subindex for weaviate
AnneYang720 Apr 29, 2023
fbffe0b
test: for method find_subindex
AnneYang720 Apr 29, 2023
509d778
fix: mypy
AnneYang720 Apr 29, 2023
c6395b9
Merge branch 'main' into feat-subindex
AnneYang720 May 4, 2023
ea7d481
test: weaviate subindex test
AnneYang720 May 5, 2023
e5d7965
fix: weaviate subindex find
AnneYang720 May 5, 2023
81e9475
feat: support subindex for qdrant
AnneYang720 May 5, 2023
146ae7d
test: for qdrant subindex
AnneYang720 May 5, 2023
5b509aa
fix: mypy
AnneYang720 May 5, 2023
3b4c083
Merge branch 'main' into feat-subindex
AnneYang720 May 5, 2023
029804f
fix: new client for qdrant subindex test
AnneYang720 May 5, 2023
a23efd6
fix: qdrant subindex test gdbm error
AnneYang720 May 5, 2023
12dde4d
fix: find_subindex
AnneYang720 May 5, 2023
aa24d00
Merge branch 'main' into feat-subindex
AnneYang720 May 5, 2023
c23411b
docs: add subindex documentation and docstring
AnneYang720 May 6, 2023
34b10a5
Merge branch 'main' into feat-subindex
AnneYang720 May 6, 2023
4df2363
Merge remote-tracking branch 'origin/main' into feat-subindex
AnneYang720 May 8, 2023
9ec3c79
fix: use default index_name
AnneYang720 May 8, 2023
c080024
Merge branch 'main' into feat-subindex
AnneYang720 May 8, 2023
ce7f08f
Merge branch 'main' into feat-subindex
AnneYang720 May 8, 2023
4e32df3
docs: add docstring
AnneYang720 May 8, 2023
8b7989a
refactor: init subindex in abstract
AnneYang720 May 8, 2023
4d7ecd7
fix: mypy
AnneYang720 May 8, 2023
7c83c63
fix: change work_dir or collection_name first
AnneYang720 May 8, 2023
24dd2c4
Merge remote-tracking branch 'origin/main' into feat-subindex
AnneYang720 May 8, 2023
cb61bf3
docs: add index_name property
AnneYang720 May 8, 2023
3090a2b
Merge branch 'main' into feat-subindex
AnneYang720 May 9, 2023
d2f8f78
Merge branch 'main' into feat-subindex
AnneYang720 May 10, 2023
722944f
refactor: minor adjustment
AnneYang720 May 10, 2023
011178f
Update docs/user_guide/storing/docindex.md
AnneYang720 May 10, 2023
0eb0112
docs: fix link
AnneYang720 May 10, 2023
255d59a
feat: remove subindex from find
AnneYang720 May 11, 2023
12cbc3d
docs: update subindex part
AnneYang720 May 11, 2023
1940b10
fix: hnswlib subdocs should be original schema
AnneYang720 May 11, 2023
3159be3
Merge branch 'main' into feat-subindex
AnneYang720 May 11, 2023
dcb7863
docs: docstring for hnswlib _get_items
AnneYang720 May 11, 2023
e1870c9
Merge branch 'main' into feat-subindex
AnneYang720 May 11, 2023
6154a18
fix: fix docstring
AnneYang720 May 11, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
280 changes: 269 additions & 11 deletions docarray/index/abstract.py

Large diffs are not rendered by default.

19 changes: 17 additions & 2 deletions docarray/index/backends/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import docarray.typing
from docarray import BaseDoc
from docarray.array.any_array import AnyDocArray
from docarray.index.abstract import BaseDocIndex, _ColumnInfo, _raise_not_composable
from docarray.typing import AnyTensor
from docarray.typing.tensor.abstract_tensor import AbstractTensor
Expand All @@ -39,6 +40,8 @@


if TYPE_CHECKING:
import tensorflow as tf # type: ignore
import torch
from elastic_transport import NodeConfig
from elasticsearch import Elasticsearch
from elasticsearch.helpers import parallel_bulk
Expand Down Expand Up @@ -90,6 +93,8 @@ def __init__(self, db_config=None, **kwargs):
mappings.update(self._db_config.index_mappings)

for col_name, col in self._column_infos.items():
if issubclass(col.docarray_type, AnyDocArray):
continue
if col.db_type == 'dense_vector' and (
not col.n_dim and col.config['dims'] < 0
):
Expand All @@ -100,7 +105,6 @@ def __init__(self, db_config=None, **kwargs):

mappings['properties'][col_name] = self._create_index_mapping(col)

# print(mappings['properties'])
if self._client.indices.exists(index=self.index_name):
self._client_put_mapping(mappings)
else:
Expand Down Expand Up @@ -334,6 +338,8 @@ def _index(
refresh: bool = True,
chunk_size: Optional[int] = None,
):
self._index_subindex(column_to_data)

data = self._transpose_col_value_dict(column_to_data)
requests = []

Expand All @@ -343,6 +349,8 @@ def _index(
'_id': row['id'],
}
for col_name, col in self._column_infos.items():
if issubclass(col.docarray_type, AnyDocArray):
continue
if col.db_type == 'dense_vector' and np.all(row[col_name] == 0):
row[col_name] = row[col_name] + 1.0e-9
if row[col_name] is None:
Expand Down Expand Up @@ -383,7 +391,7 @@ def _del_items(

self._refresh(self.index_name)

def _get_items(self, doc_ids: Sequence[str]) -> Sequence[TSchema]:
def _get_items(self, doc_ids: Sequence[str]) -> Sequence[Dict[str, Any]]:
accumulated_docs = []
accumulated_docs_id_not_found = []

Expand Down Expand Up @@ -515,6 +523,13 @@ def _text_search_batched(
)
return _FindResultBatched(documents=list(das), scores=scores)

def _filter_by_parent_id(self, id: str) -> List[str]:
resp = self._client_search(
query={'term': {'parent_id': id}}, fields=['id'], _source=False
)
ids = [hit['fields']['id'][0] for hit in resp['hits']['hits']]
return ids

###############################################
# Helpers #
###############################################
Expand Down
121 changes: 98 additions & 23 deletions docarray/index/backends/hnswlib.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import glob
import hashlib
import os
import sqlite3
Expand All @@ -7,6 +8,7 @@
TYPE_CHECKING,
Any,
Dict,
Generator,
Generic,
List,
Optional,
Expand All @@ -21,6 +23,7 @@
import numpy as np

from docarray import BaseDoc, DocList
from docarray.array.any_array import AnyDocArray
from docarray.index.abstract import (
BaseDocIndex,
_ColumnInfo,
Expand Down Expand Up @@ -67,11 +70,16 @@
class HnswDocumentIndex(BaseDocIndex, Generic[TSchema]):
def __init__(self, db_config=None, **kwargs):
"""Initialize HnswDocumentIndex"""
if db_config is not None and getattr(db_config, 'index_name'):
db_config.work_dir = db_config.index_name.replace("__", "/")

super().__init__(db_config=db_config, **kwargs)
self._db_config = cast(HnswDocumentIndex.DBConfig, self._db_config)
self._work_dir = self._db_config.work_dir
self._logger.debug(f'Working directory set to {self._work_dir}')
load_existing = os.path.exists(self._work_dir) and os.listdir(self._work_dir)
load_existing = os.path.exists(self._work_dir) and glob.glob(
f'{self._work_dir}/*.bin'
)
Path(self._work_dir).mkdir(parents=True, exist_ok=True)

# HNSWLib setup
Expand All @@ -90,6 +98,8 @@ def __init__(self, db_config=None, **kwargs):
}
self._hnsw_indices = {}
for col_name, col in self._column_infos.items():
if issubclass(col.docarray_type, AnyDocArray):
continue
if not col.config:
# non-tensor type; don't create an index
continue
Expand Down Expand Up @@ -118,6 +128,17 @@ def __init__(self, db_config=None, **kwargs):
self._sqlite_conn.commit()
self._logger.info(f'{self.__class__.__name__} has been initialized')

@property
def index_name(self):
return self._db_config.work_dir # type: ignore

@property
def out_schema(self) -> Type[BaseDoc]:
"""Return the real schema of the index."""
if self._is_subindex:
return self._ori_schema
return cast(Type[BaseDoc], self._schema)

###############################################
# Inner classes for query builder and configs #
###############################################
Expand Down Expand Up @@ -184,9 +205,23 @@ def python_type_to_db_type(self, python_type: Type) -> Any:

return None # all types allowed, but no db type needed

def _index(self, column_data_dic, **kwargs):
def _index(
self,
column_to_data: Dict[str, Generator[Any, None, None]],
docs_validated: Sequence[BaseDoc] = [],
):
self._index_subindex(column_to_data)

# not needed, we implement `index` directly
...
hashed_ids = tuple(self._to_hashed_id(doc.id) for doc in docs_validated)
# indexing into HNSWLib and SQLite sequentially
# could be improved by processing in parallel
for col_name, index in self._hnsw_indices.items():
data = column_to_data[col_name]
data_np = [self._to_numpy(arr) for arr in data]
data_stacked = np.stack(data_np)
index.add_items(data_stacked, ids=hashed_ids)
index.save_index(self._hnsw_locations[col_name])

def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):
"""Index Documents into the index.
Expand All @@ -206,16 +241,10 @@ def index(self, docs: Union[BaseDoc, Sequence[BaseDoc]], **kwargs):
n_docs = 1 if isinstance(docs, BaseDoc) else len(docs)
self._logger.debug(f'Indexing {n_docs} documents')
docs_validated = self._validate_docs(docs)
self._update_subindex_data(docs_validated)
data_by_columns = self._get_col_value_dict(docs_validated)
hashed_ids = tuple(self._to_hashed_id(doc.id) for doc in docs_validated)
# indexing into HNSWLib and SQLite sequentially
# could be improved by processing in parallel
for col_name, index in self._hnsw_indices.items():
data = data_by_columns[col_name]
data_np = [self._to_numpy(arr) for arr in data]
data_stacked = np.stack(data_np)
index.add_items(data_stacked, ids=hashed_ids)
index.save_index(self._hnsw_locations[col_name])

self._index(data_by_columns, docs_validated, **kwargs)

self._send_docs_to_sqlite(docs_validated)
self._sqlite_conn.commit()
Expand Down Expand Up @@ -312,6 +341,15 @@ def _text_search_batched(

def _del_items(self, doc_ids: Sequence[str]):
# delete from the indices
for field_name, type_, _ in self._flatten_schema(
cast(Type[BaseDoc], self._schema)
):
if issubclass(type_, AnyDocArray):
for id in doc_ids:
doc = self.__getitem__(id)
sub_ids = [sub_doc.id for sub_doc in getattr(doc, field_name)]
del self._subindices[field_name][sub_ids]

try:
for doc_id in doc_ids:
id_ = self._to_hashed_id(doc_id)
Expand All @@ -323,8 +361,15 @@ def _del_items(self, doc_ids: Sequence[str]):
self._delete_docs_from_sqlite(doc_ids)
self._sqlite_conn.commit()

def _get_items(self, doc_ids: Sequence[str]) -> Sequence[TSchema]:
out_docs = self._get_docs_sqlite_doc_id(doc_ids)
def _get_items(self, doc_ids: Sequence[str], out: bool = True) -> Sequence[TSchema]:
"""Get Documents from the hnswlib index, by `id`.
If no document is found, a KeyError is raised.

:param doc_ids: ids to get from the Document index
:param out: return the documents in the original schema(True) or inner schema(False) for subindex
:return: Sequence of Documents, sorted corresponding to the order of `doc_ids`. Duplicate `doc_ids` can be omitted in the output.
"""
out_docs = self._get_docs_sqlite_doc_id(doc_ids, out)
if len(out_docs) == 0:
raise KeyError(f'No document with id {doc_ids} found')
return out_docs
Expand Down Expand Up @@ -391,7 +436,7 @@ def _send_docs_to_sqlite(self, docs: Sequence[BaseDoc]):
((id_, self._doc_to_bytes(doc)) for id_, doc in zip(ids, docs)),
)

def _get_docs_sqlite_unsorted(self, univ_ids: Sequence[int]):
def _get_docs_sqlite_unsorted(self, univ_ids: Sequence[int], out: bool = True):
for id_ in univ_ids:
# I hope this protects from injection attacks
# properly binding with '?' doesn't work for some reason
Expand All @@ -401,13 +446,17 @@ def _get_docs_sqlite_unsorted(self, univ_ids: Sequence[int]):
'SELECT data FROM docs WHERE doc_id IN %s' % sql_id_list,
)
rows = self._sqlite_cursor.fetchall()
docs_cls = DocList.__class_getitem__(cast(Type[BaseDoc], self._schema))
return docs_cls([self._doc_from_bytes(row[0]) for row in rows])
schema = self.out_schema if out else self._schema
docs_cls = DocList.__class_getitem__(cast(Type[BaseDoc], schema))
return docs_cls([self._doc_from_bytes(row[0], out) for row in rows])

def _get_docs_sqlite_doc_id(self, doc_ids: Sequence[str]) -> DocList[TSchema]:
def _get_docs_sqlite_doc_id(
self, doc_ids: Sequence[str], out: bool = True
) -> DocList[TSchema]:
hashed_ids = tuple(self._to_hashed_id(id_) for id_ in doc_ids)
docs_unsorted = self._get_docs_sqlite_unsorted(hashed_ids)
docs_cls = DocList.__class_getitem__(cast(Type[BaseDoc], self._schema))
docs_unsorted = self._get_docs_sqlite_unsorted(hashed_ids, out)
schema = self.out_schema if out else self._schema
docs_cls = DocList.__class_getitem__(cast(Type[BaseDoc], schema))
return docs_cls(sorted(docs_unsorted, key=lambda doc: doc_ids.index(doc.id)))

def _get_docs_sqlite_hashed_id(self, hashed_ids: Sequence[int]) -> DocList:
Expand All @@ -416,7 +465,7 @@ def _get_docs_sqlite_hashed_id(self, hashed_ids: Sequence[int]) -> DocList:
def _in_position(doc):
return hashed_ids.index(self._to_hashed_id(doc.id))

docs_cls = DocList.__class_getitem__(cast(Type[BaseDoc], self._schema))
docs_cls = DocList.__class_getitem__(cast(Type[BaseDoc], self.out_schema))
return docs_cls(sorted(docs_unsorted, key=_in_position))

def _delete_docs_from_sqlite(self, doc_ids: Sequence[Union[str, int]]):
Expand All @@ -436,6 +485,32 @@ def _get_num_docs_sqlite(self) -> int:
def _doc_to_bytes(self, doc: BaseDoc) -> bytes:
return doc.to_protobuf().SerializeToString()

def _doc_from_bytes(self, data: bytes) -> BaseDoc:
schema_cls = cast(Type[BaseDoc], self._schema)
def _doc_from_bytes(self, data: bytes, out: bool = True) -> BaseDoc:
schema = self.out_schema if out else self._schema
schema_cls = cast(Type[BaseDoc], schema)
return schema_cls.from_protobuf(DocProto.FromString(data))

def _get_root_doc_id(self, id: str, root: str, sub: str) -> str:
"""Get the root_id given the id of a subindex Document and the root and subindex name for hnswlib.

:param id: id of the subindex Document
:param root: root index name
:param sub: subindex name
:return: the root_id of the Document
"""
subindex = self._subindices[root]

if not sub:
sub_doc = subindex._get_items([id], out=False) # type: ignore
parent_id = (
sub_doc[0]['parent_id']
if isinstance(sub_doc[0], dict)
else sub_doc[0].parent_id
)
return parent_id
else:
fields = sub.split('__')
cur_root_id = subindex._get_root_doc_id(
id, fields[0], '__'.join(fields[1:])
)
return self._get_root_doc_id(cur_root_id, root, '')
46 changes: 40 additions & 6 deletions docarray/index/backends/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import docarray.typing.id
from docarray import BaseDoc, DocList
from docarray.array.any_array import AnyDocArray
from docarray.index.abstract import (
BaseDocIndex,
_ColumnInfo,
Expand Down Expand Up @@ -65,6 +66,9 @@ class QdrantDocumentIndex(BaseDocIndex, Generic[TSchema]):

def __init__(self, db_config=None, **kwargs):
"""Initialize QdrantDocumentIndex"""
if db_config is not None and getattr(db_config, 'index_name'):
db_config.collection_name = db_config.index_name

super().__init__(db_config=db_config, **kwargs)
self._db_config: QdrantDocumentIndex.DBConfig = cast(
QdrantDocumentIndex.DBConfig, self._db_config
Expand Down Expand Up @@ -98,6 +102,10 @@ def collection_name(self):

return self._db_config.collection_name or default_collection_name

@property
def index_name(self):
return self.collection_name

@dataclass
class Query:
"""Dataclass describing a query."""
Expand Down Expand Up @@ -264,11 +272,14 @@ def _initialize_collection(self):
try:
self._client.get_collection(self.collection_name)
except (UnexpectedResponse, RpcError, ValueError):
vectors_config = {
column_name: self._to_qdrant_vector_params(column_info)
for column_name, column_info in self._column_infos.items()
if column_info.db_type == 'vector'
}
vectors_config = {}

for column_name, column_info in self._column_infos.items():
if column_info.db_type == 'vector':
vectors_config[column_name] = self._to_qdrant_vector_params(
column_info
)

self._client.create_collection(
collection_name=self.collection_name,
vectors_config=vectors_config,
Expand All @@ -288,6 +299,8 @@ def _initialize_collection(self):
)

def _index(self, column_to_data: Dict[str, Generator[Any, None, None]]):
self._index_subindex(column_to_data)

rows = self._transpose_col_value_dict(column_to_data)
# TODO: add batching the documents to avoid timeouts
points = [self._build_point_from_row(row) for row in rows]
Expand Down Expand Up @@ -332,7 +345,10 @@ def _get_items(
with_payload=True,
with_vectors=True,
)
return [self._convert_to_doc(point) for point in response]
return sorted(
[self._convert_to_doc(point) for point in response],
key=lambda x: doc_ids.index(x['id']),
)

def execute_query(self, query: Union[Query, RawQuery], *args, **kwargs) -> DocList:
"""
Expand Down Expand Up @@ -532,11 +548,29 @@ def _text_search_batched(
],
)

def _filter_by_parent_id(self, id: str) -> Optional[List[str]]:
response, _ = self._client.scroll(
collection_name=self._db_config.collection_name, # type: ignore
scroll_filter=rest.Filter(
must=[
rest.FieldCondition(
key='parent_id', match=rest.MatchValue(value=id)
)
]
),
with_payload=rest.PayloadSelectorInclude(include=['id']),
)

ids = [point.payload['id'] for point in response] # type: ignore
return ids

def _build_point_from_row(self, row: Dict[str, Any]) -> rest.PointStruct:
point_id = self._to_qdrant_id(row.get('id'))
vectors: Dict[str, List[float]] = {}
payload: Dict[str, Any] = {'__generated_vectors': []}
for column_name, column_info in self._column_infos.items():
if issubclass(column_info.docarray_type, AnyDocArray):
continue
if column_info.db_type in ['id', 'payload']:
payload[column_name] = row.get(column_name)
continue
Expand Down
Loading