Skip to content
Prev Previous commit
Next Next commit
Fixed bugs, cleaned code, added some methods. test_universal_historical_retrieval - 100% passed

Signed-off-by: Youngkyu OH <toping4445@gmail.com>
  • Loading branch information
toping4445 authored and younggyu-oh committed Aug 9, 2022
commit 7cbd23293e9a73abdf273e42043cb58ec56add9d
6 changes: 3 additions & 3 deletions sdk/python/feast/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from importlib_metadata import PackageNotFoundError, version as _version # type: ignore

from feast.infra.offline_stores.bigquery_source import BigQuerySource
from feast.infra.offline_stores.file_source import FileSource
from feast.infra.offline_stores.redshift_source import RedshiftSource
from feast.infra.offline_stores.snowflake_source import SnowflakeSource
from feast.infra.offline_stores.contrib.athena_offline_store.athena_source import (
AthenaSource,
)
from feast.infra.offline_stores.file_source import FileSource
from feast.infra.offline_stores.redshift_source import RedshiftSource
from feast.infra.offline_stores.snowflake_source import SnowflakeSource

from .batch_feature_view import BatchFeatureView
from .data_source import KafkaSource, KinesisSource, PushSource, RequestSource
Expand Down
1 change: 1 addition & 0 deletions sdk/python/feast/batch_feature_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"SnowflakeSource",
"SparkSource",
"TrinoSource",
"AthenaSource",
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,21 @@
from feast import OnDemandFeatureView
from feast.data_source import DataSource
from feast.errors import InvalidEntityType
from feast.feature_logging import LoggingConfig, LoggingSource, LoggingDestination
from feast.feature_logging import LoggingConfig, LoggingDestination, LoggingSource
from feast.feature_view import DUMMY_ENTITY_ID, DUMMY_ENTITY_VAL, FeatureView
from feast.infra.offline_stores import offline_utils
from feast.infra.offline_stores.contrib.athena_offline_store.athena_source import (
AthenaLoggingDestination,
AthenaSource,
SavedDatasetAthenaStorage,
)
from feast.infra.offline_stores.offline_store import (
OfflineStore,
RetrievalJob,
RetrievalMetadata,
)

from feast.infra.offline_stores.contrib.athena_offline_store.athena_source import (
AthenaSource,
AthenaLoggingDestination,
SavedDatasetAthenaStorage,
)
from feast.infra.utils import aws_utils
from feast.infra.offline_stores import offline_utils

