Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion docarray/array/doc_vec/doc_vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)

from pydantic import BaseConfig, parse_obj_as
from typing_inspect import typingGenericAlias

from docarray.array.any_array import AnyDocArray
from docarray.array.doc_list.doc_list import DocList
Expand Down Expand Up @@ -148,7 +149,9 @@ def _verify_optional_field_of_docs(docs):
for i, doc in enumerate(docs):
if getattr(doc, field_name) is not None:
raise ValueError(
f'Field {field_name} is put to None for the first doc. This mean that all of the other docs should have this field set to None as well. This is not the case for {doc} at index {i}'
f'Field {field_name} is put to None for the first doc. This mean that '
f'all of the other docs should have this field set to None as well. '
f'This is not the case for {doc} at index {i}'
)

def _check_doc_field_not_none(field_name, doc):
Expand All @@ -159,6 +162,18 @@ def _check_doc_field_not_none(field_name, doc):

if is_tensor_union(field_type):
field_type = tensor_type
# all generic tensor types such as AnyTensor, ImageTensor, etc. are subclasses of AbstractTensor.
# Perform check only if the field_type is not an alias and is a subclass of AbstractTensor
elif not isinstance(field_type, typingGenericAlias) and issubclass(
field_type, AbstractTensor
):
# check whether the tensor stored under field_name in the first document is an instance of tensor_type
# (or of one of its subclasses). e.g. if the field_type is AnyTensor but docs[0] holds a tensor that
# matches tensor_type, we narrow the field_type to tensor_type, since AnyTensor is a union of all the
# tensor types and does not override any methods of specific tensor types
tensor = getattr(docs[0], field_name)
if issubclass(tensor.__class__, tensor_type):
field_type = tensor_type

if isinstance(field_type, type):
if tf_available and issubclass(field_type, TensorFlowTensor):
Expand Down
96 changes: 87 additions & 9 deletions docarray/typing/tensor/audio/audio_tensor.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,101 @@
from typing import Union
from typing import TYPE_CHECKING, Any, Type, TypeVar, Union, cast

import numpy as np

from docarray.typing.tensor.audio.abstract_audio_tensor import AbstractAudioTensor
from docarray.typing.tensor.audio.audio_ndarray import AudioNdArray
from docarray.typing.tensor.tensor import AnyTensor
from docarray.utils._internal.misc import is_tf_available, is_torch_available

torch_available = is_torch_available()
if torch_available:
import torch

from docarray.typing.tensor.audio.audio_torch_tensor import AudioTorchTensor
from docarray.typing.tensor.torch_tensor import TorchTensor

tf_available = is_tf_available()
if tf_available:
import tensorflow as tf # type: ignore

from docarray.typing.tensor.audio.audio_tensorflow_tensor import (
AudioTensorFlowTensor as AudioTFTensor,
AudioTensorFlowTensor,
)
from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor


if TYPE_CHECKING:
from pydantic import BaseConfig
from pydantic.fields import ModelField

T = TypeVar("T", bound="AudioTensor")


class AudioTensor(AnyTensor, AbstractAudioTensor):
"""
Represents an audio tensor object that can be used with TensorFlow, PyTorch, and NumPy type.

---
'''python
from docarray import BaseDoc
from docarray.typing import AudioTensor


class MyAudioDoc(BaseDoc):
tensor: AudioTensor


# Example usage with TensorFlow:
import tensorflow as tf

doc = MyAudioDoc(tensor=tf.zeros((1000, 2)))
type(doc.tensor) # AudioTensorFlowTensor

# Example usage with PyTorch:
import torch

doc = MyAudioDoc(tensor=torch.zeros(1000, 2))
type(doc.tensor) # AudioTorchTensor

# Example usage with NumPy:
import numpy as np

doc = MyAudioDoc(tensor=np.zeros((1000, 2)))
type(doc.tensor) # AudioNdArray
'''
---

Raises:
TypeError: If the input value is not a compatible type (torch.Tensor, tensorflow.Tensor, numpy.ndarray).

"""

@classmethod
def __get_validators__(cls):
    """Pydantic custom-type hook: yield the validator callables for this type.

    Only one validator is yielded, ``cls.validate``.
    """
    yield cls.validate

