Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions docarray/array/doc_vec/doc_vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,8 +589,16 @@ def __len__(self):
####################

@classmethod
def from_protobuf(cls: Type[T], pb_msg: 'DocVecProto') -> T:
"""create a DocVec from a protobuf message"""
def from_protobuf(
cls: Type[T], pb_msg: 'DocVecProto', tensor_type: Type[AbstractTensor] = NdArray
) -> T:
"""create a DocVec from a protobuf message
:param pb_msg: the protobuf message to deserialize
:param tensor_type: the tensor type to use for the tensor columns.
Could be NdArray, TorchTensor, or TensorFlowTensor. Defaults to NdArray.
All tensors of the output DocVec will be of this type.
:return: The deserialized DocVec
"""

tensor_columns: Dict[str, Optional[AbstractTensor]] = {}
doc_columns: Dict[str, Optional['DocVec']] = {}
Expand All @@ -602,8 +610,9 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocVecProto') -> T:
# handle values that were None before serialization
tensor_columns[tens_col_name] = None
else:
# TODO(johannes): handle torch, tf, numpy
tensor_columns[tens_col_name] = NdArray.from_protobuf(tens_col_proto)
tensor_columns[tens_col_name] = tensor_type.from_protobuf(
tens_col_proto
)

for doc_col_name, doc_col_proto in pb_msg.doc_columns.items():
if _is_none_docvec_proto(doc_col_proto):
Expand All @@ -613,7 +622,7 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocVecProto') -> T:
col_doc_type: Type = cls.doc_type._get_field_type(doc_col_name)
doc_columns[doc_col_name] = DocVec.__class_getitem__(
col_doc_type
).from_protobuf(doc_col_proto)
).from_protobuf(doc_col_proto, tensor_type=tensor_type)

for docs_vec_col_name, docs_vec_col_proto in pb_msg.docs_vec_columns.items():
vec_list: Optional[ListAdvancedIndexing]
Expand All @@ -628,7 +637,7 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocVecProto') -> T:
).doc_type
vec_list.append(
DocVec.__class_getitem__(col_doc_type).from_protobuf(
doc_list_proto
doc_list_proto, tensor_type=tensor_type
)
)
docs_vec_columns[docs_vec_col_name] = vec_list
Expand All @@ -647,6 +656,7 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocVecProto') -> T:
doc_columns=doc_columns,
docs_vec_columns=docs_vec_columns,
any_columns=any_columns,
tensor_type=tensor_type,
)

return cls.from_columns_storage(storage)
Expand Down
38 changes: 38 additions & 0 deletions docs/user_guide/sending/serialization.md
Original file line number Diff line number Diff line change
Expand Up @@ -265,4 +265,42 @@ proto_message_dv = dv.to_protobuf()
dv_from_proto = DocVec[SimpleVecDoc].from_protobuf(proto_message_dv)
```

You can deserialize any [DocVec][docarray.array.doc_vec.doc_vec.DocVec] protobuf message to any tensor type
by passing the `tensor_type=...` parameter to [`from_protobuf`][docarray.array.doc_vec.doc_vec.DocVec.from_protobuf].

This means that you can choose at deserialization time whether to work with NumPy, PyTorch, or TensorFlow tensors.

If no `tensor_type` is passed, the default is `NdArray`.


```python
import torch

from docarray import BaseDoc, DocVec
from docarray.typing import AnyTensor, NdArray, TorchTensor


class AnyTensorDoc(BaseDoc):
    tensor: AnyTensor


# build a DocVec of eight documents backed by PyTorch tensors, then serialize it
docs = [AnyTensorDoc(tensor=torch.ones(16)) for _ in range(8)]
dv = DocVec[AnyTensorDoc](docs, tensor_type=TorchTensor)
proto_message_dv = dv.to_protobuf()

# deserialize to torch by passing `tensor_type` explicitly
dv_from_proto_torch = DocVec[AnyTensorDoc].from_protobuf(
    proto_message_dv,
    tensor_type=TorchTensor,
)
assert dv_from_proto_torch.tensor_type == TorchTensor
assert isinstance(dv_from_proto_torch.tensor, TorchTensor)

