Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docarray/array/mixins/find.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,6 @@ def find(

_limit = len(self) if limit is None else (limit + (1 if exclude_self else 0))

_, _ = ndarray.get_array_type(_query)
n_rows, n_dim = ndarray.get_array_rows(_query)

# Ensure query embedding to have the correct shape
Expand Down
19 changes: 16 additions & 3 deletions docarray/math/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,15 @@ def unravel(docs: Sequence['Document'], field: str) -> Optional['ArrayType']:
return None

framework, is_sparse = get_array_type(_first)
all_fields = [getattr(d, field) for d in docs]
cls_type = type(_first)

all_fields = [getattr(d, field) for d in docs]
none_idx = [idx for idx, v in enumerate(all_fields) if v is None]
if none_idx:
raise ValueError(
f'Document{none_idx}.{field} is None. Can not stack into `{field}s`.'
)

if framework == 'python':
return cls_type(all_fields)

Expand Down Expand Up @@ -124,7 +130,14 @@ def get_array_type(
return 'scipy', True

if raise_error_if_not_array:
raise TypeError(f'can not determine the array type: {module_tags}.{class_name}')
if array is not None:
raise TypeError(
f'can not determine the array type: {module_tags}.{class_name}'
)
else:
raise ValueError(
f'Empty ndarray. Did you forget to set .embedding/.tensor value and now you are operating on it?'
)
else:
return 'python', False

Expand Down Expand Up @@ -271,7 +284,7 @@ def detach_tensor_if_present(x: Any) -> Any:
:return: (num_rows, ndim)
"""
x_type, x_sparse = get_array_type(x, raise_error_if_not_array=False)
if x_type == 'torch' and x_sparse == False:
if x_type == 'torch' and not x_sparse:
import torch

x = torch.tensor(x.detach().numpy())
Expand Down
38 changes: 38 additions & 0 deletions tests/unit/array/mixins/test_exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import numpy as np
import pytest

from docarray import DocumentArray


def test_embedding_ops_error():
da = DocumentArray.empty(100)
db = DocumentArray.empty(100)
da.embeddings = np.random.random([100, 256])

da[2].embedding = None
da[3].embedding = None

with pytest.raises(ValueError, match='[2, 3]'):
da.embeddings

db.embeddings = np.random.random([100, 256])
with pytest.raises(ValueError, match='[2, 3]'):
da.match(db)
with pytest.raises(ValueError, match='[2, 3]'):
db.match(da)
with pytest.raises(ValueError, match='[2, 3]'):
db.find(da)
with pytest.raises(ValueError, match='[2, 3]'):
da.find(db)

da.embeddings = None
with pytest.raises(ValueError, match='Did you forget to set'):
da.find(db)
db.embeddings = None
with pytest.raises(ValueError, match='Did you forget to set'):
da.find(db)
with pytest.raises(ValueError, match='Did you forget to set'):
db.find(da)
da.embeddings = np.random.random([100, 256])
with pytest.raises(ValueError, match='Did you forget to set'):
da.find(None)