Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 36 additions & 3 deletions docarray/array/mixins/parallel.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
import sys
from contextlib import nullcontext
from math import ceil
from types import LambdaType
from typing import Callable, TYPE_CHECKING, Generator, Optional, overload, TypeVar
from typing import (
Callable,
TYPE_CHECKING,
Generator,
Optional,
overload,
TypeVar,
Union,
)

if TYPE_CHECKING:
from ...types import T
from ... import Document, DocumentArray
from multiprocessing.pool import ThreadPool, Pool


T_DA = TypeVar('T_DA')

Expand All @@ -20,6 +31,7 @@ def apply(
backend: str = 'thread',
num_worker: Optional[int] = None,
show_progress: bool = False,
pool: Optional[Union['Pool', 'ThreadPool']] = None,
) -> 'T':
"""Apply each element in itself with ``func``, return itself after modified.

Expand All @@ -35,6 +47,7 @@ def apply(
and the original object do **not** share the same memory.

:param num_worker: the number of parallel workers. If not given, then the number of CPUs in the system will be used.
:param pool: use an existing/external pool. If given, `backend` is ignored and you will be responsible for closing the pool.
:param show_progress: show a progress bar

"""
Expand All @@ -60,6 +73,7 @@ def map(
backend: str = 'thread',
num_worker: Optional[int] = None,
show_progress: bool = False,
pool: Optional[Union['Pool', 'ThreadPool']] = None,
) -> Generator['T', None, None]:
"""Return an iterator that applies function to every **element** of iterable in parallel, yielding the results.

Expand All @@ -81,6 +95,7 @@ def map(

:param num_worker: the number of parallel workers. If not given, then the number of CPUs in the system will be used.
:param show_progress: show a progress bar
:param pool: use an existing/external pool. If given, `backend` is ignored and you will be responsible for closing the pool.

:yield: anything return from ``func``
"""
Expand All @@ -89,7 +104,14 @@ def map(

from rich.progress import track

with _get_pool(backend, num_worker) as p:
if pool:
p = pool
ctx_p = nullcontext()
else:
p = _get_pool(backend, num_worker)
ctx_p = p

with ctx_p:
for x in track(
p.imap(func, self), total=len(self), disable=not show_progress
):
Expand All @@ -104,6 +126,7 @@ def apply_batch(
num_worker: Optional[int] = None,
shuffle: bool = False,
show_progress: bool = False,
pool: Optional[Union['Pool', 'ThreadPool']] = None,
) -> 'T':
"""Apply each element in itself with ``func``, return itself after modified.

Expand All @@ -122,6 +145,7 @@ def apply_batch(
:param batch_size: Size of each generated batch (except the last one, which might be smaller, default: 32)
:param shuffle: If set, shuffle the Documents before dividing into minibatches.
:param show_progress: show a progress bar
:param pool: use an existing/external pool. If given, `backend` is ignored and you will be responsible for closing the pool.

"""
...
Expand Down Expand Up @@ -150,6 +174,7 @@ def map_batch(
num_worker: Optional[int] = None,
shuffle: bool = False,
show_progress: bool = False,
pool: Optional[Union['Pool', 'ThreadPool']] = None,
) -> Generator['T', None, None]:
"""Return an iterator that applies function to every **minibatch** of iterable in parallel, yielding the results.
Each element in the returned iterator is :class:`DocumentArray`.
Expand All @@ -174,6 +199,7 @@ def map_batch(

:param num_worker: the number of parallel workers. If not given, then the number of CPUs in the system will be used.
:param show_progress: show a progress bar
:param pool: use an existing/external pool. If given, `backend` is ignored and you will be responsible for closing the pool.

:yield: anything return from ``func``
"""
Expand All @@ -183,7 +209,14 @@ def map_batch(

from rich.progress import track

with _get_pool(backend, num_worker) as p:
if pool:
p = pool
ctx_p = nullcontext()
else:
p = _get_pool(backend, num_worker)
ctx_p = p

with ctx_p:
for x in track(
p.imap(func, self.batch(batch_size=batch_size, shuffle=shuffle)),
total=ceil(len(self) / batch_size),
Expand Down
23 changes: 15 additions & 8 deletions docs/fundamentals/document/fluent-interface.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ from docarray import Document

d = (
Document(uri='apple.png')
.load_uri_to_image_tensor()
.set_image_tensor_shape((64, 64))
.set_image_tensor_normalization()
.save_image_tensor_to_file('apple1.png')
.load_uri_to_image_tensor()
.set_image_tensor_shape((64, 64))
.set_image_tensor_normalization()
.save_image_tensor_to_file('apple1.png')
)
```

Expand All @@ -34,10 +34,12 @@ from docarray import Document

d = Document(uri='apple.png')

(d.load_uri_to_image_tensor()
.set_image_tensor_shape((64, 64))
.set_image_tensor_normalization()
.save_image_tensor_to_file('apple1.png'))
(
d.load_uri_to_image_tensor()
.set_image_tensor_shape((64, 64))
.set_image_tensor_normalization()
.save_image_tensor_to_file('apple1.png')
)
```


Expand Down Expand Up @@ -118,6 +120,11 @@ Provide helper functions to convert to/from a Pydantic model
- {meth}`~docarray.document.mixins.pydantic.PydanticMixin.from_pydantic_model`


### Strawberry
Provide helper functions to convert to/from a Strawberry model
- {meth}`~docarray.document.mixins.strawberry.StrawberryMixin.from_strawberry_type`


### AudioData
Provide helper functions for {class}`Document` to support audio data.
- {meth}`~docarray.document.mixins.audio.AudioDataMixin.load_uri_to_audio_tensor`
Expand Down
127 changes: 78 additions & 49 deletions docs/fundamentals/documentarray/find.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,40 @@ Let's see some examples in action. First, let's prepare a DocumentArray we will
```python
from jina import Document, DocumentArray

da = DocumentArray([Document(text='journal', weight=25, tags={'h': 14, 'w': 21, 'uom': 'cm'}, modality='A'),
Document(text='notebook', weight=50, tags={'h': 8.5, 'w': 11, 'uom': 'in'}, modality='A'),
Document(text='paper', weight=100, tags={'h': 8.5, 'w': 11, 'uom': 'in'}, modality='D'),
Document(text='planner', weight=75, tags={'h': 22.85, 'w': 30, 'uom': 'cm'}, modality='D'),
Document(text='postcard', weight=45, tags={'h': 10, 'w': 15.25, 'uom': 'cm'}, modality='A')])
da = DocumentArray(
[
Document(
text='journal',
weight=25,
tags={'h': 14, 'w': 21, 'uom': 'cm'},
modality='A',
),
Document(
text='notebook',
weight=50,
tags={'h': 8.5, 'w': 11, 'uom': 'in'},
modality='A',
),
Document(
text='paper',
weight=100,
tags={'h': 8.5, 'w': 11, 'uom': 'in'},
modality='D',
),
Document(
text='planner',
weight=75,
tags={'h': 22.85, 'w': 30, 'uom': 'cm'},
modality='D',
),
Document(
text='postcard',
weight=45,
tags={'h': 10, 'w': 15.25, 'uom': 'cm'},
modality='A',
),
]
)

da.summary()
```
Expand Down Expand Up @@ -75,17 +104,17 @@ r = da.find({'modality': {'$eq': 'D'}})
pprint(r.to_dict(exclude_none=True)) # just for pretty print
```

```text
[{'id': '92aee5d665d0c4dd34db10d83642aded',
'modality': 'D',
'tags': {'h': 8.5, 'uom': 'in', 'w': 11.0},
'text': 'paper',
'weight': 100.0},
{'id': '1a9d2139b02bc1c7842ecda94b347889',
'modality': 'D',
'tags': {'h': 22.85, 'uom': 'cm', 'w': 30.0},
'text': 'planner',
'weight': 75.0}]
```json
[{"id": "92aee5d665d0c4dd34db10d83642aded",
"modality": "D",
"tags": {"h": 8.5, "uom": "in", "w": 11.0},
"text": "paper",
"weight": 100.0},
{"id": "1a9d2139b02bc1c7842ecda94b347889",
"modality": "D",
"tags": {"h": 22.85, "uom": "cm", "w": 30.0},
"text": "planner",
"weight": 75.0}]
```

To select all Documents whose `.tags['h']>10`,
Expand All @@ -94,17 +123,17 @@ To select all Documents whose `.tags['h']>10`,
r = da.find({'tags__h': {'$gt': 10}})
```

```text
[{'id': '4045a9659875fd1299e482d710753de3',
'modality': 'A',
'tags': {'h': 14.0, 'uom': 'cm', 'w': 21.0},
'text': 'journal',
'weight': 25.0},
{'id': 'cf7691c445220b94b88ff116911bad24',
'modality': 'D',
'tags': {'h': 22.85, 'uom': 'cm', 'w': 30.0},
'text': 'planner',
'weight': 75.0}]
```json
[{"id": "4045a9659875fd1299e482d710753de3",
"modality": "A",
"tags": {"h": 14.0, "uom": "cm", "w": 21.0},
"text": "journal",
"weight": 25.0},
{"id": "cf7691c445220b94b88ff116911bad24",
"modality": "D",
"tags": {"h": 22.85, "uom": "cm", "w": 30.0},
"text": "planner",
"weight": 75.0}]
```

Besides using a predefined value, one can also use a substitution with `{field}`; note the curly brackets. For example,
Expand All @@ -113,12 +142,12 @@ Beside using a predefined value, one can also use a substitution with `{field}`,
r = da.find({'tags__h': {'$gt': '{tags__w}'}})
```

```text
[{'id': '44c6a4b18eaa005c6dbe15a28a32ebce',
'modality': 'A',
'tags': {'h': 14.0, 'uom': 'cm', 'w': 10.0},
'text': 'journal',
'weight': 25.0}]
```json
[{"id": "44c6a4b18eaa005c6dbe15a28a32ebce",
"modality": "A",
"tags": {"h": 14.0, "uom": "cm", "w": 10.0},
"text": "journal",
"weight": 25.0}]
```


Expand All @@ -140,20 +169,20 @@ You can combine multiple conditions using the following operators
r = da.find({'$or': [{'weight': {'$eq': 45}}, {'modality': {'$eq': 'D'}}]})
```

```text
[{'id': '22985b71b6d483c31cbe507ed4d02bd1',
'modality': 'D',
'tags': {'h': 8.5, 'uom': 'in', 'w': 11.0},
'text': 'paper',
'weight': 100.0},
{'id': 'a071faf19feac5809642e3afcd3a5878',
'modality': 'D',
'tags': {'h': 22.85, 'uom': 'cm', 'w': 30.0},
'text': 'planner',
'weight': 75.0},
{'id': '411ecc70a71a3f00fc3259bf08c239d1',
'modality': 'A',
'tags': {'h': 10.0, 'uom': 'cm', 'w': 15.25},
'text': 'postcard',
'weight': 45.0}]
```json
[{"id": "22985b71b6d483c31cbe507ed4d02bd1",
"modality": "D",
"tags": {"h": 8.5, "uom": "in", "w": 11.0},
"text": "paper",
"weight": 100.0},
{"id": "a071faf19feac5809642e3afcd3a5878",
"modality": "D",
"tags": {"h": 22.85, "uom": "cm", "w": 30.0},
"text": "planner",
"weight": 75.0},
{"id": "411ecc70a71a3f00fc3259bf08c239d1",
"modality": "A",
"tags": {"h": 10.0, "uom": "cm", "w": 15.25},
"text": "postcard",
"weight": 45.0}]
```
6 changes: 3 additions & 3 deletions docs/fundamentals/documentarray/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ construct
serialization
access-elements
access-attributes
embedding
find
matching
evaluation
parallelization
visualization
post-external
embedding
matching
evaluation
```
11 changes: 11 additions & 0 deletions tests/unit/array/mixins/test_parallel.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from multiprocessing import Pool
from multiprocessing.pool import ThreadPool

import pytest

from docarray import DocumentArray, Document
Expand Down Expand Up @@ -25,6 +28,14 @@ def foo_batch(da: DocumentArray):
return da


@pytest.mark.parametrize('pool', [None, Pool, ThreadPool])
def test_parallel_map_apply_external_pool(pytestconfig, pool):
    """Check that ``apply`` works both without a pool and with an external one.

    ``pool`` is either ``None`` or a pool *class*. The pool is instantiated
    inside the test instead of inside the ``parametrize`` list, so no worker
    processes/threads are started at import/collection time, and it is always
    shut down afterwards: when an external pool is passed, closing it is the
    caller's responsibility.
    """
    external_pool = pool() if pool is not None else None
    try:
        da = DocumentArray.from_files(f'{pytestconfig.rootdir}/**/*.jpeg')
        assert da.tensors is None
        da.apply(foo, pool=external_pool)
        assert da.tensors is not None
    finally:
        # Release the workers even if an assertion above fails.
        if external_pool is not None:
            external_pool.close()
            external_pool.join()


@pytest.mark.parametrize(
'da_cls, config',
[
Expand Down