Skip to content

Commit 31da66b

Browse files
AnneYang720 authored and dong xiang committed
feat: redis supports geo filter (#579)
1 parent 5d06133 commit 31da66b

5 files changed

Lines changed: 139 additions & 31 deletions

File tree

docarray/array/storage/qdrant/find.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
TypeVar,
55
Sequence,
66
List,
7-
Dict,
7+
Union,
88
Optional,
9+
Dict,
910
)
10-
1111
from qdrant_client.http.models.models import Distance
1212

1313
from docarray import Document, DocumentArray
@@ -103,3 +103,32 @@ def _find(
103103
da = self._find_similar_vectors(q, limit=limit, filter=filter)
104104
closest_docs.append(da)
105105
return closest_docs
106+
107+
def _find_with_filter(
108+
self, filter: Optional[Dict], limit: Optional[Union[int, float]] = 10
109+
):
110+
list_of_points, _offset = self.client.scroll(
111+
collection_name=self.collection_name,
112+
scroll_filter=filter,
113+
with_payload=True,
114+
limit=limit,
115+
)
116+
da = DocumentArray()
117+
for result in list_of_points[:limit]:
118+
doc = Document.from_base64(
119+
result.payload['_serialized'], **self.serialize_config
120+
)
121+
da.append(doc)
122+
return da
123+
124+
def _filter(
125+
self, filter: Optional[Dict], limit: Optional[Union[int, float]] = 10
126+
) -> 'DocumentArray':
127+
"""Returns a subset of documents by filtering by the given filter (`Qdrant` filter)..
128+
:param limit: number of retrieved items
129+
:param filter: filter query used for filtering.
130+
For more information: https://docarray.jina.ai/advanced/document-store/qdrant/#qdrant
131+
:return: a `DocumentArray` containing the `Document` objects that satisfy the filter.
132+
"""
133+
134+
return self._find_with_filter(filter, limit=limit)

docarray/array/storage/redis/backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from docarray.helper import dataclass_from_dict, random_identity, filter_dict
99

1010
from redis import Redis
11-
from redis.commands.search.field import NumericField, TextField, VectorField, GeoField
11+
from redis.commands.search.field import NumericField, TextField, VectorField
1212
from redis.commands.search.indexDefinition import IndexDefinition
1313

1414
if TYPE_CHECKING:
@@ -46,7 +46,7 @@ class BackendMixin(BaseBackendMixin):
4646
'float': TypeMap(type='float', converter=NumericField),
4747
'double': TypeMap(type='double', converter=NumericField),
4848
'long': TypeMap(type='long', converter=NumericField),
49-
'geo': TypeMap(type='geo', converter=GeoField),
49+
'bool': TypeMap(type='long', converter=NumericField),
5050
}
5151

