Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions docarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

from .document import Document
from .array import DocumentArray
from .dataclasses import dataclass, Image, Text, Audio, JSON, field

from .dataclasses import dataclass, field

if 'DA_NO_RICH_HANDLER' not in os.environ:
from rich.traceback import install
Expand Down
5 changes: 2 additions & 3 deletions docarray/array/mixins/io/pushpull.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def push(self, name: str, show_progress: bool = False) -> Dict:
headers['Authorization'] = f'token {auth_token}'

_head, _tail = data.split(delimiter)
_head += self._stream_header
from rich import filesize
from .pbar import get_progressbar

Expand All @@ -103,9 +104,7 @@ def gen():

pbar.start_task(t)

chunk = _head + self._stream_header

yield chunk
yield _head

def _get_chunk(_batch):
return b''.join(
Expand Down
2 changes: 1 addition & 1 deletion docarray/dataclasses/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .types import dataclass, Text, Audio, JSON, Image, is_multimodal, field
from .types import dataclass, is_multimodal, field
49 changes: 28 additions & 21 deletions docarray/dataclasses/getter.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,43 @@
import json
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from docarray import Document


def image_getter(doc: 'Document', field_name: str):
if 'image_type' in doc._metadata:
if doc._metadata['image_type'] == 'uri':
return doc._metadata['image_uri']
elif doc._metadata['image_type'] == 'PIL':
from PIL import Image
def image_getter(doc: 'Document'):
if doc._metadata['image_type'] == 'uri':
return doc.uri
elif doc._metadata['image_type'] == 'PIL':
from PIL import Image

return Image.fromarray(doc.tensor)
elif doc._metadata['image_type'] == 'ndarray':
return doc.tensor
else:
raise ValueError('Invalid image Document')
return Image.fromarray(doc.tensor)
elif doc._metadata['image_type'] == 'ndarray':
return doc.tensor


def text_getter(doc: 'Document', field_name: str):
def text_getter(doc: 'Document'):
    """Read the text field back from its backing Document.

    :param doc: the Document that stores the text attribute.
    :return: the value of ``doc.text``.
    """
    content = doc.text
    return content


def audio_getter(doc: 'Document', field_name: str):
from PIL import Image
def audio_getter(doc: 'Document'):
    """Read the audio field back from its backing Document.

    :param doc: the Document that stores the audio attribute.
    :return: ``doc.uri`` when it is truthy, otherwise ``doc.tensor``.
    """
    location = doc.uri
    # The tensor is only looked up when there is no usable URI.
    return location if location else doc.tensor

return Image.fromarray(doc.tensor)

def video_getter(doc: 'Document'):
    """Read the video field back from its backing Document.

    :param doc: the Document that stores the video attribute.
    :return: ``doc.uri`` when it is truthy, otherwise ``doc.tensor``.
    """
    location = doc.uri
    # The tensor is only looked up when there is no usable URI.
    return location if location else doc.tensor

def json_getter(doc: 'Document', field_name: str):
if doc._metadata['json_type'] == 'str':
return json.dumps(doc.tags[field_name])
else:
return doc.tags[field_name]

def mesh_getter(doc: 'Document'):
    """Read the mesh field back from its backing Document.

    :param doc: the Document that stores the mesh attribute.
    :return: ``doc.uri`` when it is truthy, otherwise ``doc.tensor``.
    """
    location = doc.uri
    # The tensor is only looked up when there is no usable URI.
    return location if location else doc.tensor


def tabular_getter(doc: 'Document'):
    """Read the tabular field back from its backing Document.

    :param doc: the Document that stores the tabular attribute.
    :return: the value of ``doc.uri``.
    """
    location = doc.uri
    return location


def blob_getter(doc: 'Document'):
    """Read the blob field back from its backing Document.

    :param doc: the Document that stores the blob attribute.
    :return: ``doc.uri`` when it is truthy, otherwise ``doc.blob``.
    """
    location = doc.uri
    # The raw blob is only looked up when there is no usable URI.
    return location if location else doc.blob


def json_getter(doc: 'Document'):
    """Read the JSON field back from its backing Document.

    :param doc: the Document that stores the JSON attribute.
    :return: the value of ``doc.tags``.
    """
    payload = doc.tags
    return payload
79 changes: 54 additions & 25 deletions docarray/dataclasses/setter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
from typing import TYPE_CHECKING

import numpy as np
Expand All @@ -7,52 +6,82 @@
from docarray import Document


def image_setter(field_name: str, value) -> 'Document':
from PIL.Image import Image
def image_setter(value) -> 'Document':
from docarray import Document

doc = Document(modality='image')

if isinstance(value, str):
doc.uri = value
doc._metadata['image_type'] = 'uri'
doc._metadata['image_uri'] = value
doc.load_uri_to_image_tensor()
elif isinstance(value, Image):
doc.tensor = np.array(value)
doc._metadata['image_type'] = 'PIL'
else:
elif isinstance(value, np.ndarray):
doc.tensor = value
doc._metadata['image_type'] = 'ndarray'
else:
from PIL.Image import Image

if isinstance(value, Image):
doc.tensor = np.array(value)
doc._metadata['image_type'] = 'PIL'
return doc


def text_setter(field_name: str, value) -> 'Document':
def text_setter(value) -> 'Document':
    """Wrap a text value in a Document.

    :param value: the text to store.
    :return: a Document built with ``text=value`` and ``modality='text'``.
    """
    # Imported lazily inside the function, as in the sibling setters.
    from docarray import Document

    text_doc = Document(text=value, modality='text')
    return text_doc


def audio_setter(field_name: str, value) -> 'Document':
import librosa
def audio_setter(value) -> 'Document':
from docarray import Document

audio, sr = librosa.load(value)
return Document(
tensor=audio,
_metadata={'audio_sample_rate': sr, 'audio_uri': str(value)},
modality='audio',
)
if isinstance(value, np.ndarray):
return Document(tensor=value, _metadata={'audio_type': 'ndarray'})
else:
return Document(
uri=value, modality='audio', _metadata={'audio_type': 'uri'}
).load_uri_to_audio_tensor()


def json_setter(field_name: str, value) -> 'Document':
def video_setter(value) -> 'Document':
from docarray import Document

doc = Document()
if isinstance(value, str):
value = json.loads(value)
doc._metadata['json_type'] = 'str'
if isinstance(value, np.ndarray):
return Document(tensor=value, _metadata={'video_type': 'ndarray'})
else:
doc._metadata['json_type'] = 'dict'
doc.tags[field_name] = value
return doc
return Document(
uri=value, modality='video', _metadata={'video_type': 'uri'}
).load_uri_to_video_tensor()


def mesh_setter(value) -> 'Document':
    """Wrap a mesh value in a Document.

    :param value: either an ``np.ndarray`` or a value usable as a Document URI.
    :return: for an ``np.ndarray``, a Document holding it as ``tensor`` with
        ``_metadata={'mesh_type': 'ndarray'}``; otherwise the result of calling
        ``load_uri_to_point_cloud_tensor(1000)`` on a Document created with
        ``uri=value``, ``modality='mesh'`` and ``_metadata={'mesh_type': 'uri'}``.
    """
    from docarray import Document

    if not isinstance(value, np.ndarray):
        uri_doc = Document(uri=value, modality='mesh', _metadata={'mesh_type': 'uri'})
        # NOTE(review): 1000 is presumably the number of sampled points -- confirm
        # against the load_uri_to_point_cloud_tensor documentation.
        return uri_doc.load_uri_to_point_cloud_tensor(1000)

    # NOTE(review): unlike the URI branch, no modality is set here -- confirm intended.
    return Document(tensor=value, _metadata={'mesh_type': 'ndarray'})


def blob_setter(value) -> 'Document':
    """Wrap a blob value in a Document.

    :param value: either ``bytes`` or a value usable as a Document URI.
    :return: for ``bytes``, a Document holding it as ``blob`` with
        ``_metadata={'blob_type': 'bytes'}``; otherwise the result of calling
        ``load_uri_to_blob()`` on a Document created with ``uri=value`` and
        ``_metadata={'blob_type': 'uri'}``.
    """
    from docarray import Document

    if not isinstance(value, bytes):
        uri_doc = Document(uri=value, _metadata={'blob_type': 'uri'})
        return uri_doc.load_uri_to_blob()

    return Document(blob=value, _metadata={'blob_type': 'bytes'})


def json_setter(value) -> 'Document':
    """Wrap a JSON value in a Document.

    :param value: the object to store as the Document's ``tags``.
    :return: a Document built with ``modality='json'`` and ``tags=value``.
    """
    # Imported lazily inside the function, as in the sibling setters.
    from docarray import Document

    json_doc = Document(modality='json', tags=value)
    return json_doc


def tabular_setter(value) -> 'Document':
    """Wrap a tabular source in a Document.

    :param value: the location passed both as the Document URI and to
        ``DocumentArray.from_csv``.
    :return: a Document with ``uri=value``, ``modality='tabular'`` and the
        result of ``DocumentArray.from_csv(value)`` as its ``chunks``.
    """
    from docarray import Document, DocumentArray

    rows = DocumentArray.from_csv(value)
    return Document(uri=value, chunks=rows, modality='tabular')
77 changes: 25 additions & 52 deletions docarray/dataclasses/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,39 +10,26 @@
MISSING,
)
from enum import Enum
from pathlib import Path
from typing import (
TypeVar,
ForwardRef,
Callable,
Optional,
TYPE_CHECKING,
overload,
Dict,
Type,
)

