Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docarray/index/backends/hnswlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,9 @@ def _find_batched(
limit: int,
search_field: str = '',
) -> _FindResultBatched:
if not self.num_docs():
Comment thread
jupyterjazz marked this conversation as resolved.
Outdated
return _FindResultBatched(documents=[], scores=[])

index = self._hnsw_indices[search_field]
labels, distances = index.knn_query(queries, k=limit)
result_das = [
Expand All @@ -293,6 +296,9 @@ def _find_batched(
def _find(
self, query: np.ndarray, limit: int, search_field: str = ''
) -> _FindResult:
if not self.num_docs():
return _FindResult(documents=[], scores=[]) # type: ignore

query_batched = np.expand_dims(query, axis=0)
docs, scores = self._find_batched(
queries=query_batched, limit=limit, search_field=search_field
Expand Down
8 changes: 8 additions & 0 deletions docarray/index/backends/in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,10 @@ def find(
"""
self._logger.debug(f'Executing `find` for search field {search_field}')
self._validate_search_field(search_field)

if not self.num_docs():
Comment thread
jupyterjazz marked this conversation as resolved.
Outdated
return FindResult(documents=[], scores=[]) # type: ignore

config = self._column_infos[search_field].config

docs, scores = find(
Expand Down Expand Up @@ -233,6 +237,10 @@ def find_batched(
"""
self._logger.debug(f'Executing `find_batched` for search field {search_field}')
self._validate_search_field(search_field)

if not self.num_docs():
return FindResultBatched(documents=[], scores=[])

config = self._column_infos[search_field].config

find_res = find_batched(
Expand Down
9 changes: 9 additions & 0 deletions tests/index/hnswlib/test_find.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,15 @@ class SimpleSchema(BaseDoc):
assert np.allclose(result.tens, np.zeros(10))


def test_find_empty_index(tmp_path):
empty_index = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path))
query = SimpleDoc(tens=np.ones(10))

docs, scores = empty_index.find(query, search_field='tens', limit=5)
assert len(docs) == 0
assert len(scores) == 0


@pytest.mark.parametrize('space', ['cosine', 'l2', 'ip'])
def test_find_torch(tmp_path, space):
index = HnswDocumentIndex[TorchDoc](work_dir=str(tmp_path))
Expand Down
10 changes: 10 additions & 0 deletions tests/index/in_memory/test_in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ class MyDoc(BaseDoc):
assert len(scores) == 5
assert doc_index.num_docs() == 10

empty_index = InMemoryExactNNIndex[MyDoc]()
docs, scores = empty_index.find(query, search_field='tensor', limit=5)
assert len(docs) == 0
assert len(scores) == 0


@pytest.mark.parametrize('space', ['cosine_sim', 'euclidean_dist', 'sqeuclidean_dist'])
@pytest.mark.parametrize('is_query_doc', [True, False])
Expand All @@ -96,6 +101,11 @@ class MyDoc(BaseDoc):
assert len(result) == 5
assert doc_index.num_docs() == 10

empty_index = InMemoryExactNNIndex[MyDoc]()
docs, scores = empty_index.find_batched(query, search_field='tensor', limit=5)
assert len(docs) == 0
assert len(scores) == 0


def test_concatenated_queries(doc_index):
query = SchemaDoc(text='query', price=0, tensor=np.ones(10))
Expand Down