Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
93d85d7
perf: speedup id construction time
davidbp Mar 10, 2022
3729744
Merge branch 'main' of https://github.com/jina-ai/docarray
davidbp Mar 10, 2022
61343a1
Merge branch 'main' of https://github.com/jina-ai/docarray
davidbp Mar 10, 2022
41c9aa5
Merge branch 'main' of https://github.com/jina-ai/docarray
davidbp Mar 11, 2022
d7e435d
Merge branch 'main' of https://github.com/jina-ai/docarray
davidbp Mar 11, 2022
bdcd9e2
Merge branch 'main' of https://github.com/jina-ai/docarray
davidbp Mar 14, 2022
54b7cef
Merge branch 'main' of https://github.com/jina-ai/docarray
davidbp Mar 14, 2022
30cd10a
Merge branch 'main' of https://github.com/jina-ai/docarray
davidbp Mar 15, 2022
ebae898
Merge branch 'main' of https://github.com/jina-ai/docarray
davidbp Mar 16, 2022
ef2b174
Merge branch 'main' of https://github.com/jina-ai/docarray
davidbp Mar 22, 2022
f11e826
Merge branch 'main' of https://github.com/jina-ai/docarray
davidbp Mar 22, 2022
d5a4846
Merge branch 'main' of https://github.com/jina-ai/docarray
davidbp Mar 28, 2022
db63e06
Merge branch 'main' of https://github.com/jina-ai/docarray
davidbp Mar 28, 2022
dfc1164
Merge branch 'main' of https://github.com/jina-ai/docarray
davidbp Mar 28, 2022
5f067f9
Merge branch 'main' of https://github.com/jina-ai/docarray
davidbp Mar 29, 2022
04e7f45
Merge branch 'main' of https://github.com/jina-ai/docarray
davidbp Mar 30, 2022
f8b0ca1
Merge branch 'main' of https://github.com/jina-ai/docarray
davidbp Mar 31, 2022
5fb4614
Merge branch 'main' of https://github.com/jina-ai/docarray
davidbp Mar 31, 2022
c797ff6
perf: decrease memory cost storing torch tensors
davidbp Mar 31, 2022
16aeec0
refactor: write detach tensor in ndarray
davidbp Mar 31, 2022
333b331
refactor: allow array type to handle non arrays
davidbp Mar 31, 2022
65217b8
refactor: rename method
davidbp Mar 31, 2022
d0b90f9
refactor: detach only for dense
davidbp Mar 31, 2022
b11866c
refactor: remove attribute modification at set tie
davidbp Apr 1, 2022
a266adc
refactor: move detached attribute modification to getstate
davidbp Apr 1, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docarray/array/mixins/setitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def __setitem__(
raise IndexError(f'Unsupported index type {typename(index)}: {index}')

def _set_by_pair(self, idx1, idx2, value):

if isinstance(idx1, str) and not idx1.startswith('@'):
# second is an ID
# allows da[id1, id2] = [d1, d2]
Expand All @@ -136,6 +137,7 @@ def _set_by_pair(self, idx1, idx2, value):
and all(isinstance(attr, str) for attr in idx2)
and all(hasattr(self[idx1], attr) for attr in idx2)
):

for attr, _v in zip(idx2, value):
self._set_doc_attr_by_id(idx1, attr, _v)
else:
Expand Down
14 changes: 14 additions & 0 deletions docarray/document/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .data import DocumentData
from .mixins import AllMixins
from ..base import BaseDCType
from ..math.ndarray import detach_tensor_if_present

if TYPE_CHECKING:
from ..types import ArrayType, StructValueType, DocumentContentType
Expand Down Expand Up @@ -88,3 +89,16 @@ def __init__(

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def __getstate__(self):
state = self.__dict__.copy()

for attribute in ['embedding', 'tensor']:
if hasattr(self, attribute):
setattr(
state['_data'],
attribute,
detach_tensor_if_present(getattr(state['_data'], attribute)),
)

return state
24 changes: 21 additions & 3 deletions docarray/math/ndarray.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Tuple, Sequence, Optional, List
from typing import TYPE_CHECKING, Tuple, Sequence, Optional, List, Any

import numpy as np

Expand Down Expand Up @@ -81,7 +81,9 @@ def ravel(value: 'ArrayType', docs: 'DocumentArray', field: str) -> None:
docs[d.id, field] = value[j, ...]


def get_array_type(array: 'ArrayType') -> Tuple[str, bool]:
def get_array_type(
array: 'ArrayType', raise_error_if_not_array: bool = True
) -> Tuple[str, bool]:
"""Get the type of ndarray without importing the framework

:param array: any array, scipy, numpy, tf, torch, etc.
Expand Down Expand Up @@ -121,7 +123,10 @@ def get_array_type(array: 'ArrayType') -> Tuple[str, bool]:
if 'scipy' in module_tags and 'sparse' in module_tags:
return 'scipy', True

raise TypeError(f'can not determine the array type: {module_tags}.{class_name}')
if raise_error_if_not_array:
raise TypeError(f'can not determine the array type: {module_tags}.{class_name}')
Comment thread
davidbp marked this conversation as resolved.
else:
return 'python', False


def to_numpy_array(value) -> 'np.ndarray':
Expand Down Expand Up @@ -258,3 +263,16 @@ def check_arraylike_equality(x: 'ArrayType', y: 'ArrayType'):
return same_array
else:
return same_array


def detach_tensor_if_present(x: Any) -> Any:
"""Check if input is a dense torch array and detaches the tensor from the current graph.
:param array: input array
:return: (num_rows, ndim)
"""
x_type, x_sparse = get_array_type(x, raise_error_if_not_array=False)
if x_type == 'torch' and x_sparse == False:
import torch

x = torch.tensor(x.detach().numpy())
return x