Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 121 additions & 0 deletions docarray/document/mixins/plot.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
import copy
from typing import Optional

import numpy as np

from ...helper import deprecate_by


Expand Down Expand Up @@ -97,6 +102,122 @@ def display(self):

# Keep the old ``plot`` name available as an alias built from ``display``.
# NOTE(review): presumably ``deprecate_by`` wraps the call with a deprecation
# notice until the alias is removed in 0.5 -- confirm against ``helper``.
plot = deprecate_by(display, removed_at='0.5')

def plot_matches_sprites(
    self,
    top_k: int = 10,
    channel_axis: int = -1,
    inv_normalize: bool = False,
    skip_empty: bool = False,
    canvas_size: int = 1920,
    min_size: int = 100,
    output: Optional[str] = None,
):
    """Generate a sprite image for the query and its matching images in this Document object.

    An image sprite is a collection of images put into a single image. The query image is
    pasted at the left edge, one column is left blank as a separator, and the matching
    images follow from the third column on. The Document object should contain matches.

    :param top_k: the maximum number of matching documents to show in the sprite.
    :param channel_axis: the axis id of the color channel, ``-1`` indicates the color channel info at the last axis
    :param inv_normalize: If set to True, inverse the normalization of a float32 image :attr:`.tensor` into a uint8
        image :attr:`.tensor`. The operation is applied to deep copies, the Documents themselves are untouched.
    :param skip_empty: skip matches that have neither ``.uri`` nor ``.tensor``. Skipped matches do not
        occupy a column and do not count towards ``top_k``.
    :param canvas_size: the width of the canvas
    :param min_size: the minimum width of each image; the canvas is widened when the requested
        ``canvas_size`` would make the images narrower than this.
    :param output: Optional path to store the visualization. If not given, show in UI
    :raises ValueError: if the Document is empty, has no matches, has neither ``uri`` nor ``tensor``,
        if ``top_k`` is not positive, if a match has no content and ``skip_empty`` is False, or if
        any image tensor cannot be loaded or reshaped onto the canvas.
    """
    if not self or not self.matches:
        raise ValueError(f'{self!r} is empty or has no matches')

    if not self.uri and self.tensor is None:
        raise ValueError('Document has neither `uri` nor `tensor`, cannot be plotted')

    if top_k <= 0:
        raise ValueError(f'`top_k` must be larger than 0, receiving {top_k}')

    # Pick the matches to draw up front. Doing it outside the ``try`` below keeps the
    # "no content" error from being re-wrapped as a misleading "Bad image tensor" error,
    # and makes skipped matches leave no hole in the sprite.
    shown = []
    for match in self.matches:
        if not match.uri and match.tensor is None:
            if skip_empty:
                continue
            raise ValueError(
                'Document match has neither `uri` nor `tensor`, cannot be plotted'
            )
        shown.append(match)
        if len(shown) >= top_k:
            break

    # Two extra columns: one for the query image and one blank separator.
    img_per_row = len(shown) + 2

    img_size = int((canvas_size - 50) / img_per_row)
    if img_size < min_size:
        # image is too small, recompute the image size and canvas size
        img_size = min_size
        canvas_size = img_per_row * img_size + 50

    try:
        # Work on a deep copy so the caller's Document keeps its original content.
        query = copy.deepcopy(self)
        if query.content_type != 'tensor':
            query.load_uri_to_image_tensor()  # the channel axis is -1

        if inv_normalize:
            # inverse normalise to uint8 and set the channel axis to -1
            query.set_image_tensor_inv_normalization(channel_axis)

        query.set_image_tensor_channel_axis(channel_axis, -1)

        # Maintain the aspect ratio keeping the width fixed
        h, w, _ = query.tensor.shape
        img_h, img_w = int(h * (img_size / float(w))), img_size

        # 10px margin above and below; background pixels hold the value 1.
        sprite_img = np.ones([img_h + 20, canvas_size, 3], dtype='uint8')

        query.set_image_tensor_shape(shape=(img_h, img_w))

        sprite_img[10 : img_h + 10, 10 : 10 + img_w] = query.tensor
        pos = canvas_size // img_per_row

        # Column 0 holds the query and column 1 is the separator, so matches start at 2.
        for col_id, match in enumerate(shown, start=2):
            tile = copy.deepcopy(match)
            if tile.content_type != 'tensor':
                tile.load_uri_to_image_tensor()

            if inv_normalize:
                tile.set_image_tensor_inv_normalization(channel_axis=channel_axis)

            # Every match is resized to the query's size so all tiles line up.
            tile.set_image_tensor_channel_axis(
                channel_axis, -1
            ).set_image_tensor_shape(shape=(img_h, img_w))

            # paste it on the main canvas
            left = col_id * pos
            sprite_img[10 : img_h + 10, left : left + img_w] = tile.tensor
    except Exception as ex:
        raise ValueError('Bad image tensor. Try different `channel_axis`') from ex

    from PIL import Image

    im = Image.fromarray(sprite_img)

    if output:
        with open(output, 'wb') as fp:
            im.save(fp)
    else:
        # matplotlib is only needed for on-screen display, so import it lazily here.
        import matplotlib.pyplot as plt

        plt.figure(figsize=(img_per_row, 2))
        plt.gca().set_axis_off()
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        plt.gca().xaxis.set_major_locator(plt.NullLocator())
        plt.gca().yaxis.set_major_locator(plt.NullLocator())
        plt.imshow(im, interpolation="none")
        plt.show()


