Skip to content

Commit 30cc652

Browse files
authored
feat: support filtering based on text keywords for qdrant (#849)
Signed-off-by: AnneY <evangeline-lun@foxmail.com>
1 parent 95d6c5f commit 30cc652

4 files changed

Lines changed: 122 additions & 32 deletions

File tree

docarray/array/storage/qdrant/backend.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import numpy as np
1616
from qdrant_client import QdrantClient
17+
from qdrant_client.http import models
1718
from qdrant_client.http.models.models import (
1819
Distance,
1920
CreateCollection,
@@ -24,7 +25,7 @@
2425
)
2526

2627
from docarray import Document
27-
from docarray.array.storage.base.backend import BaseBackendMixin
28+
from docarray.array.storage.base.backend import BaseBackendMixin, TypeMap
2829
from docarray.array.storage.qdrant.helper import DISTANCES
2930
from docarray.helper import dataclass_from_dict, random_identity
3031
from docarray.math.helper import EPSILON
@@ -74,6 +75,15 @@ def distance(self) -> 'Distance':
7475
def _tmp_collection_name(cls) -> str:
7576
return uuid.uuid4().hex
7677

78+
TYPE_MAP = {
79+
'int': TypeMap(type='integer', converter=int),
80+
'float': TypeMap(type='float', converter=float),
81+
'bool': TypeMap(type='int', converter=bool),
82+
'str': TypeMap(type='keyword', converter=str),
83+
'text': TypeMap(type='text', converter=str),
84+
'geo': TypeMap(type='geo', converter=dict),
85+
}
86+
7787
def _init_storage(
7888
self,
7989
docs: Optional['DocumentArraySourceType'] = None,
@@ -172,6 +182,23 @@ def _initialize_qdrant_schema(self):
172182
hnsw_config=hnsw_config,
173183
)
174184

185+
for col, coltype in self._config.columns.items():
186+
if coltype == 'text':
187+
self.client.create_payload_index(
188+
collection_name=self.collection_name,
189+
field_name=col,
190+
field_schema=models.TextIndexParams(
191+
type="text",
192+
tokenizer=models.TokenizerType.WORD,
193+
),
194+
)
195+
else:
196+
self.client.create_payload_index(
197+
collection_name=self.collection_name,
198+
field_name=col,
199+
field_schema=self._map_type(coltype),
200+
)
201+
175202
def _collection_exists(self, collection_name):
176203
resp = self.client.get_collections()
177204
collections = [collection.name for collection in resp.collections]

docarray/array/storage/qdrant/find.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from docarray import Document, DocumentArray
55
from docarray.math import ndarray
66
from docarray.score import NamedScore
7-
from qdrant_client.http import models as rest
7+
from qdrant_client.http import models
88
from qdrant_client.http.models.models import Distance
99

1010
if TYPE_CHECKING: # pragma: no cover
@@ -59,7 +59,7 @@ def _find_similar_vectors(
5959
query_filter=filter,
6060
search_params=None
6161
if not search_params
62-
else rest.SearchParams(**search_params),
62+
else models.SearchParams(**search_params),
6363
limit=limit,
6464
append_payload=['_serialized'],
6565
)
@@ -117,7 +117,7 @@ def _find_with_filter(
117117
):
118118
list_of_points, _offset = self.client.scroll(
119119
collection_name=self.collection_name,
120-
scroll_filter=rest.Filter(**filter),
120+
scroll_filter=models.Filter(**filter),
121121
with_payload=True,
122122
limit=limit,
123123
)

docs/advanced/document-store/qdrant.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ The following configs can be set:
9797
| `root_id` | Boolean flag indicating whether to store `root_id` in the tags of chunk level Documents | True |
9898

9999

100-
101100
*You can read more about the HNSW parameters and their default values [here](https://qdrant.tech/documentation/indexing/#vector-index)
102101

103102
## Minimum example
@@ -150,8 +149,7 @@ print(da.find(np.random.random(D), limit=10))
150149
(qdrant-filter)=
151150
## Vector search with filter
152151

153-
Search with `.find` can be restricted by user-defined filters. Such filters can be constructed following the guidelines
154-
in [Qdrant's Documentation](https://qdrant.tech/documentation/filtering/)
152+
Search with `.find` can be restricted by user-defined filters. The supported tag types for filter are `'int'`, `'float'`, `'bool'`, `'str'`, `'text'` and `'geo'` as in [Qdrant](https://qdrant.tech/documentation/payload/). Such filters can be constructed following the guidelines in [Qdrant's Documentation](https://qdrant.tech/documentation/filtering/)
155153

156154

157155
### Example of `.find` with a filter
@@ -276,4 +274,4 @@ Points with "price" at most 7:
276274
embedding=[7. 7. 7.], price=7
277275
embedding=[1. 1. 1.], price=1
278276
embedding=[2. 2. 2.], price=2
279-
```
277+
```

tests/unit/array/mixins/test_find.py

Lines changed: 89 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -663,39 +663,104 @@ def test_filtering(
663663

664664

665665
@pytest.mark.parametrize(
666-
'storage,filter_gen,numeric_operators,operator',
666+
'columns',
667667
[
668-
*[
669-
tuple(
670-
[
671-
'qdrant',
672-
lambda operator, threshold: {
673-
'must': [{'key': 'price', 'match': {'value': threshold}}]
668+
[
669+
('price', 'float'),
670+
('category', 'str'),
671+
('info', 'text'),
672+
('location', 'geo'),
673+
],
674+
{'price': 'float', 'category': 'str', 'info': 'text', 'location': 'geo'},
675+
],
676+
)
677+
@pytest.mark.parametrize(
678+
'filter,checker',
679+
[
680+
(
681+
{
682+
'must': [
683+
{"key": "category", "match": {"value": "Shoes"}},
684+
{"key": "price", "range": {"gte": 5.0}},
685+
]
686+
},
687+
lambda r: r.tags['category'] == "Shoes" and r.tags['price'] >= 5.0,
688+
),
689+
(
690+
{
691+
'must_not': [
692+
{"key": "info", "match": {"text": "shoes"}},
693+
{
694+
"key": "location",
695+
"geo_radius": {
696+
"center": {"lon": -98.17, "lat": 38.71},
697+
"radius": 500.0 * 1000,
698+
},
674699
},
675-
numeric_operators_qdrant,
676-
'eq',
677700
]
678-
)
679-
],
701+
},
702+
lambda r: r.tags['info'].find("shoes") == -1
703+
and (
704+
haversine_distances(
705+
[
706+
[-98.17, 38.71],
707+
[r.tags['location']['lon'], r.tags['location']['lat']],
708+
]
709+
)
710+
* 6371
711+
)[0][1]
712+
> 500.0,
713+
),
714+
(
715+
{
716+
'should': [
717+
{"key": "info", "match": {"text": "shoes"}},
718+
{"key": "price", "range": {"gte": 5.0}},
719+
]
720+
},
721+
lambda r: r.tags['info'].find("shoes") != -1 or r.tags['price'] >= 5.0,
722+
),
680723
],
681724
)
682-
@pytest.mark.parametrize('columns', [[('price', 'int')], {'price': 'int'}])
683-
def test_qdrant_filter_function(
684-
storage, filter_gen, operator, numeric_operators, start_storage, columns
685-
):
725+
def test_qdrant_filter_query(filter, checker, columns, start_storage):
686726
n_dim = 128
687727
da = DocumentArray(storage='qdrant', config={'n_dim': n_dim, 'columns': columns})
688-
da.extend([Document(id=f'r{i}', tags={'price': i}) for i in range(50)])
689-
thresholds = [10, 20, 30]
690-
for threshold in thresholds:
691-
filter = filter_gen(operator, threshold)
692-
results = da._filter(filter=filter)
693728

694-
assert len(results) > 0
729+
da.extend(
730+
[
731+
Document(
732+
id=f'r{i}',
733+
embedding=np.random.rand(n_dim),
734+
tags={
735+
'price': i + 0.5,
736+
'category': 'Shoes',
737+
'info': f'shoes {i}',
738+
'location': {"lon": -98.17 + i, "lat": 38.93 + i},
739+
},
740+
)
741+
for i in range(10)
742+
]
743+
)
695744

696-
assert all(
697-
[numeric_operators[operator](r.tags['price'], threshold) for r in results]
698-
)
745+
da.extend(
746+
[
747+
Document(
748+
id=f'r{i+10}',
749+
embedding=np.random.rand(n_dim),
750+
tags={
751+
'price': i + 0.5,
752+
'category': 'Jeans',
753+
'info': 'jeans {i}',
754+
'location': {"lon": -98.17 + i, "lat": 38.93 + i},
755+
},
756+
)
757+
for i in range(10)
758+
]
759+
)
760+
761+
results = da.find(np.random.rand(n_dim), filter=filter)
762+
assert len(results) > 0
763+
assert all([checker(r) for r in results])
699764

700765

701766
@pytest.mark.parametrize('columns', [[('price', 'int')], {'price': 'int'}])

0 commit comments

Comments
 (0)