# deserialize to numpy: omitting `tensor_type` falls back to the NdArray default
dv_from_proto_numpy = DocVec[AnyTensorDoc].from_protobuf(proto_message_dv)
assert dv_from_proto_numpy.tensor_type == NdArray
assert isinstance(dv_from_proto_numpy.tensor, NdArray)
```


101 changes: 101 additions & 0 deletions tests/units/array/stack/test_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,104 @@ class MyDoc(BaseDoc):

assert da_after._storage.any_columns['text'] == [None, None]
assert da_after._storage.any_columns['d'] == [None, None]


@pytest.mark.proto
@pytest.mark.parametrize('tensor_type', [NdArray, TorchTensor])
def test_proto_tensor_type(tensor_type):
    """Round-trip a nested DocVec through protobuf for a given tensor type.

    The document schema has a top-level tensor, a nested document and a nested
    DocVec column. After ``to_protobuf`` / ``from_protobuf(..., tensor_type=...)``
    every tensor column must be an instance of ``tensor_type`` and hold the same
    values as before serialization.

    :param tensor_type: tensor class under test (parametrized over NdArray and
        TorchTensor)
    """

    class InnerDoc(BaseDoc):
        embedding: tensor_type

    class MyDoc(BaseDoc):
        tensor: tensor_type
        inner: InnerDoc
        inner_v: DocVec[InnerDoc]

    def _random_tensor():
        # torch needs an explicit conversion; the numpy case uses the raw array
        values = np.random.random(512)
        if tensor_type == TorchTensor:
            return tensor_type.from_ndarray(values)
        return values

    def _make_doc():
        return MyDoc(
            tensor=_random_tensor(),
            inner=InnerDoc(embedding=_random_tensor()),
            inner_v=DocVec[InnerDoc]([InnerDoc(embedding=_random_tensor())]),
        )

    before = DocVec[MyDoc]([_make_doc() for _ in range(2)])

    # sanity-check the columns before serialization
    assert isinstance(before.tensor, tensor_type)
    assert before.tensor.shape == (2, 512)
    assert isinstance(before.inner.embedding, tensor_type)
    assert before.inner.embedding.shape == (2, 512)
    assert isinstance(before.inner_v[0].embedding, tensor_type)
    assert before.inner_v[0].embedding.shape == (1, 512)

    serialized = before.to_protobuf()
    after = DocVec[MyDoc].from_protobuf(serialized, tensor_type=tensor_type)

    # top-level tensor column
    assert isinstance(after.tensor, tensor_type)
    assert (before.tensor == after.tensor).all()
    # tensor column of the nested document
    assert isinstance(after.inner.embedding, tensor_type)
    assert (before.inner.embedding == after.inner.embedding).all()
    # tensor column of the first nested DocVec
    assert isinstance(after.inner_v[0].embedding, tensor_type)
    assert (before.inner_v[0].embedding == after.inner_v[0].embedding).all()


@pytest.mark.tensorflow
def test_proto_tensor_type_tf():
    """Round-trip a nested DocVec through protobuf using TensorFlowTensor.

    Builds a DocVec whose documents hold a top-level tensor, a nested document
    and a nested DocVec column, serializes it with ``to_protobuf`` and checks
    that ``from_protobuf(..., tensor_type=TensorFlowTensor)`` returns
    TensorFlowTensor columns with equal values.
    """
    import tensorflow as tf

    from docarray.typing import TensorFlowTensor

    class InnerDoc(BaseDoc):
        embedding: TensorFlowTensor

    class MyDoc(BaseDoc):
        tensor: TensorFlowTensor
        inner: InnerDoc
        inner_v: DocVec[InnerDoc]

    def _random_tensor():
        values = np.random.random(512)
        return TensorFlowTensor.from_ndarray(values)

    def _make_doc():
        return MyDoc(
            tensor=_random_tensor(),
            inner=InnerDoc(embedding=_random_tensor()),
            inner_v=DocVec[InnerDoc]([InnerDoc(embedding=_random_tensor())]),
        )

    def _same(left, right):
        # compare the wrapped tensors exposed through the `.tensor` attribute
        return tf.math.reduce_all(tf.equal(left.tensor, right.tensor))

    before = DocVec[MyDoc]([_make_doc() for _ in range(2)])

    # sanity-check the columns before serialization
    assert isinstance(before.tensor, TensorFlowTensor)
    assert len(before.tensor) == 2
    assert isinstance(before.inner.embedding, TensorFlowTensor)
    assert len(before.inner.embedding) == 2
    assert isinstance(before.inner_v[0].embedding, TensorFlowTensor)
    assert len(before.inner_v[0].embedding) == 1

    serialized = before.to_protobuf()
    after = DocVec[MyDoc].from_protobuf(serialized, tensor_type=TensorFlowTensor)

    # top-level tensor column
    assert isinstance(after.tensor, TensorFlowTensor)
    assert _same(before.tensor, after.tensor)
    # tensor column of the nested document
    assert isinstance(after.inner.embedding, TensorFlowTensor)
    assert _same(before.inner.embedding, after.inner.embedding)
    # tensor column of the first nested DocVec
    assert isinstance(after.inner_v[0].embedding, TensorFlowTensor)
    assert _same(before.inner_v[0].embedding, after.inner_v[0].embedding)