Skip to content

Commit 59940cf

Browse files
okay well have the unit test working
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
1 parent 74e7ede commit 59940cf

4 files changed

Lines changed: 32 additions & 26 deletions

File tree

sdk/python/feast/inference.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -306,9 +306,7 @@ def _infer_on_demand_features_and_entities(
306306
field = Field(
307307
name=col_name,
308308
dtype=from_value_type(
309-
batch_source.source_datatype_to_feast_value_type()(
310-
col_datatype
311-
)
309+
batch_source.source_datatype_to_feast_value_type()(col_datatype)
312310
),
313311
)
314312
if field.name not in [

sdk/python/feast/infra/online_stores/online_store.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,8 @@ def get_online_features(
151151
native_entity_values=True,
152152
)
153153

154+
if join_key_values.get("driver_id", None):
155+
print("table_entity_values:", join_key_values)
154156
for table, requested_features in grouped_refs:
155157
# Get the correct set of entity values with the correct join keys.
156158
table_entity_values, idxs = utils._get_unique_entities(

sdk/python/feast/utils.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from feast.constants import FEAST_FS_YAML_FILE_PATH_ENV_NAME
2929
from feast.entity import Entity
3030
from feast.errors import (
31-
EntityNotFoundException,
3231
FeatureNameCollisionError,
3332
FeatureViewNotFoundException,
3433
RequestDataNotFoundInEntityRowsException,
@@ -1032,23 +1031,20 @@ def _prepare_entities_to_read_from_online_store(
10321031
# Found request data
10331032
if join_key_or_entity_name in needed_request_data:
10341033
request_data_features[join_key_or_entity_name] = values
1035-
else:
1036-
if join_key_or_entity_name in join_keys_set:
1037-
join_key = join_key_or_entity_name
1038-
else:
1039-
try:
1040-
if join_key_or_entity_name in request_source_keys:
1041-
join_key = entity_name_to_join_key_map[join_key_or_entity_name]
1042-
except KeyError:
1043-
raise EntityNotFoundException(join_key_or_entity_name, project)
1044-
else:
1045-
warnings.warn(
1046-
"Using entity name is deprecated. Use join_key instead."
1047-
)
1048-
1049-
# All join keys should be returned in the result.
1034+
elif join_key_or_entity_name in join_keys_set:
1035+
# It's a join key
1036+
join_key = join_key_or_entity_name
10501037
requested_result_row_names.add(join_key)
10511038
join_key_values[join_key] = values
1039+
elif join_key_or_entity_name in entity_name_to_join_key_map:
1040+
# It's an entity name (deprecated)
1041+
join_key = entity_name_to_join_key_map[join_key_or_entity_name]
1042+
warnings.warn("Using entity name is deprecated. Use join_key instead.")
1043+
requested_result_row_names.add(join_key)
1044+
join_key_values[join_key] = values
1045+
else:
1046+
# Key is not recognized (likely a feature value), so we skip it.
1047+
continue # Or handle accordingly
10521048

10531049
ensure_request_data_values_exist(needed_request_data, request_data_features)
10541050

sdk/python/tests/unit/test_on_demand_python_transformation.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,8 @@ def test_invalid_python_transformation_raises_type_error_on_apply():
493493
schema=[Field(name="driver_name_lower", dtype=String)],
494494
mode="python",
495495
)
496-
def python_view(inputs: dict[str, Any]) -> dict[str, Any]: return {"driver_name_lower": []}
496+
def python_view(inputs: dict[str, Any]) -> dict[str, Any]:
497+
return {"driver_name_lower": []}
497498

498499
with pytest.raises(
499500
TypeError,
@@ -628,7 +629,9 @@ def python_stored_writes_feature_view(
628629
)
629630
assert (
630631
python_stored_writes_feature_view.entity_columns
631-
== self.store.get_on_demand_feature_view("python_stored_writes_feature_view").entity_columns
632+
== self.store.get_on_demand_feature_view(
633+
"python_stored_writes_feature_view"
634+
).entity_columns
632635
)
633636

634637
current_datetime = _utc_now()
@@ -660,7 +663,7 @@ def python_stored_writes_feature_view(
660663
{
661664
"driver_id": 1001,
662665
"conv_rate": 0.25,
663-
"acc_rate": 0.25,
666+
"acc_rate": 0.50,
664667
"counter": 0,
665668
"input_datetime": current_datetime,
666669
}
@@ -679,14 +682,21 @@ def python_stored_writes_feature_view(
679682
"driver_hourly_stats:avg_daily_trips",
680683
],
681684
).to_dict()
682-
print(online_python_response)
685+
686+
assert online_python_response == {
687+
"driver_id": [1001],
688+
"conv_rate": [0.25],
689+
"avg_daily_trips": [2],
690+
"acc_rate": [0.25],
691+
}
692+
683693
print("storing odfv features")
684694
self.store.write_to_online_store(
685695
feature_view_name="python_stored_writes_feature_view",
686696
df=odfv_entity_rows_to_write,
687697
)
688698
print("reading odfv features")
689-
online_python_response = self.store.get_online_features(
699+
online_odfv_python_response = self.store.get_online_features(
690700
entity_rows=odfv_entity_rows_to_read,
691701
features=[
692702
"python_stored_writes_feature_view:conv_rate_plus_acc",
@@ -695,8 +705,8 @@ def python_stored_writes_feature_view(
695705
"python_stored_writes_feature_view:input_datetime",
696706
],
697707
).to_dict()
698-
print(online_python_response)
699-
assert sorted(list(online_python_response.keys())) == sorted(
708+
print(online_odfv_python_response)
709+
assert sorted(list(online_odfv_python_response.keys())) == sorted(
700710
[
701711
"driver_id",
702712
"conv_rate_plus_acc",

0 commit comments

Comments
 (0)