Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docarray/index/backends/hnswlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,9 @@ def _find_batched(
limit: int,
search_field: str = '',
) -> _FindResultBatched:
if self.num_docs() == 0:
return _FindResultBatched(documents=[], scores=[]) # type: ignore

index = self._hnsw_indices[search_field]
labels, distances = index.knn_query(queries, k=limit)
result_das = [
Expand All @@ -293,6 +296,9 @@ def _find_batched(
def _find(
self, query: np.ndarray, limit: int, search_field: str = ''
) -> _FindResult:
if self.num_docs() == 0:
return _FindResult(documents=[], scores=[]) # type: ignore

query_batched = np.expand_dims(query, axis=0)
docs, scores = self._find_batched(
queries=query_batched, limit=limit, search_field=search_field
Expand Down
8 changes: 8 additions & 0 deletions docarray/index/backends/in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,10 @@ def find(
"""
self._logger.debug(f'Executing `find` for search field {search_field}')
self._validate_search_field(search_field)

if self.num_docs() == 0:
return FindResult(documents=[], scores=[]) # type: ignore

config = self._column_infos[search_field].config

docs, scores = find(
Expand Down Expand Up @@ -233,6 +237,10 @@ def find_batched(
"""
self._logger.debug(f'Executing `find_batched` for search field {search_field}')
self._validate_search_field(search_field)

if self.num_docs() == 0:
return FindResultBatched(documents=[], scores=[]) # type: ignore

config = self._column_infos[search_field].config

find_res = find_batched(
Expand Down
9 changes: 9 additions & 0 deletions tests/index/hnswlib/test_find.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,15 @@ class SimpleSchema(BaseDoc):
assert np.allclose(result.tens, np.zeros(10))


def test_find_empty_index(tmp_path):
    """Check that `find` on an index with nothing added returns no hits.

    The index is created under `tmp_path` and queried immediately, without
    any documents having been indexed in this test; both the returned
    documents and the returned scores are expected to have length zero.
    """
    index = HnswDocumentIndex[SimpleDoc](work_dir=str(tmp_path))
    probe = SimpleDoc(tens=np.ones(10))

    # Ask for more neighbours than the index can possibly hold.
    found_docs, found_scores = index.find(probe, search_field='tens', limit=5)

    assert len(found_docs) == 0
    assert len(found_scores) == 0


@pytest.mark.parametrize('space', ['cosine', 'l2', 'ip'])
def test_find_torch(tmp_path, space):
index = HnswDocumentIndex[TorchDoc](work_dir=str(tmp_path))
Expand Down
10 changes: 10 additions & 0 deletions tests/index/in_memory/test_in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ class MyDoc(BaseDoc):
assert len(scores) == 5
assert doc_index.num_docs() == 10

empty_index = InMemoryExactNNIndex[MyDoc]()
docs, scores = empty_index.find(query, search_field='tensor', limit=5)
assert len(docs) == 0
assert len(scores) == 0


@pytest.mark.parametrize('space', ['cosine_sim', 'euclidean_dist', 'sqeuclidean_dist'])
@pytest.mark.parametrize('is_query_doc', [True, False])
Expand All @@ -96,6 +101,11 @@ class MyDoc(BaseDoc):
assert len(result) == 5
assert doc_index.num_docs() == 10

empty_index = InMemoryExactNNIndex[MyDoc]()
docs, scores = empty_index.find_batched(query, search_field='tensor', limit=5)
assert len(docs) == 0
assert len(scores) == 0


def test_concatenated_queries(doc_index):
query = SchemaDoc(text='query', price=0, tensor=np.ones(10))
Expand Down