from .setter import (
image_setter,
text_setter,
audio_setter,
json_setter,
)
from .getter import (
image_getter,
text_getter,
audio_getter,
json_getter,
)
from .getter import *
from .setter import *

if TYPE_CHECKING:
import scipy.sparse
import tensorflow
import torch
import numpy as np
from ..typing import T
from docarray import Document
from PIL.Image import Image as PILImage

from ..typing import Image, Text, Audio, Video, Mesh, Tabular, Blob, JSON

__all__ = ['field', 'dataclass', 'is_multimodal']


class AttributeType(str, Enum):
Expand Down Expand Up @@ -71,12 +58,6 @@ def copy_from(self, f: '_Field'):
for s in f.__slots__:
setattr(self, s, getattr(f, s))

def get_field(self, doc: 'Document'):
return self.getter(doc, self.name)

def set_field(self, val) -> 'Document':
return self.setter(self.name, val)


@overload
def field(
Expand All @@ -91,38 +72,25 @@ def field(
hash=None,
compare=True,
metadata=None,
) -> _Field:
) -> Field:
...


def field(**kwargs) -> Field:
    """Create a :class:`Field` from keyword arguments.

    :param kwargs: forwarded unchanged to the ``Field`` constructor.
    :return: the newly constructed ``Field``.
    """
    new_field = Field(**kwargs)
    return new_field