def _convert_display_uri(uri, mime_type):
import urllib
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
19 changes: 19 additions & 0 deletions docs/fundamentals/documentarray/visualization.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,32 @@ If a DocumentArray contains all image Documents, you can plot all images in one

```python
from docarray import DocumentArray

docs = DocumentArray.from_files('*.jpg')
docs.plot_image_sprites()
```

```{figure} images/sprite-image.png
:width: 60%
```
(plot-matches)=
### Plot Matches

If an image Document contains the matching images in its `.matches` attribute, you can visualise the matching results using {meth}`~docarray.document.mixins.plot.PlotMixin.plot_matches_sprites`.

```python
import numpy as np
from docarray import DocumentArray

da = DocumentArray.from_files('*.jpg')
da.embeddings = np.random.random([len(da), 10])
da.match(da)
da[0].plot_matches_sprites(top_k=5, channel_axis=-1, inv_normalize=False)
```

```{figure} images/sprite-match.png
:width: 60%
```

(visualize-embeddings)=
## Embedding projector
Expand Down
25 changes: 25 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,34 @@
import tempfile
import os
import time

import pytest
from elasticsearch import Elasticsearch

# Absolute directory that contains this conftest module.
_this_file = os.path.abspath(__file__)
cur_dir = os.path.dirname(_this_file)

# Absolute path of the docker-compose file under ``unit/array``, used by the
# ``start_storage`` fixture to bring the storage backends up and down.
_compose_rel = os.path.join(cur_dir, 'unit', 'array', 'docker-compose.yml')
compose_yml = os.path.abspath(_compose_rel)


@pytest.fixture(autouse=True)
def tmpfile(tmpdir):
    """Return a unique ``docarray_test_*.db`` path inside the per-test ``tmpdir``.

    The file itself is not created; only the path is returned.
    """
    # Function-scope import so the fixture does not depend on the module's import block.
    import uuid

    # ``uuid4`` replaces the private ``tempfile._get_candidate_names()`` helper,
    # which is not part of the public ``tempfile`` API.
    return tmpdir / f'docarray_test_{uuid.uuid4().hex}.db'


@pytest.fixture(scope='session')
def start_storage():
    """Start the docker-compose storage stack once per session and stop it afterwards.

    Runs ``docker-compose up`` for ``compose_yml``, then polls Elasticsearch on
    ``http://localhost:9200/`` until it answers ``ping``. The stack is brought
    down again when the session ends, and also when start-up fails.

    :raises RuntimeError: if Elasticsearch does not answer within the timeout.
    """
    ping_timeout = 300  # seconds to wait for Elasticsearch before giving up
    os.system(
        f"docker-compose -f {compose_yml} --project-directory . up --build -d "
        f"--remove-orphans"
    )
    try:
        es = Elasticsearch(hosts='http://localhost:9200/')
        deadline = time.monotonic() + ping_timeout
        while not es.ping():
            # Without a deadline a failed container start would hang the test run forever.
            if time.monotonic() >= deadline:
                raise RuntimeError(
                    f'Elasticsearch did not respond within {ping_timeout} seconds'
                )
            time.sleep(0.5)

        yield
    finally:
        # ``finally`` guarantees the containers are removed even if setup raised.
        os.system(
            f"docker-compose -f {compose_yml} --project-directory . down "
            f"--remove-orphans"
        )
25 changes: 0 additions & 25 deletions tests/unit/array/conftest.py

This file was deleted.

105 changes: 105 additions & 0 deletions tests/unit/document/test_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import os

import numpy as np
import pytest

from docarray import DocumentArray, Document
from docarray.array.annlite import DocumentArrayAnnlite
from docarray.array.elastic import DocumentArrayElastic, ElasticConfig
from docarray.array.qdrant import DocumentArrayQdrant
from docarray.array.sqlite import DocumentArraySqlite
from docarray.array.storage.annlite import AnnliteConfig
from docarray.array.storage.qdrant import QdrantConfig
from docarray.array.storage.weaviate import WeaviateConfig
from docarray.array.weaviate import DocumentArrayWeaviate


