Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions docarray/array/doc_vec/column_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,21 @@ def __getitem__(self: T, item: IndexIterType) -> T:
self.tensor_type,
)

def __eq__(self, other: Any) -> bool:
    """Return whether ``other`` is a ColumnStorage holding equal column data.

    Two storages are equal when ``other`` is a ``ColumnStorage``, both have
    the same ``tensor_type``, and each pair of column maps (walked in
    lockstep over ``columns.maps``) has the same keys and equal values.
    The ``'id'`` column is ignored.

    :param other: the object to compare against.
    :return: True if the storages hold equal data, False otherwise.
    """
    if not isinstance(other, ColumnStorage):
        return False
    if self.tensor_type != other.tensor_type:
        return False
    for col_map_self, col_map_other in zip(self.columns.maps, other.columns.maps):
        if col_map_self.keys() != col_map_other.keys():
            return False
        for key, val_self in col_map_self.items():
            # NOTE(review): ids are skipped, presumably because they are
            # unique per document and would make equal data compare unequal.
            if key == 'id':
                continue
            val_other = col_map_other[key]

            # Array-like columns with different shapes are never equal.
            # Checking up front avoids broadcasting errors (or a false
            # "equal" through broadcasting) in the comparison below.
            shape_self = getattr(val_self, 'shape', None)
            shape_other = getattr(val_other, 'shape', None)
            if (
                shape_self is not None
                and shape_other is not None
                and shape_self != shape_other
            ):
                return False

            differs = val_self != val_other
            if not isinstance(differs, bool):
                # Tensor columns compare element-wise, and truth-testing a
                # multi-element result raises; reduce it with ``any()``
                # when the result offers it.
                reduce_any = getattr(differs, 'any', None)
                differs = bool(reduce_any()) if callable(reduce_any) else bool(differs)
            if differs:
                return False
    return True


class ColumnStorageView(dict, MutableMapping[str, Any]):
index: int
Expand Down
11 changes: 11 additions & 0 deletions docarray/array/doc_vec/doc_vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,17 @@ def __iter__(self):
def __len__(self):
    """Return the length reported by the underlying storage."""
    storage = self._storage
    return len(storage)

def __eq__(self, other: Any) -> bool:
    """Return whether ``other`` is a DocVec equal to this one.

    Equality requires ``other`` to be a ``DocVec`` with the same
    ``doc_type`` and ``tensor_type`` and an equal ``_storage``.

    :param other: the object to compare against.
    :return: True if all of the above match, False otherwise.
    """
    if not isinstance(other, DocVec):
        return False
    # Checked cheapest-first; ``or`` short-circuits, so the storage is only
    # compared when both type checks already agree.
    mismatch = (
        self.doc_type != other.doc_type
        or self.tensor_type != other.tensor_type
        or self._storage != other._storage
    )
    return not mismatch

####################
# IO related #
####################
Expand Down
32 changes: 32 additions & 0 deletions tests/units/array/stack/test_array_stacked.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,3 +585,35 @@ def test_doc_view_dict(batch):
d = doc_view_two.dict()
assert d['tensor'].shape == (3, 224, 224)
assert d['id'] == doc_view_two.id


def test_doc_vec_equality():
    """A DocVec differs from a DocList but equals that DocList's DocVec form."""

    class Text(BaseDoc):
        text: str

    def hello_docs():
        # Fresh documents on every call, so the two containers share none.
        return [Text(text='hello') for _ in range(10)]

    doc_vec = DocVec[Text](hello_docs())
    doc_list = DocList[Text](hello_docs())

    assert doc_vec != doc_list
    assert doc_vec == doc_list.to_doc_vec()


def test_doc_vec_nested(batch_nested_doc):
    """The fixture's DocVec equals a freshly built one with nested documents."""
    fixture_vec, Doc, Inner = batch_nested_doc
    fresh_docs = [Doc(inner=Inner(hello='hello')) for _ in range(10)]
    fresh_vec = DocVec[Doc](fresh_docs)

    assert fixture_vec == fresh_vec


def test_doc_vec_tensor_type():
    """DocVecs built with different tensor types must not compare equal."""

    class ImageDoc(BaseDoc):
        tensor: AnyTensor

    numpy_docs = [ImageDoc(tensor=np.zeros((3, 224, 224))) for _ in range(10)]
    numpy_vec = DocVec[ImageDoc](numpy_docs)

    torch_docs = [ImageDoc(tensor=torch.zeros(3, 224, 224)) for _ in range(10)]
    torch_vec = DocVec[ImageDoc](torch_docs, tensor_type=TorchTensor)

    assert numpy_vec != torch_vec