Skip to content

Commit 4be29fb

Browse files
feat: Add Additional Sort Key Validation (#228)
* adding changes to hybrid health-check (#221) * Adding more validation to sort keys in sorted feature view * fixing linting --------- Co-authored-by: Vineet Belur <vbelur@gmail.com>
1 parent 3bd8a2b commit 4be29fb

6 files changed

Lines changed: 106 additions & 15 deletions

File tree

go/internal/feast/server/grpc_server.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package server
33
import (
44
"context"
55
"fmt"
6+
67
"google.golang.org/grpc/reflection"
78

89
"github.com/feast-dev/feast/go/internal/feast"
@@ -212,7 +213,7 @@ func (s *grpcServingServiceServer) GetOnlineFeaturesRange(ctx context.Context, r
212213
}
213214

214215
// Register services used by the grpcServingServiceServer.
215-
func (s *grpcServingServiceServer) RegisterServices() (*grpc.Server, *health.Server) {
216+
func (s *grpcServingServiceServer) RegisterServices() *grpc.Server {
216217
grpcPromMetrics := grpcPrometheus.NewServerMetrics()
217218
prometheus.MustRegister(grpcPromMetrics)
218219
grpcServer := grpc.NewServer(
@@ -224,7 +225,7 @@ func (s *grpcServingServiceServer) RegisterServices() (*grpc.Server, *health.Ser
224225
grpc_health_v1.RegisterHealthServer(grpcServer, healthService)
225226
reflection.Register(grpcServer)
226227

227-
return grpcServer, healthService
228+
return grpcServer
228229
}
229230

230231
func GenerateRequestId() string {

go/internal/feast/server/hybrid_server.go

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,20 @@ package server
33

44
import (
55
"context"
6+
"fmt"
67
"net/http"
78
"time"
89

910
"github.com/prometheus/client_golang/prometheus/promhttp"
10-
"google.golang.org/grpc/health"
11+
"google.golang.org/grpc"
12+
"google.golang.org/grpc/credentials/insecure"
1113
healthpb "google.golang.org/grpc/health/grpc_health_v1"
1214
)
1315

1416
var defaultCheckTimeout = 2 * time.Second
1517

1618
// Register default HTTP handlers specific to the hybrid server configuration.
17-
func DefaultHybridHandlers(s *httpServer, hs *health.Server) []Handler {
19+
func DefaultHybridHandlers(s *httpServer, port int) []Handler {
1820
return []Handler{
1921
{
2022
path: "/get-online-features",
@@ -26,25 +28,36 @@ func DefaultHybridHandlers(s *httpServer, hs *health.Server) []Handler {
2628
},
2729
{
2830
path: "/health",
29-
handlerFunc: http.HandlerFunc(combinedHealthCheck(hs)),
31+
handlerFunc: http.HandlerFunc(combinedHealthCheck(port)),
3032
},
3133
}
3234
}
3335

3436
// This function wraps an http.Handler that is registered during hybrid server creation.
3537
// Calls the grpc.server healthcheck check endpoint
36-
func combinedHealthCheck(hs *health.Server) http.HandlerFunc {
38+
func combinedHealthCheck(port int) http.HandlerFunc {
3739
return func(w http.ResponseWriter, r *http.Request) {
3840
ctx, cancel := context.WithTimeout(r.Context(), defaultCheckTimeout)
3941
defer cancel()
4042

41-
req := &healthpb.HealthCheckRequest{
42-
Service: "", // Empty string means that it will simply check overall servingStatus
43+
target := fmt.Sprintf("localhost:%d", port)
44+
conn, err := grpc.DialContext(
45+
ctx,
46+
target,
47+
grpc.WithTransportCredentials(insecure.NewCredentials()),
48+
grpc.WithBlock(),
49+
)
50+
51+
if err != nil {
52+
http.Error(w, fmt.Sprintf("gRPC server connectivity check failed: %v", err), http.StatusServiceUnavailable)
53+
return
4354
}
55+
defer conn.Close()
4456

45-
resp, err := hs.Check(ctx, req)
57+
hc := healthpb.NewHealthClient(conn)
58+
resp, err := hc.Check(ctx, &healthpb.HealthCheckRequest{Service: ""})
4659
if err != nil {
47-
http.Error(w, "gRPC health check failed", http.StatusInternalServerError)
60+
http.Error(w, fmt.Sprintf("gRPC health check failed: %v", err), http.StatusInternalServerError)
4861
return
4962
}
5063

go/main.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ func StartGrpcServer(fs *feast.FeatureStore, host string, port int, loggingServi
137137
return err
138138
}
139139

140-
grpcServer, _ := ser.RegisterServices()
140+
grpcServer := ser.RegisterServices()
141141

142142
// Running Prometheus metrics endpoint on a separate goroutine
143143
go func() {
@@ -206,7 +206,7 @@ func StartHybridServer(fs *feast.FeatureStore, host string, httpPort int, grpcPo
206206
return err
207207
}
208208

209-
grpcSer, healthService := ser.RegisterServices()
209+
grpcSer := ser.RegisterServices()
210210

211211
if err != nil {
212212
return err
@@ -235,7 +235,11 @@ func StartHybridServer(fs *feast.FeatureStore, host string, httpPort int, grpcPo
235235
log.Info().Msg("HTTP and gRPC servers terminated")
236236
}()
237237

238-
err = httpSer.Serve(host, httpPort, server.DefaultHybridHandlers(httpSer, healthService))
238+
go func() {
239+
if err := httpSer.Serve(host, httpPort, server.DefaultHybridHandlers(httpSer, grpcPort)); err != nil && err != http.ErrServerClosed {
240+
log.Error().Err(err).Msg("HTTP server failed")
241+
}
242+
}()
239243

240244
if err != nil {
241245
return err

sdk/python/feast/sort_key.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import warnings
2-
from typing import Dict, Optional
2+
from typing import Dict, Optional, Union
33

44
from typeguard import typechecked
55

@@ -30,10 +30,19 @@ def __init__(
3030
self,
3131
name: str,
3232
value_type: ValueType,
33-
default_sort_order: SortOrder.Enum.ValueType = SortOrder.ASC,
33+
default_sort_order: Union[str, SortOrder.Enum.ValueType] = SortOrder.ASC,
3434
tags: Optional[Dict[str, str]] = None,
3535
description: str = "",
3636
):
37+
if isinstance(default_sort_order, str):
38+
try:
39+
default_sort_order = SortOrder.Enum.Value(default_sort_order.upper())
40+
except ValueError:
41+
raise ValueError("default_sort_order must be 'ASC' or 'DESC'")
42+
if default_sort_order not in (SortOrder.ASC, SortOrder.DESC):
43+
raise ValueError(
44+
"default_sort_order must be SortOrder.ASC or SortOrder.DESC"
45+
)
3746
self.name = name
3847
# TODO: Handle ValueType conversion, user should be able to pass in a dtype instead of ValueType
3948
self.value_type = value_type

sdk/python/feast/sorted_feature_view.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,13 @@ def ensure_valid(self):
121121
"SortedFeatureView must have at least one sort key defined."
122122
)
123123

124+
seen_sort_keys = set()
124125
for sort_key in self.sort_keys:
126+
# Check for duplicate sort keys
127+
if sort_key.name in seen_sort_keys:
128+
raise ValueError(f"Duplicate sort key found: '{sort_key.name}'.")
129+
seen_sort_keys.add(sort_key.name)
130+
125131
# Sort keys should not conflict with entity names.
126132
if sort_key.name in self.entities:
127133
raise ValueError(

sdk/python/tests/unit/test_sorted_feature_view.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,3 +423,61 @@ def test_sorted_feature_view_duplicate_features():
423423
schema=schema,
424424
sort_keys=[sort_key],
425425
)
426+
427+
428+
def test_sorted_feature_view_duplicate_sort_keys():
429+
"""
430+
Test that a SortedFeatureView fails validation if duplicate sort key names are present.
431+
"""
432+
source = FileSource(path="some path")
433+
entity = Entity(name="entity1", join_keys=["entity1_id"])
434+
# Schema with duplicate feature names.
435+
schema = [
436+
Field(name="dup_field", dtype=Int64),
437+
]
438+
sort_key_1 = SortKey(
439+
name="dup_field",
440+
value_type=ValueType.INT64,
441+
default_sort_order=SortOrder.ASC,
442+
)
443+
sort_key_2 = SortKey(
444+
name="dup_field",
445+
value_type=ValueType.INT64,
446+
default_sort_order=SortOrder.ASC,
447+
)
448+
with pytest.raises(ValueError, match="Duplicate sort key found: 'dup_field'."):
449+
SortedFeatureView(
450+
name="invalid_sorted_feature_view",
451+
source=source,
452+
entities=[entity],
453+
schema=schema,
454+
sort_keys=[sort_key_1, sort_key_2],
455+
)
456+
457+
458+
def test_sorted_feature_view_invalid_sort_key_order_str():
459+
"""
460+
Test that a SortedFeatureView fails validation if default_sort_order is incorrect.
461+
"""
462+
463+
with pytest.raises(ValueError, match="default_sort_order must be 'ASC' or 'DESC'"):
464+
SortKey(
465+
name="dup_field",
466+
value_type=ValueType.INT64,
467+
default_sort_order="99",
468+
)
469+
470+
471+
def test_sorted_feature_view_invalid_sort_key_order_int():
472+
"""
473+
Test that a SortedFeatureView fails validation if default_sort_order is incorrect.
474+
"""
475+
476+
with pytest.raises(
477+
ValueError, match="default_sort_order must be SortOrder.ASC or SortOrder.DESC"
478+
):
479+
SortKey(
480+
name="dup_field",
481+
value_type=ValueType.INT64,
482+
default_sort_order=99,
483+
)

0 commit comments

Comments
 (0)