@pytest.fixture()
def embed_docs(pytestconfig):
    """Build query and index DocumentArrays from the test images.

    The ``*.jpg`` files under ``tests/image-data`` form the index and the ``*.png``
    files form the queries. Every Document gets a random 128-dimensional embedding.

    :return: a ``(queries, index)`` tuple of DocumentArrays.
    """
    index_da = DocumentArray.from_files(
        [f'{pytestconfig.rootdir}/tests/image-data/*.jpg']
    )
    query_da = DocumentArray.from_files(
        [f'{pytestconfig.rootdir}/tests/image-data/*.png']
    )

    # Index documents are embedded first, then the queries.
    every_doc = index_da + query_da
    for item in every_doc:
        item.embedding = np.random.random(128)

    return query_da, index_da


def test_empty_doc(embed_docs):
    """Check that ``plot_matches_sprites`` rejects unusable input with ``ValueError``."""
    # A Document with neither matches nor image content.
    da = DocumentArray([Document(embedding=np.random.random(128))])
    with pytest.raises(ValueError):
        da[0].plot_matches_sprites()

    daq, dai = embed_docs

    # An image Document that has not been matched yet.
    with pytest.raises(ValueError):
        daq[0].plot_matches_sprites()

    # Populate the matches first; otherwise the call fails on the "no matches"
    # check and the ``top_k`` validation is never reached.
    daq.match(dai)
    with pytest.raises(ValueError):
        daq[0].plot_matches_sprites(top_k=0)


@pytest.mark.parametrize('top_k', [1, 10, 20])
@pytest.mark.parametrize(
    'da_cls,config_gen',
    [
        (DocumentArray, None),
        (DocumentArraySqlite, None),
        (DocumentArrayAnnlite, lambda: AnnliteConfig(n_dim=128)),
        (DocumentArrayWeaviate, lambda: WeaviateConfig(n_dim=128)),
        (DocumentArrayQdrant, lambda: QdrantConfig(n_dim=128, scroll_batch_size=8)),
        (DocumentArrayElastic, lambda: ElasticConfig(n_dim=128)),
    ],
)
def test_matches_sprites(
    pytestconfig, tmpdir, da_cls, config_gen, embed_docs, start_storage, top_k
):
    """Plot the matches of the first query for several ``top_k`` values and backends.

    Configs are produced by factories so each parametrized run gets a fresh config
    object, as in ``test_matches_sprite_image_generator``.
    NOTE(review): presumably a config instance shared across runs can carry state
    from one storage backend instance to the next -- confirm against the backends.
    """
    da, das = embed_docs
    if config_gen:
        das = da_cls(das, config=config_gen())
    else:
        das = da_cls(das)
    da.match(das)
    da[0].plot_matches_sprites(top_k, output=tmpdir / 'sprint_da.png')
    assert os.path.exists(tmpdir / 'sprint_da.png')


@pytest.mark.parametrize('image_source', ['tensor', 'uri'])
@pytest.mark.parametrize(
    'da_cls,config_gen',
    [
        (DocumentArray, None),
        (DocumentArraySqlite, None),
        (DocumentArrayAnnlite, lambda: AnnliteConfig(n_dim=128)),
        (DocumentArrayWeaviate, lambda: WeaviateConfig(n_dim=128)),
        (DocumentArrayQdrant, lambda: QdrantConfig(n_dim=128, scroll_batch_size=8)),
        (DocumentArrayElastic, lambda: ElasticConfig(n_dim=128)),
    ],
)
def test_matches_sprite_image_generator(
    pytestconfig,
    tmpdir,
    image_source,
    da_cls,
    config_gen,
    embed_docs,
    start_storage,
):
    """Plot the matches sprite when images come from tensors or from URIs.

    For ``image_source == 'tensor'`` both the queries and the index are loaded into
    tensors before matching; for ``'uri'`` the Documents keep only their URI.
    """
    queries, index = embed_docs

    if image_source == 'tensor':
        # Queries are converted first, then the index.
        for docs in (queries, index):
            docs.apply(lambda d: d.load_uri_to_image_tensor())

    # Only pass ``config`` when the backend has a config factory.
    storage_kwargs = {'config': config_gen()} if config_gen else {}
    index = da_cls(index, **storage_kwargs)

    queries.match(index)

    queries[0].plot_matches_sprites(output=tmpdir / 'sprint_da.png')
    assert os.path.exists(tmpdir / 'sprint_da.png')