Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions docarray/dataclasses/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,23 +234,23 @@ def _from_document(cls: Type['T'], doc: 'Document') -> 'T':
]:
attributes[key] = doc.tags[key]
elif attribute_info['attribute_type'] == AttributeType.DOCUMENT:
attribute_doc = doc.chunks[position]
attribute_doc = doc.chunks[int(position)]
attribute = _get_doc_attribute(attribute_doc, field)
attributes[key] = attribute
elif attribute_info['attribute_type'] == AttributeType.ITERABLE_DOCUMENT:
attribute_list = []
for chunk_doc in doc.chunks[position].chunks:
for chunk_doc in doc.chunks[int(position)].chunks:
attribute_list.append(_get_doc_attribute(chunk_doc, field))
attributes[key] = attribute_list
elif attribute_info['attribute_type'] == AttributeType.NESTED:
nested_cls = field.type
attributes[key] = _get_doc_nested_attribute(
doc.chunks[position], nested_cls
doc.chunks[int(position)], nested_cls
)
elif attribute_info['attribute_type'] == AttributeType.ITERABLE_NESTED:
nested_cls = cls.__dataclass_fields__[key].type.__args__[0]
attribute_list = []
for chunk_doc in doc.chunks[position].chunks:
for chunk_doc in doc.chunks[int(position)].chunks:
attribute_list.append(_get_doc_nested_attribute(chunk_doc, nested_cls))
attributes[key] = attribute_list
else:
Expand Down
4 changes: 2 additions & 2 deletions docarray/document/mixins/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,12 @@ def get_multi_modal_attribute(self, attribute: str) -> 'DocumentArray':
position = self._metadata['multi_modal_schema'][attribute].get('position')
Comment thread
alaeddine-13 marked this conversation as resolved.

if attribute_type in [AttributeType.DOCUMENT, AttributeType.NESTED]:
return DocumentArray([self.chunks[position]])
return DocumentArray([self.chunks[int(position)]])
Comment thread
alaeddine-13 marked this conversation as resolved.
elif attribute_type in [
AttributeType.ITERABLE_DOCUMENT,
AttributeType.ITERABLE_NESTED,
]:
return self.chunks[position].chunks
return self.chunks[int(position)].chunks
Comment thread
alaeddine-13 marked this conversation as resolved.
else:
raise ValueError(
f'Invalid attribute {attribute}: must a Document attribute or nested dataclass'
Expand Down
12 changes: 10 additions & 2 deletions tests/unit/document/test_multi_modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,12 @@ class MMDocument:
assert deserialized_doc.chunks[1].tensor.shape == (10, 10, 3)
assert deserialized_doc.tags['version'] == 20

images = deserialized_doc.get_multi_modal_attribute('image')
titles = doc.get_multi_modal_attribute('title')

assert images[0].tensor.shape == (10, 10, 3)
assert titles[0].text == 'hello world'

assert 'multi_modal_schema' in deserialized_doc._metadata

expected_schema = [
Expand All @@ -491,8 +497,10 @@ class MMDocument:
]
_assert_doc_schema(deserialized_doc, expected_schema)

translated_obj = MMDocument(doc)
assert translated_obj == obj
translated_obj = MMDocument(deserialized_doc)
assert (translated_obj.image == obj.image).all()
assert translated_obj.title == obj.title
assert translated_obj.version == obj.version


def test_json_type():
Expand Down