AudioTensor = AudioNdArray
if tf_available and torch_available:
AudioTensor = Union[AudioNdArray, AudioTorchTensor, AudioTFTensor] # type: ignore
elif tf_available:
AudioTensor = Union[AudioNdArray, AudioTFTensor] # type: ignore
elif torch_available:
AudioTensor = Union[AudioNdArray, AudioTorchTensor] # type: ignore
@classmethod
def validate(
    cls: Type[T],
    value: Union[T, np.ndarray, Any],
    field: "ModelField",
    config: "BaseConfig",
):
    """Validate ``value`` and return it as a framework-specific audio tensor.

    Dispatch order:

    1. If torch is available: a ``TorchTensor`` instance is returned as is;
       a native ``torch.Tensor`` is wrapped via
       ``AudioTorchTensor._docarray_from_native``.
    2. If TensorFlow is available: a ``TensorFlowTensor`` instance is returned
       as is; a native ``tf.Tensor`` is wrapped via
       ``AudioTensorFlowTensor._docarray_from_native``.
    3. Otherwise the value is handed to ``AudioNdArray.validate``.

    Args:
        value: The object to validate.
        field: The pydantic model field, forwarded to ``AudioNdArray.validate``.
        config: The pydantic config, forwarded to ``AudioNdArray.validate``.

    Returns:
        The validated tensor.

    Raises:
        TypeError: If ``AudioNdArray.validate`` rejects the value. The
            underlying exception is attached as ``__cause__``.
    """
    if torch_available:
        if isinstance(value, TorchTensor):
            # NOTE(review): typing.cast is a no-op at runtime, so the value
            # keeps its existing class (it is not converted to
            # AudioTorchTensor) -- confirm this is intended.
            return cast(AudioTorchTensor, value)
        elif isinstance(value, torch.Tensor):
            return AudioTorchTensor._docarray_from_native(value)  # noqa
    if tf_available:
        if isinstance(value, TensorFlowTensor):
            # NOTE(review): same runtime no-op cast as in the torch branch.
            return cast(AudioTensorFlowTensor, value)
        elif isinstance(value, tf.Tensor):
            return AudioTensorFlowTensor._docarray_from_native(value)  # noqa
    try:
        return AudioNdArray.validate(value, field, config)
    except Exception as err:  # noqa
        # Keep the original failure as the cause instead of discarding it,
        # so the real reason the value was rejected stays visible.
        raise TypeError(
            f"Expected one of [torch.Tensor, tensorflow.Tensor, numpy.ndarray] "
            f"compatible type, got {type(value)}"
        ) from err
100 changes: 86 additions & 14 deletions docarray/typing/tensor/embedding/embedding.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,99 @@
from typing import Union
from typing import TYPE_CHECKING, Any, Type, TypeVar, Union, cast

import numpy as np

from docarray.typing.tensor.embedding.embedding_mixin import EmbeddingMixin
from docarray.typing.tensor.embedding.ndarray import NdArrayEmbedding
from docarray.utils._internal.misc import is_tf_available, is_torch_available
from docarray.typing.tensor.tensor import AnyTensor
from docarray.utils._internal.misc import is_tf_available, is_torch_available # noqa

torch_available = is_torch_available()
if torch_available:
import torch

from docarray.typing.tensor.embedding.torch import TorchEmbedding
from docarray.typing.tensor.torch_tensor import TorchTensor # noqa: F401


tf_available = is_tf_available()
if tf_available:
from docarray.typing.tensor.embedding.tensorflow import (
TensorFlowEmbedding as TFEmbedding,
)
import tensorflow as tf # type: ignore

from docarray.typing.tensor.embedding.tensorflow import TensorFlowEmbedding
from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor # noqa: F401


if TYPE_CHECKING:
from pydantic import BaseConfig
from pydantic.fields import ModelField

T = TypeVar("T", bound="AnyEmbedding")