Image = TypeVar(
'Image',
ForwardRef('np.ndarray'),
ForwardRef('tensorflow.Tensor'),
ForwardRef('torch.Tensor'),
str,
ForwardRef('PILImage'),
)

Text = TypeVar('Text', bound=str)

Audio = TypeVar(
'Audio',
str,
Path,
)

JSON = TypeVar('JSON', str, dict)

# Registry keyed by the type markers (Image, Text, Audio, ...).
# Each value is a one-argument factory: it calls `field(...)` with the
# setter/getter pair matching that key and passes its argument through
# as `_source_field`.
_TYPES_REGISTRY = {
Image: lambda x: field(setter=image_setter, getter=image_getter, _source_field=x),
Text: lambda x: field(setter=text_setter, getter=text_getter, _source_field=x),
Audio: lambda x: field(setter=audio_setter, getter=audio_getter, _source_field=x),
JSON: lambda x: field(setter=json_setter, getter=json_getter, _source_field=x),
Video: lambda x: field(setter=video_setter, getter=video_getter, _source_field=x),
Tabular: lambda x: field(
setter=tabular_setter, getter=tabular_getter, _source_field=x
),
Blob: lambda x: field(setter=blob_setter, getter=blob_getter, _source_field=x),
Mesh: lambda x: field(setter=mesh_setter, getter=mesh_getter, _source_field=x),
}