from feast.registry import Registry
from feast.registry import Registry, BaseRegistry
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.saved_dataset import SavedDatasetStorage
from feast.usage import log_exceptions_and_usage
Expand Down Expand Up @@ -82,7 +80,7 @@ def pull_latest_from_table_or_query(
assert isinstance(data_source, AthenaSource)
assert isinstance(config.offline_store, AthenaOfflineStoreConfig)

from_expression = data_source.get_table_query_string()
from_expression = data_source.get_table_query_string(config)

partition_by_join_key_string = ", ".join(join_key_columns)
if partition_by_join_key_string != "":
Expand All @@ -99,9 +97,7 @@ def pull_latest_from_table_or_query(

date_partition_column = data_source.date_partition_column

athena_client = aws_utils.get_athena_data_client(
config.offline_store.region
)
athena_client = aws_utils.get_athena_data_client(config.offline_store.region)
s3_resource = aws_utils.get_s3_resource(config.offline_store.region)

start_date = start_date.astimezone(tz=utc)
Expand Down Expand Up @@ -142,15 +138,13 @@ def pull_all_from_table_or_query(
end_date: datetime,
) -> RetrievalJob:
assert isinstance(data_source, AthenaSource)
from_expression = data_source.get_table_query_string()
from_expression = data_source.get_table_query_string(config)

field_string = ", ".join(
join_key_columns + feature_name_columns + [timestamp_field]
)

athena_client = aws_utils.get_athena_data_client(
config.offline_store.region
)
athena_client = aws_utils.get_athena_data_client(config.offline_store.region)
s3_resource = aws_utils.get_s3_resource(config.offline_store.region)

date_partition_column = data_source.date_partition_column
Expand Down Expand Up @@ -186,9 +180,7 @@ def get_historical_features(
) -> RetrievalJob:
assert isinstance(config.offline_store, AthenaOfflineStoreConfig)

athena_client = aws_utils.get_athena_data_client(
config.offline_store.region
)
athena_client = aws_utils.get_athena_data_client(config.offline_store.region)
s3_resource = aws_utils.get_s3_resource(config.offline_store.region)

# get pandas dataframe consisting of 1 row (LIMIT 1) and generate the schema out of it
Expand All @@ -197,23 +189,24 @@ def get_historical_features(
)

# Find the timestamp column of the entity df (default = "event_timestamp"). An exception occurs if there are more than two timestamp columns.
entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df(
entity_schema
entity_df_event_timestamp_col = (
offline_utils.infer_event_timestamp_from_entity_df(entity_schema)
)

# get min,max of event_timestamp.
entity_df_event_timestamp_range = _get_entity_df_event_timestamp_range(
entity_df, entity_df_event_timestamp_col, athena_client, config,
entity_df,
entity_df_event_timestamp_col,
athena_client,
config,
)

@contextlib.contextmanager
def query_generator() -> Iterator[str]:

table_name = offline_utils.get_temp_entity_table_name()

_upload_entity_df(
entity_df, athena_client, config, s3_resource, table_name
)
_upload_entity_df(entity_df, athena_client, config, s3_resource, table_name)

expected_join_keys = offline_utils.get_expected_join_keys(
project, feature_views, registry
Expand All @@ -232,7 +225,6 @@ def query_generator() -> Iterator[str]:
entity_df_event_timestamp_range,
)


# Generate the Athena SQL query from the query context
query = offline_utils.build_point_in_time_query(
query_context,
Expand All @@ -247,17 +239,20 @@ def query_generator() -> Iterator[str]:
yield query
finally:

#Always clean up the temp Athena table
# Always clean up the temp Athena table
aws_utils.execute_athena_query(
athena_client,
config.offline_store.data_source,
config.offline_store.database,
f"DROP TABLE IF EXISTS {config.offline_store.database}.{table_name}",
)

bucket = config.offline_store.s3_staging_location.replace("s3://", "").split("/", 1)[0]
aws_utils.delete_s3_directory(s3_resource,bucket, "entity_df/"+table_name+"/")

bucket = config.offline_store.s3_staging_location.replace(
"s3://", ""
).split("/", 1)[0]
aws_utils.delete_s3_directory(
s3_resource, bucket, "entity_df/" + table_name + "/"
)

return AthenaRetrievalJob(
query=query_generator,
Expand All @@ -276,21 +271,18 @@ def query_generator() -> Iterator[str]:
),
)


@staticmethod
def write_logged_features(
config: RepoConfig,
data: Union[pyarrow.Table, Path],
source: LoggingSource,
logging_config: LoggingConfig,
registry: Registry,
registry: BaseRegistry,
):
destination = logging_config.destination
assert isinstance(destination, AthenaLoggingDestination)

athena_client = aws_utils.get_athena_data_client(
config.offline_store.region
)
athena_client = aws_utils.get_athena_data_client(config.offline_store.region)
s3_resource = aws_utils.get_s3_resource(config.offline_store.region)
if isinstance(data, Path):
s3_path = f"{config.offline_store.s3_staging_location}/logged_features/{uuid.uuid4()}"
Expand All @@ -299,7 +291,7 @@ def write_logged_features(

aws_utils.upload_arrow_table_to_athena(
table=data,
athena_data_client=athena_client,
athena_client=athena_client,
data_source=config.offline_store.data_source,
database=config.offline_store.database,
s3_resource=s3_resource,
Expand Down Expand Up @@ -332,7 +324,6 @@ def __init__(
on_demand_feature_views (optional): A list of on demand transforms to apply at retrieval time
"""


if not isinstance(query, str):
self._query_generator = query
else:
Expand All @@ -352,7 +343,6 @@ def query_generator() -> Iterator[str]:
)
self._metadata = metadata


@property
def full_feature_names(self) -> bool:
return self._full_feature_names
Expand All @@ -362,9 +352,15 @@ def on_demand_feature_views(self) -> Optional[List[OnDemandFeatureView]]:
return self._on_demand_feature_views

def get_temp_s3_path(self) -> str:
return self._config.offline_store.s3_staging_location + "/unload/" + str(uuid.uuid4())
return (
self._config.offline_store.s3_staging_location
+ "/unload/"
+ str(uuid.uuid4())
)

def get_temp_table_dml_header(self, temp_table_name:str, temp_external_location:str) -> str:
def get_temp_table_dml_header(
self, temp_table_name: str, temp_external_location: str
) -> str:
temp_table_dml_header = f"""
CREATE TABLE {temp_table_name}
WITH (
Expand All @@ -387,7 +383,8 @@ def _to_df_internal(self) -> pd.DataFrame:
self._config.offline_store.database,
self._s3_resource,
temp_external_location,
self.get_temp_table_dml_header(temp_table_name, temp_external_location) + query,
self.get_temp_table_dml_header(temp_table_name, temp_external_location)
+ query,
temp_table_name,
)

Expand All @@ -402,7 +399,8 @@ def _to_arrow_internal(self) -> pa.Table:
self._config.offline_store.database,
self._s3_resource,
temp_external_location,
self.get_temp_table_dml_header(temp_table_name, temp_external_location) + query,
self.get_temp_table_dml_header(temp_table_name, temp_external_location)
+ query,
temp_table_name,
)

Expand All @@ -412,7 +410,33 @@ def metadata(self) -> Optional[RetrievalMetadata]:

def persist(self, storage: SavedDatasetStorage):
assert isinstance(storage, SavedDatasetAthenaStorage)
# self.to_athena(table_name=storage.athena_options.table)
self.to_athena(table_name=storage.athena_options.table)

@log_exceptions_and_usage
def to_athena(self, table_name: str) -> None:
    """Save the result of this retrieval job as an Athena table.

    Args:
        table_name: Name of the Athena table to create.

    When the job has on-demand feature views, the result is first
    computed with ``to_df()`` and the resulting dataframe is handed to
    ``_upload_entity_df``. Otherwise the job's SQL is wrapped in a
    ``CREATE TABLE ... AS`` statement and executed through
    ``aws_utils.execute_athena_query``.
    """
    if not self.on_demand_feature_views:
        # No transformations to apply: let Athena build the table straight
        # from the query. The statement is executed while the query context
        # manager is still open (presumably so any temporary tables it
        # manages still exist -- confirm against the generator in use).
        with self._query_generator() as query:
            ctas_statement = f'CREATE TABLE "{table_name}" AS ({query});\n'
            offline_store = self._config.offline_store
            aws_utils.execute_athena_query(
                self._athena_client,
                offline_store.data_source,
                offline_store.database,
                ctas_statement,
            )
        return

    # On-demand feature views are applied by to_df(), so the transformed
    # dataframe is what gets uploaded under the requested table name.
    _upload_entity_df(
        self.to_df(),
        self._athena_client,
        self._config,
        self._s3_resource,
        table_name,
    )


def _upload_entity_df(
Expand Down Expand Up @@ -496,12 +520,14 @@ def _get_entity_df_event_timestamp_range(
f"SELECT MIN({entity_df_event_timestamp_col}) AS min, MAX({entity_df_event_timestamp_col}) AS max "
f"FROM ({entity_df})",
)
res = aws_utils.get_athena_query_result(athena_client, statement_id)[
"Records"
][0]
res = aws_utils.get_athena_query_result(athena_client, statement_id)
entity_df_event_timestamp_range = (
res.parse(res[0]["stringValue"]),
res.parse(res[1]["stringValue"]),
datetime.strptime(
res["Rows"][1]["Data"][0]["VarCharValue"], "%Y-%m-%d %H:%M:%S.%f"
),
datetime.strptime(
res["Rows"][1]["Data"][1]["VarCharValue"], "%Y-%m-%d %H:%M:%S.%f"
),
)
else:
raise InvalidEntityType(type(entity_df))
Expand Down
Loading