Skip to content

Commit 40549f4

Browse files
author
Joan Fontanals
authored
fix: fix anydoc deserialization (#1571)
Signed-off-by: Joan Fontanals Martinez <joan.martinez@jina.ai>
1 parent 0e6aa3b commit 40549f4

2 files changed

Lines changed: 77 additions & 2 deletions

File tree

docarray/base_doc/mixins/io.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -309,15 +309,23 @@ def _get_content_from_node_proto(
309309
return_field = getattr(value, content_key)
310310

311311
elif content_key in arg_to_container.keys():
312-
field_type = cls.__fields__[field_name].type_ if field_name else None
312+
field_type = (
313+
cls.__fields__[field_name].type_
314+
if field_name and field_name in cls.__fields__
315+
else None
316+
)
313317
return_field = arg_to_container[content_key](
314318
cls._get_content_from_node_proto(node, field_type=field_type)
315319
for node in getattr(value, content_key).data
316320
)
317321

318322
elif content_key == 'dict':
319323
deser_dict: Dict[str, Any] = dict()
320-
field_type = cls.__fields__[field_name].type_ if field_name else None
324+
field_type = (
325+
cls.__fields__[field_name].type_
326+
if field_name and field_name in cls.__fields__
327+
else None
328+
)
321329
for key_name, node in value.dict.data.items():
322330
deser_dict[key_name] = cls._get_content_from_node_proto(
323331
node, field_type=field_type

tests/units/document/test_any_document.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import numpy as np
2+
import pytest
3+
from typing import Dict, List
24

5+
from docarray import DocList
36
from docarray.base_doc import AnyDoc, BaseDoc
47
from docarray.typing import NdArray
58

@@ -22,3 +25,67 @@ class CustomDoc(BaseDoc):
2225
assert any_doc.text == doc.text
2326
assert any_doc.inner.text == doc.inner.text
2427
assert (any_doc.inner.tensor == doc.inner.tensor).all()
28+
29+
30+
@pytest.mark.parametrize('protocol', ['proto', 'json'])
31+
def test_any_document_from_to(protocol):
32+
class InnerDoc(BaseDoc):
33+
text: str
34+
t: Dict[str, str]
35+
36+
class DocTest(BaseDoc):
37+
text: str
38+
tags: Dict[str, int]
39+
l: List[int]
40+
d: InnerDoc
41+
ld: DocList[InnerDoc]
42+
43+
inner_doc = InnerDoc(text='I am inner', t={'a': 'b'})
44+
da = DocList[DocTest](
45+
[
46+
DocTest(
47+
text='type1',
48+
tags={'type': 1},
49+
l=[1, 2],
50+
d=inner_doc,
51+
ld=DocList[InnerDoc]([inner_doc]),
52+
),
53+
DocTest(
54+
text='type2',
55+
tags={'type': 2},
56+
l=[1, 2],
57+
d=inner_doc,
58+
ld=DocList[InnerDoc]([inner_doc]),
59+
),
60+
]
61+
)
62+
63+
from docarray.base_doc import AnyDoc
64+
65+
if protocol == 'proto':
66+
aux = DocList[AnyDoc].from_protobuf(da.to_protobuf())
67+
else:
68+
aux = DocList[AnyDoc].from_json(da.to_json())
69+
assert len(aux) == 2
70+
assert len(aux.id) == 2
71+
for i, d in enumerate(aux):
72+
assert d.tags['type'] == i + 1
73+
assert d.text == f'type{i + 1}'
74+
assert d.l == [1, 2]
75+
if protocol == 'proto':
76+
assert isinstance(d.d, AnyDoc)
77+
assert d.d.text == 'I am inner' # inner Document is an AnyDoc (proto protocol)
78+
assert d.d.t == {'a': 'b'}
79+
else:
80+
assert isinstance(d.d, dict)
81+
assert d.d['text'] == 'I am inner' # inner Document is a Dict
82+
assert d.d['t'] == {'a': 'b'}
83+
assert len(d.ld) == 1
84+
if protocol == 'proto':
85+
assert isinstance(d.ld[0], AnyDoc)
86+
assert d.ld[0].text == 'I am inner'
87+
assert d.ld[0].t == {'a': 'b'}
88+
else:
89+
assert isinstance(d.ld[0], dict)
90+
assert d.ld[0]['text'] == 'I am inner'
91+
assert d.ld[0]['t'] == {'a': 'b'}

0 commit comments

Comments
 (0)