Skip to content

Commit 7c4e819

Browse files
vanitabhagwatvanitabhagwat
andauthored
fix: Updated the elastic search to use vector_index defined in the feature view to identify vector fields (#348)
* updated the elastic search to use vector_index defined in the feature view to identify vector fields * fix: formatting * Added logging and switched to use open source elastic search --------- Co-authored-by: vanitabhagwat <vbhagwat@expediagroup.com>
1 parent fec2a10 commit 7c4e819

3 files changed

Lines changed: 103 additions & 10 deletions

File tree

sdk/python/feast/infra/online_stores/elasticsearch_online_store/elasticsearch.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
to_naive_utc,
2727
)
2828

29+
logger = logging.getLogger(__name__)
30+
2931

3032
class ElasticSearchOnlineStoreConfig(FeastConfigBaseModel, VectorStoreConfig):
3133
"""
@@ -93,6 +95,8 @@ def online_write_batch(
9395
],
9496
progress: Optional[Callable[[int], Any]],
9597
) -> None:
98+
vector_field = _get_feature_view_vector_field_metadata(table)
99+
vector_field_name = vector_field.name if vector_field else None
96100
insert_values = []
97101
grouped_docs: dict[str, dict[str, Any]] = defaultdict(
98102
lambda: {
@@ -115,7 +119,9 @@ def online_write_batch(
115119
doc_key = f"{encoded_entity_key}_{timestamp}"
116120

117121
for feature_name, value in values.items():
118-
doc = _encode_feature_value(value)
122+
doc = _encode_feature_value(
123+
value, is_vector=(feature_name == vector_field_name)
124+
)
119125
grouped_docs[doc_key]["features"][feature_name] = doc
120126
grouped_docs[doc_key]["timestamp"] = timestamp
121127
grouped_docs[doc_key]["created_ts"] = created_ts
@@ -299,8 +305,11 @@ def retrieve_online_documents(
299305
Optional[ValueProto],
300306
]
301307
] = []
308+
vector_field = _get_feature_view_vector_field_metadata(table)
302309
vector_field_path = (
303-
config.online_store.vector_field_path or "embedding.vector_value"
310+
f"{vector_field.name}.vector_value"
311+
if vector_field
312+
else config.online_store.vector_field_path or "embedding.vector_value"
304313
)
305314
query = {
306315
"script_score": {
@@ -384,10 +393,21 @@ def retrieve_online_documents_v2(
384393
body["_source"] = source_fields
385394

386395
if embedding:
387-
similarity = (distance_metric or config.online_store.similarity).lower()
396+
vector_field = _get_feature_view_vector_field_metadata(table)
388397
vector_field_path = (
389-
config.online_store.vector_field_path or "embedding.vector_value"
398+
f"{vector_field.name}.vector_value"
399+
if vector_field
400+
else config.online_store.vector_field_path or "embedding.vector_value"
390401
)
402+
similarity = (
403+
distance_metric
404+
or (
405+
vector_field.vector_search_metric
406+
if vector_field and vector_field.vector_search_metric
407+
else None
408+
)
409+
or config.online_store.similarity
410+
).lower()
391411
if similarity == "cosine":
392412
script = f"cosineSimilarity(params.query_vector, '{vector_field_path}') + 1.0"
393413
elif similarity == "dot_product":
@@ -489,16 +509,21 @@ def _to_value_proto(value: Any) -> ValueProto:
489509
return val_proto
490510

491511

492-
def _encode_feature_value(value: ValueProto) -> Dict[str, Any]:
512+
def _encode_feature_value(value: ValueProto, is_vector: bool = False) -> Dict[str, Any]:
493513
"""
494514
Encode a ValueProto into a dictionary for Elasticsearch storage.
495515
"""
496516
encoded_value = base64.b64encode(value.SerializeToString()).decode("utf-8")
497517
result = {"feature_value": encoded_value}
498-
vector_val = get_list_val_str(value)
499518

500-
if vector_val:
501-
result["vector_value"] = json.loads(vector_val)
519+
if is_vector:
520+
vector_val = get_list_val_str(value)
521+
if vector_val:
522+
result["vector_value"] = json.loads(vector_val)
523+
else:
524+
logger.warning(
525+
"Feature is marked as vector but value does not contain a valid vector."
526+
)
502527
if value.HasField("string_val"):
503528
result["value_text"] = value.string_val
504529
elif value.HasField("bytes_val"):

sdk/python/feast/repo_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@
8383
"hazelcast": "feast.infra.online_stores.hazelcast_online_store.hazelcast_online_store.HazelcastOnlineStore",
8484
"ikv": "feast.infra.online_stores.ikv_online_store.ikv.IKVOnlineStore",
8585
"eg-milvus": "feast.expediagroup.vectordb.eg_milvus_online_store.EGMilvusOnlineStore",
86-
"elasticsearch": "feast.expediagroup.vectordb.elasticsearch_online_store.ElasticsearchOnlineStore",
87-
# "elasticsearch": "feast.infra.online_stores.elasticsearch_online_store.elasticsearch.ElasticSearchOnlineStore",
86+
# "elasticsearch": "feast.expediagroup.vectordb.elasticsearch_online_store.ElasticsearchOnlineStore",
87+
"elasticsearch": "feast.infra.online_stores.elasticsearch_online_store.elasticsearch.ElasticSearchOnlineStore",
8888
"remote": "feast.infra.online_stores.remote.RemoteOnlineStore",
8989
"singlestore": "feast.infra.online_stores.singlestore_online_store.singlestore.SingleStoreOnlineStore",
9090
"qdrant": "feast.infra.online_stores.qdrant_online_store.qdrant.QdrantOnlineStore",
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import base64
2+
3+
import pytest
4+
5+
from feast.infra.online_stores.elasticsearch_online_store.elasticsearch import (
6+
_encode_feature_value,
7+
)
8+
from feast.protos.feast.types.Value_pb2 import (
9+
FloatList,
10+
Int64List,
11+
)
12+
from feast.protos.feast.types.Value_pb2 import (
13+
Value as ValueProto,
14+
)
15+
16+
17+
class TestEncodeFeatureValue:
18+
def test_vector_field_includes_vector_value(self):
19+
"""When is_vector=True and value is a float list, vector_value should be present."""
20+
value = ValueProto(float_list_val=FloatList(val=[0.1, 0.2, 0.3]))
21+
result = _encode_feature_value(value, is_vector=True)
22+
23+
assert "vector_value" in result
24+
assert result["vector_value"] == pytest.approx([0.1, 0.2, 0.3])
25+
26+
def test_non_vector_list_excludes_vector_value(self):
27+
"""When is_vector=False and value is a float list, vector_value should NOT be present."""
28+
value = ValueProto(float_list_val=FloatList(val=[0.1, 0.2, 0.3]))
29+
result = _encode_feature_value(value, is_vector=False)
30+
31+
assert "vector_value" not in result
32+
33+
def test_non_vector_int_list_excludes_vector_value(self):
34+
"""An int64 list with is_vector=False should not produce vector_value."""
35+
value = ValueProto(int64_list_val=Int64List(val=[1, 2, 3]))
36+
result = _encode_feature_value(value, is_vector=False)
37+
38+
assert "vector_value" not in result
39+
40+
def test_string_value_has_value_text(self):
41+
"""A string ValueProto should produce value_text, not vector_value."""
42+
value = ValueProto(string_val="hello")
43+
result = _encode_feature_value(value, is_vector=False)
44+
45+
assert result["value_text"] == "hello"
46+
assert "vector_value" not in result
47+
48+
def test_feature_value_always_present(self):
49+
"""feature_value (base64 binary) should always be present regardless of is_vector."""
50+
vector_value = ValueProto(float_list_val=FloatList(val=[1.0, 2.0]))
51+
string_value = ValueProto(string_val="test")
52+
int_value = ValueProto(int64_val=42)
53+
54+
for val in [vector_value, string_value, int_value]:
55+
for is_vector in [True, False]:
56+
result = _encode_feature_value(val, is_vector=is_vector)
57+
assert "feature_value" in result
58+
# Verify it's valid base64 that deserializes back
59+
decoded = base64.b64decode(result["feature_value"])
60+
roundtrip = ValueProto()
61+
roundtrip.ParseFromString(decoded)
62+
63+
def test_default_is_vector_false(self):
64+
"""Calling without is_vector should default to False (no vector_value)."""
65+
value = ValueProto(float_list_val=FloatList(val=[0.1, 0.2]))
66+
result = _encode_feature_value(value)
67+
68+
assert "vector_value" not in result

0 commit comments

Comments
 (0)