Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions docarray/base_doc/mixins/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,15 +309,23 @@ def _get_content_from_node_proto(
return_field = getattr(value, content_key)

elif content_key in arg_to_container.keys():
field_type = cls.__fields__[field_name].type_ if field_name else None
field_type = (
cls.__fields__[field_name].type_
if field_name and field_name in cls.__fields__
else None
)
return_field = arg_to_container[content_key](
cls._get_content_from_node_proto(node, field_type=field_type)
for node in getattr(value, content_key).data
)

elif content_key == 'dict':
deser_dict: Dict[str, Any] = dict()
field_type = cls.__fields__[field_name].type_ if field_name else None
field_type = (
cls.__fields__[field_name].type_
if field_name and field_name in cls.__fields__
else None
)
for key_name, node in value.dict.data.items():
deser_dict[key_name] = cls._get_content_from_node_proto(
node, field_type=field_type
Expand Down
67 changes: 67 additions & 0 deletions tests/units/document/test_any_document.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import numpy as np
import pytest
from typing import Dict, List

from docarray import DocList
from docarray.base_doc import AnyDoc, BaseDoc
from docarray.typing import NdArray

Expand All @@ -22,3 +25,67 @@ class CustomDoc(BaseDoc):
assert any_doc.text == doc.text
assert any_doc.inner.text == doc.inner.text
assert (any_doc.inner.tensor == doc.inner.tensor).all()


@pytest.mark.parametrize('protocol', ['proto', 'json'])
def test_any_document_from_to(protocol):
class InnerDoc(BaseDoc):
text: str
t: Dict[str, str]

class DocTest(BaseDoc):
text: str
tags: Dict[str, int]
l: List[int]
d: InnerDoc
ld: DocList[InnerDoc]

inner_doc = InnerDoc(text='I am inner', t={'a': 'b'})
da = DocList[DocTest](
[
DocTest(
text='type1',
tags={'type': 1},
l=[1, 2],
d=inner_doc,
ld=DocList[InnerDoc]([inner_doc]),
),
DocTest(
text='type2',
tags={'type': 2},
l=[1, 2],
d=inner_doc,
ld=DocList[InnerDoc]([inner_doc]),
),
]
)

from docarray.base_doc import AnyDoc

if protocol == 'proto':
aux = DocList[AnyDoc].from_protobuf(da.to_protobuf())
else:
aux = DocList[AnyDoc].from_json(da.to_json())
assert len(aux) == 2
assert len(aux.id) == 2
for i, d in enumerate(aux):
assert d.tags['type'] == i + 1
assert d.text == f'type{i + 1}'
assert d.l == [1, 2]
if protocol == 'proto':
assert isinstance(d.d, AnyDoc)
assert d.d.text == 'I am inner' # inner Document is a Dict
assert d.d.t == {'a': 'b'}
else:
assert isinstance(d.d, dict)
assert d.d['text'] == 'I am inner' # inner Document is a Dict
assert d.d['t'] == {'a': 'b'}
assert len(d.ld) == 1
if protocol == 'proto':
assert isinstance(d.ld[0], AnyDoc)
assert d.ld[0].text == 'I am inner'
assert d.ld[0].t == {'a': 'b'}
else:
assert isinstance(d.ld[0], dict)
assert d.ld[0]['text'] == 'I am inner'
assert d.ld[0]['t'] == {'a': 'b'}