diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index 6ea7d438c..11587fb1c 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -37,6 +37,7 @@ import ( backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" @@ -66,6 +67,7 @@ type StreamingServer struct { type Scheduler interface { Schedule(ctx context.Context, b *schedulingtypes.LLMRequest) (result *schedulingtypes.Result, err error) + RunPostResponsePlugins(ctx context.Context, req *types.LLMRequest, tragetPodName string) (*schedulingtypes.Result, error) } // RequestContext stores context information during the life time of an HTTP request. @@ -189,6 +191,7 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer) case *extProcPb.ProcessingRequest_RequestTrailers: // This is currently unused. case *extProcPb.ProcessingRequest_ResponseHeaders: + responseHeaders := make(map[string]string) for _, header := range v.ResponseHeaders.Headers.GetHeaders() { value := string(header.RawValue) @@ -199,27 +202,53 @@ func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer) reqCtx.modelServerStreaming = true loggerTrace.Info("model server is streaming response") } + responseHeaders[header.Key] = value } - reqCtx.RequestState = ResponseRecieved - reqCtx.respHeaderResp = &extProcPb.ProcessingResponse{ - Response: &extProcPb.ProcessingResponse_ResponseHeaders{ - ResponseHeaders: &extProcPb.HeadersResponse{ - Response: &extProcPb.CommonResponse{ - HeaderMutation: &extProcPb.HeaderMutation{ - SetHeaders: []*configPb.HeaderValueOption{ - { - Header: &configPb.HeaderValue{ - // This is for debugging purpose only. - Key: "x-went-into-resp-headers", - RawValue: []byte("true"), - }, - }, + llmReq := &schedulingtypes.LLMRequest{ + Model: reqCtx.Model, + Headers: responseHeaders, + ResolvedTargetModel: reqCtx.ResolvedTargetModel, + } + + var result *types.Result + result, err = s.scheduler.RunPostResponsePlugins(ctx, llmReq, reqCtx.TargetPod) + if err != nil { + logger.V(logutil.DEFAULT).Error(err, "Error handling response") + reqCtx.ResponseStatusCode = errutil.ModelServerError + } else { + headers := []*configPb.HeaderValueOption{ + { + Header: &configPb.HeaderValue{ + // This is for debugging purpose only. + Key: "x-went-into-resp-headers", + RawValue: []byte("true"), + }, + }, + } + + // Add headers added by PostResponse + for key, value := range result.MutatedHeaders { + headers = append(headers, &configPb.HeaderValueOption{ + Header: &configPb.HeaderValue{ + Key: key, + RawValue: []byte(value), + }, + }) + } + + reqCtx.RequestState = ResponseRecieved + reqCtx.respHeaderResp = &extProcPb.ProcessingResponse{ + Response: &extProcPb.ProcessingResponse_ResponseHeaders{ + ResponseHeaders: &extProcPb.HeadersResponse{ + Response: &extProcPb.CommonResponse{ + HeaderMutation: &extProcPb.HeaderMutation{ + SetHeaders: headers, }, }, }, }, - }, + } } case *extProcPb.ProcessingRequest_ResponseBody: diff --git a/pkg/epp/scheduling/config.go b/pkg/epp/scheduling/config.go index 5c64228ca..3f064fe75 100644 --- a/pkg/epp/scheduling/config.go +++ b/pkg/epp/scheduling/config.go @@ -26,6 +26,7 @@ type SchedulerConfig struct { scorers map[plugins.Scorer]int // map from scorer to weight picker plugins.Picker postSchedulePlugins []plugins.PostSchedule + postResponsePlugins []plugins.PostResponse } var defPlugin = &defaultPlugin{} @@ -40,4 +41,5 @@ var defaultConfig = &SchedulerConfig{ scorers: map[plugins.Scorer]int{}, picker: defPlugin, postSchedulePlugins: []plugins.PostSchedule{}, + postResponsePlugins: []plugins.PostResponse{}, } diff --git a/pkg/epp/scheduling/local_config.go b/pkg/epp/scheduling/local_config.go index 85b91d7cd..d1df2459c 100644 --- a/pkg/epp/scheduling/local_config.go +++ b/pkg/epp/scheduling/local_config.go @@ -36,6 +36,10 @@ const ( loadAwareScorerWeightEnvVar = "LOAD_AWARE_SCORER_WEIGHT" ) +func init() { + setDefaultConfig() +} + func setDefaultConfig() { // since the default config is a global variable, we add this function to minimize rebase conflicts. // this configuration is a temporary state, it should be better streamlined. diff --git a/pkg/epp/scheduling/scheduler.go b/pkg/epp/scheduling/scheduler.go index f4e1714d4..3dd0ca059 100644 --- a/pkg/epp/scheduling/scheduler.go +++ b/pkg/epp/scheduling/scheduler.go @@ -69,7 +69,6 @@ var ( ) func NewScheduler(datastore Datastore) *Scheduler { - setDefaultConfig() return NewSchedulerWithConfig(datastore, defaultConfig) } @@ -81,6 +80,7 @@ func NewSchedulerWithConfig(datastore Datastore, config *SchedulerConfig) *Sched scorers: config.scorers, picker: config.picker, postSchedulePlugins: config.postSchedulePlugins, + postResponsePlugins: config.postResponsePlugins, } } @@ -91,6 +91,7 @@ type Scheduler struct { scorers map[plugins.Scorer]int // map from scorer to its weight picker plugins.Picker postSchedulePlugins []plugins.PostSchedule + postResponsePlugins []plugins.PostResponse } type Datastore interface { @@ -211,6 +212,38 @@ func (s *Scheduler) runPostSchedulePlugins(ctx *types.SchedulingContext, res *ty } } +func (s *Scheduler) RunPostResponsePlugins(ctx context.Context, req *types.LLMRequest, targetPodName string) (*types.Result, error) { + logger := log.FromContext(ctx) + + pool, err := s.datastore.PoolGet() + if err != nil { + return nil, errutil.Error{Code: errutil.Internal, Msg: "failed to find a target pod"} // pool not defined, no pods + } + + // Snapshot pod metrics from the datastore to: + // 1. Reduce concurrent access to the datastore. + // 2. Ensure consistent data during the scheduling operation of a request. + pods := types.ToSchedulerPodMetrics(s.datastore.PodGetAll()) + var targetPod types.Pod + for _, pod := range pods { + if pod.GetPod().NamespacedName.String() == targetPodName { + targetPod = pod + break + } + } + + sCtx := types.NewSchedulingContext(ctx, req, pods, pool.Spec.TargetPortNumber) + + for _, plugin := range s.postResponsePlugins { + logger.V(logutil.DEBUG).Info("Running post-response plugin", "plugin", plugin.Name()) + before := time.Now() + plugin.PostResponse(sCtx, targetPod) + metrics.RecordSchedulerPluginProcessingLatency(plugins.PostResponsePluginType, plugin.Name(), time.Since(before)) + } + + return &types.Result{TargetPod: nil, MutatedHeaders: sCtx.MutatedHeaders}, nil +} + type defaultPlugin struct { picker.RandomPicker } diff --git a/pkg/epp/scheduling/scheduler_test.go b/pkg/epp/scheduling/scheduler_test.go index e6d229aee..eafa8d681 100644 --- a/pkg/epp/scheduling/scheduler_test.go +++ b/pkg/epp/scheduling/scheduler_test.go @@ -483,6 +483,56 @@ func TestSchedulePlugins(t *testing.T) { } } +func TestPostResponse(t *testing.T) { + pr1 := &testPostResponse{ + NameRes: "pr1", + ExtraHeaders: map[string]string{"x-session-id": "qwer-asdf-zxcv"}, + ReceivedResponseHeaders: make(map[string]string), + } + + tests := []struct { + name string + config SchedulerConfig + input []*backendmetrics.FakePodMetrics + responseHeaders map[string]string + wantMutatedHeaders map[string]string + }{ + { + name: "Simple postResponse test", + config: SchedulerConfig{ + postResponsePlugins: []plugins.PostResponse{pr1}, + }, + input: []*backendmetrics.FakePodMetrics{ + {Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, + }, + responseHeaders: map[string]string{"Content-type": "application/json", "Content-Length": "1234"}, + wantMutatedHeaders: map[string]string{"x-session-id": "qwer-asdf-zxcv"}, + }, + } + + for _, test := range tests { + scheduler := NewSchedulerWithConfig(&fakeDataStore{pods: test.input}, &test.config) + + req := &types.LLMRequest{ + Model: "test-model", + Headers: test.responseHeaders, + } + + result, err := scheduler.RunPostResponsePlugins(context.Background(), req, test.input[0].Pod.NamespacedName.String()) + if err != nil { + t.Errorf("Received an error. Error: %s", err) + } + + if diff := cmp.Diff(test.responseHeaders, pr1.ReceivedResponseHeaders); diff != "" { + t.Errorf("Unexpected output (-responseHeaders +ReceivedResponseHeaders): %v", diff) + } + + if diff := cmp.Diff(test.wantMutatedHeaders, result.MutatedHeaders); diff != "" { + t.Errorf("Unexpected output (-wantedMutatedHeaders +MutatedHeaders): %v", diff) + } + } +} + type fakeDataStore struct { pods []*backendmetrics.FakePodMetrics } @@ -571,6 +621,23 @@ func (tp *TestPlugin) reset() { tp.NumOfPickerCandidates = 0 } +type testPostResponse struct { + NameRes string + ReceivedResponseHeaders map[string]string + ExtraHeaders map[string]string +} + +func (pr *testPostResponse) Name() string { return pr.NameRes } + +func (pr *testPostResponse) PostResponse(ctx *types.SchedulingContext, pod types.Pod) { + for key, value := range ctx.Req.Headers { + pr.ReceivedResponseHeaders[key] = value + } + for key, value := range pr.ExtraHeaders { + ctx.MutatedHeaders[key] = value + } +} + func findPods(ctx *types.SchedulingContext, names ...k8stypes.NamespacedName) []types.Pod { res := []types.Pod{} for _, pod := range ctx.PodsSnapshot { diff --git a/pkg/epp/scheduling/scorers_test.go b/pkg/epp/scheduling/scorers_test.go index a98a838b1..640143bf1 100644 --- a/pkg/epp/scheduling/scorers_test.go +++ b/pkg/epp/scheduling/scorers_test.go @@ -86,19 +86,23 @@ func TestScorers(t *testing.T) { }, }, wantRes: &types.Result{ - TargetPod: &types.PodMetrics{ - Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 0, - KVCacheUsagePercent: 0.2, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - "bar": 1, + TargetPod: &types.ScoredPod{ + Pod: &types.PodMetrics{ + Pod: &backendmetrics.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, + Metrics: &backendmetrics.Metrics{ + WaitingQueueSize: 0, + KVCacheUsagePercent: 0.2, + MaxActiveModels: 2, + ActiveModels: map[string]int{ + "foo": 1, + "bar": 1, + }, + WaitingModels: map[string]int{}, }, - WaitingModels: map[string]int{}, }, + Score: 0.5, }, + MutatedHeaders: map[string]string{}, }, }, }