class AnyEmbedding(AnyTensor, EmbeddingMixin):
"""
Represents an embedding tensor object that can be used with TensorFlow, PyTorch, and NumPy type.

---
'''python
from docarray import BaseDoc
from docarray.typing import AnyEmbedding


class MyEmbeddingDoc(BaseDoc):
embedding: AnyEmbedding


# Example usage with TensorFlow:
import tensorflow as tf

doc = MyEmbeddingDoc(embedding=tf.zeros((1000, 2)))
type(doc.embedding) # TensorFlowEmbedding

# Example usage with PyTorch:
import torch

doc = MyEmbeddingDoc(embedding=torch.zeros(1000, 2))
type(doc.embedding) # TorchEmbedding

# Example usage with NumPy:
import numpy as np

doc = MyEmbeddingDoc(embedding=np.zeros((1000, 2)))
type(doc.embedding) # NdArrayEmbedding
'''
---

Raises:
TypeError: If the type of the value is not one of [torch.Tensor, tensorflow.Tensor, numpy.ndarray]
"""

if tf_available and torch_available:
AnyEmbedding = Union[NdArrayEmbedding, TorchEmbedding, TFEmbedding] # type: ignore
elif tf_available:
AnyEmbedding = Union[NdArrayEmbedding, TFEmbedding] # type: ignore
elif torch_available:
AnyEmbedding = Union[NdArrayEmbedding, TorchEmbedding] # type: ignore
else:
AnyEmbedding = Union[NdArrayEmbedding] # type: ignore
@classmethod
def __get_validators__(cls):
    """Pydantic custom-type hook: yield the validator callables for this type.

    Only one validator is yielded, ``cls.validate``.
    """
    yield cls.validate

__all__ = ['AnyEmbedding']
@classmethod
def validate(
    cls: Type[T],
    value: Union[T, np.ndarray, Any],
    field: "ModelField",
    config: "BaseConfig",
):
    """Validate ``value`` and return it as a framework-specific embedding.

    Dispatch order:

    1. If torch is available: a ``TorchTensor`` instance is returned as is;
       a native ``torch.Tensor`` is wrapped via
       ``TorchEmbedding._docarray_from_native``.
    2. If TensorFlow is available: a ``TensorFlowTensor`` instance is returned
       as is; a native ``tf.Tensor`` is wrapped via
       ``TensorFlowEmbedding._docarray_from_native``.
    3. Otherwise the value is handed to ``NdArrayEmbedding.validate``.

    Args:
        value: The object to validate.
        field: The pydantic model field, forwarded to
            ``NdArrayEmbedding.validate``.
        config: The pydantic config, forwarded to
            ``NdArrayEmbedding.validate``.

    Returns:
        The validated embedding tensor.

    Raises:
        TypeError: If ``NdArrayEmbedding.validate`` rejects the value. The
            underlying exception is attached as ``__cause__``.
    """
    if torch_available:
        if isinstance(value, TorchTensor):
            # NOTE(review): typing.cast is a no-op at runtime, so the value
            # keeps its existing class (it is not converted to
            # TorchEmbedding) -- confirm this is intended.
            return cast(TorchEmbedding, value)
        elif isinstance(value, torch.Tensor):
            return TorchEmbedding._docarray_from_native(value)  # noqa
    if tf_available:
        if isinstance(value, TensorFlowTensor):
            # NOTE(review): same runtime no-op cast as in the torch branch.
            return cast(TensorFlowEmbedding, value)
        elif isinstance(value, tf.Tensor):
            return TensorFlowEmbedding._docarray_from_native(value)  # noqa
    try:
        return NdArrayEmbedding.validate(value, field, config)
    except Exception as err:  # noqa
        # Keep the original failure as the cause instead of discarding it,
        # so the real reason the value was rejected stays visible.
        raise TypeError(
            f"Expected one of [torch.Tensor, tensorflow.Tensor, numpy.ndarray] "
            f"compatible type, got {type(value)}"
        ) from err
101 changes: 91 additions & 10 deletions docarray/typing/tensor/image/image_tensor.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,104 @@
from typing import Union
from typing import TYPE_CHECKING, Any, Type, TypeVar, Union, cast

import numpy as np

from docarray.typing.tensor.image.abstract_image_tensor import AbstractImageTensor
from docarray.typing.tensor.image.image_ndarray import ImageNdArray
from docarray.typing.tensor.tensor import AnyTensor
from docarray.utils._internal.misc import is_tf_available, is_torch_available

torch_available = is_torch_available()
if torch_available:
from docarray.typing.tensor.image.image_torch_tensor import ImageTorchTensor
import torch

from docarray.typing.tensor.image.image_torch_tensor import ImageTorchTensor
from docarray.typing.tensor.torch_tensor import TorchTensor