5252
def _init_storage(

docarray/array/storage/redis/find.py

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import warnings
21
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, TypeVar, Union
32

43
import numpy as np
@@ -39,13 +38,14 @@ class FindMixin(BaseFindMixin):
3938
def _find_similar_vectors(
4039
self,
4140
query: 'RedisArrayType',
42-
filter: Optional[Union[str, Dict]] = None,
41+
filter: Optional[Dict] = None,
4342
limit: Union[int, float] = 20,
4443
**kwargs,
4544
):
4645

4746
if filter:
48-
query_str = _get_redis_filter_query(filter)
47+
nodes = _build_query_nodes(filter)
48+
query_str = intersect(*nodes).to_string()
4949
else:
5050
query_str = '*'
5151

@@ -74,7 +74,7 @@ def _find(
7474
self,
7575
query: 'RedisArrayType',
7676
limit: Union[int, float] = 20,
77-
filter: Optional[Union[str, Dict]] = None,
77+
filter: Optional[Dict] = None,
7878
**kwargs,
7979
) -> List['DocumentArray']:
8080

@@ -90,10 +90,11 @@ def _find(
9090

9191
def _find_with_filter(
9292
self,
93-
filter: Union[str, Dict],
93+
filter: Dict,
9494
limit: Union[int, float] = 20,
9595
):
96-
query_str = _get_redis_filter_query(filter)
96+
nodes = _build_query_nodes(filter)
97+
query_str = intersect(*nodes).to_string()
9798
q = Query(query_str)
9899
q.paging(0, limit)
99100

@@ -107,7 +108,7 @@ def _find_with_filter(
107108

108109
def _filter(
109110
self,
110-
filter: Union[str, Dict],
111+
filter: Dict,
111112
limit: Union[int, float] = 20,
112113
) -> 'DocumentArray':
113114

@@ -217,19 +218,3 @@ def _build_query_nodes(filter):
217218
def _build_query_str(query):
218219
query_str = '|'.join(query.split(' '))
219220
return query_str
220-
221-
222-
def _get_redis_filter_query(filter: Union[str, Dict]):
223-
if isinstance(filter, dict):
224-
warnings.warn(
225-
"Dict syntax for redis filter will be deprecated, use string literals instead",
226-
DeprecationWarning,
227-
)
228-
nodes = _build_query_nodes(filter)
229-
query_str = intersect(*nodes).to_string()
230-
elif isinstance(filter, str):
231-
query_str = filter
232-
else:
233-
raise ValueError(f'Unexpected type of filter: {type(filter)}, expected str')
234-
235-
return query_str

docs/advanced/document-store/qdrant.md

Lines changed: 62 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ da = DocumentArray(
6161
'host': 'localhost',
6262
'port': '6333',
6363
'n_dim': 10,
64+
'distance': "cosine",
6465
},
6566
)
6667

@@ -98,7 +99,7 @@ Create `docker-compose.yml`:
9899
version: '3.4'
99100
services:
100101
qdrant:
101-
image: qdrant/qdrant:v0.7.0
102+
image: qdrant/qdrant:v0.8.0
102103
ports:
103104
- "6333:6333"
104105
ulimits: # Only required for tests, as there are a lot of collections created
@@ -121,7 +122,9 @@ from docarray import DocumentArray
121122

122123
N, D = 100, 128
123124

124-
da = DocumentArray.empty(N, storage='qdrant', config={'n_dim': D}) # init
125+
da = DocumentArray.empty(
126+
N, storage='qdrant', config={'n_dim': D, 'distance': 'cosine'}
127+
) # init
125128

126129
da.embeddings = np.random.random([N, D])
127130

@@ -146,7 +149,7 @@ in [Qdrant's Documentation](https://qdrant.tech/documentation/filtering/)
146149
Consider Documents with embeddings `[0,0,0]` up to ` [9,9,9]` where the document with embedding `[i,i,i]`
147150
has a tag `price` with value `i`. We can create such an example with the following code:
148151

149-
```python
152+
```python
150153
from docarray import Document, DocumentArray
151154
import numpy as np
152155

@@ -175,7 +178,7 @@ for embedding, price in zip(da.embeddings, da[:, 'tags__price']):
175178

176179
Consider we want the nearest vectors to the embedding `[8. 8. 8.]`, with the restriction that
177180
prices must follow a filter. As an example, let's consider that retrieved documents must have `price` value lower
178-
or equal than `max_price`. We can encode this information in annlite using `filter = {'price': {'$lte': max_price}}`.
181+
than or equal to `max_price`. We can encode this information in Qdrant using `filter = {'must': [{'key': 'price', 'range': {'lte': max_price}}]}`.
179182

180183
Then the search with the proposed filter can be implemented and used with the following code:
181184

@@ -206,3 +209,58 @@ Embeddings Nearest Neighbours with "price" at most 7:
206209
embedding=[5. 5. 5.], price=5
207210
embedding=[4. 4. 4.], price=4
208211
```
212+
### Example of `.filter` with a filter
213+
Consider Documents that each have a tag `price` with value `i`. We can create such an example with the following code:
214+
```python
215+
from docarray import Document, DocumentArray
216+
import numpy as np
217+
218+
n_dim = 3
219+
distance = 'euclidean'
220+
221+
da = DocumentArray(
222+
storage='qdrant',
223+
config={'n_dim': n_dim, 'columns': {'price': 'float'}, 'distance': distance},
224+
)
225+
226+
print(f'\nDocumentArray distance: {distance}')
227+
228+
with da:
229+
da.extend(
230+
[
231+
Document(id=f'r{i}', embedding=i * np.ones(n_dim), tags={'price': i})
232+
for i in range(10)
233+
]
234+
)
235+
236+
print('\nIndexed Prices:\n')
237+
for embedding, price in zip(da.embeddings, da[:, 'tags__price']):
238+
print(f'\tembedding={embedding},\t price={price}')
239+
```
240+
Suppose we want to retrieve only the Documents whose prices satisfy a filter. As an example,
241+
let's consider that retrieved documents must have a `price` value lower than or equal to `max_price`. We can encode
242+
this information in Qdrant using `filter = {'must': [{'key': 'price', 'range': {'lte': max_price}}]}`.
243+
244+
Then the search with the proposed filter can be implemented and used with the following code:
245+
```python
246+
max_price = 7
247+
n_limit = 4
248+
249+
filter = {'must': [{'key': 'price', 'range': {'lte': max_price}}]}
250+
results = da.filter(filter=filter, limit=n_limit)
251+
252+
print('\nPoints with "price" at most 7:\n')
253+
for embedding, price in zip(results.embeddings, results[:, 'tags__price']):
254+
print(f'\tembedding={embedding},\t price={price}')
255+
```
256+
This would print:
257+
258+
```
259+
260+
Points with "price" at most 7:
261+
262+
embedding=[6. 6. 6.], price=6
263+
embedding=[7. 7. 7.], price=7
264+
embedding=[1. 1. 1.], price=1
265+
embedding=[2. 2. 2.], price=2
266+
```

tests/unit/array/mixins/test_find.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,42 @@ def test_filtering(
545545
)
546546

547547

548+
@pytest.mark.parametrize(
549+
'storage,filter_gen,numeric_operators,operator',
550+
[
551+
*[
552+
tuple(
553+
[
554+
'qdrant',
555+
lambda operator, threshold: {
556+
'must': [{'key': 'price', 'match': {'value': threshold}}]
557+
},
558+
numeric_operators_qdrant,
559+
'eq',
560+
]
561+
)
562+
],
563+
],
564+
)
565+
@pytest.mark.parametrize('columns', [[('price', 'int')], {'price': 'int'}])
566+
def test_qdrant_filter_function(
567+
storage, filter_gen, operator, numeric_operators, start_storage, columns
568+
):
569+
n_dim = 128
570+
da = DocumentArray(storage='qdrant', config={'n_dim': n_dim, 'columns': columns})
571+
da.extend([Document(id=f'r{i}', tags={'price': i}) for i in range(50)])
572+
thresholds = [10, 20, 30]
573+
for threshold in thresholds:
574+
filter = filter_gen(operator, threshold)
575+
results = da._filter(filter=filter)
576+
577+
assert len(results) > 0
578+
579+
assert all(
580+
[numeric_operators[operator](r.tags['price'], threshold) for r in results]
581+
)
582+
583+
548584
@pytest.mark.parametrize('columns', [[('price', 'int')], {'price': 'int'}])
549585
def test_weaviate_filter_query(start_storage, columns):
550586
n_dim = 128

0 commit comments

Comments
 (0)