Skip to content

Commit 808acf9

Browse files
feat: torch tensor type (#800)
* feat: add tensor type for ndarray * fix: fix mypy typing * feat: torch tensor type Signed-off-by: Johannes Messner <messnerjo@gmail.com> * fix: protobuf for pytorch type Signed-off-by: Johannes Messner <messnerjo@gmail.com> * ci: install all extras in the ci Signed-off-by: Johannes Messner <messnerjo@gmail.com> * refactor: make nice looking * docs: update docarray/typing/tensor/torch_tensor.py Co-authored-by: samsja <55492238+samsja@users.noreply.github.com> Signed-off-by: Johannes Messner <44071807+JohannesMessner@users.noreply.github.com> * refactor: code style Signed-off-by: Johannes Messner <messnerjo@gmail.com> * fix: black and mypy Signed-off-by: Johannes Messner <messnerjo@gmail.com> * fix: suppress mypy import error * ci: fix ci install Signed-off-by: Johannes Messner <messnerjo@gmail.com> Signed-off-by: Johannes Messner <messnerjo@gmail.com> Signed-off-by: Johannes Messner <44071807+JohannesMessner@users.noreply.github.com> Co-authored-by: samsja <55492238+samsja@users.noreply.github.com>
1 parent 218b123 commit 808acf9

12 files changed

Lines changed: 1039 additions & 188 deletions

File tree

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ jobs:
5858
run: |
5959
python -m pip install --upgrade pip
6060
python -m pip install poetry
61-
poetry install -E common --without dev
61+
poetry install --all-extras --without dev
6262
- name: Test basic import
6363
run: poetry run python -c 'from docarray import DocumentArray, Document'
6464

@@ -110,7 +110,7 @@ jobs:
110110
run: |
111111
python -m pip install --upgrade pip
112112
python -m pip install poetry
113-
poetry install -E common
113+
poetry install --all-extras
114114
115115
- name: Test
116116
id: test

docarray/document/mixins/proto.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from docarray.document.abstract_document import AbstractDocument
66
from docarray.document.base_node import BaseNode
77
from docarray.proto import DocumentProto, NodeProto
8-
from docarray.typing import ID, AnyUrl, Embedding, ImageUrl, Tensor
8+
from docarray.typing import ID, AnyUrl, Embedding, ImageUrl, Tensor, TorchTensor
99

1010
T = TypeVar('T', bound='ProtoMixin')
1111

@@ -27,6 +27,8 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocumentProto') -> T:
2727
# the check should be delegated to the type level
2828
if content_type == 'tensor':
2929
fields[field] = Tensor._read_from_proto(value.tensor)
30+
elif content_type == 'torch_tensor':
31+
fields[field] = TorchTensor._read_from_proto(value.torch_tensor)
3032
elif content_type == 'embedding':
3133
fields[field] = Embedding._read_from_proto(value.embedding)
3234
elif content_type == 'any_url':

docarray/proto/docarray.proto

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ message NodeProto {
5555

5656
string id = 9;
5757

58+
NdArrayProto torch_tensor = 10;
59+
5860
}
5961

6062
}

docarray/proto/pb2/docarray_pb2.py

Lines changed: 18 additions & 20 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docarray/typing/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from docarray.typing.embedding import Embedding
22
from docarray.typing.id import ID
3-
from docarray.typing.tensor import Tensor
3+
from docarray.typing.tensor import Tensor, TorchTensor
44
from docarray.typing.url import AnyUrl, ImageUrl
55

6-
__all__ = ['Tensor', 'Embedding', 'ImageUrl', 'AnyUrl', 'ID']
6+
__all__ = ['Tensor', 'Embedding', 'ImageUrl', 'AnyUrl', 'ID', 'TorchTensor']

docarray/typing/tensor/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from docarray.typing.tensor.tensor import Tensor
2+
from docarray.typing.tensor.torch_tensor import TorchTensor
23