tf_available = is_tf_available()
if tf_available:
import tensorflow as tf # type: ignore

from docarray.typing.tensor.image.image_tensorflow_tensor import (
ImageTensorFlowTensor as ImageTFTensor,
ImageTensorFlowTensor,
)
from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor


if TYPE_CHECKING:
from pydantic import BaseConfig
from pydantic.fields import ModelField


T = TypeVar("T", bound="ImageTensor")


class ImageTensor(AnyTensor, AbstractImageTensor):
"""
Represents an image tensor object that can be used with TensorFlow, PyTorch, and NumPy type.

---
'''python
from docarray import BaseDoc
from docarray.typing import ImageTensor


class MyImageDoc(BaseDoc):
image: ImageTensor


# Example usage with TensorFlow:
import tensorflow as tf

doc = MyImageDoc(image=tf.zeros((1000, 2)))
type(doc.image) # ImageTensorFlowTensor

# Example usage with PyTorch:
import torch

doc = MyImageDoc(image=torch.zeros((1000, 2)))
type(doc.image) # ImageTorchTensor

# Example usage with NumPy:
import numpy as np

doc = MyImageDoc(image=np.zeros((1000, 2)))
type(doc.image) # ImageNdArray
'''
---

Returns:
Union[ImageTorchTensor, ImageTensorFlowTensor, ImageNdArray]: The validated and converted image tensor.

Raises:
TypeError: If the input type is not one of [torch.Tensor, tensorflow.Tensor, numpy.ndarray].
"""

@classmethod
def __get_validators__(cls):
    """Pydantic custom-type hook: yield the validator callables for this type.

    Only one validator is yielded, ``cls.validate``.
    """
    yield cls.validate

ImageTensor = Union[ImageNdArray] # type: ignore
if tf_available and torch_available:
ImageTensor = Union[ImageNdArray, ImageTorchTensor, ImageTFTensor] # type: ignore
elif tf_available:
ImageTensor = Union[ImageNdArray, ImageTFTensor] # type: ignore
elif torch_available:
ImageTensor = Union[ImageNdArray, ImageTorchTensor] # type: ignore
@classmethod
def validate(
    cls: Type[T],
    value: Union[T, np.ndarray, Any],
    field: "ModelField",
    config: "BaseConfig",
):
    """Validate ``value`` and return it as a framework-specific image tensor.

    Dispatch order:

    1. If torch is available: a ``TorchTensor`` instance is returned as is;
       a native ``torch.Tensor`` is wrapped via
       ``ImageTorchTensor._docarray_from_native``.
    2. If TensorFlow is available: a ``TensorFlowTensor`` instance is returned
       as is; a native ``tf.Tensor`` is wrapped via
       ``ImageTensorFlowTensor._docarray_from_native``.
    3. Otherwise the value is handed to ``ImageNdArray.validate``.

    Args:
        value: The object to validate.
        field: The pydantic model field, forwarded to ``ImageNdArray.validate``.
        config: The pydantic config, forwarded to ``ImageNdArray.validate``.

    Returns:
        The validated tensor.

    Raises:
        TypeError: If ``ImageNdArray.validate`` rejects the value. The
            underlying exception is attached as ``__cause__``.
    """
    if torch_available:
        if isinstance(value, TorchTensor):
            # NOTE(review): typing.cast is a no-op at runtime, so the value
            # keeps its existing class (it is not converted to
            # ImageTorchTensor) -- confirm this is intended.
            return cast(ImageTorchTensor, value)
        elif isinstance(value, torch.Tensor):
            return ImageTorchTensor._docarray_from_native(value)  # noqa
    if tf_available:
        if isinstance(value, TensorFlowTensor):
            # NOTE(review): same runtime no-op cast as in the torch branch.
            return cast(ImageTensorFlowTensor, value)
        elif isinstance(value, tf.Tensor):
            return ImageTensorFlowTensor._docarray_from_native(value)  # noqa
    try:
        return ImageNdArray.validate(value, field, config)
    except Exception as err:  # noqa
        # Keep the original failure as the cause instead of discarding it,
        # so the real reason the value was rejected stays visible.
        raise TypeError(
            f"Expected one of [torch.Tensor, tensorflow.Tensor, numpy.ndarray] "
            f"compatible type, got {type(value)}"
        ) from err
Loading