Expand Down Expand Up @@ -151,8 +119,8 @@ def dataclass(

Example usage:

>>> from docarray import dataclass, Image, Text
>>>
>>> from docarray.typing import Image, Text
>>> from docarray import dataclass
>>> @dataclass
>>> class X:
>>> banner: Image = 'apple.png'
Expand Down Expand Up @@ -192,7 +160,7 @@ def deco(f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
if not kwargs and len(args) == 2 and isinstance(args[1], Document):
return f(args[0], **from_document(type(args[0]), args[1]))
return f(args[0], **_from_document(type(args[0]), args[1]))
else:
return f(*args, **kwargs)

Expand Down Expand Up @@ -233,10 +201,15 @@ def wrap(cls):

def is_multimodal(obj) -> bool:
"""Returns True if obj is an instance of :meth:`.dataclass`."""
return _is_dataclass(obj) and hasattr(obj, '__is_multimodal__')
from docarray import Document

if isinstance(obj, Document):
return obj.is_multimodal
else:
return _is_dataclass(obj) and hasattr(obj, '__is_multimodal__')


def from_document(cls: Type['T'], doc: 'Document') -> 'T':
def _from_document(cls: Type['T'], doc: 'Document') -> 'T':
if not doc.is_multimodal:
raise ValueError(
f'{doc} is not a multimodal doc instantiated from a class wrapped by `docarray.dataclasses.tdataclass`.'
Expand Down Expand Up @@ -284,12 +257,12 @@ def from_document(cls: Type['T'], doc: 'Document') -> 'T':

def _get_doc_attribute(attribute_doc: 'Document', field):
if isinstance(field, Field):
return field.get_field(attribute_doc)
return field.getter(attribute_doc)
else:
raise ValueError('Invalid attribute type')


def _get_doc_nested_attribute(attribute_doc: 'Document', nested_cls: Type['T']) -> 'T':
if not is_multimodal(nested_cls):
raise ValueError(f'Nested attribute {nested_cls.__name__} is not a dataclass')
return nested_cls(**from_document(nested_cls, attribute_doc))
return nested_cls(**_from_document(nested_cls, attribute_doc))
7 changes: 4 additions & 3 deletions docarray/document/mixins/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ def _from_dataclass(cls, obj) -> 'Document':
multi_modal_schema = {}
for key, field in obj.__dataclass_fields__.items():
attribute = getattr(obj, key)
if attribute is None:
continue

if field.type in [str, int, float, bool] and not isinstance(field, Field):
tags[key] = attribute
multi_modal_schema[key] = {
Expand Down Expand Up @@ -126,15 +129,13 @@ def get_multi_modal_attribute(self, attribute: str) -> 'DocumentArray':

@classmethod
def _from_obj(cls, obj, obj_type, field) -> typing.Tuple['Document', AttributeType]:
from docarray import Document

attribute_type = AttributeType.DOCUMENT

if is_multimodal(obj_type):
doc = cls(obj)
attribute_type = AttributeType.NESTED
elif isinstance(field, Field):
doc = field.set_field(obj)
doc = field.setter(obj)
else:
raise ValueError(f'Unsupported type annotation')
return doc, attribute_type
Loading