Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions docarray/base_doc/mixins/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
Tuple,
Type,
TypeVar,
Union,
get_origin,
)

import numpy as np
Expand Down Expand Up @@ -286,9 +288,18 @@ def _get_content_from_node_proto(
raise ValueError(
'field_type cannot be None when trying to deserialize a BaseDoc'
)
return_field = field_type.from_protobuf(
getattr(value, content_key)
) # we get to the parent class
try:
return_field = field_type.from_protobuf(
getattr(value, content_key)
) # we get to the parent class
except Exception:
if get_origin(field_type) is Union:
raise ValueError(
'Union type is not supported for proto deserialization. Please use JSON serialization instead'
)
raise ValueError(
f'{field_type} is not supported for proto deserialization'
)
elif content_key == 'doc_array':
if field_name is None:
raise ValueError(
Expand Down
6 changes: 4 additions & 2 deletions docarray/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
Union,
)

from docarray.utils._internal._typing import safe_issubclass

if TYPE_CHECKING:
from docarray import BaseDoc

Expand Down Expand Up @@ -147,9 +149,9 @@ def _get_field_type_by_access_path(
return doc_type._get_field_type(field)
else:
d = doc_type._get_field_type(field)
if issubclass(d, DocList):
if safe_issubclass(d, DocList):
return _get_field_type_by_access_path(d.doc_type, remaining)
elif issubclass(d, BaseDoc):
elif safe_issubclass(d, BaseDoc):
return _get_field_type_by_access_path(d, remaining)
else:
return None
Expand Down
4 changes: 2 additions & 2 deletions docarray/utils/_internal/_typing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, ForwardRef, Optional
from typing import Any, ForwardRef, Optional, Union

from typing_extensions import get_origin
from typing_inspect import get_args, is_typevar, is_union_type
Expand Down Expand Up @@ -47,7 +47,7 @@ def safe_issubclass(x: type, a_tuple: type) -> bool:
Note that if the origin of 'x' is a list or tuple, the function immediately returns 'False'.
"""
if (
(get_origin(x) in (list, tuple, dict, set))
(get_origin(x) in (list, tuple, dict, set, Union))
or is_typevar(x)
or (type(x) == ForwardRef)
or is_typevar(x)
Expand Down
4 changes: 4 additions & 0 deletions docs/user_guide/sending/serialization.md
Original file line number Diff line number Diff line change
Expand Up @@ -303,4 +303,8 @@ assert dv_from_proto_numpy.tensor_type == NdArray
assert isinstance(dv_from_proto_numpy.tensor, NdArray)
```

!!! note
Serialization to protobuf is not supported for union types involving `BaseDoc` types.



21 changes: 21 additions & 0 deletions tests/units/array/test_array_from_to_bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,24 @@ def test_from_to_base64(protocol, compress, show_progress):
assert d1.image.url == d2.image.url
assert da[1].image.url is None
assert da2[1].image.url is None


def test_union_type_error(tmp_path):
from typing import Union

from docarray.documents import TextDoc

class CustomDoc(BaseDoc):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about Union[int, float] or Union[int, str]

@maxwelljin maxwelljin Jun 16, 2023

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After the testing, the Union with basic types should work in all serialization methods. Only the Union with document types will cause those issues.

ud: Union[TextDoc, ImageDoc] = TextDoc(text='union type')

docs = DocList[CustomDoc]([CustomDoc(ud=TextDoc(text='union type'))])

with pytest.raises(ValueError):
docs.from_bytes(docs.to_bytes())

class BasisUnion(BaseDoc):
ud: Union[int, str]

docs_basic = DocList[BasisUnion]([BasisUnion(ud="hello")])
docs_copy = DocList[BasisUnion].from_bytes(docs_basic.to_bytes())
assert docs_copy == docs_basic
23 changes: 23 additions & 0 deletions tests/units/array/test_array_from_to_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,26 @@ class Book(BaseDoc):
tmp_file = str(tmpdir / 'tmp.csv')
with pytest.raises(TypeError):
docs.to_csv(tmp_file)


def test_union_type_error(tmp_path):
from typing import Union

from docarray.documents import TextDoc

class CustomDoc(BaseDoc):
ud: Union[TextDoc, ImageDoc] = TextDoc(text='union type')

docs = DocList[CustomDoc]([CustomDoc(ud=TextDoc(text='union type'))])

with pytest.raises(ValueError):
docs.to_csv(str(tmp_path) + ".csv")
DocList[CustomDoc].from_csv(str(tmp_path) + ".csv")

class BasisUnion(BaseDoc):
ud: Union[int, str]

docs_basic = DocList[BasisUnion]([BasisUnion(ud="hello")])
docs_basic.to_csv(str(tmp_path) + ".csv")
docs_copy = DocList[BasisUnion].from_csv(str(tmp_path) + ".csv")
assert docs_copy == docs_basic
14 changes: 14 additions & 0 deletions tests/units/array/test_array_from_to_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,17 @@ def test_from_to_json():
assert d1.image.url == d2.image.url
assert da[1].image.url is None
assert da2[1].image.url is None


def test_union_type():
from typing import Union

from docarray.documents import TextDoc

class CustomDoc(BaseDoc):
ud: Union[TextDoc, ImageDoc] = TextDoc(text='union type')

docs = DocList[CustomDoc]([CustomDoc(ud=TextDoc(text='union type'))])

docs_copy = docs.from_json(docs.to_json())
assert docs == docs_copy
22 changes: 22 additions & 0 deletions tests/units/array/test_array_from_to_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,25 @@ class Book(BaseDoc):
docs = DocList([Book(title='hello'), Book(title='world')])
with pytest.raises(TypeError):
docs.to_dataframe()


@pytest.mark.proto
def test_union_type_error():
from typing import Union

from docarray.documents import TextDoc

class CustomDoc(BaseDoc):
ud: Union[TextDoc, ImageDoc] = TextDoc(text='union type')

docs = DocList[CustomDoc]([CustomDoc(ud=TextDoc(text='union type'))])

with pytest.raises(ValueError):
DocList[CustomDoc].from_dataframe(docs.to_dataframe())

class BasisUnion(BaseDoc):
ud: Union[int, str]

docs_basic = DocList[BasisUnion]([BasisUnion(ud="hello")])
docs_copy = DocList[BasisUnion].from_dataframe(docs_basic.to_dataframe())
assert docs_copy == docs_basic
20 changes: 20 additions & 0 deletions tests/units/array/test_array_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,23 @@ class ResultTestDoc(BaseDoc):
assert docs[0].matches[0].id == '0'
assert len(docs[0].matches) == 2
assert len(docs) == 1


@pytest.mark.proto
def test_union_type_error():
from typing import Union

class CustomDoc(BaseDoc):
ud: Union[TextDoc, ImageDoc] = TextDoc(text='union type')

docs = DocList[CustomDoc]([CustomDoc(ud=TextDoc(text='union type'))])

with pytest.raises(ValueError):
DocList[CustomDoc].from_protobuf(docs.to_protobuf())

class BasisUnion(BaseDoc):
ud: Union[int, str]

docs_basic = DocList[BasisUnion]([BasisUnion(ud="hello")])
docs_copy = DocList[BasisUnion].from_protobuf(docs_basic.to_protobuf())
assert docs_copy == docs_basic