3-
__all__ = ['Tensor']
4+
__all__ = ['Tensor', 'TorchTensor']
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
from typing import TYPE_CHECKING, Any, Type, TypeVar, Union, cast
2+
3+
import numpy as np
4+
import torch # type: ignore
5+
6+
if TYPE_CHECKING:
7+
from pydantic.fields import ModelField
8+
from pydantic import BaseConfig
9+
import numpy as np
10+
11+
from docarray.document.base_node import BaseNode
12+
from docarray.proto import NdArrayProto, NodeProto
13+
14+
T = TypeVar('T', bound='TorchTensor')
15+
16+
torch_base = type(torch.Tensor) # type: Any
17+
node_base = type(BaseNode) # type: Any
18+
19+
20+
class metaTorchAndNode(torch_base, node_base):
21+
pass
22+
23+
24+
class TorchTensor(torch.Tensor, BaseNode, metaclass=metaTorchAndNode):
25+
# Subclassing torch.Tensor following the advice from here:
26+
# https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor
27+
@classmethod
28+
def __get_validators__(cls):
29+
# one or more validators may be yielded which will be called in the
30+
# order to validate the input, each validator will receive as an input
31+
# the value returned from the previous validator
32+
yield cls.validate
33+
34+
@classmethod
35+
def validate(
36+
cls: Type[T],
37+
value: Union[T, np.ndarray, Any],
38+
field: 'ModelField',
39+
config: 'BaseConfig',
40+
) -> T:
41+
if isinstance(value, TorchTensor):
42+
return cast(T, value)
43+
elif isinstance(value, torch.Tensor):
44+
return cls.from_native_torch_tensor(value)
45+
46+
else:
47+
try:
48+
arr: torch.Tensor = torch.tensor(value)
49+
return cls.from_native_torch_tensor(arr)
50+
except Exception:
51+
pass # handled below
52+
raise ValueError(f'Expected a torch.Tensor, got {type(value)}')
53+
54+
@classmethod
55+
def from_native_torch_tensor(cls: Type[T], value: torch.Tensor) -> T:
56+
"""Create a TorchTensor from a native torch.Tensor
57+
58+
:param value: the native torch.Tensor
59+
:return: a TorchTensor
60+
"""
61+
value.__class__ = cls
62+
return cast(T, value)
63+
64+
@classmethod
65+
def from_ndarray(cls: Type[T], value: np.ndarray) -> T:
66+
"""Create a TorchTensor from a numpy array
67+
68+
:param value: the numpy array
69+
:return: a TorchTensor
70+
"""
71+
return cls.from_native_torch_tensor(torch.from_numpy(value))
72+
73+
def _to_node_protobuf(self: T, field: str = 'torch_tensor') -> NodeProto:
74+
"""Convert Document into a NodeProto protobuf message. This function should
75+
be called when the Document is nested into another Document that need to be
76+
converted into a protobuf
77+
:param field: field in which to store the content in the node proto
78+
:return: the nested item protobuf message
79+
"""
80+
nd_proto = NdArrayProto()
81+
self._flush_tensor_to_proto(nd_proto, value=self)
82+
return NodeProto(**{field: nd_proto})
83+
84+
@classmethod
85+
def _read_from_proto(cls: Type[T], pb_msg: 'NdArrayProto') -> 'T':
86+
"""
87+
read ndarray from a proto msg
88+
:param pb_msg:
89+
:return: a numpy array
90+
"""
91+
source = pb_msg.dense
92+
if source.buffer:
93+
x = np.frombuffer(source.buffer, dtype=source.dtype)
94+
return cls.from_ndarray(x.reshape(source.shape))
95+
elif len(source.shape) > 0:
96+
return cls.from_ndarray(np.zeros(source.shape))
97+
else:
98+
raise ValueError(f'proto message {pb_msg} cannot be cast to a TorchTensor')
99+
100+
@staticmethod
101+
def _flush_tensor_to_proto(pb_msg: 'NdArrayProto', value: 'TorchTensor'):
102+
value_np = value.detach().cpu().numpy()
103+
pb_msg.dense.buffer = value_np.tobytes()
104+
pb_msg.dense.ClearField('shape')
105+
pb_msg.dense.shape.extend(list(value_np.shape))
106+
pb_msg.dense.dtype = value_np.dtype.str

0 commit comments

Comments
 (0)