diff --git a/.github/ISSUE_TEMPLATE/blank_issue.md b/.github/ISSUE_TEMPLATE/blank_issue.md new file mode 100644 index 00000000..dd6ebabf --- /dev/null +++ b/.github/ISSUE_TEMPLATE/blank_issue.md @@ -0,0 +1,8 @@ +--- +name: Blank Issue +about: Create a new issue from scratch +title: '' +labels: needs-triage +assignees: '' + +--- \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/bug_request.md b/.github/ISSUE_TEMPLATE/bug_request.md index c2597eb3..15ed35e1 100644 --- a/.github/ISSUE_TEMPLATE/bug_request.md +++ b/.github/ISSUE_TEMPLATE/bug_request.md @@ -1,7 +1,9 @@ --- name: Bug Report about: Report a bug you encountered -labels: kind/bug +title: '' +labels: kind/bug, needs-triage +assignees: '' --- diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 00000000..3ba13e0c --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1 @@ +blank_issues_enabled: false diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md index 53a885c7..1eee5871 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -2,7 +2,7 @@ name: Feature request about: Suggest an idea for this project title: '' -labels: '' +labels: needs-triage assignees: '' --- @@ -12,4 +12,3 @@ assignees: '' **What would you like to be added**: **Why is this needed**: - diff --git a/.github/ISSUE_TEMPLATE/new-release.md b/.github/ISSUE_TEMPLATE/new-release.md index be569844..27e83784 100644 --- a/.github/ISSUE_TEMPLATE/new-release.md +++ b/.github/ISSUE_TEMPLATE/new-release.md @@ -4,6 +4,7 @@ about: Propose a new release title: Release v0.x.0 labels: '' assignees: '' + --- - [Introduction](#introduction) diff --git a/Dockerfile b/Dockerfile index 8fb00dfb..d050b869 100644 --- a/Dockerfile +++ b/Dockerfile @@ -19,8 +19,9 @@ COPY cmd ./cmd COPY pkg ./pkg COPY internal ./internal COPY api ./api +COPY .git ./.git WORKDIR /src/cmd/epp -RUN go build -o /epp +RUN go build -buildvcs=true -o /epp ## Multistage deploy FROM ${BASE_IMAGE} diff --git a/Makefile b/Makefile index 66fe89d4..4826a029 100644 --- a/Makefile +++ b/Makefile @@ -121,10 +121,14 @@ vet: ## Run go vet against code. .PHONY: test test: manifests generate fmt vet envtest image-build ## Run tests. - KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" go test $$(go list ./... | grep -v /e2e) -race -coverprofile cover.out + KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" go test $$(go list ./... | grep -v /e2e | grep -v /conformance) -race -coverprofile cover.out + +.PHONY: test-unit +test-unit: ## Run unit tests. + KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" go test ./pkg/... -race -coverprofile cover.out .PHONY: test-integration -test-integration: ## Run tests. +test-integration: ## Run integration tests. KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" go test ./test/integration/epp/... -race -coverprofile cover.out .PHONY: test-e2e @@ -232,7 +236,7 @@ bbr-image-local-load: bbr-image-local-build .PHONY: bbr-image-build bbr-image-build: ## Build the image using Docker Buildx. 
- $(IMAGE_BUILD_CMD) -f body-based-routing.Dockerfile -t $(BBR_IMAGE_TAG) \ + $(IMAGE_BUILD_CMD) -f bbr.Dockerfile -t $(BBR_IMAGE_TAG) \ --platform=$(PLATFORMS) \ --build-arg BASE_IMAGE=$(BASE_IMAGE) \ --build-arg BUILDER_IMAGE=$(BUILDER_IMAGE) \ diff --git a/OWNERS_ALIASES b/OWNERS_ALIASES index 6e8e0c5d..933fbe9c 100644 --- a/OWNERS_ALIASES +++ b/OWNERS_ALIASES @@ -11,6 +11,9 @@ aliases: gateway-api-inference-extension-reviewers: - liu-cong - robscott + - shaneutt + - nirrozenbaum + wg-serving-leads: - ArangoGutierrez diff --git a/README.md b/README.md index 2ff00581..ffd86758 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,57 @@ -# Gateway API Inference Extension +[![Go Report Card](https://goreportcard.com/badge/sigs.k8s.io/gateway-api-inference-extension)](https://goreportcard.com/report/sigs.k8s.io/gateway-api-inference-extension) +[![Go Reference](https://pkg.go.dev/badge/sigs.k8s.io/gateway-api-inference-extension.svg)](https://pkg.go.dev/sigs.k8s.io/gateway-api-inference-extension) +[![License](https://img.shields.io/github/license/kubernetes-sigs/gateway-api-inference-extension)](/LICENSE) + +# Gateway API Inference Extension (GIE) + +This project offers tools for AI Inference, enabling developers to build [Inference Gateways]. + +[Inference Gateways]:#concepts-and-definitions + +## Concepts and Definitions + +The following are some key industry terms that are important to understand for +this project: + +- **Model**: A generative AI model that has learned patterns from data and is + used for inference. Models vary in size and architecture, from smaller + domain-specific models to massive multi-billion parameter neural networks that + are optimized for diverse language tasks. +- **Inference**: The process of running a generative AI model, such as a large + language model, diffusion model, etc., to generate text, embeddings, or other + outputs from input data. +- **Model server**: A service (in our case, containerized) responsible for + receiving inference requests and returning predictions from a model. +- **Accelerator**: Specialized hardware, such as Graphics Processing Units + (GPUs), that can be attached to Kubernetes nodes to speed up computations, + particularly for training and inference tasks. + +And the following are more specific terms to this project: + +- **Scheduler**: Makes decisions about which endpoint is optimal (best cost / + best performance) for an inference request based on `Metrics and Capabilities` + from [Model Serving](/docs/proposals/003-model-server-protocol/README.md). +- **Metrics and Capabilities**: Data provided by model serving platforms about + performance, availability and capabilities to optimize routing. Includes + things like [Prefix Cache] status or [LoRA Adapters] availability. +- **Endpoint Selector**: A `Scheduler` combined with `Metrics and Capabilities` + systems is often referred to collectively as an [Endpoint Selection Extension] + (this is also sometimes referred to as an "endpoint picker", or "EPP"). +- **Inference Gateway**: A proxy/load-balancer which has been coupled with an + `Endpoint Selector`. It provides optimized routing and load balancing for + serving Kubernetes self-hosted generative Artificial Intelligence (AI) + workloads. It simplifies the deployment, management, and observability of AI + inference workloads (see the sketch below). + +For deeper insights and more advanced concepts, refer to our [proposals](/docs/proposals).
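To make these definitions concrete, the following is a minimal sketch of how the pieces fit together: an HTTPRoute attaches to a Gateway and forwards matching traffic to an InferencePool, whose Endpoint Selector then picks the best model server endpoint. The resource names (`llm-route`, `inference-gateway`) are illustrative only; the group/kind wiring follows the Gateway API and `inference.networking.x-k8s.io` types used elsewhere in this change.

```yaml
# Hypothetical example: routing traffic from a Gateway to an InferencePool.
# Names are illustrative; only the group/kind wiring is load-bearing.
apiVersion: gateway.networking.k8s.io/v1
kind: HTTPRoute
metadata:
  name: llm-route
  namespace: default
spec:
  parentRefs:
  - name: inference-gateway   # the proxy acting as the Inference Gateway
  rules:
  - backendRefs:
    - group: inference.networking.x-k8s.io
      kind: InferencePool     # the Endpoint Selector (EPP) picks the endpoint within this pool
      name: vllm-llama3-8b-instruct
```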
+ +[Inference]:https://www.digitalocean.com/community/tutorials/llm-inference-optimization +[Gateway API]:https://github.com/kubernetes-sigs/gateway-api +[Prefix Cache]:https://docs.vllm.ai/en/stable/design/v1/prefix_caching.html +[LoRA Adapters]:https://docs.vllm.ai/en/stable/features/lora.html +[Endpoint Selection Extension]:https://gateway-api-inference-extension.sigs.k8s.io/#endpoint-selection-extension + +## Technical Overview This extension upgrades an [ext-proc](https://www.envoyproxy.io/docs/envoy/latest/configuration/http/http_filters/ext_proc_filter)-capable proxy or gateway - such as Envoy Gateway, kGateway, or the GKE Gateway - to become an **inference gateway** - supporting inference platform teams self-hosting large language models on Kubernetes. This integration makes it easy to expose and control access to your local [OpenAI-compatible chat completion endpoints](https://platform.openai.com/docs/api-reference/chat) to other workloads on or off cluster, or to integrate your self-hosted models alongside model-as-a-service providers in a higher level **AI Gateway** like LiteLLM, Solo AI Gateway, or Apigee. @@ -15,7 +68,7 @@ It currently requires a version of vLLM that supports the necessary metrics to p ## Status -This project is [alpha (0.2 release)](https://github.com/kubernetes-sigs/gateway-api-inference-extension/releases/tag/v0.2.0). It should not be used in production yet. +This project is [alpha (0.3 release)](https://github.com/kubernetes-sigs/gateway-api-inference-extension/releases/tag/v0.3.0). It should not be used in production yet. ## Getting Started diff --git a/api/v1alpha2/inferencemodel_types.go b/api/v1alpha2/inferencemodel_types.go index 052683d8..7cd98a74 100644 --- a/api/v1alpha2/inferencemodel_types.go +++ b/api/v1alpha2/inferencemodel_types.go @@ -126,7 +126,7 @@ type PoolObjectReference struct { } // Criticality defines how important it is to serve the model compared to other models. -// Criticality is intentionally a bounded enum to contain the possibilities that need to be supported by the load balancing algorithm. Any reference to the Criticality field must be optional(use a pointer), and set no default. +// Criticality is intentionally a bounded enum to contain the possibilities that need to be supported by the load balancing algorithm. Any reference to the Criticality field must be optional (use a pointer), and set no default. // This allows us to union this with a oneOf field in the future should we wish to adjust/extend this behavior. 
// +kubebuilder:validation:Enum=Critical;Standard;Sheddable type Criticality string diff --git a/body-based-routing.Dockerfile b/bbr.Dockerfile similarity index 76% rename from body-based-routing.Dockerfile rename to bbr.Dockerfile index e0afcf20..03024e49 100644 --- a/body-based-routing.Dockerfile +++ b/bbr.Dockerfile @@ -18,13 +18,13 @@ RUN go mod download COPY cmd ./cmd COPY pkg ./pkg COPY internal ./internal -WORKDIR /src/cmd/body-based-routing -RUN go build -o /body-based-routing +WORKDIR /src/cmd/bbr +RUN go build -o /bbr ## Multistage deploy FROM ${BASE_IMAGE} WORKDIR / -COPY --from=builder /body-based-routing /body-based-routing +COPY --from=builder /bbr /bbr -ENTRYPOINT ["/body-based-routing"] +ENTRYPOINT ["/bbr"] diff --git a/cmd/body-based-routing/health.go b/cmd/bbr/health.go similarity index 100% rename from cmd/body-based-routing/health.go rename to cmd/bbr/health.go diff --git a/cmd/body-based-routing/main.go b/cmd/bbr/main.go similarity index 98% rename from cmd/body-based-routing/main.go rename to cmd/bbr/main.go index cfc584ce..84b1fffa 100644 --- a/cmd/body-based-routing/main.go +++ b/cmd/bbr/main.go @@ -36,7 +36,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/manager" "sigs.k8s.io/controller-runtime/pkg/metrics/filters" "sigs.k8s.io/gateway-api-inference-extension/internal/runnable" - runserver "sigs.k8s.io/gateway-api-inference-extension/pkg/body-based-routing/server" + runserver "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/server" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) diff --git a/cmd/epp/main.go b/cmd/epp/main.go index 39baf18b..2bd779c5 100644 --- a/cmd/epp/main.go +++ b/cmd/epp/main.go @@ -30,6 +30,7 @@ import ( "go.uber.org/zap/zapcore" "google.golang.org/grpc" healthPb "google.golang.org/grpc/health/grpc_health_v1" + "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/rest" "k8s.io/component-base/metrics/legacyregistry" ctrl "sigs.k8s.io/controller-runtime" @@ -40,6 +41,7 @@ import ( backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling" runserver "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/server" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) @@ -120,11 +122,6 @@ func run() error { flag.Parse() initLogging(&opts) - useStreamingServer, err := strconv.ParseBool(os.Getenv("USE_STREAMING")) - if err != nil { - setupLog.Error(err, "Failed to parse env var USE_STREAMING, defaulting to false") - } - // Validate flags if err := validateFlags(); err != nil { setupLog.Error(err, "Failed to validate flags") @@ -145,14 +142,16 @@ func run() error { return err } - mgr, err := runserver.NewDefaultManager(*poolNamespace, *poolName, cfg) + poolNamespacedName := types.NamespacedName{ + Name: *poolName, + Namespace: *poolNamespace, + } + mgr, err := runserver.NewDefaultManager(poolNamespacedName, cfg) if err != nil { setupLog.Error(err, "Failed to create controller manager") return err } - ctx := ctrl.SetupSignalHandler() - // Set up mapper for metric scraping. mapping, err := backendmetrics.NewMetricMapping( *totalQueuedRequestsMetric, @@ -167,19 +166,21 @@ func run() error { pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.PodMetricsClientImpl{MetricMapping: mapping}, *refreshMetricsInterval) // Setup runner. 
+ ctx := ctrl.SetupSignalHandler() + datastore := datastore.NewDatastore(ctx, pmf) + scheduler := scheduling.NewScheduler(datastore) serverRunner := &runserver.ExtProcServerRunner{ GrpcPort: *grpcPort, DestinationEndpointHintMetadataNamespace: *destinationEndpointHintMetadataNamespace, DestinationEndpointHintKey: *destinationEndpointHintKey, - PoolName: *poolName, - PoolNamespace: *poolNamespace, + PoolNamespacedName: poolNamespacedName, Datastore: datastore, SecureServing: *secureServing, CertPath: *certPath, - UseStreaming: useStreamingServer, RefreshPrometheusMetricsInterval: *refreshPrometheusMetricsInterval, + Scheduler: scheduler, } if err := serverRunner.SetupWithManager(ctx, mgr); err != nil { setupLog.Error(err, "Failed to setup ext-proc controllers") @@ -249,6 +250,8 @@ func registerHealthServer(mgr manager.Manager, logger logr.Logger, ds datastore. func registerMetricsHandler(mgr manager.Manager, port int, cfg *rest.Config) error { metrics.Register() + metrics.RecordInferenceExtensionInfo() + // Init HTTP server. h, err := metricsHandlerWithAuthenticationAndAuthorization(cfg) if err != nil { diff --git a/config/charts/body-based-routing/README.md b/config/charts/body-based-routing/README.md index 062f2b5c..d311b8c3 100644 --- a/config/charts/body-based-routing/README.md +++ b/config/charts/body-based-routing/README.md @@ -10,7 +10,7 @@ To install a body-based router named `body-based-router`, you can run the follow ```txt $ helm install body-based-router ./config/charts/body-based-routing \ --set provider.name=[gke|istio] \ - --set inference-gateway.name=inference-gateway + --set inferenceGateway.name=inference-gateway ``` Note that the provider name is needed to ensure provider-specific manifests are also applied. If no provider is specified, then only @@ -19,7 +19,7 @@ the deployment and service are deployed. To install via the latest published chart in staging (--version v0 indicates latest dev version), you can run the following command: ```txt -$ helm install body-based-router oci://us-central1-docker.pkg.dev/k8s-staging-images/gateway-api-inference-extension/charts/body-based-router \ +$ helm install body-based-router oci://us-central1-docker.pkg.dev/k8s-staging-images/gateway-api-inference-extension/charts/body-based-routing \ --version v0 --set provider.name=[gke|istio] ``` @@ -47,8 +47,8 @@ The following table list the configurable parameters of the chart. | `bbr.image.tag` | Image tag. | | `bbr.image.pullPolicy` | Image pull policy for the container. Possible values: `Always`, `IfNotPresent`, or `Never`. Defaults to `Always`. | | `provider.name` | Name of the Inference Gateway implementation being used. Possible values: `istio`, `gke`. Defaults to `none`. | -| `inference-gateway.name` | The name of the Gateway. Defaults to `inference-gateway`. | +| `inferenceGateway.name` | The name of the Gateway. Defaults to `inference-gateway`. | ## Notes -This chart should only be deployed once per Gateway. \ No newline at end of file +This chart should only be deployed once per Gateway. 
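For instance, a custom values file that overrides the renamed camelCase key might look like the following sketch (the file name and gateway name are hypothetical); it would be passed to the install commands above with `-f values-override.yaml`:

```yaml
# values-override.yaml (hypothetical): overrides for this chart, using the
# inferenceGateway key as renamed in values.yaml.
provider:
  name: istio                    # or gke
inferenceGateway:
  name: my-inference-gateway     # must match the Gateway this BBR extension serves
```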
diff --git a/config/charts/body-based-routing/templates/gke.yaml b/config/charts/body-based-routing/templates/gke.yaml index 937bfa0b..77b776a4 100644 --- a/config/charts/body-based-routing/templates/gke.yaml +++ b/config/charts/body-based-routing/templates/gke.yaml @@ -9,7 +9,7 @@ spec: targetRefs: - group: "gateway.networking.k8s.io" kind: Gateway - name: {{ .Values.inference-gateway.name }} + name: {{ .Values.inferenceGateway.name }} extensionChains: - name: chain1 extensions: diff --git a/config/charts/body-based-routing/templates/istio.yaml b/config/charts/body-based-routing/templates/istio.yaml index c4c1444f..6d4535cc 100644 --- a/config/charts/body-based-routing/templates/istio.yaml +++ b/config/charts/body-based-routing/templates/istio.yaml @@ -25,9 +25,9 @@ spec: processing_mode: request_header_mode: "SEND" response_header_mode: "SKIP" - request_body_mode: "BUFFERED" + request_body_mode: "FULL_DUPLEX_STREAMED" response_body_mode: "NONE" - request_trailer_mode: "SKIP" + request_trailer_mode: "SEND" response_trailer_mode: "SKIP" grpc_service: envoy_grpc: diff --git a/config/charts/body-based-routing/values.yaml b/config/charts/body-based-routing/values.yaml index b77d7542..0b88dc43 100644 --- a/config/charts/body-based-routing/values.yaml +++ b/config/charts/body-based-routing/values.yaml @@ -12,5 +12,5 @@ bbr: provider: name: none -inference-gateway: +inferenceGateway: name: inference-gateway diff --git a/config/charts/inferencepool/README.md b/config/charts/inferencepool/README.md index 681fc783..301e3d9c 100644 --- a/config/charts/inferencepool/README.md +++ b/config/charts/inferencepool/README.md @@ -2,7 +2,6 @@ A chart to deploy an InferencePool and a corresponding EndpointPicker (epp) deployment. - ## Install To install an InferencePool named `vllm-llama3-8b-instruct` that selects from endpoints with label `app: vllm-llama3-8b-instruct` and listening on port `8000`, you can run the following command: @@ -17,6 +16,21 @@ To install via the latest published chart in staging (--version v0 indicates la ```txt $ helm install vllm-llama3-8b-instruct \ --set inferencePool.modelServers.matchLabels.app=vllm-llama3-8b-instruct \ + --set provider.name=[none|gke] \ + oci://us-central1-docker.pkg.dev/k8s-staging-images/gateway-api-inference-extension/charts/inferencepool --version v0 +``` + +Note that the provider name is needed to deploy provider-specific resources. If no provider is specified, then only the InferencePool object and the EPP are deployed. + +### Install for Triton TensorRT-LLM + +Use `--set inferencePool.modelServerType=triton-tensorrt-llm` to install for Triton TensorRT-LLM, e.g., + +```txt +$ helm install triton-llama3-8b-instruct \ + --set inferencePool.modelServers.matchLabels.app=triton-llama3-8b-instruct \ + --set inferencePool.modelServerType=triton-tensorrt-llm \ + --set provider.name=[none|gke] \ oci://us-central1-docker.pkg.dev/k8s-staging-images/gateway-api-inference-extension/charts/inferencepool --version v0 ``` @@ -34,8 +48,8 @@ The following table list the configurable parameters of the chart. | **Parameter Name** | **Description** | |---------------------------------------------|------------------------------------------------------------------------------------------------------------------------| -| `inferencePool.name` | Name for the InferencePool, and endpoint picker deployment and service will be named as `{.Release.name}-epp`. 
| | `inferencePool.targetPortNumber` | Target port number for the vllm backends, will be used to scrape metrics by the inference extension. Defaults to 8000. | +| `inferencePool.modelServerType` | Type of the model servers in the pool, valid options are [vllm, triton-tensorrt-llm], default is vllm. | | `inferencePool.modelServers.matchLabels` | Label selector to match vllm backends managed by the inference pool. | | `inferenceExtension.replicas` | Number of replicas for the endpoint picker extension service. Defaults to `1`. | | `inferenceExtension.image.name` | Name of the container image used for the endpoint picker. | @@ -43,6 +57,7 @@ The following table list the configurable parameters of the chart. | `inferenceExtension.image.tag` | Image tag of the endpoint picker. | | `inferenceExtension.image.pullPolicy` | Image pull policy for the container. Possible values: `Always`, `IfNotPresent`, or `Never`. Defaults to `Always`. | | `inferenceExtension.extProcPort` | Port where the endpoint picker service is served for external processing. Defaults to `9002`. | +| `provider.name` | Name of the Inference Gateway implementation being used. Possible values: `gke`. Defaults to `none`. | ## Notes diff --git a/config/charts/inferencepool/templates/epp-deployment.yaml b/config/charts/inferencepool/templates/epp-deployment.yaml index d925a38e..fc490210 100644 --- a/config/charts/inferencepool/templates/epp-deployment.yaml +++ b/config/charts/inferencepool/templates/epp-deployment.yaml @@ -35,9 +35,14 @@ spec: - "9003" - -metricsPort - "9090" - env: - - name: USE_STREAMING - value: "true" + {{- if eq (.Values.inferencePool.modelServerType | default "vllm") "triton-tensorrt-llm" }} + - -totalQueuedRequestsMetric + - "nv_trt_llm_request_metrics{request_type=waiting}" + - -kvCacheUsagePercentageMetric + - "nv_trt_llm_kv_cache_block_metrics{kv_cache_block_type=fraction}" + - -loraInfoMetric + - "" # Set an empty metric to disable LoRA metric scraping as they are not supported by Triton yet. + {{- end }} ports: - name: grpc containerPort: 9002 @@ -57,4 +62,3 @@ spec: service: inference-extension initialDelaySeconds: 5 periodSeconds: 10 - diff --git a/config/charts/inferencepool/templates/gke.yaml b/config/charts/inferencepool/templates/gke.yaml index 220b3bea..70e05b56 100644 --- a/config/charts/inferencepool/templates/gke.yaml +++ b/config/charts/inferencepool/templates/gke.yaml @@ -33,6 +33,8 @@ spec: name: {{ .Release.Name }} default: timeoutSec: 300 # 5-minute timeout (adjust as needed) + logging: + enabled: true # log all requests by default --- apiVersion: monitoring.googleapis.com/v1 kind: ClusterPodMonitoring diff --git a/config/charts/inferencepool/values.yaml b/config/charts/inferencepool/values.yaml index 766ee087..bd48f37e 100644 --- a/config/charts/inferencepool/values.yaml +++ b/config/charts/inferencepool/values.yaml @@ -9,6 +9,7 @@ inferenceExtension: inferencePool: targetPortNumber: 8000 + modelServerType: vllm # vllm, triton-tensorrt-llm # modelServers: # REQUIRED # matchLabels: # app: vllm-llama3-8b-instruct diff --git a/config/default/kustomization.yaml b/config/default/kustomization.yaml deleted file mode 100644 index 1fd9939f..00000000 --- a/config/default/kustomization.yaml +++ /dev/null @@ -1,151 +0,0 @@ -# Adds namespace to all resources. -namespace: api-system - -# Value of this field is prepended to the -# names of all resources, e.g. a deployment named -# "wordpress" becomes "alices-wordpress". 
-# Note that it should also match with the prefix (text before '-') of the namespace -# field above. -namePrefix: api- - -# Labels to add to all resources and selectors. -#labels: -#- includeSelectors: true -# pairs: -# someName: someValue - -resources: -- ../crd -- ../rbac -- ../manager -# [WEBHOOK] To enable webhook, uncomment all the sections with [WEBHOOK] prefix including the one in -# crd/kustomization.yaml -#- ../webhook -# [CERTMANAGER] To enable cert-manager, uncomment all sections with 'CERTMANAGER'. 'WEBHOOK' components are required. -#- ../certmanager -# [PROMETHEUS] To enable prometheus monitor, uncomment all sections with 'PROMETHEUS'. -#- ../prometheus -# [METRICS] Expose the controller manager metrics service. -- metrics_service.yaml -# [NETWORK POLICY] Protect the /metrics endpoint and Webhook Server with NetworkPolicy. -# Only Pod(s) running a namespace labeled with 'metrics: enabled' will be able to gather the metrics. -# Only CR(s) which requires webhooks and are applied on namespaces labeled with 'webhooks: enabled' will -# be able to communicate with the Webhook Server. -#- ../network-policy - -# Uncomment the patches line if you enable Metrics, and/or are using webhooks and cert-manager -patches: -# [METRICS] The following patch will enable the metrics endpoint using HTTPS and the port :8443. -# More info: https://book.kubebuilder.io/reference/metrics -- path: manager_metrics_patch.yaml - target: - kind: Deployment - -# [WEBHOOK] To enable webhook, uncomment all the sections with [WEBHOOK] prefix including the one in -# crd/kustomization.yaml -#- path: manager_webhook_patch.yaml - -# [CERTMANAGER] To enable cert-manager, uncomment all sections with 'CERTMANAGER'. -# Uncomment 'CERTMANAGER' sections in crd/kustomization.yaml to enable the CA injection in the admission webhooks. -# 'CERTMANAGER' needs to be enabled to use ca injection -#- path: webhookcainjection_patch.yaml - -# [CERTMANAGER] To enable cert-manager, uncomment all sections with 'CERTMANAGER' prefix. 
-# Uncomment the following replacements to add the cert-manager CA injection annotations -#replacements: -# - source: # Add cert-manager annotation to ValidatingWebhookConfiguration, MutatingWebhookConfiguration and CRDs -# kind: Certificate -# group: cert-manager.io -# version: v1 -# name: serving-cert # this name should match the one in certificate.yaml -# fieldPath: .metadata.namespace # namespace of the certificate CR -# targets: -# - select: -# kind: ValidatingWebhookConfiguration -# fieldPaths: -# - .metadata.annotations.[cert-manager.io/inject-ca-from] -# options: -# delimiter: '/' -# index: 0 -# create: true -# - select: -# kind: MutatingWebhookConfiguration -# fieldPaths: -# - .metadata.annotations.[cert-manager.io/inject-ca-from] -# options: -# delimiter: '/' -# index: 0 -# create: true -# - select: -# kind: CustomResourceDefinition -# fieldPaths: -# - .metadata.annotations.[cert-manager.io/inject-ca-from] -# options: -# delimiter: '/' -# index: 0 -# create: true -# - source: -# kind: Certificate -# group: cert-manager.io -# version: v1 -# name: serving-cert # this name should match the one in certificate.yaml -# fieldPath: .metadata.name -# targets: -# - select: -# kind: ValidatingWebhookConfiguration -# fieldPaths: -# - .metadata.annotations.[cert-manager.io/inject-ca-from] -# options: -# delimiter: '/' -# index: 1 -# create: true -# - select: -# kind: MutatingWebhookConfiguration -# fieldPaths: -# - .metadata.annotations.[cert-manager.io/inject-ca-from] -# options: -# delimiter: '/' -# index: 1 -# create: true -# - select: -# kind: CustomResourceDefinition -# fieldPaths: -# - .metadata.annotations.[cert-manager.io/inject-ca-from] -# options: -# delimiter: '/' -# index: 1 -# create: true -# - source: # Add cert-manager annotation to the webhook Service -# kind: Service -# version: v1 -# name: webhook-service -# fieldPath: .metadata.name # namespace of the service -# targets: -# - select: -# kind: Certificate -# group: cert-manager.io -# version: v1 -# fieldPaths: -# - .spec.dnsNames.0 -# - .spec.dnsNames.1 -# options: -# delimiter: '.' -# index: 0 -# create: true -# - source: -# kind: Service -# version: v1 -# name: webhook-service -# fieldPath: .metadata.namespace # namespace of the service -# targets: -# - select: -# kind: Certificate -# group: cert-manager.io -# version: v1 -# fieldPaths: -# - .spec.dnsNames.0 -# - .spec.dnsNames.1 -# options: -# delimiter: '.' 
-# index: 1 -# create: true diff --git a/config/default/manager_metrics_patch.yaml b/config/default/manager_metrics_patch.yaml deleted file mode 100644 index 2aaef653..00000000 --- a/config/default/manager_metrics_patch.yaml +++ /dev/null @@ -1,4 +0,0 @@ -# This patch adds the args to allow exposing the metrics endpoint using HTTPS -- op: add - path: /spec/template/spec/containers/0/args/0 - value: --metrics-bind-address=:8443 diff --git a/config/default/metrics_service.yaml b/config/default/metrics_service.yaml deleted file mode 100644 index 140d4943..00000000 --- a/config/default/metrics_service.yaml +++ /dev/null @@ -1,17 +0,0 @@ -apiVersion: v1 -kind: Service -metadata: - labels: - control-plane: controller-manager - app.kubernetes.io/name: api - app.kubernetes.io/managed-by: kustomize - name: controller-manager-metrics-service - namespace: system -spec: - ports: - - name: https - port: 8443 - protocol: TCP - targetPort: 8443 - selector: - control-plane: controller-manager diff --git a/config/manifests/gateway/gke/gcp-backend-policy.yaml b/config/manifests/gateway/gke/gcp-backend-policy.yaml index 519a5a93..7b294304 100644 --- a/config/manifests/gateway/gke/gcp-backend-policy.yaml +++ b/config/manifests/gateway/gke/gcp-backend-policy.yaml @@ -9,3 +9,5 @@ spec: name: vllm-llama3-8b-instruct default: timeoutSec: 300 + logging: + enabled: true diff --git a/config/manifests/gateway/gke/healthcheck.yaml b/config/manifests/gateway/gke/healthcheck.yaml index 95f4f2d2..93b6cd7f 100644 --- a/config/manifests/gateway/gke/healthcheck.yaml +++ b/config/manifests/gateway/gke/healthcheck.yaml @@ -7,7 +7,7 @@ spec: targetRef: group: "inference.networking.x-k8s.io" kind: InferencePool - name: vllm-llama2-7b + name: vllm-llama3-8b-instruct default: config: type: HTTP diff --git a/config/manifests/inferencemodel.yaml b/config/manifests/inferencemodel.yaml index 75c9bb17..67c91d0e 100644 --- a/config/manifests/inferencemodel.yaml +++ b/config/manifests/inferencemodel.yaml @@ -8,9 +8,8 @@ spec: poolRef: name: vllm-llama3-8b-instruct targetModels: - - name: food-review + - name: food-review-1 weight: 100 - --- apiVersion: inference.networking.x-k8s.io/v1alpha2 kind: InferenceModel @@ -21,7 +20,6 @@ spec: criticality: Critical poolRef: name: vllm-llama3-8b-instruct - --- apiVersion: inference.networking.x-k8s.io/v1alpha2 kind: InferenceModel diff --git a/config/manifests/inferencepool-resources.yaml b/config/manifests/inferencepool-resources.yaml index cef70d7f..3d978292 100644 --- a/config/manifests/inferencepool-resources.yaml +++ b/config/manifests/inferencepool-resources.yaml @@ -1,7 +1,9 @@ +# Note: If you change this file, please also change the file used for e2e tests! 
+# +# https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/test/testdata/inferencepool-e2e.yaml apiVersion: inference.networking.x-k8s.io/v1alpha2 kind: InferencePool metadata: - labels: name: vllm-llama3-8b-instruct spec: targetPortNumber: 8000 @@ -51,6 +53,8 @@ spec: args: - -poolName - "vllm-llama3-8b-instruct" + - "-poolNamespace" + - "default" - -v - "4" - --zap-encoder @@ -59,9 +63,6 @@ spec: - "9002" - -grpcHealthPort - "9003" - env: - - name: USE_STREAMING - value: "true" ports: - containerPort: 9002 - containerPort: 9003 diff --git a/config/manifests/vllm/cpu-deployment.yaml b/config/manifests/vllm/cpu-deployment.yaml index 6fb40950..827f2156 100644 --- a/config/manifests/vllm/cpu-deployment.yaml +++ b/config/manifests/vllm/cpu-deployment.yaml @@ -113,5 +113,8 @@ data: ensureExist: models: - base-model: Qwen/Qwen2.5-1.5B - id: food-review-1 + id: food-review + source: SriSanth2345/Qwen-1.5B-Tweet-Generations + - base-model: Qwen/Qwen2.5-1.5B + id: cad-fabricator source: SriSanth2345/Qwen-1.5B-Tweet-Generations \ No newline at end of file diff --git a/config/manifests/vllm/gpu-deployment.yaml b/config/manifests/vllm/gpu-deployment.yaml index 4f13736d..16f93882 100644 --- a/config/manifests/vllm/gpu-deployment.yaml +++ b/config/manifests/vllm/gpu-deployment.yaml @@ -24,9 +24,15 @@ spec: - "1" - "--port" - "8000" + - "--max-num-seq" + - "1024" + - "--compilation-config" + - "3" - "--enable-lora" - "--max-loras" - "2" + - "--max-lora-rank" + - "8" - "--max-cpu-loras" - "12" env: @@ -77,7 +83,7 @@ spec: #exec: # command: # - /usr/bin/sleep - # - 30 + # - "30" livenessProbe: httpGet: path: /health @@ -133,7 +139,6 @@ spec: path: /health port: http scheme: HTTP - resources: limits: nvidia.com/gpu: 1 @@ -244,12 +249,10 @@ metadata: data: configmap.yaml: | vLLMLoRAConfig: - name: vllm-llama3.1-8b-instruct + name: vllm-llama3-8b-instruct-adapters port: 8000 defaultBaseModel: meta-llama/Llama-3.1-8B-Instruct ensureExist: models: - - id: food-review + - id: food-review-1 source: Kawon/llama3.1-food-finetune_v14_r8 - - id: cad-fabricator - source: redcathode/fabricator diff --git a/config/network-policy/allow-metrics-traffic.yaml b/config/network-policy/allow-metrics-traffic.yaml deleted file mode 100644 index aae53668..00000000 --- a/config/network-policy/allow-metrics-traffic.yaml +++ /dev/null @@ -1,26 +0,0 @@ -# This NetworkPolicy allows ingress traffic -# with Pods running on namespaces labeled with 'metrics: enabled'. Only Pods on those -# namespaces are able to gathering data from the metrics endpoint. 
-apiVersion: networking.k8s.io/v1 -kind: NetworkPolicy -metadata: - labels: - app.kubernetes.io/name: api - app.kubernetes.io/managed-by: kustomize - name: allow-metrics-traffic - namespace: system -spec: - podSelector: - matchLabels: - control-plane: controller-manager - policyTypes: - - Ingress - ingress: - # This allows ingress traffic from any namespace with the label metrics: enabled - - from: - - namespaceSelector: - matchLabels: - metrics: enabled # Only from namespaces with this label - ports: - - port: 8443 - protocol: TCP diff --git a/config/network-policy/kustomization.yaml b/config/network-policy/kustomization.yaml deleted file mode 100644 index ec0fb5e5..00000000 --- a/config/network-policy/kustomization.yaml +++ /dev/null @@ -1,2 +0,0 @@ -resources: -- allow-metrics-traffic.yaml diff --git a/config/prometheus/kustomization.yaml b/config/prometheus/kustomization.yaml deleted file mode 100644 index ed137168..00000000 --- a/config/prometheus/kustomization.yaml +++ /dev/null @@ -1,2 +0,0 @@ -resources: -- monitor.yaml diff --git a/config/prometheus/monitor.yaml b/config/prometheus/monitor.yaml deleted file mode 100644 index aac24ef3..00000000 --- a/config/prometheus/monitor.yaml +++ /dev/null @@ -1,30 +0,0 @@ -# Prometheus Monitor Service (Metrics) -apiVersion: monitoring.coreos.com/v1 -kind: ServiceMonitor -metadata: - labels: - control-plane: controller-manager - app.kubernetes.io/name: api - app.kubernetes.io/managed-by: kustomize - name: controller-manager-metrics-monitor - namespace: system -spec: - endpoints: - - path: /metrics - port: https # Ensure this is the name of the port that exposes HTTPS metrics - scheme: https - bearerTokenFile: /var/run/secrets/kubernetes.io/serviceaccount/token - tlsConfig: - # TODO(user): The option insecureSkipVerify: true is not recommended for production since it disables - # certificate verification. This poses a significant security risk by making the system vulnerable to - # man-in-the-middle attacks, where an attacker could intercept and manipulate the communication between - # Prometheus and the monitored services. This could lead to unauthorized access to sensitive metrics data, - # compromising the integrity and confidentiality of the information. - # Please use the following options for secure configurations: - # caFile: /etc/metrics-certs/ca.crt - # certFile: /etc/metrics-certs/tls.crt - # keyFile: /etc/metrics-certs/tls.key - insecureSkipVerify: true - selector: - matchLabels: - control-plane: controller-manager diff --git a/config/rbac/inferencemodel_editor_role.yaml b/config/rbac/inferencemodel_editor_role.yaml deleted file mode 100644 index b175a9a3..00000000 --- a/config/rbac/inferencemodel_editor_role.yaml +++ /dev/null @@ -1,27 +0,0 @@ -# permissions for end users to edit inferencemodels. -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRole -metadata: - labels: - app.kubernetes.io/name: api - app.kubernetes.io/managed-by: kustomize - name: inferencemodel-editor-role -rules: -- apiGroups: - - inference.networking.x-k8s.io - resources: - - inferencemodels - verbs: - - create - - delete - - get - - list - - patch - - update - - watch -- apiGroups: - - inference.networking.x-k8s.io - resources: - - inferencemodels/status - verbs: - - get diff --git a/config/rbac/inferencemodel_viewer_role.yaml b/config/rbac/inferencemodel_viewer_role.yaml deleted file mode 100644 index 3b3e67f6..00000000 --- a/config/rbac/inferencemodel_viewer_role.yaml +++ /dev/null @@ -1,23 +0,0 @@ -# permissions for end users to view inferencemodels. 
-apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRole -metadata: - labels: - app.kubernetes.io/name: api - app.kubernetes.io/managed-by: kustomize - name: inferencemodel-viewer-role -rules: -- apiGroups: - - inference.networking.x-k8s.io - resources: - - inferencemodels - verbs: - - get - - list - - watch -- apiGroups: - - inference.networking.x-k8s.io - resources: - - inferencemodels/status - verbs: - - get diff --git a/config/rbac/inferencepool_editor_role.yaml b/config/rbac/inferencepool_editor_role.yaml deleted file mode 100644 index cc1f7c35..00000000 --- a/config/rbac/inferencepool_editor_role.yaml +++ /dev/null @@ -1,27 +0,0 @@ -# permissions for end users to edit inferencepools. -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRole -metadata: - labels: - app.kubernetes.io/name: api - app.kubernetes.io/managed-by: kustomize - name: inferencepool-editor-role -rules: -- apiGroups: - - inference.networking.x-k8s.io - resources: - - inferencepools - verbs: - - create - - delete - - get - - list - - patch - - update - - watch -- apiGroups: - - inference.networking.x-k8s.io - resources: - - inferencepools/status - verbs: - - get diff --git a/config/rbac/inferencepool_viewer_role.yaml b/config/rbac/inferencepool_viewer_role.yaml deleted file mode 100644 index 828e0022..00000000 --- a/config/rbac/inferencepool_viewer_role.yaml +++ /dev/null @@ -1,23 +0,0 @@ -# permissions for end users to view inferencepools. -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRole -metadata: - labels: - app.kubernetes.io/name: api - app.kubernetes.io/managed-by: kustomize - name: inferencepool-viewer-role -rules: -- apiGroups: - - inference.networking.x-k8s.io - resources: - - inferencepools - verbs: - - get - - list - - watch -- apiGroups: - - inference.networking.x-k8s.io - resources: - - inferencepools/status - verbs: - - get diff --git a/config/rbac/kustomization.yaml b/config/rbac/kustomization.yaml deleted file mode 100644 index c3a52137..00000000 --- a/config/rbac/kustomization.yaml +++ /dev/null @@ -1,29 +0,0 @@ -resources: -# All RBAC will be applied under this service account in -# the deployment namespace. You may comment out this resource -# if your manager will use a service account that exists at -# runtime. Be sure to update RoleBinding and ClusterRoleBinding -# subjects if changing service account names. -- service_account.yaml -- role.yaml -- role_binding.yaml -- leader_election_role.yaml -- leader_election_role_binding.yaml -# The following RBAC configurations are used to protect -# the metrics endpoint with authn/authz. These configurations -# ensure that only authorized users and service accounts -# can access the metrics endpoint. Comment the following -# permissions if you want to disable this protection. -# More info: https://book.kubebuilder.io/reference/metrics.html -- metrics_auth_role.yaml -- metrics_auth_role_binding.yaml -- metrics_reader_role.yaml -# For each CRD, "Editor" and "Viewer" roles are scaffolded by -# default, aiding admins in cluster management. Those roles are -# not used by the Project itself. You can comment the following lines -# if you do not want those helpers be installed with your Project. 
-- inferencemodel_editor_role.yaml -- inferencemodel_viewer_role.yaml -- inferencepool_editor_role.yaml -- inferencepool_viewer_role.yaml - diff --git a/config/rbac/leader_election_role.yaml b/config/rbac/leader_election_role.yaml deleted file mode 100644 index e2f8551b..00000000 --- a/config/rbac/leader_election_role.yaml +++ /dev/null @@ -1,40 +0,0 @@ -# permissions to do leader election. -apiVersion: rbac.authorization.k8s.io/v1 -kind: Role -metadata: - labels: - app.kubernetes.io/name: api - app.kubernetes.io/managed-by: kustomize - name: leader-election-role -rules: -- apiGroups: - - "" - resources: - - configmaps - verbs: - - get - - list - - watch - - create - - update - - patch - - delete -- apiGroups: - - coordination.k8s.io - resources: - - leases - verbs: - - get - - list - - watch - - create - - update - - patch - - delete -- apiGroups: - - "" - resources: - - events - verbs: - - create - - patch diff --git a/config/rbac/leader_election_role_binding.yaml b/config/rbac/leader_election_role_binding.yaml deleted file mode 100644 index fb71a122..00000000 --- a/config/rbac/leader_election_role_binding.yaml +++ /dev/null @@ -1,15 +0,0 @@ -apiVersion: rbac.authorization.k8s.io/v1 -kind: RoleBinding -metadata: - labels: - app.kubernetes.io/name: api - app.kubernetes.io/managed-by: kustomize - name: leader-election-rolebinding -roleRef: - apiGroup: rbac.authorization.k8s.io - kind: Role - name: leader-election-role -subjects: -- kind: ServiceAccount - name: controller-manager - namespace: system diff --git a/config/rbac/metrics_auth_role.yaml b/config/rbac/metrics_auth_role.yaml deleted file mode 100644 index 32d2e4ec..00000000 --- a/config/rbac/metrics_auth_role.yaml +++ /dev/null @@ -1,17 +0,0 @@ -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRole -metadata: - name: metrics-auth-role -rules: -- apiGroups: - - authentication.k8s.io - resources: - - tokenreviews - verbs: - - create -- apiGroups: - - authorization.k8s.io - resources: - - subjectaccessreviews - verbs: - - create diff --git a/config/rbac/metrics_auth_role_binding.yaml b/config/rbac/metrics_auth_role_binding.yaml deleted file mode 100644 index e775d67f..00000000 --- a/config/rbac/metrics_auth_role_binding.yaml +++ /dev/null @@ -1,12 +0,0 @@ -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRoleBinding -metadata: - name: metrics-auth-rolebinding -roleRef: - apiGroup: rbac.authorization.k8s.io - kind: ClusterRole - name: metrics-auth-role -subjects: -- kind: ServiceAccount - name: controller-manager - namespace: system diff --git a/config/rbac/metrics_reader_role.yaml b/config/rbac/metrics_reader_role.yaml deleted file mode 100644 index 51a75db4..00000000 --- a/config/rbac/metrics_reader_role.yaml +++ /dev/null @@ -1,9 +0,0 @@ -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRole -metadata: - name: metrics-reader -rules: -- nonResourceURLs: - - "/metrics" - verbs: - - get diff --git a/config/rbac/role.yaml b/config/rbac/role.yaml deleted file mode 100644 index 9d6247eb..00000000 --- a/config/rbac/role.yaml +++ /dev/null @@ -1,11 +0,0 @@ -apiVersion: rbac.authorization.k8s.io/v1 -kind: ClusterRole -metadata: - labels: - app.kubernetes.io/name: api - app.kubernetes.io/managed-by: kustomize - name: manager-role -rules: -- apiGroups: [""] - resources: ["pods"] - verbs: ["get", "list", "watch"] diff --git a/config/rbac/role_binding.yaml b/config/rbac/role_binding.yaml deleted file mode 100644 index c66b66bf..00000000 --- a/config/rbac/role_binding.yaml +++ /dev/null @@ -1,15 +0,0 @@ -apiVersion: 
rbac.authorization.k8s.io/v1 -kind: ClusterRoleBinding -metadata: - labels: - app.kubernetes.io/name: api - app.kubernetes.io/managed-by: kustomize - name: manager-rolebinding -roleRef: - apiGroup: rbac.authorization.k8s.io - kind: ClusterRole - name: manager-role -subjects: -- kind: ServiceAccount - name: controller-manager - namespace: system diff --git a/config/rbac/service_account.yaml b/config/rbac/service_account.yaml deleted file mode 100644 index 9286120f..00000000 --- a/config/rbac/service_account.yaml +++ /dev/null @@ -1,8 +0,0 @@ -apiVersion: v1 -kind: ServiceAccount -metadata: - labels: - app.kubernetes.io/name: api - app.kubernetes.io/managed-by: kustomize - name: controller-manager - namespace: system diff --git a/config/samples/gateway_v1alpha1_inferencemodel.yaml b/config/samples/gateway_v1alpha1_inferencemodel.yaml deleted file mode 100644 index 34ea0680..00000000 --- a/config/samples/gateway_v1alpha1_inferencemodel.yaml +++ /dev/null @@ -1,17 +0,0 @@ -apiVersion: inference.networking.x-k8s.io/v1alpha1 -kind: InferenceModel -metadata: - labels: - app.kubernetes.io/name: api - app.kubernetes.io/managed-by: kustomize - name: sample-sql-assist -spec: - criticality: Critical - modelName: sql-code-assist - poolRef: - name: vllm-llama-31-8b-sample-pool - targetModels: - - name: npc-bot-v1 - weight: 50 - - name: npc-bot-v2 - weight: 50 diff --git a/config/samples/gateway_v1alpha1_inferencepool.yaml b/config/samples/gateway_v1alpha1_inferencepool.yaml deleted file mode 100644 index 4993d786..00000000 --- a/config/samples/gateway_v1alpha1_inferencepool.yaml +++ /dev/null @@ -1,11 +0,0 @@ -apiVersion: inference.networking.x-k8s.io/v1alpha1 -kind: InferencePool -metadata: - labels: - app.kubernetes.io/name: api - app.kubernetes.io/managed-by: kustomize - name: vllm-llama-31-8b-sample-pool -spec: - selector: - app: npc-bot - targetPortNumber: 8000 diff --git a/config/samples/kustomization.yaml b/config/samples/kustomization.yaml deleted file mode 100644 index e4b9f2e8..00000000 --- a/config/samples/kustomization.yaml +++ /dev/null @@ -1,5 +0,0 @@ -## Append samples of your project ## -resources: -- gateway_v1alpha1_inferencepool.yaml -- gateway_v1alpha1_inferencemodel.yaml -# +kubebuilder:scaffold:manifestskustomizesamples diff --git a/conformance/conformance.go b/conformance/conformance.go new file mode 100644 index 00000000..20d80fde --- /dev/null +++ b/conformance/conformance.go @@ -0,0 +1,230 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package conformance contains the core setup and execution logic +// for the Gateway API Inference Extension conformance test suite. 
+package conformance + +import ( + "fmt" + "io/fs" + "os" + "testing" + + "github.com/stretchr/testify/require" + apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" + clientset "k8s.io/client-go/kubernetes" + + // Import runtime package for scheme creation + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/sets" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/config" + "sigs.k8s.io/yaml" + + // Import necessary types and utilities from the core Gateway API conformance suite. + // Assumes sigs.k8s.io/gateway-api is a dependency in the go.mod. + gatewayv1 "sigs.k8s.io/gateway-api/apis/v1" // Import core Gateway API types + confapis "sigs.k8s.io/gateway-api/conformance/apis/v1" // Report struct definition + confconfig "sigs.k8s.io/gateway-api/conformance/utils/config" + confflags "sigs.k8s.io/gateway-api/conformance/utils/flags" + confsuite "sigs.k8s.io/gateway-api/conformance/utils/suite" + "sigs.k8s.io/gateway-api/pkg/features" // Using core features definitions if applicable + + // Import the test definitions package to access the ConformanceTests slice + "sigs.k8s.io/gateway-api-inference-extension/conformance/tests" + + // Import test packages using blank identifier + // This triggers the init() functions in these packages, which register the tests + // by appending them to the tests.ConformanceTests slice. + _ "sigs.k8s.io/gateway-api-inference-extension/conformance/tests/basic" + // TODO: Add blank imports for other test categories as they are created. + // _ "sigs.k8s.io/gateway-api-inference-extension/conformance/tests/model_routing" + + // Import the Inference Extension API types + inferencev1alpha2 "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" +) + +// GatewayLayerProfileName defines the name for the conformance profile that tests +// the Gateway API layer aspects of the Inference Extension (e.g., InferencePool, InferenceModel CRDs). +// Future profiles will cover EPP and ModelServer layers. +const GatewayLayerProfileName confsuite.ConformanceProfileName = "Gateway" + +var InferenceCoreFeatures = sets.New[features.FeatureName]() // Placeholder - Populate with actual features specific to this profile or manage features per profile + +// GatewayLayerProfile defines the conformance profile for the Gateway API layer +// of the Inference Extension. +// In future iterations, we will add constants and ConformanceProfile structs for +// EPPProfileName ("EPP") and ModelServerProfileName ("ModelServer") +// to cover their respective conformance layers. +var GatewayLayerProfile = confsuite.ConformanceProfile{ + Name: GatewayLayerProfileName, + CoreFeatures: InferenceCoreFeatures, +} + +// DefaultOptions parses command line flags and sets up the suite options. +// Adapted from the core Gateway API conformance suite. +func DefaultOptions(t *testing.T) confsuite.ConformanceOptions { + t.Helper() + + cfg, err := config.GetConfig() + require.NoError(t, err, "error loading Kubernetes config") + + // Initialize client options. The scheme must include Gateway API types + // and the Inference Extension types. + clientOptions := client.Options{} + scheme := clientOptions.Scheme + if scheme == nil { + // If default options don't provide a scheme, create one using runtime.NewScheme(). 
+ scheme = runtime.NewScheme() + clientOptions.Scheme = scheme + } + + // Register necessary API Types + require.NoError(t, gatewayv1.Install(scheme)) // Add core Gateway API types + // Add the Inference Extension API types to the scheme using the correct import alias + require.NoError(t, inferencev1alpha2.Install(scheme)) + require.NoError(t, apiextensionsv1.AddToScheme(scheme)) // Needed for CRD checks + + // Create the Kubernetes clients + c, err := client.New(cfg, clientOptions) + require.NoError(t, err, "error initializing Kubernetes client") + cs, err := clientset.NewForConfig(cfg) + require.NoError(t, err, "error initializing Kubernetes clientset") + + exemptFeatures := confsuite.ParseSupportedFeatures(*confflags.ExemptFeatures) + skipTests := confsuite.ParseSkipTests(*confflags.SkipTests) + // Initially, run the GatewayLayerProfile. This will expand as other profiles + // (EPP, ModelServer) are added and can be selected via flags in future iterations. + conformanceProfiles := sets.New(GatewayLayerProfileName) + + // Implementation details from flags + implementation := confsuite.ParseImplementation( + *confflags.ImplementationOrganization, + *confflags.ImplementationProject, + *confflags.ImplementationURL, + *confflags.ImplementationVersion, + *confflags.ImplementationContact, + ) + + // Inference Extension Specific Report Fields + inferenceExtensionVersion := "v0.3.0" + _ = inferenceExtensionVersion // Avoid unused variable error until implemented + + // Create ConformanceOptions + opts := confsuite.ConformanceOptions{ + Client: c, + Clientset: cs, + RestConfig: cfg, + GatewayClassName: *confflags.GatewayClassName, + Debug: *confflags.ShowDebug, + CleanupBaseResources: *confflags.CleanupBaseResources, + SupportedFeatures: sets.New[features.FeatureName](), // Initialize empty, will be populated below + TimeoutConfig: confconfig.DefaultTimeoutConfig(), + SkipTests: skipTests, + ExemptFeatures: exemptFeatures, + RunTest: *confflags.RunTest, + Mode: *confflags.Mode, + Implementation: implementation, + ConformanceProfiles: conformanceProfiles, + ManifestFS: []fs.FS{&Manifests}, // Assumes embed.go defines `Manifests` + ReportOutputPath: *confflags.ReportOutput, + SkipProvisionalTests: *confflags.SkipProvisionalTests, + // TODO: Add the inference extension specific fields to ConformanceOptions struct if needed, + // or handle them during report generation. + // GatewayAPIInferenceExtensionChannel: inferenceExtensionChannel, + // GatewayAPIInferenceExtensionVersion: inferenceExtensionVersion, + } + + // Populate SupportedFeatures based on the GatewayLayerProfile. + // Since all features are mandatory for this profile, add all defined core features. + if opts.ConformanceProfiles.Has(GatewayLayerProfileName) { + for feature := range GatewayLayerProfile.CoreFeatures { + opts.SupportedFeatures.Insert(feature) + } + } + + // Remove any features explicitly exempted via flags. + for feature := range opts.ExemptFeatures { + opts.SupportedFeatures.Delete(feature) + } + + return opts +} + +// RunConformance runs the Inference Extension conformance tests using default options. +func RunConformance(t *testing.T) { + RunConformanceWithOptions(t, DefaultOptions(t)) +} + +// RunConformanceWithOptions runs the Inference Extension conformance tests with specific options. 
+func RunConformanceWithOptions(t *testing.T, opts confsuite.ConformanceOptions) { + t.Logf("Running Inference Extension conformance tests with GatewayClass %s", opts.GatewayClassName) + + // Register the GatewayLayerProfile with the suite runner. + // In the future, other profiles (EPP, ModelServer) will also be registered here, + // and the suite runner will execute tests based on the selected profiles. + confsuite.RegisterConformanceProfile(GatewayLayerProfile) + + // Initialize the test suite. + cSuite, err := confsuite.NewConformanceTestSuite(opts) + require.NoError(t, err, "error initializing conformance suite") + + t.Log("Setting up Inference Extension conformance tests") + // Setup requires the list of tests, which is populated by the init() functions + // triggered by the blank imports at the top of this file. + cSuite.Setup(t, tests.ConformanceTests) + + t.Log("Running Inference Extension conformance tests") + // Run the tests. + err = cSuite.Run(t, tests.ConformanceTests) + require.NoError(t, err, "error running conformance tests") + + // Generate and write the report if requested. + if opts.ReportOutputPath != "" { + t.Log("Generating Inference Extension conformance report") + report, err := cSuite.Report() // Use the existing report generation logic. + require.NoError(t, err, "error generating conformance report") + + // TODO: Modify the report struct here if channel, version need to be modified. + // Example (requires adding fields to confapis.ConformanceReport): + // report.GatewayAPIInferenceExtensionChannel = opts.GatewayAPIInferenceExtensionChannel + // report.GatewayAPIInferenceExtensionVersion = opts.GatewayAPIInferenceExtensionVersion + + err = writeReport(t.Logf, *report, opts.ReportOutputPath) + require.NoError(t, err, "error writing conformance report") + } +} + +// writeReport writes the generated conformance report to the specified output file or logs it. +// Adapted from the core Gateway API suite. +func writeReport(logf func(string, ...any), report confapis.ConformanceReport, output string) error { + rawReport, err := yaml.Marshal(report) + if err != nil { + return fmt.Errorf("error marshaling report: %w", err) + } + + if output != "" { + if err = os.WriteFile(output, rawReport, 0o600); err != nil { + return fmt.Errorf("error writing report file %s: %w", output, err) + } + logf("Conformance report written to %s", output) + } else { + // Log the report YAML to stdout if no output file is specified. + logf("Conformance report:\n%s", string(rawReport)) + } + return nil +} diff --git a/pkg/epp/scheduling/types.go b/conformance/conformance_test.go similarity index 60% rename from pkg/epp/scheduling/types.go rename to conformance/conformance_test.go index 29e6648d..de82d5ec 100644 --- a/pkg/epp/scheduling/types.go +++ b/conformance/conformance_test.go @@ -14,14 +14,16 @@ See the License for the specific language governing permissions and limitations under the License. */ -package scheduling +package conformance -// LLMRequest is a structured representation of the fields we parse out of the LLMRequest body. -type LLMRequest struct { - Model string - // Target models is a map of target model name to weight. - TargetModels map[string]int - // Resolved target model is the final target model after traffic split. - ResolvedTargetModel string - Critical bool +import ( + "testing" +) + +// TestConformance is the top-level function that runs the conformance tests. +// It calls the RunConformance function which sets up the suite and executes +// the registered tests. 
+func TestConformance(t *testing.T) { + // RunConformance is defined in conformance.go + RunConformance(t) } diff --git a/conformance/embed.go b/conformance/embed.go new file mode 100644 index 00000000..f7fa64c9 --- /dev/null +++ b/conformance/embed.go @@ -0,0 +1,25 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package conformance + +import "embed" + +// Manifests embeds the contents of the conformance/resources and conformance/tests +// directories, making the YAML files within them available to the test suite at runtime. +// +//go:embed resources/* tests/* +var Manifests embed.FS diff --git a/conformance/reports/README.md b/conformance/reports/README.md new file mode 100644 index 00000000..81652b1c --- /dev/null +++ b/conformance/reports/README.md @@ -0,0 +1,93 @@ +# Conformance Reports for Gateway API Inference Extension + +This directory stores conformance reports submitted by various implementations of the Gateway API Inference Extension. This structure closely follows the [kubernetes-sigs/gateway-api/conformance/reports](https://github.com/kubernetes-sigs/gateway-api/blob/main/conformance/reports/README.md). + +## How this folder is structured + +This folder stores conformance reports organized first by the version of the Gateway API Inference Extension specification they were tested against, and then by the specific conformance profile (e.g., Gateway, EPP, Model Server): + +|-- conformance/reports | |-- v0.3.0 # Example extension version | | |-- gateway # Conformance profile/category | | | |-- my-inference-gateway | | | | |-- README.md | | | | |-- experimental-v1.2.3-default-gateway-report.yaml # Example report file | | | |-- another-implementation | | | | |-- README.md | | | | |-- ... | | |-- epp # Future conformance profile/category | | | |-- my-epp-implementation | | | | |-- ... | | |-- model-server # Future conformance profile/category | | | |-- ... | |-- v0.4.0 # Future extension version | | |-- ... + +## Implementation Submissions + +Each implementation conformant with a specific profile of a specific version of the Gateway API Inference Extension should have its own folder within the corresponding version and profile directory (e.g., `/conformance/reports/v0.3.0/Gateway/my-implementation/`). + +The implementation is the owner of its folder and is responsible for: + +1. Uploading one or more conformance reports (YAML files). +2. Maintaining a mandatory `README.md` file within their folder, structured as follows: + + # My Inference Gateway Implementation (Gateway Profile Conformance) + + General information about the My/Implementation project. + + ## Table of Contents + +| Extension Version Tested | Profile Tested | Implementation Version | Mode | Report | |--------------------------|----------------|------------------------|---------|----------------------------------------------------------------------------| | v0.3.0 | Gateway | v1.2.3 | default | [v1.2.3 Gateway report](./experimental-v1.2.3-default-gateway-report.yaml) | | ... | ... | ... | ...
+
+   ## Reproduce
+
+   Instructions on how to reproduce the claimed report(s).
+
+### Table of Contents (within Implementation README)
+
+The table of contents within an implementation's `README.md` should contain one row for each submitted report and include the following columns:
+
+* **Extension Version Tested**: The version of the Gateway API Inference Extension specification tested against (e.g., `v0.3.0`). Must correspond to the `gatewayAPIInferenceExtensionVersion` field in the report.
+* **Profile Tested**: The specific conformance profile tested (e.g., `Gateway`, `EPP`, `ModelServer`). Must correspond to the `name` of the profile in the `profiles` list within the report.
+* **Implementation Version**: A link to the GitHub/website page for the specific release/commit of the implementation tested. The version value MUST correspond to the `implementation.version` field in the report.
+* **Mode**: The operating mode of the implementation used for the test run (default is `default`). Must correspond to the `mode` field in the report. If a mode other than `default` is used, the "Reproduce" section must explain how to configure it.
+* **Report**: A link to the corresponding report YAML file. Reports MUST be named according to the pattern: `<channel>-<implementation version>-<mode>-<profile>-report.yaml` (e.g., `experimental-v1.2.3-default-gateway-report.yaml`).
+
+### Reproduce Section (within Implementation README)
+
+This section MUST exist and contain the manual or automatic steps required to reproduce the results claimed by the uploaded conformance reports for that specific implementation. If reproduction steps differ significantly between implementation versions, use sub-sections.
+
+## Report Files
+
+Conformance reports MUST be uploaded exactly as generated by the official Gateway API Inference Extension conformance test suite, without any modifications. The "Reproduce" section allows for verification of the submitted report against a fresh run.
+
+### Report Rules
+
+To be accepted, submitted conformance reports must comply with the following rules:
+
+1. **Implementation Details:** All fields within the `implementation` block must have meaningful values:
+   * `organization`: The entity maintaining the implementation (company, open source org, individual).
+   * `project`: The name of the implementation project, unique within the organization.
+   * `url`: A valid URL for the project (e.g., GitHub repository, product page).
+   * `version`: A specific, reproducible snapshot of the implementation (e.g., tag, commit hash, release version). Branch names are not acceptable.
+   * `contact`: A list of contact points (GitHub handles like `@maintainer`, team handles like `@org/team`, email addresses, or support URLs like an issue tracker).
+2. **Inference Extension Versioning:** The report MUST include:
+   * `gatewayAPIInferenceExtensionVersion`: The specific version of the Gateway API Inference Extension specification tested against (e.g., `v0.3.0`).
+3. **Mode:** The `mode` field indicates the implementation's operating mode during the test run.
+4. **Test Profile & Result:**
+   * The report MUST contain exactly one profile result under the `profiles` list for the specific conformance category being submitted (e.g., a report for "Gateway" conformance should only contain the "Gateway" profile result).
+   * The profile's `name` MUST match the conformance category (e.g., `Gateway`, `EPP`, `ModelServer`).
+   * The profile's `result` field MUST be `success`. A `success` result indicates that **all** tests defined within the Gateway API Inference Extension conformance suite for that specific profile and version passed.
+
+## Submission Process
+
+Conformance reports demonstrating a `success` result for a specific profile (e.g., `Gateway`) should be submitted via Pull Request directly to this repository (`kubernetes-sigs/gateway-api-inference-extension`).
+
+1. Create a new folder structure under `/conformance/reports/<version>/<profile>/` named after your implementation (e.g., `/conformance/reports/v0.3.0/Gateway/my-implementation/`).
+2. Add your implementation's `README.md` to this folder, following the structure described above.
+3. Add your generated conformance report YAML file(s) to this folder, ensuring they follow the naming convention `<channel>-<implementation version>-<mode>-<profile>-report.yaml`.
+4. Submit the Pull Request.
diff --git a/conformance/resources/manifests/manifests.yaml b/conformance/resources/manifests/manifests.yaml
new file mode 100644
index 00000000..7b43b784
--- /dev/null
+++ b/conformance/resources/manifests/manifests.yaml
@@ -0,0 +1,49 @@
+# Base Kubernetes resources for the Gateway API Inference Extension conformance tests.
+# This includes namespaces and a minimal set of resources (Gateway, Backend)
+# required by many tests. More specific resources should be defined within
+# individual test files or other resource directories (e.g., sample_backends).
+
+---
+# Namespace for core infrastructure like Gateways.
+apiVersion: v1
+kind: Namespace
+metadata:
+  name: gateway-conformance-infra
+  labels:
+    gateway-conformance: infra
+
+---
+# Namespace for application backends (potentially simulating model servers
+# or where InferencePools might reside in some tests).
+apiVersion: v1
+kind: Namespace
+metadata:
+  name: gateway-conformance-app-backend
+  labels:
+    gateway-conformance: backend
+
+---
+# A basic Gateway resource that allows HTTPRoutes from the same namespace.
+# Tests can use this as a parent reference for routes that target InferencePools.
+# Using a simple echo server instead of an actual model server to simplify test
+# execution; this design may need to be revised based on test-case needs.
+apiVersion: gateway.networking.k8s.io/v1 # Using v1 as per the latest Gateway API standard
+kind: Gateway
+metadata:
+  name: same-namespace
+  namespace: gateway-conformance-infra
+spec:
+  # The conformance suite runner will replace this placeholder
+  # with the actual GatewayClass name provided via flags.
+  gatewayClassName: "{GATEWAY_CLASS_NAME}"
+  listeners:
+  - name: http # Standard listener name
+    port: 80
+    protocol: HTTP
+    allowedRoutes:
+      namespaces:
+        from: Same # Restrict to same namespace initially for simplicity
+      kinds:
+      # Allows HTTPRoutes to attach, which can then reference InferencePools.
+      - group: gateway.networking.k8s.io
+        kind: HTTPRoute
diff --git a/conformance/tests/basic/inferencepool_accepted.go b/conformance/tests/basic/inferencepool_accepted.go
new file mode 100644
index 00000000..eae59404
--- /dev/null
+++ b/conformance/tests/basic/inferencepool_accepted.go
@@ -0,0 +1,60 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package basic
+
+import (
+	"testing"
+
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/types"
+	gatewayv1 "sigs.k8s.io/gateway-api/apis/v1" // For standard condition types
+	"sigs.k8s.io/gateway-api/conformance/utils/suite"
+	"sigs.k8s.io/gateway-api/pkg/features" // For standard feature names
+
+	// Import the tests package to append to ConformanceTests
+	"sigs.k8s.io/gateway-api-inference-extension/conformance/tests"
+	infrakubernetes "sigs.k8s.io/gateway-api-inference-extension/conformance/utils/kubernetes"
+)
+
+func init() {
+	// Register the InferencePoolAccepted test case with the conformance suite.
+	// This ensures it will be discovered and run by the test runner.
+	tests.ConformanceTests = append(tests.ConformanceTests, InferencePoolAccepted)
+}
+
+// InferencePoolAccepted defines the test case for verifying basic InferencePool acceptance.
+var InferencePoolAccepted = suite.ConformanceTest{
+	ShortName:   "InferencePoolAccepted",
+	Description: "A minimal InferencePool resource should be accepted by the controller and report an Accepted condition",
+	Manifests:   []string{"tests/basic/inferencepool_accepted.yaml"},
+	Features:    []features.FeatureName{},
+	Test: func(t *testing.T, s *suite.ConformanceTestSuite) {
+		// Reference the InferencePool created by the associated manifest file.
+		poolNN := types.NamespacedName{Name: "inferencepool-basic-accepted", Namespace: "gateway-conformance-app-backend"}
+
+		t.Run("InferencePool should have Accepted condition set to True", func(t *testing.T) {
+			// Define the expected status condition. We use the standard "Accepted"
+			// condition type from the Gateway API for consistency.
+			acceptedCondition := metav1.Condition{
+				Type:   string(gatewayv1.GatewayConditionAccepted), // Standard condition type
+				Status: metav1.ConditionTrue,
+				Reason: "", // "" means we don't strictly check the Reason for this basic test.
+			}
+			infrakubernetes.InferencePoolMustHaveCondition(t, s.Client, s.TimeoutConfig, poolNN, acceptedCondition)
+		})
+	},
+}
diff --git a/conformance/tests/basic/inferencepool_accepted.yaml b/conformance/tests/basic/inferencepool_accepted.yaml
new file mode 100644
index 00000000..8ae327d8
--- /dev/null
+++ b/conformance/tests/basic/inferencepool_accepted.yaml
@@ -0,0 +1,27 @@
+# Basic InferencePool for acceptance testing.
+# This manifest defines the minimal required fields to create a valid
+# InferencePool resource, which the InferencePoolAccepted test will use
+# to verify that the controller recognizes and accepts the resource.
+
+apiVersion: inference.networking.x-k8s.io/v1alpha2
+kind: InferencePool
+metadata:
+  # This name must match the 'poolNN' variable defined in the
+  # conformance/tests/basic/inferencepool_accepted.go test file.
+  name: inferencepool-basic-accepted
+  # This namespace should be one created by the base manifests.
+  namespace: gateway-conformance-app-backend
+spec:
+  # --- Selector (Required) ---
+  # Selects the Pods belonging to this pool.
+  selector:
+    app: "infra-backend-v1"
+
+  # --- Target Port (Required) ---
+  # The port the model server container listens on.
+  targetPortNumber: 3000
+
+  # --- Extension Reference ---
+  # Reference to the endpoint picker (EPP) extension for this pool.
+  extensionRef:
+    name: infra-backend-v1-epp
diff --git a/conformance/tests/main.go b/conformance/tests/main.go
new file mode 100644
index 00000000..fc66c765
--- /dev/null
+++ b/conformance/tests/main.go
@@ -0,0 +1,35 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Package tests is the root package for all Gateway API Inference Extension
+// conformance test implementations.
+package tests
+
+import (
+	// Importing the suite package to access the ConformanceTest struct definition.
+	// For the initial version, we import directly from the core gateway-api repo.
+	// This may be adjusted in the future if we need to maintain our own copy of
+	// the suite utilities.
+	"sigs.k8s.io/gateway-api/conformance/utils/suite"
+	// Do NOT add blank imports for specific test packages here.
+	// They should be added to the main conformance package instead
+	// to avoid import cycles.
+)
+
+// ConformanceTests holds all the conformance test definitions for the
+// Gateway API Inference Extension suite. Tests are registered from other packages
+// using init() functions like the one in the basic package.
+var ConformanceTests []suite.ConformanceTest
diff --git a/conformance/utils/assertions.go b/conformance/utils/assertions.go
new file mode 100644
index 00000000..c77d0fc5
--- /dev/null
+++ b/conformance/utils/assertions.go
@@ -0,0 +1,25 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Package assertions contains custom assertion helper functions used within
+// the Gateway API Inference Extension conformance test suite.
+package assertions
+
+// TODO: Implement custom assertion functions specific to Inference Extension testing.
+// Examples might include:
+// - Asserting specific fields or structures within an inference API response body.
+// - Asserting specific metrics reported by mock model servers or EPPs.
+// - Asserting specific conditions or status fields unique to InferencePool or InferenceModel.
diff --git a/conformance/utils/kubernetes/helpers.go b/conformance/utils/kubernetes/helpers.go
new file mode 100644
index 00000000..3d517863
--- /dev/null
+++ b/conformance/utils/kubernetes/helpers.go
@@ -0,0 +1,49 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Package kubernetes contains helper functions for interacting with
+// Kubernetes objects within the conformance test suite.
+package kubernetes
+
+import (
+	"testing"
+
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/types"
+	"sigs.k8s.io/controller-runtime/pkg/client"
+
+	// Import necessary utilities from the core Gateway API conformance suite
+	"sigs.k8s.io/gateway-api/conformance/utils/config"
+)
+
+// InferencePoolMustHaveCondition waits for the specified InferencePool resource
+// to exist and report the expected status condition.
+// This is a placeholder and needs full implementation.
+//
+// TODO: Implement the actual logic for this helper function.
+// It should fetch the InferencePool using the provided client and check its
+// Status.Conditions field, polling until the condition is met or a timeout
+// occurs, similar to HTTPRouteMustHaveCondition in the core suite.
+func InferencePoolMustHaveCondition(t *testing.T, c client.Client, timeoutConfig config.TimeoutConfig, poolNN types.NamespacedName, expectedCondition metav1.Condition) {
+	t.Helper() // Marks this function as a test helper
+
+	// Placeholder implementation: Log and skip the check.
+	t.Logf("Verification for InferencePool condition (%s=%s) on %s - Placeholder: Skipping check.",
+		expectedCondition.Type, expectedCondition.Status, poolNN.String())
+
+	// Skip the test using this helper until it's fully implemented.
+	t.Skip("InferencePoolMustHaveCondition helper not yet implemented")
+}
diff --git a/conformance/utils/traffic/traffic.go b/conformance/utils/traffic/traffic.go
new file mode 100644
index 00000000..4f13f980
--- /dev/null
+++ b/conformance/utils/traffic/traffic.go
@@ -0,0 +1,22 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+// Package traffic contains helper functions specifically for generating,
+// sending, and validating network traffic related to inference workloads
+// within the Gateway API Inference Extension conformance tests.
+package traffic
+
+// TODO: Add helpers for specific inference protocols or request patterns as needed.
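Editor's note: the `InferencePoolMustHaveCondition` placeholder above describes the intended behavior but currently skips the check. As a rough sketch of the direction the TODO points at (not the committed implementation), the polling could be built on `wait.PollUntilContextTimeout` from `k8s.io/apimachinery`. The `DefaultTestTimeout` field on `config.TimeoutConfig` and a flat `Status.Conditions` list on `v1alpha2.InferencePool` are assumptions here and may need adjusting to the actual types.

```go
package kubernetes

import (
	"context"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/types"
	"k8s.io/apimachinery/pkg/util/wait"
	"sigs.k8s.io/controller-runtime/pkg/client"
	"sigs.k8s.io/gateway-api/conformance/utils/config"

	"sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2"
)

// inferencePoolMustHaveCondition is a hypothetical sketch of the polling logic
// the TODO describes; the status field layout on InferencePool is an assumption.
func inferencePoolMustHaveCondition(t *testing.T, c client.Client, timeoutConfig config.TimeoutConfig, poolNN types.NamespacedName, expected metav1.Condition) {
	t.Helper()

	err := wait.PollUntilContextTimeout(context.Background(), time.Second, timeoutConfig.DefaultTestTimeout, true,
		func(ctx context.Context) (bool, error) {
			pool := &v1alpha2.InferencePool{}
			if err := c.Get(ctx, poolNN, pool); err != nil {
				return false, nil // Not found yet (or transient error); keep polling.
			}
			for _, cond := range pool.Status.Conditions { // Assumes a flat Conditions list.
				if cond.Type == expected.Type && cond.Status == expected.Status {
					return true, nil
				}
			}
			return false, nil
		})
	require.NoErrorf(t, err, "InferencePool %s never reported condition %s=%s", poolNN, expected.Type, expected.Status)
}
```

If something along these lines lands, the `t.Logf`/`t.Skip` placeholder body would simply be replaced by this polling loop, mirroring how `HTTPRouteMustHaveCondition` behaves in the core suite.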
diff --git a/docs/proposals/0683-epp-architecture-proposal/README.md b/docs/proposals/0683-epp-architecture-proposal/README.md
new file mode 100644
index 00000000..48c7720f
--- /dev/null
+++ b/docs/proposals/0683-epp-architecture-proposal/README.md
@@ -0,0 +1,99 @@
+# Gateway API Inference Extension
+
+Author(s): @kfswain
+## Proposal Status
+ ***Draft***
+
+## Table of Contents
+
+
+
+- [Summary](#summary)
+- [Goals](#goals)
+- [Non-Goals](#non-goals)
+- [Proposal](#proposal)
+  - [Personas](#personas)
+    - [Inference Platform Admin](#inference-platform-admin)
+    - [Inference Workload Owner](#workload-owner)
+  - [Axioms](#axioms)
+  - [InferencePool](#inferencepool)
+  - [InferenceModel](#inferencemodel)
+  - [Spec](#spec)
+  - [Diagrams](#diagrams)
+  - [Alternatives](#alternatives)
+- [Open Questions](#open-questions)
+
+
+
+## Summary
+
+This proposal seeks to standardize the implementation of an EPP (End-point Picker) for the Inference Gateway extension (also known as Gateway API Inference Extension). Additionally, this proposes to restructure the current implementation of the EPP to be more modular and approachable.
+
+## Goals
+
+- Set a standard on how the EPP & APIs interact
+- Settle on common nomenclature for clearer communication
+- Allow for modularization of the EPP, to be extended to a user's specific needs
+
+## Non-Goals
+
+- Reshaping the current API
+- A change in scope of the current project
+
+## Proposal
+
+This proposal is not proposing any net new features; instead, we are refactoring our current implementation to better handle more devs, more features, etc. At the time of writing, GIE is currently at v0.3, and that stronger experimental context (along with external feedback) made clear the need for this restructure. The image below gives a high-level view of how our components work together.
+
+![Scheduling Algorithm](./images/epp_arch.svg)
+
+## Overview
+At a quick glance, the EPP is being broken into specific layers. The `Data Layer` is of note, as it is a vertical that will be accessed by all the others. The data layer manages the k8s data, metric & usage data, as well as processing of the above data to determine resource scarcity regimes.
+
+The other layers are handled in a sequential process, starting with the **Ext-Proc** call. The request is buffered and then sent to the **Routing Layer**, which first processes any user-defined per-InferenceModel routing rules & request enrichment (at the time of writing, that is just translating the InferenceModel name to a weight-split actual model). Then _all_ requests pass through the to-be-implemented [**Flow Controller**](https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/674) to ensure that any request entering the pool adheres to the guidelines set by the Priority, Fairness, & Queueing configuration. And finally, the **Scheduling Layer** is the load balancing algorithm that intelligently routes requests based on the current state of the InferencePool.
+
+## Components
+
+To further expand upon these component layers, we will first break them into `extensible` and `non-extensible` layers. `Non-extensible` layers are intended to be static, and handled on behalf of the user, typically implementing low-opinion infrastructure.
+
+The `Extensible` layers are:
+- Data Layer
+- Routing Layer
+- Flow Controller
+- Scheduling Layer
+
+The `Non-Extensible` layer(s) are:
+- The Ext-Proc Server
+
+### `Extensible`
+
+#### Data Layer
+
+The data layer will consume and store: the InferencePool/InferenceModel config and the pre-defined [Model Server Protocol](../003-model-server-protocol/README.md). Additionally, the data fed from the model servers will be processed and digested to provide resource scarcity regime hints, and autoscaling recommendations.
+
+Many extensions to scheduling will require changes to ingested metrics; as such, the data layer will be built to be extended, but extenders must accept that the Model Server Protocol will no longer provide guarantees on portability of a model server out of the box.
+
+#### Routing Layer
+
+The routing layer is likely to be the most opinion-heavy section, as the scope of what constitutes a 'Route Rule' is somewhat broad. The current examples we expect would be:
+
+- System Prompt injection
+- RAG callout
+- Per-InferenceModel request validation (such as safety/on-topic, etc)
+
+Due to the possibility of this becoming a bit of a dumping ground, the API will keep a _very_ tight scope on which of these route rules are included in the spec. A standard method of extension will be provided if the need to define a custom rule arises.
+
+#### Flow Controller (WIP - implementation tracked in [#674](https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/674))
+
+The flow controller will consume resource regime data, and enforce proper resource sharing between workloads. This will primarily be done through a queuing mechanism [as described here](https://docs.google.com/document/d/1VZL7opFWuwgWquvgiOzLlXAJ633qZ9U-A0ZixGjBgaI/edit?usp=sharing).
+
+#### Scheduling Layer
+
+As the Scheduling Layer is the final interface to the entirety of the pool, all configuration will be at the _pool_ level. The default scheduling layer will be an experimentally-backed LB algorithm, with exposed config values.
+
+The Scheduler will define a strong interface API, so that new scheduling algos may be plugged & dark-launched to test in production traffic without impacting said traffic. Extension is expected to adhere to the [Scheduler Subsystem definition](https://github.com/kubernetes-sigs/gateway-api-inference-extension/pull/603).
+
+### `Non-extensible`
+
+#### Ext-Proc Server
+
+The Ext-Proc Server protocol is very well defined & specific; deviation could cause the EPP to become unusable or unstable. Extension is ill-advised.
diff --git a/docs/proposals/0683-epp-architecture-proposal/images/epp_arch.svg b/docs/proposals/0683-epp-architecture-proposal/images/epp_arch.svg
new file mode 100644
index 00000000..4c585728
--- /dev/null
+++ b/docs/proposals/0683-epp-architecture-proposal/images/epp_arch.svg
@@ -0,0 +1 @@
+ 
\ No newline at end of file
diff --git a/docs/proposals/README.md b/docs/proposals/README.md
new file mode 100644
index 00000000..2b0408d3
--- /dev/null
+++ b/docs/proposals/README.md
@@ -0,0 +1,5 @@
+# Proposals Best Practices
+
+
+## Naming
+The directory of the proposal should lead with a 4-digit PR number (will move to 5,6,... should our PR count get that high), followed by a kebab-cased title. The PR number is not known until the PR is cut, so development can use a placeholder, ex. XXXX-my-proposal. PR number is used b/c it is unique & chronological, allowing the default ordering of proposals to follow the timeline of development.
\ No newline at end of file diff --git a/go.mod b/go.mod index fba85f91..30d0487e 100644 --- a/go.mod +++ b/go.mod @@ -7,24 +7,25 @@ require ( github.com/envoyproxy/go-control-plane/envoy v1.32.4 github.com/go-logr/logr v1.4.2 github.com/google/go-cmp v0.7.0 - github.com/onsi/ginkgo/v2 v2.23.3 - github.com/onsi/gomega v1.36.3 - github.com/prometheus/client_golang v1.21.1 - github.com/prometheus/client_model v0.6.1 + github.com/onsi/ginkgo/v2 v2.23.4 + github.com/onsi/gomega v1.37.0 + github.com/prometheus/client_golang v1.22.0 + github.com/prometheus/client_model v0.6.2 github.com/prometheus/common v0.63.0 github.com/stretchr/testify v1.10.0 go.uber.org/multierr v1.11.0 go.uber.org/zap v1.27.0 - google.golang.org/grpc v1.71.0 + google.golang.org/grpc v1.71.1 google.golang.org/protobuf v1.36.6 - k8s.io/api v0.32.3 - k8s.io/apiextensions-apiserver v0.32.3 - k8s.io/apimachinery v0.32.3 - k8s.io/client-go v0.32.3 - k8s.io/code-generator v0.32.3 - k8s.io/component-base v0.32.3 + k8s.io/api v0.32.4 + k8s.io/apiextensions-apiserver v0.32.4 + k8s.io/apimachinery v0.32.4 + k8s.io/client-go v0.32.4 + k8s.io/code-generator v0.32.4 + k8s.io/component-base v0.32.4 k8s.io/utils v0.0.0-20241210054802-24370beab758 sigs.k8s.io/controller-runtime v0.20.4 + sigs.k8s.io/gateway-api v1.2.1 sigs.k8s.io/structured-merge-diff/v4 v4.6.0 sigs.k8s.io/yaml v1.4.0 ) @@ -42,17 +43,17 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cncf/xds/go v0.0.0-20241223141626-cff3c89139a3 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/emicklei/go-restful/v3 v3.11.0 // indirect + github.com/emicklei/go-restful/v3 v3.12.0 // indirect github.com/envoyproxy/protoc-gen-validate v1.2.1 // indirect github.com/evanphx/json-patch/v5 v5.9.11 // indirect - github.com/fatih/color v1.16.0 // indirect + github.com/fatih/color v1.17.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-logr/zapr v1.3.0 // indirect github.com/go-openapi/jsonpointer v0.21.0 // indirect - github.com/go-openapi/jsonreference v0.20.2 // indirect + github.com/go-openapi/jsonreference v0.21.0 // indirect github.com/go-openapi/swag v0.23.0 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.0 // indirect @@ -65,16 +66,15 @@ require ( github.com/google/cel-go v0.22.0 // indirect github.com/google/gnostic-models v0.6.8 // indirect github.com/google/gofuzz v1.2.0 // indirect - github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad // indirect + github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 // indirect github.com/google/uuid v1.6.0 // indirect - github.com/gorilla/websocket v1.5.0 // indirect + github.com/gorilla/websocket v1.5.1 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect github.com/huandu/xstrings v1.3.3 // indirect - github.com/imdario/mergo v0.3.11 // indirect + github.com/imdario/mergo v0.3.16 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect - github.com/klauspost/compress v1.17.11 // indirect github.com/kylelemons/godebug v1.1.0 // indirect github.com/leodido/go-urn v1.2.1 // indirect github.com/mailru/easyjson v0.7.7 // indirect @@ -104,17 +104,18 @@ require ( go.opentelemetry.io/otel/sdk v1.34.0 // indirect 
go.opentelemetry.io/otel/trace v1.34.0 // indirect go.opentelemetry.io/proto/otlp v1.3.1 // indirect + go.uber.org/automaxprocs v1.6.0 // indirect golang.org/x/crypto v0.36.0 // indirect golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 // indirect - golang.org/x/mod v0.23.0 // indirect + golang.org/x/mod v0.24.0 // indirect golang.org/x/net v0.37.0 // indirect golang.org/x/oauth2 v0.25.0 // indirect golang.org/x/sync v0.12.0 // indirect - golang.org/x/sys v0.31.0 // indirect + golang.org/x/sys v0.32.0 // indirect golang.org/x/term v0.30.0 // indirect golang.org/x/text v0.23.0 // indirect golang.org/x/time v0.7.0 // indirect - golang.org/x/tools v0.30.0 // indirect + golang.org/x/tools v0.31.0 // indirect golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422 // indirect @@ -123,11 +124,11 @@ require ( gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect - k8s.io/apiserver v0.32.3 // indirect + k8s.io/apiserver v0.32.4 // indirect k8s.io/gengo/v2 v2.0.0-20240911193312-2b36238f13e9 // indirect k8s.io/klog/v2 v2.130.1 // indirect k8s.io/kube-openapi v0.0.0-20241105132330-32ad38e42d3f // indirect sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.31.0 // indirect - sigs.k8s.io/controller-tools v0.14.0 // indirect + sigs.k8s.io/controller-tools v0.16.3 // indirect sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3 // indirect ) diff --git a/go.sum b/go.sum index 2bcff108..6688c578 100644 --- a/go.sum +++ b/go.sum @@ -23,25 +23,24 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL github.com/cncf/xds/go v0.0.0-20241223141626-cff3c89139a3 h1:boJj011Hh+874zpIySeApCX4GeOjPl9qhRF3QuIZq+Q= github.com/cncf/xds/go v0.0.0-20241223141626-cff3c89139a3/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/elastic/crd-ref-docs v0.1.0 h1:Cr5kz89QB3Iuuj7dhAfLMApCrChEGAaIBTxGk/xuRKw= github.com/elastic/crd-ref-docs v0.1.0/go.mod h1:X83mMBdJt05heJUYiS3T0yJ/JkCuliuhSUNav5Gjo/U= -github.com/emicklei/go-restful/v3 v3.11.0 h1:rAQeMHw1c7zTmncogyy8VvRZwtkmkZ4FxERmMY4rD+g= -github.com/emicklei/go-restful/v3 v3.11.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= +github.com/emicklei/go-restful/v3 v3.12.0 h1:y2DdzBAURM29NFF94q6RaY4vjIH1rtwDapwQtU84iWk= +github.com/emicklei/go-restful/v3 v3.12.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= github.com/envoyproxy/go-control-plane/envoy v1.32.4 h1:jb83lalDRZSpPWW2Z7Mck/8kXZ5CQAFYVjQcdVIr83A= github.com/envoyproxy/go-control-plane/envoy v1.32.4/go.mod h1:Gzjc5k8JcJswLjAx1Zm+wSYE20UrLtt7JZMWiWQXQEw= github.com/envoyproxy/protoc-gen-validate v1.2.1 h1:DEo3O99U8j4hBFwbJfrz9VtgcDfUKS7KJ7spH3d86P8= github.com/envoyproxy/protoc-gen-validate v1.2.1/go.mod h1:d/C80l/jxXLdfEIhX1W2TmLfsJ31lvEjwamM4DxlWXU= -github.com/evanphx/json-patch v0.5.2 
h1:xVCHIVMUu1wtM/VkR9jVZ45N3FhZfYMMYGorLCR8P3k= -github.com/evanphx/json-patch v0.5.2/go.mod h1:ZWS5hhDbVDyob71nXKNL0+PWn6ToqBHMikGIFbs31qQ= +github.com/evanphx/json-patch v5.7.0+incompatible h1:vgGkfT/9f8zE6tvSCe74nfpAVDQ2tG6yudJd8LBksgI= +github.com/evanphx/json-patch v5.7.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= github.com/evanphx/json-patch/v5 v5.9.11 h1:/8HVnzMq13/3x9TPvjG08wUGqBTmZBsCWzjTM0wiaDU= github.com/evanphx/json-patch/v5 v5.9.11/go.mod h1:3j+LviiESTElxA4p3EMKAB9HXj3/XEtnUf6OZxqIQTM= -github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM= -github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE= +github.com/fatih/color v1.17.0 h1:GlRw1BRJxkpqUCBKzKOw098ed57fEsKeNjpTe3cSjK4= +github.com/fatih/color v1.17.0/go.mod h1:YZ7TlrGPkiz6ku9fK3TLD/pl3CpsiFyu8N92HLgmosI= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= @@ -55,12 +54,10 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-logr/zapr v1.3.0 h1:XGdV8XW8zdwFiwOA2Dryh1gj2KRQyOOoNmBy4EplIcQ= github.com/go-logr/zapr v1.3.0/go.mod h1:YKepepNBd1u/oyhd/yQmtjVXmm9uML4IXUgMOwR8/Gg= -github.com/go-openapi/jsonpointer v0.19.6/go.mod h1:osyAmYz/mB/C3I+WsTTSgw1ONzaLJoLCyoi6/zppojs= github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ= github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY= -github.com/go-openapi/jsonreference v0.20.2 h1:3sVjiK66+uXK/6oQ8xgcRKcFgQ5KXa2KvnJRumpMGbE= -github.com/go-openapi/jsonreference v0.20.2/go.mod h1:Bl1zwGIM8/wsvqjsOQLJ/SH+En5Ap4rVB5KVcIDZG2k= -github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14= +github.com/go-openapi/jsonreference v0.21.0 h1:Rs+Y7hSXT83Jacb7kFyjn4ijOuVGSvOdF2+tg1TRrwQ= +github.com/go-openapi/jsonreference v0.21.0/go.mod h1:LmZmgsrTkVg9LG4EaHeY8cBDslNPMo06cago5JNLkm4= github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE= github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ= github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs= @@ -92,18 +89,18 @@ github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad h1:a6HEuzUHeKH6hwfN/ZoQgRgVIWFJljSWa/zetS2WTvg= -github.com/google/pprof v0.0.0-20241210010833-40e02aabc2ad/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= +github.com/google/pprof v0.0.0-20250403155104-27863c87afa6 h1:BHT72Gu3keYf3ZEu2J0b1vyeLSOYI8bm5wbJM/8yDe8= +github.com/google/pprof v0.0.0-20250403155104-27863c87afa6/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= 
-github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= +github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 h1:bkypFPDjIYGfCYD5mRBvpqxfYX1YCS1PXdKYWi8FsN0= github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0/go.mod h1:P+Lt/0by1T8bfcF3z737NnSbmxQAppXMRziHUxPOC8k= github.com/huandu/xstrings v1.3.3 h1:/Gcsuc1x8JVbJ9/rlye4xZnVAbEkGauT8lbebqcQws4= github.com/huandu/xstrings v1.3.3/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= -github.com/imdario/mergo v0.3.11 h1:3tnifQM4i+fbajXKBHXWEH+KvNHqojZ778UH75j3bGA= -github.com/imdario/mergo v0.3.11/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA= +github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4= +github.com/imdario/mergo v0.3.16/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= @@ -112,13 +109,10 @@ github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnr github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= -github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= @@ -151,10 +145,10 @@ github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= -github.com/onsi/ginkgo/v2 v2.23.3 h1:edHxnszytJ4lD9D5Jjc4tiDkPBZ3siDeJJkUZJJVkp0= -github.com/onsi/ginkgo/v2 v2.23.3/go.mod h1:zXTP6xIp3U8aVuXN8ENK9IXRaTjFnpVB9mGmaSRvxnM= -github.com/onsi/gomega v1.36.3 h1:hID7cr8t3Wp26+cYnfcjR6HpJ00fdogN6dqZ1t6IylU= -github.com/onsi/gomega v1.36.3/go.mod h1:8D9+Txp43QWKhM24yyOBEdpkzN8FvJyAwecBgsU4KU0= +github.com/onsi/ginkgo/v2 v2.23.4 h1:ktYTpKJAVZnDT4VjxSbiBenUjmlL/5QkBEocaWXiQus= +github.com/onsi/ginkgo/v2 v2.23.4/go.mod h1:Bt66ApGPBFzHyR+JO10Zbt0Gsp4uWxu5mIOTusL46e8= +github.com/onsi/gomega v1.37.0 
h1:CdEG8g0S133B4OswTDC/5XPSzE1OeP29QOioj2PID2Y= +github.com/onsi/gomega v1.37.0/go.mod h1:8D9+Txp43QWKhM24yyOBEdpkzN8FvJyAwecBgsU4KU0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= @@ -162,10 +156,12 @@ github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_golang v1.21.1 h1:DOvXXTqVzvkIewV/CDPFdejpMCGeMcbGCQ8YOmu+Ibk= -github.com/prometheus/client_golang v1.21.1/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg= -github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= -github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= +github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g= +github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U= +github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q= +github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0= +github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= +github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= github.com/prometheus/common v0.63.0 h1:YR/EIY1o3mEFP/kZCD7iDMnLPlGyuU2Gb3HIcXnA98k= github.com/prometheus/common v0.63.0/go.mod h1:VVFF/fBIoToEnWRVkYoXEkq3R3paCoxG9PXP74SnV18= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= @@ -213,6 +209,8 @@ go.opentelemetry.io/otel/trace v1.34.0 h1:+ouXS2V8Rd4hp4580a8q23bg0azF2nI8cqLYnC go.opentelemetry.io/otel/trace v1.34.0/go.mod h1:Svm7lSjQD7kG7KJ/MUHPVXSDGz2OX4h0M2jHBhmSfRE= go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0= go.opentelemetry.io/proto/otlp v1.3.1/go.mod h1:0X1WI4de4ZsLrrJNLAQbFeLCm3T7yBkR0XqQ7niQU+8= +go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs= +go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= @@ -228,8 +226,8 @@ golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0 golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM= -golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= +golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU= +golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= golang.org/x/net 
v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -248,8 +246,8 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= -golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= +golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y= golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -263,8 +261,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY= -golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY= +golang.org/x/tools v0.31.0 h1:0EedkvKDbh+qistFTd0Bcwe/YLh4vHwWEkiI0toFIBU= +golang.org/x/tools v0.31.0/go.mod h1:naFTU+Cev749tSJRXJlna0T3WxKvb1kWEx15xA4SdmQ= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -277,8 +275,8 @@ google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422 h1: google.golang.org/genproto/googleapis/api v0.0.0-20250106144421-5f5ef82da422/go.mod h1:b6h1vNKhxaSoEI+5jc3PJUCustfli/mRab7295pY7rw= google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f h1:OxYkA3wjPsZyBylwymxSHa7ViiW1Sml4ToBrncvFehI= google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f/go.mod h1:+2Yz8+CLJbIfL9z73EW45avw8Lmge3xVElCP9zEKi50= -google.golang.org/grpc v1.71.0 h1:kF77BGdPTQ4/JZWMlb9VpJ5pa25aqvVqogsxNHHdeBg= -google.golang.org/grpc v1.71.0/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd/8Ec= +google.golang.org/grpc v1.71.1 h1:ffsFWr7ygTUscGPI0KKK6TLrGz0476KUvvsbqWK0rPI= +google.golang.org/grpc v1.71.1/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd/8Ec= google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -290,26 +288,25 @@ gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= gopkg.in/inf.v0 v0.9.1/go.mod 
h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= -gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -k8s.io/api v0.32.3 h1:Hw7KqxRusq+6QSplE3NYG4MBxZw1BZnq4aP4cJVINls= -k8s.io/api v0.32.3/go.mod h1:2wEDTXADtm/HA7CCMD8D8bK4yuBUptzaRhYcYEEYA3k= -k8s.io/apiextensions-apiserver v0.32.3 h1:4D8vy+9GWerlErCwVIbcQjsWunF9SUGNu7O7hiQTyPY= -k8s.io/apiextensions-apiserver v0.32.3/go.mod h1:8YwcvVRMVzw0r1Stc7XfGAzB/SIVLunqApySV5V7Dss= -k8s.io/apimachinery v0.32.3 h1:JmDuDarhDmA/Li7j3aPrwhpNBA94Nvk5zLeOge9HH1U= -k8s.io/apimachinery v0.32.3/go.mod h1:GpHVgxoKlTxClKcteaeuF1Ul/lDVb74KpZcxcmLDElE= -k8s.io/apiserver v0.32.3 h1:kOw2KBuHOA+wetX1MkmrxgBr648ksz653j26ESuWNY8= -k8s.io/apiserver v0.32.3/go.mod h1:q1x9B8E/WzShF49wh3ADOh6muSfpmFL0I2t+TG0Zdgc= -k8s.io/client-go v0.32.3 h1:RKPVltzopkSgHS7aS98QdscAgtgah/+zmpAogooIqVU= -k8s.io/client-go v0.32.3/go.mod h1:3v0+3k4IcT9bXTc4V2rt+d2ZPPG700Xy6Oi0Gdl2PaY= -k8s.io/code-generator v0.32.3 h1:31p2TVzC9+hVdSkAFruAk3JY+iSfzrJ83Qij1yZutyw= -k8s.io/code-generator v0.32.3/go.mod h1:+mbiYID5NLsBuqxjQTygKM/DAdKpAjvBzrJd64NU1G8= -k8s.io/component-base v0.32.3 h1:98WJvvMs3QZ2LYHBzvltFSeJjEx7t5+8s71P7M74u8k= -k8s.io/component-base v0.32.3/go.mod h1:LWi9cR+yPAv7cu2X9rZanTiFKB2kHA+JjmhkKjCZRpI= +k8s.io/api v0.32.4 h1:kw8Y/G8E7EpNy7gjB8gJZl3KJkNz8HM2YHrZPtAZsF4= +k8s.io/api v0.32.4/go.mod h1:5MYFvLvweRhyKylM3Es/6uh/5hGp0dg82vP34KifX4g= +k8s.io/apiextensions-apiserver v0.32.4 h1:IA+CoR63UDOijR/vEpow6wQnX4V6iVpzazJBskHrpHE= +k8s.io/apiextensions-apiserver v0.32.4/go.mod h1:Y06XO/b92H8ymOdG1HlA1submf7gIhbEDc3RjriqZOs= +k8s.io/apimachinery v0.32.4 h1:8EEksaxA7nd7xWJkkwLDN4SvWS5ot9g6Z/VZb3ju25I= +k8s.io/apimachinery v0.32.4/go.mod h1:GpHVgxoKlTxClKcteaeuF1Ul/lDVb74KpZcxcmLDElE= +k8s.io/apiserver v0.32.4 h1:Yf7sd/y+GOQKH1Qf6wUeayZrYXe2SKZ17Bcq7VQM5HQ= +k8s.io/apiserver v0.32.4/go.mod h1:JFUMNtE2M5yqLZpIsgCb06SkVSW1YcxW1oyLSTfjXR8= +k8s.io/client-go v0.32.4 h1:zaGJS7xoYOYumoWIFXlcVrsiYioRPrXGO7dBfVC5R6M= +k8s.io/client-go v0.32.4/go.mod h1:k0jftcyYnEtwlFW92xC7MTtFv5BNcZBr+zn9jPlT9Ic= +k8s.io/code-generator v0.32.4 h1:d4dm/43RD6xhPBX22JgJw9JUpwTKzVR6tAxJD7pz83o= +k8s.io/code-generator v0.32.4/go.mod h1:R0bKdIg1smtvsKvj9q7SxTeKq5X9ko6PuICCGt4yqxg= +k8s.io/component-base v0.32.4 h1:HuF+2JVLbFS5GODLIfPCb1Td6b+G2HszJoArcWOSr5I= +k8s.io/component-base v0.32.4/go.mod h1:10KloJEYw1keU/Xmjfy9TKJqUq7J2mYdiD1VDXoco4o= k8s.io/gengo/v2 v2.0.0-20240911193312-2b36238f13e9 h1:si3PfKm8dDYxgfbeA6orqrtLkvvIeH8UqffFJDl0bz4= k8s.io/gengo/v2 v2.0.0-20240911193312-2b36238f13e9/go.mod h1:EJykeLsmFC60UQbYJezXkEsG2FLrt0GPNkU5iK5GWxU= k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk= @@ -322,8 +319,10 @@ sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.31.0 h1:CPT0ExVicCzcp sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.31.0/go.mod h1:Ve9uj1L+deCXFrPOk1LpFXqTg7LCFzFso6PA48q/XZw= sigs.k8s.io/controller-runtime v0.20.4 
h1:X3c+Odnxz+iPTRobG4tp092+CvBU9UK0t/bRf+n0DGU= sigs.k8s.io/controller-runtime v0.20.4/go.mod h1:xg2XB0K5ShQzAgsoujxuKN4LNXR2LfwwHsPj7Iaw+XY= -sigs.k8s.io/controller-tools v0.14.0 h1:rnNoCC5wSXlrNoBKKzL70LNJKIQKEzT6lloG6/LF73A= -sigs.k8s.io/controller-tools v0.14.0/go.mod h1:TV7uOtNNnnR72SpzhStvPkoS/U5ir0nMudrkrC4M9Sc= +sigs.k8s.io/controller-tools v0.16.3 h1:z48C5/d4jCVQQvtiSBL5MYyZ3EO2eFIOXrIKMgHVhFY= +sigs.k8s.io/controller-tools v0.16.3/go.mod h1:AEj6k+w1kYpLZv2einOH3mj52ips4W/6FUjnB5tkJGs= +sigs.k8s.io/gateway-api v1.2.1 h1:fZZ/+RyRb+Y5tGkwxFKuYuSRQHu9dZtbjenblleOLHM= +sigs.k8s.io/gateway-api v1.2.1/go.mod h1:EpNfEXNjiYfUJypf0eZ0P5iXA9ekSGWaS1WgPaM42X0= sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3 h1:/Rv+M11QRah1itp8VhT6HoVx1Ray9eB4DBr+K+/sCJ8= sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3/go.mod h1:18nIHnGi6636UCz6m8i4DhaJ65T6EruyzmoQqI2BVDo= sigs.k8s.io/randfill v0.0.0-20250304075658-069ef1bbf016 h1:kXv6kKdoEtedwuqMmkqhbkgvYKeycVbC8+iPCP9j5kQ= diff --git a/mkdocs.yml b/mkdocs.yml index 2dc4d2a1..e5927ed5 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -10,7 +10,7 @@ theme: icon: repo: fontawesome/brands/git-alt logo: images/logo/logo-text-large-horizontal-white.png - favicon: images/k8s-favicon.png + favicon: images/favicon-64.png features: - search.highlight - navigation.tabs @@ -54,13 +54,16 @@ nav: API Overview: concepts/api-overview.md Conformance: concepts/conformance.md Roles and Personas: concepts/roles-and-personas.md - - Implementations: implementations.md + - Implementations: + - Gateways: implementations/gateways.md + - Model Servers: implementations/model-servers.md - FAQ: faq.md - Guides: - User Guides: - Getting started: guides/index.md - Adapter Rollout: guides/adapter-rollout.md - Metrics: guides/metrics.md + - Replacing an Inference Pool: guides/replacing-inference-pool.md - Implementer's Guide: guides/implementers.md - Performance: - Benchmark: performance/benchmark/index.md diff --git a/pkg/body-based-routing/README.md b/pkg/bbr/README.md similarity index 100% rename from pkg/body-based-routing/README.md rename to pkg/bbr/README.md diff --git a/pkg/body-based-routing/handlers/request.go b/pkg/bbr/handlers/request.go similarity index 98% rename from pkg/body-based-routing/handlers/request.go rename to pkg/bbr/handlers/request.go index c0be46ac..32fffc02 100644 --- a/pkg/body-based-routing/handlers/request.go +++ b/pkg/bbr/handlers/request.go @@ -25,7 +25,7 @@ import ( eppb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/pkg/body-based-routing/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/metrics" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) diff --git a/pkg/body-based-routing/handlers/request_test.go b/pkg/bbr/handlers/request_test.go similarity index 98% rename from pkg/body-based-routing/handlers/request_test.go rename to pkg/bbr/handlers/request_test.go index 0f088702..55c42a21 100644 --- a/pkg/body-based-routing/handlers/request_test.go +++ b/pkg/bbr/handlers/request_test.go @@ -28,7 +28,7 @@ import ( "google.golang.org/protobuf/testing/protocmp" "k8s.io/component-base/metrics/legacyregistry" metricsutils "k8s.io/component-base/metrics/testutil" - "sigs.k8s.io/gateway-api-inference-extension/pkg/body-based-routing/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/metrics" logutil 
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) diff --git a/pkg/body-based-routing/handlers/response.go b/pkg/bbr/handlers/response.go similarity index 100% rename from pkg/body-based-routing/handlers/response.go rename to pkg/bbr/handlers/response.go diff --git a/pkg/body-based-routing/handlers/server.go b/pkg/bbr/handlers/server.go similarity index 98% rename from pkg/body-based-routing/handlers/server.go rename to pkg/bbr/handlers/server.go index 24664f98..484b3318 100644 --- a/pkg/body-based-routing/handlers/server.go +++ b/pkg/bbr/handlers/server.go @@ -114,16 +114,16 @@ func (s *Server) processRequestBody(ctx context.Context, body *extProcPb.HttpBod var requestBody map[string]interface{} if s.streaming { + streamedBody.body = append(streamedBody.body, body.Body...) // In the stream case, we can receive multiple request bodies. - if !body.EndOfStream { - streamedBody.body = append(streamedBody.body, body.Body...) - return nil, nil - } else { + if body.EndOfStream { loggerVerbose.Info("Flushing stream buffer") err := json.Unmarshal(streamedBody.body, &requestBody) if err != nil { logger.V(logutil.DEFAULT).Error(err, "Error unmarshaling request body") } + } else { + return nil, nil } } else { if err := json.Unmarshal(body.GetBody(), &requestBody); err != nil { diff --git a/pkg/body-based-routing/handlers/server_test.go b/pkg/bbr/handlers/server_test.go similarity index 100% rename from pkg/body-based-routing/handlers/server_test.go rename to pkg/bbr/handlers/server_test.go diff --git a/pkg/body-based-routing/metrics/metrics.go b/pkg/bbr/metrics/metrics.go similarity index 100% rename from pkg/body-based-routing/metrics/metrics.go rename to pkg/bbr/metrics/metrics.go diff --git a/pkg/body-based-routing/server/runserver.go b/pkg/bbr/server/runserver.go similarity index 96% rename from pkg/body-based-routing/server/runserver.go rename to pkg/bbr/server/runserver.go index 1646aa5a..2001b7ff 100644 --- a/pkg/body-based-routing/server/runserver.go +++ b/pkg/bbr/server/runserver.go @@ -27,7 +27,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/manager" "sigs.k8s.io/gateway-api-inference-extension/internal/runnable" tlsutil "sigs.k8s.io/gateway-api-inference-extension/internal/tls" - "sigs.k8s.io/gateway-api-inference-extension/pkg/body-based-routing/handlers" + "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/handlers" ) // ExtProcServerRunner provides methods to manage an external process server. diff --git a/pkg/epp/README.md b/pkg/epp/README.md index 1bf47993..99d1bf06 100644 --- a/pkg/epp/README.md +++ b/pkg/epp/README.md @@ -1,5 +1,5 @@ # The EndPoint Picker (EPP) -This package provides the reference implementation for the Endpoint Picker (EPP). As demonistrated in the diagram below, it implements the [extension protocol](../../docs/proposals/004-endpoint-picker-protocol), enabling a proxy or gateway to request endpoint hints from an extension, and interacts with the model servers through the defined [model server protocol](../..//docs/proposals/003-model-server-protocol). +This package provides the reference implementation for the Endpoint Picker (EPP). As demonstrated in the diagram below, it implements the [extension protocol](../../docs/proposals/004-endpoint-picker-protocol), enabling a proxy or gateway to request endpoint hints from an extension, and interacts with the model servers through the defined [model server protocol](../..//docs/proposals/003-model-server-protocol). 
![Architecture Diagram](../../docs/endpoint-picker.svg) diff --git a/pkg/epp/backend/metrics/fake.go b/pkg/epp/backend/metrics/fake.go index 7fd4970d..58d05026 100644 --- a/pkg/epp/backend/metrics/fake.go +++ b/pkg/epp/backend/metrics/fake.go @@ -24,13 +24,13 @@ import ( corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) // FakePodMetrics is an implementation of PodMetrics that doesn't run the async refresh loop. type FakePodMetrics struct { - Pod *Pod + Pod *backend.Pod Metrics *Metrics } @@ -38,7 +38,7 @@ func (fpm *FakePodMetrics) String() string { return fmt.Sprintf("Pod: %v; Metrics: %v", fpm.GetPod(), fpm.GetMetrics()) } -func (fpm *FakePodMetrics) GetPod() *Pod { +func (fpm *FakePodMetrics) GetPod() *backend.Pod { return fpm.Pod } func (fpm *FakePodMetrics) GetMetrics() *Metrics { @@ -56,7 +56,7 @@ type FakePodMetricsClient struct { Res map[types.NamespacedName]*Metrics } -func (f *FakePodMetricsClient) FetchMetrics(ctx context.Context, pod *Pod, existing *Metrics, port int32) (*Metrics, error) { +func (f *FakePodMetricsClient) FetchMetrics(ctx context.Context, pod *backend.Pod, existing *Metrics, port int32) (*Metrics, error) { f.errMu.RLock() err, ok := f.Err[pod.NamespacedName] f.errMu.RUnlock() @@ -84,11 +84,3 @@ func (f *FakePodMetricsClient) SetErr(new map[types.NamespacedName]error) { defer f.errMu.Unlock() f.Err = new } - -type FakeDataStore struct { - Res map[string]*v1alpha2.InferenceModel -} - -func (fds *FakeDataStore) FetchModelData(modelName string) (returnModel *v1alpha2.InferenceModel) { - return fds.Res[modelName] -} diff --git a/pkg/epp/backend/metrics/logger.go b/pkg/epp/backend/metrics/logger.go index d71dc3fa..7dc1a8b8 100644 --- a/pkg/epp/backend/metrics/logger.go +++ b/pkg/epp/backend/metrics/logger.go @@ -32,6 +32,7 @@ const ( // Note currently the EPP treats stale metrics same as fresh. // TODO: https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/336 metricsValidityPeriod = 5 * time.Second + debugPrintInterval = 5 * time.Second ) type Datastore interface { @@ -46,17 +47,16 @@ type Datastore interface { // enabled; 2) flushes Prometheus metrics about the backend servers. func StartMetricsLogger(ctx context.Context, datastore Datastore, refreshPrometheusMetricsInterval time.Duration) { logger := log.FromContext(ctx) - - // Periodically flush prometheus metrics for inference pool + ticker := time.NewTicker(refreshPrometheusMetricsInterval) go func() { + defer ticker.Stop() for { select { case <-ctx.Done(): logger.V(logutil.DEFAULT).Info("Shutting down prometheus metrics thread") return - default: - time.Sleep(refreshPrometheusMetricsInterval) - flushPrometheusMetricsOnce(logger, datastore) + case <-ticker.C: // Periodically refresh prometheus metrics for inference pool + refreshPrometheusMetrics(logger, datastore) } } }() @@ -64,13 +64,14 @@ func StartMetricsLogger(ctx context.Context, datastore Datastore, refreshPrometh // Periodically print out the pods and metrics for DEBUGGING. 
if logger := logger.V(logutil.DEBUG); logger.Enabled() { go func() { + ticker := time.NewTicker(debugPrintInterval) + defer ticker.Stop() for { select { case <-ctx.Done(): logger.V(logutil.DEFAULT).Info("Shutting down metrics logger thread") return - default: - time.Sleep(5 * time.Second) + case <-ticker.C: podsWithFreshMetrics := datastore.PodList(func(pm PodMetrics) bool { return time.Since(pm.GetMetrics().UpdateTime) <= metricsValidityPeriod }) @@ -85,11 +86,11 @@ func StartMetricsLogger(ctx context.Context, datastore Datastore, refreshPrometh } } -func flushPrometheusMetricsOnce(logger logr.Logger, datastore Datastore) { +func refreshPrometheusMetrics(logger logr.Logger, datastore Datastore) { pool, err := datastore.PoolGet() if err != nil { // No inference pool or not initialize. - logger.V(logutil.DEFAULT).Info("pool is not initialized, skipping flushing metrics") + logger.V(logutil.DEFAULT).Info("Pool is not initialized, skipping refreshing metrics") return } @@ -97,7 +98,7 @@ func flushPrometheusMetricsOnce(logger logr.Logger, datastore Datastore) { var queueTotal int podMetrics := datastore.PodGetAll() - logger.V(logutil.VERBOSE).Info("Flushing Prometheus Metrics", "ReadyPods", len(podMetrics)) + logger.V(logutil.TRACE).Info("Refreshing Prometheus Metrics", "ReadyPods", len(podMetrics)) if len(podMetrics) == 0 { return } diff --git a/pkg/epp/backend/metrics/metrics.go b/pkg/epp/backend/metrics/metrics.go index d48b1dc5..4cf56179 100644 --- a/pkg/epp/backend/metrics/metrics.go +++ b/pkg/epp/backend/metrics/metrics.go @@ -26,6 +26,7 @@ import ( dto "github.com/prometheus/client_model/go" "github.com/prometheus/common/expfmt" "go.uber.org/multierr" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" ) const ( @@ -39,15 +40,8 @@ type PodMetricsClientImpl struct { MetricMapping *MetricMapping } -// FetchMetrics fetches metrics from a given pod, clones the existing metrics object and returns an -// updated one. -func (p *PodMetricsClientImpl) FetchMetrics( - ctx context.Context, - pod *Pod, - existing *Metrics, - port int32, -) (*Metrics, error) { - +// FetchMetrics fetches metrics from a given pod, clones the existing metrics object and returns an updated one. +func (p *PodMetricsClientImpl) FetchMetrics(ctx context.Context, pod *backend.Pod, existing *Metrics, port int32) (*Metrics, error) { // Currently the metrics endpoint is hard-coded, which works with vLLM. // TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/16): Consume this from InferencePool config. 
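`FetchMetrics` above builds `http://<address>:<port>/metrics` and, per the imports, parses the response with `expfmt`. A rough sketch of just that fetch-and-parse step, assuming `github.com/prometheus/common/expfmt`; the real client additionally clones the existing `Metrics` and merges the parsed values into it:

```go
package main

import (
	"fmt"
	"net/http"
	"strconv"

	dto "github.com/prometheus/client_model/go"
	"github.com/prometheus/common/expfmt"
)

// scrapeMetricFamilies fetches a pod's Prometheus text endpoint and
// parses it into metric families keyed by metric name.
func scrapeMetricFamilies(address string, port int32) (map[string]*dto.MetricFamily, error) {
	url := "http://" + address + ":" + strconv.Itoa(int(port)) + "/metrics"
	resp, err := http.Get(url)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("unexpected status %d from %s", resp.StatusCode, url)
	}
	var parser expfmt.TextParser
	return parser.TextToMetricFamilies(resp.Body)
}

func main() {
	families, err := scrapeMetricFamilies("127.0.0.1", 8000)
	if err != nil {
		fmt.Println("scrape failed:", err) // expected if nothing is listening
		return
	}
	for name := range families {
		fmt.Println(name)
	}
}
```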
url := "http://" + pod.Address + ":" + strconv.Itoa(int(port)) + "/metrics" @@ -109,6 +103,7 @@ func (p *PodMetricsClientImpl) promToPodMetrics( if loraMetrics != nil { updated.ActiveModels = make(map[string]int) + updated.WaitingModels = make(map[string]int) for _, label := range loraMetrics.GetLabel() { if label.GetName() == LoraInfoRunningAdaptersMetricName { if label.GetValue() != "" { @@ -122,7 +117,7 @@ func (p *PodMetricsClientImpl) promToPodMetrics( if label.GetValue() != "" { adapterList := strings.Split(label.GetValue(), ",") for _, adapter := range adapterList { - updated.ActiveModels[adapter] = 0 + updated.WaitingModels[adapter] = 0 } } } diff --git a/pkg/epp/backend/metrics/metrics_test.go b/pkg/epp/backend/metrics/metrics_test.go index d0396bf7..53127010 100644 --- a/pkg/epp/backend/metrics/metrics_test.go +++ b/pkg/epp/backend/metrics/metrics_test.go @@ -30,6 +30,7 @@ import ( "google.golang.org/protobuf/proto" "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) @@ -404,7 +405,8 @@ func TestPromToPodMetrics(t *testing.T) { expectedMetrics: &Metrics{ WaitingQueueSize: 7, KVCacheUsagePercent: 0.8, - ActiveModels: map[string]int{"lora1": 0, "lora2": 0, "lora3": 0}, + ActiveModels: map[string]int{"lora1": 0, "lora2": 0}, + WaitingModels: map[string]int{"lora3": 0}, MaxActiveModels: 3, }, }, @@ -416,8 +418,8 @@ func TestPromToPodMetrics(t *testing.T) { KVCacheUtilization: &MetricSpec{MetricName: "vllm_usage"}, LoraRequestInfo: &MetricSpec{MetricName: "vllm:lora_requests_info"}, }, - existingMetrics: &Metrics{ActiveModels: map[string]int{}}, - expectedMetrics: &Metrics{ActiveModels: map[string]int{}}, + existingMetrics: &Metrics{ActiveModels: map[string]int{}, WaitingModels: map[string]int{}}, + expectedMetrics: &Metrics{ActiveModels: map[string]int{}, WaitingModels: map[string]int{}}, expectedErr: multierr.Combine(errors.New("metric family \"vllm_waiting\" not found"), errors.New("metric family \"vllm_usage\" not found"), errors.New("metric family \"vllm:lora_requests_info\" not found")), }, { @@ -439,7 +441,8 @@ func TestPromToPodMetrics(t *testing.T) { expectedMetrics: &Metrics{ WaitingQueueSize: 0, KVCacheUsagePercent: 0.8, - ActiveModels: map[string]int{"lora1": 0, "lora2": 0, "lora3": 0}, + ActiveModels: map[string]int{"lora1": 0, "lora2": 0}, + WaitingModels: map[string]int{"lora3": 0}, MaxActiveModels: 3, }, expectedErr: errors.New("metric family \"vllm_waiting\" not found"), @@ -457,6 +460,7 @@ func TestPromToPodMetrics(t *testing.T) { existingMetrics: &Metrics{}, expectedMetrics: &Metrics{ ActiveModels: map[string]int{"lora1": 0}, + WaitingModels: map[string]int{}, MaxActiveModels: 0, // Should still default to 0. }, @@ -483,7 +487,7 @@ func TestPromToPodMetrics(t *testing.T) { // there's no server running on the specified port. 
func TestFetchMetrics(t *testing.T) { ctx := logutil.NewTestLoggerIntoContext(context.Background()) - pod := &Pod{ + pod := &backend.Pod{ Address: "127.0.0.1", NamespacedName: types.NamespacedName{ Namespace: "test", diff --git a/pkg/epp/backend/metrics/pod_metrics.go b/pkg/epp/backend/metrics/pod_metrics.go index cfb6b138..bdeb28ba 100644 --- a/pkg/epp/backend/metrics/pod_metrics.go +++ b/pkg/epp/backend/metrics/pod_metrics.go @@ -27,6 +27,7 @@ import ( corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) @@ -35,28 +36,27 @@ const ( ) type podMetrics struct { - pod atomic.Pointer[Pod] + pod atomic.Pointer[backend.Pod] metrics atomic.Pointer[Metrics] pmc PodMetricsClient ds Datastore interval time.Duration - parentCtx context.Context - once sync.Once // ensure the StartRefreshLoop is only called once. - done chan struct{} + once sync.Once // ensure the StartRefreshLoop is only called once. + done chan struct{} logger logr.Logger } type PodMetricsClient interface { - FetchMetrics(ctx context.Context, pod *Pod, existing *Metrics, port int32) (*Metrics, error) + FetchMetrics(ctx context.Context, pod *backend.Pod, existing *Metrics, port int32) (*Metrics, error) } func (pm *podMetrics) String() string { return fmt.Sprintf("Pod: %v; Metrics: %v", pm.GetPod(), pm.GetMetrics()) } -func (pm *podMetrics) GetPod() *Pod { +func (pm *podMetrics) GetPod() *backend.Pod { return pm.pod.Load() } @@ -68,8 +68,8 @@ func (pm *podMetrics) UpdatePod(in *corev1.Pod) { pm.pod.Store(toInternalPod(in)) } -func toInternalPod(in *corev1.Pod) *Pod { - return &Pod{ +func toInternalPod(in *corev1.Pod) *backend.Pod { + return &backend.Pod{ NamespacedName: types.NamespacedName{ Name: in.Name, Namespace: in.Namespace, @@ -79,26 +79,24 @@ func toInternalPod(in *corev1.Pod) *Pod { } // start starts a goroutine exactly once to periodically update metrics. The goroutine will be -// stopped either when stop() is called, or the parentCtx is cancelled. -func (pm *podMetrics) startRefreshLoop() { +// stopped either when stop() is called, or the given ctx is cancelled. 
+func (pm *podMetrics) startRefreshLoop(ctx context.Context) { pm.once.Do(func() { go func() { pm.logger.V(logutil.DEFAULT).Info("Starting refresher", "pod", pm.GetPod()) + ticker := time.NewTicker(pm.interval) + defer ticker.Stop() for { select { case <-pm.done: return - case <-pm.parentCtx.Done(): + case <-ctx.Done(): return - default: + case <-ticker.C: // refresh metrics periodically + if err := pm.refreshMetrics(); err != nil { + pm.logger.V(logutil.TRACE).Error(err, "Failed to refresh metrics", "pod", pm.GetPod()) + } } - - err := pm.refreshMetrics() - if err != nil { - pm.logger.V(logutil.TRACE).Error(err, "Failed to refresh metrics", "pod", pm.GetPod()) - } - - time.Sleep(pm.interval) } }() }) diff --git a/pkg/epp/backend/metrics/pod_metrics_test.go b/pkg/epp/backend/metrics/pod_metrics_test.go index cf6698ca..e79c1bf0 100644 --- a/pkg/epp/backend/metrics/pod_metrics_test.go +++ b/pkg/epp/backend/metrics/pod_metrics_test.go @@ -44,6 +44,7 @@ var ( "foo": 1, "bar": 1, }, + WaitingModels: map[string]int{}, } updated = &Metrics{ WaitingQueueSize: 9999, @@ -53,6 +54,7 @@ var ( "foo": 1, "bar": 1, }, + WaitingModels: map[string]int{}, } ) diff --git a/pkg/epp/backend/metrics/types.go b/pkg/epp/backend/metrics/types.go index 17db23b4..4932e3ac 100644 --- a/pkg/epp/backend/metrics/types.go +++ b/pkg/epp/backend/metrics/types.go @@ -24,8 +24,8 @@ import ( "time" corev1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" ) func NewPodMetricsFactory(pmc PodMetricsClient, refreshMetricsInterval time.Duration) *PodMetricsFactory { @@ -41,45 +41,34 @@ type PodMetricsFactory struct { } func (f *PodMetricsFactory) NewPodMetrics(parentCtx context.Context, in *corev1.Pod, ds Datastore) PodMetrics { + pod := toInternalPod(in) pm := &podMetrics{ - pmc: f.pmc, - ds: ds, - interval: f.refreshMetricsInterval, - parentCtx: parentCtx, - once: sync.Once{}, - done: make(chan struct{}), - logger: log.FromContext(parentCtx), + pmc: f.pmc, + ds: ds, + interval: f.refreshMetricsInterval, + once: sync.Once{}, + done: make(chan struct{}), + logger: log.FromContext(parentCtx).WithValues("pod", pod.NamespacedName), } - pm.pod.Store(toInternalPod(in)) + pm.pod.Store(pod) pm.metrics.Store(newMetrics()) - pm.startRefreshLoop() + pm.startRefreshLoop(parentCtx) return pm } type PodMetrics interface { - GetPod() *Pod + GetPod() *backend.Pod GetMetrics() *Metrics UpdatePod(*corev1.Pod) StopRefreshLoop() String() string } -type Pod struct { - NamespacedName types.NamespacedName - Address string -} - -func (p *Pod) String() string { - if p == nil { - return "" - } - return fmt.Sprintf("%+v", *p) -} - type Metrics struct { // ActiveModels is a set of models(including LoRA adapters) that are currently cached to GPU. - ActiveModels map[string]int + ActiveModels map[string]int + WaitingModels map[string]int // MaxActiveModels is the maximum number of models that can be loaded to GPU. 
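The `podMetrics` struct above keeps both the pod and its metrics behind `atomic.Pointer`, so readers always see an immutable snapshot while the refresh loop swaps in fresh values without a mutex. A minimal sketch of that publish-by-replacement pattern (types simplified, Go 1.19+ for the generic `atomic.Pointer`):

```go
package main

import (
	"fmt"
	"sync/atomic"
)

type metrics struct {
	WaitingQueueSize int
}

type podState struct {
	m atomic.Pointer[metrics]
}

// load returns the current snapshot; callers must treat it as read-only.
func (p *podState) load() *metrics { return p.m.Load() }

// update publishes a brand-new value instead of mutating in place,
// which is what makes lock-free readers safe.
func (p *podState) update(queue int) {
	p.m.Store(&metrics{WaitingQueueSize: queue})
}

func main() {
	var p podState
	p.update(0)
	go p.update(7) // concurrent writer is safe: only whole pointers are swapped
	fmt.Println(p.load().WaitingQueueSize)
}
```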
MaxActiveModels int RunningQueueSize int @@ -93,7 +82,8 @@ type Metrics struct { func newMetrics() *Metrics { return &Metrics{ - ActiveModels: make(map[string]int), + ActiveModels: make(map[string]int), + WaitingModels: make(map[string]int), } } @@ -105,12 +95,20 @@ func (m *Metrics) String() string { } func (m *Metrics) Clone() *Metrics { + if m == nil { + return nil + } cm := make(map[string]int, len(m.ActiveModels)) for k, v := range m.ActiveModels { cm[k] = v } + wm := make(map[string]int, len(m.WaitingModels)) + for k, v := range m.WaitingModels { + wm[k] = v + } clone := &Metrics{ ActiveModels: cm, + WaitingModels: wm, MaxActiveModels: m.MaxActiveModels, RunningQueueSize: m.RunningQueueSize, WaitingQueueSize: m.WaitingQueueSize, diff --git a/pkg/epp/backend/pod.go b/pkg/epp/backend/pod.go new file mode 100644 index 00000000..a63a0a83 --- /dev/null +++ b/pkg/epp/backend/pod.go @@ -0,0 +1,45 @@ +/* +Copyright 2025 The Kubernetes Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package backend + +import ( + "fmt" + + "k8s.io/apimachinery/pkg/types" +) + +type Pod struct { + NamespacedName types.NamespacedName + Address string +} + +func (p *Pod) String() string { + if p == nil { + return "" + } + return fmt.Sprintf("%+v", *p) +} + +func (p *Pod) Clone() *Pod { + if p == nil { + return nil + } + return &Pod{ + NamespacedName: types.NamespacedName{ + Name: p.NamespacedName.Name, + Namespace: p.NamespacedName.Namespace, + }, + Address: p.Address, + } +} diff --git a/pkg/epp/controller/inferencemodel_reconciler_test.go b/pkg/epp/controller/inferencemodel_reconciler_test.go index cd1ff1fb..80c30e19 100644 --- a/pkg/epp/controller/inferencemodel_reconciler_test.go +++ b/pkg/epp/controller/inferencemodel_reconciler_test.go @@ -25,6 +25,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" + clientgoscheme "k8s.io/client-go/kubernetes/scheme" "k8s.io/client-go/tools/record" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" @@ -178,7 +179,8 @@ func TestInferenceModelReconciler(t *testing.T) { t.Run(test.name, func(t *testing.T) { // Create a fake client with no InferenceModel objects. scheme := runtime.NewScheme() - _ = v1alpha2.AddToScheme(scheme) + _ = clientgoscheme.AddToScheme(scheme) + _ = v1alpha2.Install(scheme) initObjs := []client.Object{} if test.model != nil { initObjs = append(initObjs, test.model) @@ -186,6 +188,7 @@ func TestInferenceModelReconciler(t *testing.T) { for _, m := range test.modelsInAPIServer { initObjs = append(initObjs, m) } + fakeClient := fake.NewClientBuilder(). WithScheme(scheme). WithObjects(initObjs...). 
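The `Clone` methods above copy each map key by key and check for a nil receiver first. That is necessary because assigning a struct with map fields copies only the map headers, so a naive copy still aliases the originals. A quick illustration with stand-in types (not the project's):

```go
package main

import "fmt"

type stats struct {
	Active map[string]int
}

func (s *stats) Clone() *stats {
	if s == nil {
		return nil // cloning a nil receiver stays nil instead of panicking
	}
	cp := make(map[string]int, len(s.Active))
	for k, v := range s.Active {
		cp[k] = v
	}
	return &stats{Active: cp}
}

func main() {
	a := &stats{Active: map[string]int{"lora1": 0}}

	shallow := *a           // struct copy: both share the same map
	shallow.Active["x"] = 1 // mutates a.Active too
	fmt.Println(a.Active)   // map[lora1:0 x:1]

	deep := a.Clone()
	deep.Active["y"] = 1 // independent copy; a.Active is unchanged
	fmt.Println(a.Active, deep.Active)
}
```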
@@ -196,7 +199,7 @@ func TestInferenceModelReconciler(t *testing.T) { for _, m := range test.modelsInStore { ds.ModelSetIfOlder(m) } - ds.PoolSet(pool) + _ = ds.PoolSet(context.Background(), fakeClient, pool) reconciler := &InferenceModelReconciler{ Client: fakeClient, Record: record.NewFakeRecorder(10), diff --git a/pkg/epp/controller/inferencepool_reconciler.go b/pkg/epp/controller/inferencepool_reconciler.go index c92d4ecc..fb7d7727 100644 --- a/pkg/epp/controller/inferencepool_reconciler.go +++ b/pkg/epp/controller/inferencepool_reconciler.go @@ -18,10 +18,8 @@ package controller import ( "context" - "reflect" "k8s.io/apimachinery/pkg/api/errors" - "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/tools/record" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" @@ -36,9 +34,8 @@ import ( // will have the proper controller that will create/manage objects on behalf of the server pool. type InferencePoolReconciler struct { client.Client - Record record.EventRecorder - PoolNamespacedName types.NamespacedName - Datastore datastore.Datastore + Record record.EventRecorder + Datastore datastore.Datastore } func (c *InferencePoolReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { @@ -62,28 +59,15 @@ func (c *InferencePoolReconciler) Reconcile(ctx context.Context, req ctrl.Reques c.Datastore.Clear() return ctrl.Result{}, nil } - - c.updateDatastore(ctx, infPool) + // update pool in datastore + if err := c.Datastore.PoolSet(ctx, c.Client, infPool); err != nil { + logger.Error(err, "Failed to update datastore") + return ctrl.Result{}, err + } return ctrl.Result{}, nil } -func (c *InferencePoolReconciler) updateDatastore(ctx context.Context, newPool *v1alpha2.InferencePool) { - logger := log.FromContext(ctx) - oldPool, err := c.Datastore.PoolGet() - c.Datastore.PoolSet(newPool) - if err != nil || !reflect.DeepEqual(newPool.Spec.Selector, oldPool.Spec.Selector) { - logger.V(logutil.DEFAULT).Info("Updating inference pool endpoints", "selector", newPool.Spec.Selector) - // A full resync is required to address two cases: - // 1) At startup, the pod events may get processed before the pool is synced with the datastore, - // and hence they will not be added to the store since pool selector is not known yet - // 2) If the selector on the pool was updated, then we will not get any pod events, and so we need - // to resync the whole pool: remove pods in the store that don't match the new selector and add - // the ones that may have existed already to the store. - c.Datastore.PodResyncAll(ctx, c.Client, newPool) - } -} - func (c *InferencePoolReconciler) SetupWithManager(mgr ctrl.Manager) error { return ctrl.NewControllerManagedBy(mgr). For(&v1alpha2.InferencePool{}). diff --git a/pkg/epp/controller/inferencepool_reconciler_test.go b/pkg/epp/controller/inferencepool_reconciler_test.go index 27c4238e..b7e28334 100644 --- a/pkg/epp/controller/inferencepool_reconciler_test.go +++ b/pkg/epp/controller/inferencepool_reconciler_test.go @@ -77,7 +77,7 @@ func TestInferencePoolReconciler(t *testing.T) { // Set up the scheme. scheme := runtime.NewScheme() _ = clientgoscheme.AddToScheme(scheme) - _ = v1alpha2.AddToScheme(scheme) + _ = v1alpha2.Install(scheme) // Create a fake client with the pool and the pods. 
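With the resync logic folded into the datastore, the pool reconciler above reduces to "fetch, and hand whatever happened to `PoolSet`". A condensed sketch of that control flow with simplified stand-in types (`pool` for `*v1alpha2.InferencePool`, `store` for the slice of `Datastore` the reconciler needs):

```go
package main

import (
	"context"
	"errors"
	"fmt"
)

type pool struct{ name string }

type store interface {
	PoolSet(ctx context.Context, p *pool) error
	Clear()
}

type fakeStore struct{ current *pool }

func (f *fakeStore) PoolSet(_ context.Context, p *pool) error { f.current = p; return nil }
func (f *fakeStore) Clear()                                   { f.current = nil }

var errNotFound = errors.New("not found")

// reconcilePool mirrors the new flow: NotFound clears the datastore,
// other errors are surfaced, and a live pool is handed to PoolSet,
// which now decides internally whether a pod resync is needed.
func reconcilePool(ctx context.Context, get func() (*pool, error), ds store) error {
	p, err := get()
	if errors.Is(err, errNotFound) {
		ds.Clear()
		return nil
	}
	if err != nil {
		return err
	}
	return ds.PoolSet(ctx, p)
}

func main() {
	ds := &fakeStore{}
	_ = reconcilePool(context.Background(), func() (*pool, error) { return &pool{name: "pool1"}, nil }, ds)
	fmt.Println(ds.current.name) // pool1
	_ = reconcilePool(context.Background(), func() (*pool, error) { return nil, errNotFound }, ds)
	fmt.Println(ds.current) // <nil>
}
```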
initialObjects := []client.Object{pool1, pool2} @@ -96,7 +96,7 @@ func TestInferencePoolReconciler(t *testing.T) { pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second) datastore := datastore.NewDatastore(ctx, pmf) - inferencePoolReconciler := &InferencePoolReconciler{PoolNamespacedName: namespacedName, Client: fakeClient, Datastore: datastore} + inferencePoolReconciler := &InferencePoolReconciler{Client: fakeClient, Datastore: datastore} // Step 1: Inception, only ready pods matching pool1 are added to the store. if _, err := inferencePoolReconciler.Reconcile(ctx, req); err != nil { diff --git a/pkg/epp/controller/pod_reconciler.go b/pkg/epp/controller/pod_reconciler.go index 046561e4..5f1df10d 100644 --- a/pkg/epp/controller/pod_reconciler.go +++ b/pkg/epp/controller/pod_reconciler.go @@ -26,10 +26,12 @@ import ( "k8s.io/client-go/tools/record" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/event" "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" + "sigs.k8s.io/controller-runtime/pkg/predicate" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + podutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/pod" ) type PodReconciler struct { @@ -40,8 +42,7 @@ type PodReconciler struct { func (c *PodReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { logger := log.FromContext(ctx) - pool, err := c.Datastore.PoolGet() - if err != nil { + if !c.Datastore.PoolHasSynced() { logger.V(logutil.TRACE).Info("Skipping reconciling Pod because the InferencePool is not available yet") // When the inferencePool is initialized it lists the appropriate pods and populates the datastore, so no need to requeue. return ctrl.Result{}, nil @@ -59,38 +60,46 @@ func (c *PodReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.R return ctrl.Result{}, err } - c.updateDatastore(logger, pod, pool) + c.updateDatastore(logger, pod) return ctrl.Result{}, nil } func (c *PodReconciler) SetupWithManager(mgr ctrl.Manager) error { + filter := predicate.Funcs{ + CreateFunc: func(ce event.CreateEvent) bool { + pod := ce.Object.(*corev1.Pod) + return c.Datastore.PoolLabelsMatch(pod.GetLabels()) + }, + UpdateFunc: func(ue event.UpdateEvent) bool { + oldPod := ue.ObjectOld.(*corev1.Pod) + newPod := ue.ObjectNew.(*corev1.Pod) + return c.Datastore.PoolLabelsMatch(oldPod.GetLabels()) || c.Datastore.PoolLabelsMatch(newPod.GetLabels()) + }, + DeleteFunc: func(de event.DeleteEvent) bool { + pod := de.Object.(*corev1.Pod) + return c.Datastore.PoolLabelsMatch(pod.GetLabels()) + }, + GenericFunc: func(ge event.GenericEvent) bool { + pod := ge.Object.(*corev1.Pod) + return c.Datastore.PoolLabelsMatch(pod.GetLabels()) + }, + } return ctrl.NewControllerManagedBy(mgr). For(&corev1.Pod{}). + WithEventFilter(filter). 
Complete(c) } -func (c *PodReconciler) updateDatastore(logger logr.Logger, pod *corev1.Pod, pool *v1alpha2.InferencePool) { +func (c *PodReconciler) updateDatastore(logger logr.Logger, pod *corev1.Pod) { namespacedName := types.NamespacedName{Name: pod.Name, Namespace: pod.Namespace} - if !pod.DeletionTimestamp.IsZero() || !c.Datastore.PoolLabelsMatch(pod.Labels) || !podIsReady(pod) { + if !podutil.IsPodReady(pod) || !c.Datastore.PoolLabelsMatch(pod.Labels) { logger.V(logutil.DEBUG).Info("Pod removed or not added", "name", namespacedName) c.Datastore.PodDelete(namespacedName) } else { - if c.Datastore.PodUpdateOrAddIfNotExist(pod, pool) { + if c.Datastore.PodUpdateOrAddIfNotExist(pod) { logger.V(logutil.DEFAULT).Info("Pod added", "name", namespacedName) } else { logger.V(logutil.DEFAULT).Info("Pod already exists", "name", namespacedName) } } } - -func podIsReady(pod *corev1.Pod) bool { - for _, condition := range pod.Status.Conditions { - if condition.Type == corev1.PodReady { - if condition.Status == corev1.ConditionTrue { - return true - } - break - } - } - return false -} diff --git a/pkg/epp/controller/pod_reconciler_test.go b/pkg/epp/controller/pod_reconciler_test.go index e4cb0b62..d2bdd5d0 100644 --- a/pkg/epp/controller/pod_reconciler_test.go +++ b/pkg/epp/controller/pod_reconciler_test.go @@ -182,9 +182,9 @@ func TestPodReconciler(t *testing.T) { // Configure the initial state of the datastore. store := datastore.NewDatastore(t.Context(), pmf) - store.PoolSet(test.pool) + _ = store.PoolSet(t.Context(), fakeClient, test.pool) for _, pod := range test.existingPods { - store.PodUpdateOrAddIfNotExist(pod, pool) + store.PodUpdateOrAddIfNotExist(pod) } podReconciler := &PodReconciler{Client: fakeClient, Datastore: store} diff --git a/pkg/epp/datastore/datastore.go b/pkg/epp/datastore/datastore.go index 8ada3e64..22c50022 100644 --- a/pkg/epp/datastore/datastore.go +++ b/pkg/epp/datastore/datastore.go @@ -20,6 +20,7 @@ import ( "context" "errors" "fmt" + "reflect" "sync" corev1 "k8s.io/api/core/v1" @@ -30,6 +31,7 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + podutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/pod" ) const ( @@ -43,7 +45,10 @@ var ( // The datastore is a local cache of relevant data for the given InferencePool (currently all pulled from k8s-api) type Datastore interface { // InferencePool operations - PoolSet(pool *v1alpha2.InferencePool) + // PoolSet sets the given pool in datastore. If the given pool has different label selector than the previous pool + // that was stored, the function triggers a resync of the pods to keep the datastore updated. If the given pool + // is nil, this call triggers the datastore.Clear() function. + PoolSet(ctx context.Context, client client.Client, pool *v1alpha2.InferencePool) error PoolGet() (*v1alpha2.InferencePool, error) PoolHasSynced() bool PoolLabelsMatch(podLabels map[string]string) bool @@ -59,16 +64,15 @@ type Datastore interface { // PodGetAll returns all pods and metrics, including fresh and stale. PodGetAll() []backendmetrics.PodMetrics // PodList lists pods matching the given predicate. 
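The `predicate.Funcs` filter wired in above drops events before they are enqueued, so pods outside the pool's selector never reach `Reconcile` at all. The update case deliberately admits the event if *either* the old or the new labels match, so a pod leaving the pool still produces the event that removes it from the datastore. A sketch of the same shape against controller-runtime's API (`labelsMatch` is a hypothetical stand-in for `Datastore.PoolLabelsMatch`; the generic case is omitted for brevity, which defaults to admit):

```go
package main

import (
	"fmt"

	corev1 "k8s.io/api/core/v1"
	"sigs.k8s.io/controller-runtime/pkg/event"
	"sigs.k8s.io/controller-runtime/pkg/predicate"
)

// labelsMatch is a stand-in for Datastore.PoolLabelsMatch.
func labelsMatch(labels map[string]string) bool {
	return labels["app"] == "vllm"
}

func podFilter() predicate.Funcs {
	return predicate.Funcs{
		CreateFunc: func(e event.CreateEvent) bool {
			return labelsMatch(e.Object.GetLabels())
		},
		UpdateFunc: func(e event.UpdateEvent) bool {
			// Admit if either side matches so label removals are observed.
			return labelsMatch(e.ObjectOld.GetLabels()) || labelsMatch(e.ObjectNew.GetLabels())
		},
		DeleteFunc: func(e event.DeleteEvent) bool {
			return labelsMatch(e.Object.GetLabels())
		},
	}
}

func main() {
	old := &corev1.Pod{}
	old.SetLabels(map[string]string{"app": "vllm"})
	updated := old.DeepCopy()
	updated.SetLabels(map[string]string{"app": "other"})

	f := podFilter()
	// Still true: the old labels matched, so the removal is processed.
	fmt.Println(f.Update(event.UpdateEvent{ObjectOld: old, ObjectNew: updated}))
}
```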
- PodList(func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics - PodUpdateOrAddIfNotExist(pod *corev1.Pod, pool *v1alpha2.InferencePool) bool + PodList(predicate func(backendmetrics.PodMetrics) bool) []backendmetrics.PodMetrics + PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool PodDelete(namespacedName types.NamespacedName) - PodResyncAll(ctx context.Context, ctrlClient client.Client, pool *v1alpha2.InferencePool) // Clears the store state, happens when the pool gets deleted. Clear() } -func NewDatastore(parentCtx context.Context, pmf *backendmetrics.PodMetricsFactory) *datastore { +func NewDatastore(parentCtx context.Context, pmf *backendmetrics.PodMetricsFactory) Datastore { store := &datastore{ parentCtx: parentCtx, poolAndModelsMu: sync.RWMutex{}, @@ -101,10 +105,31 @@ func (ds *datastore) Clear() { } // /// InferencePool APIs /// -func (ds *datastore) PoolSet(pool *v1alpha2.InferencePool) { +func (ds *datastore) PoolSet(ctx context.Context, client client.Client, pool *v1alpha2.InferencePool) error { + if pool == nil { + ds.Clear() + return nil + } + logger := log.FromContext(ctx) ds.poolAndModelsMu.Lock() defer ds.poolAndModelsMu.Unlock() + + oldPool := ds.pool ds.pool = pool + if oldPool == nil || !reflect.DeepEqual(pool.Spec.Selector, oldPool.Spec.Selector) { + logger.V(logutil.DEFAULT).Info("Updating inference pool endpoints", "selector", pool.Spec.Selector) + // A full resync is required to address two cases: + // 1) At startup, the pod events may get processed before the pool is synced with the datastore, + // and hence they will not be added to the store since pool selector is not known yet + // 2) If the selector on the pool was updated, then we will not get any pod events, and so we need + // to resync the whole pool: remove pods in the store that don't match the new selector and add + // the ones that may have existed already to the store. 
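The comment below spells out the two cases that force a full pod resync; the gate itself reduces to a first-sync or selector-changed check. A tiny sketch of that condition with an illustrative `poolSpec` (the real code compares `pool.Spec.Selector` with `reflect.DeepEqual`):

```go
package main

import (
	"fmt"
	"reflect"
)

type poolSpec struct {
	Selector map[string]string
}

// needsResync reproduces the condition used in PoolSet: resync on the
// first sync (no previous pool) or whenever the label selector changed.
func needsResync(prev, curr *poolSpec) bool {
	return prev == nil || !reflect.DeepEqual(curr.Selector, prev.Selector)
}

func main() {
	a := &poolSpec{Selector: map[string]string{"app": "vllm"}}
	b := &poolSpec{Selector: map[string]string{"app": "vllm"}}
	c := &poolSpec{Selector: map[string]string{"app": "other"}}

	fmt.Println(needsResync(nil, a)) // true: startup, pod events may predate the pool
	fmt.Println(needsResync(a, b))   // false: selector unchanged, no resync
	fmt.Println(needsResync(a, c))   // true: selector changed, full pod resync
}
```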
+ if err := ds.podResyncAll(ctx, client); err != nil { + return fmt.Errorf("failed to update pods according to the pool selector - %w", err) + } + } + + return nil } func (ds *datastore) PoolGet() (*v1alpha2.InferencePool, error) { @@ -125,6 +150,9 @@ func (ds *datastore) PoolHasSynced() bool { func (ds *datastore) PoolLabelsMatch(podLabels map[string]string) bool { ds.poolAndModelsMu.RLock() defer ds.poolAndModelsMu.RUnlock() + if ds.pool == nil { + return false + } poolSelector := selectorFromInferencePoolSelector(ds.pool.Spec.Selector) podSet := labels.Set(podLabels) return poolSelector.Matches(podSet) @@ -228,7 +256,7 @@ func (ds *datastore) PodList(predicate func(backendmetrics.PodMetrics) bool) []b return res } -func (ds *datastore) PodUpdateOrAddIfNotExist(pod *corev1.Pod, pool *v1alpha2.InferencePool) bool { +func (ds *datastore) PodUpdateOrAddIfNotExist(pod *corev1.Pod) bool { namespacedName := types.NamespacedName{ Name: pod.Name, Namespace: pod.Namespace, @@ -246,27 +274,35 @@ func (ds *datastore) PodUpdateOrAddIfNotExist(pod *corev1.Pod, pool *v1alpha2.In return ok } -func (ds *datastore) PodResyncAll(ctx context.Context, ctrlClient client.Client, pool *v1alpha2.InferencePool) { +func (ds *datastore) PodDelete(namespacedName types.NamespacedName) { + v, ok := ds.pods.LoadAndDelete(namespacedName) + if ok { + pmr := v.(backendmetrics.PodMetrics) + pmr.StopRefreshLoop() + } +} + +func (ds *datastore) podResyncAll(ctx context.Context, ctrlClient client.Client) error { logger := log.FromContext(ctx) podList := &corev1.PodList{} if err := ctrlClient.List(ctx, podList, &client.ListOptions{ - LabelSelector: selectorFromInferencePoolSelector(pool.Spec.Selector), - Namespace: pool.Namespace, + LabelSelector: selectorFromInferencePoolSelector(ds.pool.Spec.Selector), + Namespace: ds.pool.Namespace, }); err != nil { - log.FromContext(ctx).V(logutil.DEFAULT).Error(err, "Failed to list clients") - return + return fmt.Errorf("failed to list pods - %w", err) } activePods := make(map[string]bool) for _, pod := range podList.Items { - if podIsReady(&pod) { - namespacedName := types.NamespacedName{Name: pod.Name, Namespace: pod.Namespace} - activePods[pod.Name] = true - if ds.PodUpdateOrAddIfNotExist(&pod, pool) { - logger.V(logutil.DEFAULT).Info("Pod added", "name", namespacedName) - } else { - logger.V(logutil.DEFAULT).Info("Pod already exists", "name", namespacedName) - } + if !podutil.IsPodReady(&pod) { + continue + } + namespacedName := types.NamespacedName{Name: pod.Name, Namespace: pod.Namespace} + activePods[pod.Name] = true + if ds.PodUpdateOrAddIfNotExist(&pod) { + logger.V(logutil.DEFAULT).Info("Pod added", "name", namespacedName) + } else { + logger.V(logutil.DEFAULT).Info("Pod already exists", "name", namespacedName) } } @@ -280,14 +316,8 @@ func (ds *datastore) PodResyncAll(ctx context.Context, ctrlClient client.Client, return true } ds.pods.Range(deleteFn) -} -func (ds *datastore) PodDelete(namespacedName types.NamespacedName) { - v, ok := ds.pods.LoadAndDelete(namespacedName) - if ok { - pmr := v.(backendmetrics.PodMetrics) - pmr.StopRefreshLoop() - } + return nil } func selectorFromInferencePoolSelector(selector map[v1alpha2.LabelKey]v1alpha2.LabelValue) labels.Selector { @@ -301,23 +331,3 @@ func stripLabelKeyAliasFromLabelMap(labels map[v1alpha2.LabelKey]v1alpha2.LabelV } return outMap } - -func IsCritical(model *v1alpha2.InferenceModel) bool { - if model.Spec.Criticality != nil && *model.Spec.Criticality == v1alpha2.Critical { - return true - } - return false -} - -// TODO: 
move out to share with pod_reconciler.go -func podIsReady(pod *corev1.Pod) bool { - for _, condition := range pod.Status.Conditions { - if condition.Type == corev1.PodReady { - if condition.Status == corev1.ConditionTrue { - return true - } - break - } - } - return false -} diff --git a/pkg/epp/datastore/datastore_test.go b/pkg/epp/datastore/datastore_test.go index 22bb0365..b6466e6b 100644 --- a/pkg/epp/datastore/datastore_test.go +++ b/pkg/epp/datastore/datastore_test.go @@ -27,7 +27,10 @@ import ( "github.com/stretchr/testify/assert" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" + clientgoscheme "k8s.io/client-go/kubernetes/scheme" + "sigs.k8s.io/controller-runtime/pkg/client/fake" "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" testutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" @@ -71,9 +74,15 @@ func TestPool(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + // Set up the scheme. + scheme := runtime.NewScheme() + _ = clientgoscheme.AddToScheme(scheme) + fakeClient := fake.NewClientBuilder(). + WithScheme(scheme). + Build() pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second) datastore := NewDatastore(context.Background(), pmf) - datastore.PoolSet(tt.inferencePool) + _ = datastore.PoolSet(context.Background(), fakeClient, tt.inferencePool) gotPool, gotErr := datastore.PoolGet() if diff := cmp.Diff(tt.wantErr, gotErr, cmpopts.EquateErrors()); diff != "" { t.Errorf("Unexpected error diff (+got/-want): %s", diff) @@ -236,6 +245,7 @@ var ( "foo": 1, "bar": 1, }, + WaitingModels: map[string]int{}, } pod2 = &corev1.Pod{ ObjectMeta: metav1.ObjectMeta{ @@ -250,6 +260,7 @@ var ( "foo1": 1, "bar1": 1, }, + WaitingModels: map[string]int{}, } pod1NamespacedName = types.NamespacedName{Name: pod1.Name, Namespace: pod1.Namespace} pod2NamespacedName = types.NamespacedName{Name: pod2.Name, Namespace: pod2.Namespace} @@ -305,6 +316,7 @@ func TestMetrics(t *testing.T) { // Failed to fetch pod2 metrics so it remains the default values. { ActiveModels: map[string]int{}, + WaitingModels: map[string]int{}, WaitingQueueSize: 0, KVCacheUsagePercent: 0, MaxActiveModels: 0, @@ -317,11 +329,17 @@ func TestMetrics(t *testing.T) { t.Run(test.name, func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() + // Set up the scheme. + scheme := runtime.NewScheme() + _ = clientgoscheme.AddToScheme(scheme) + fakeClient := fake.NewClientBuilder(). + WithScheme(scheme). 
+ Build() pmf := backendmetrics.NewPodMetricsFactory(test.pmc, time.Millisecond) ds := NewDatastore(ctx, pmf) - ds.PoolSet(inferencePool) + _ = ds.PoolSet(ctx, fakeClient, inferencePool) for _, pod := range test.storePods { - ds.PodUpdateOrAddIfNotExist(pod, inferencePool) + ds.PodUpdateOrAddIfNotExist(pod) } assert.EventuallyWithT(t, func(t *assert.CollectT) { got := ds.PodGetAll() @@ -337,3 +355,94 @@ func TestMetrics(t *testing.T) { }) } } + +func TestPods(t *testing.T) { + updatedPod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pod1", + }, + Spec: corev1.PodSpec{ + NodeName: "node-1", + }, + } + tests := []struct { + name string + op func(ctx context.Context, ds Datastore) + existingPods []*corev1.Pod + wantPods []*corev1.Pod + }{ + { + name: "Add new pod, no existing pods, should add", + existingPods: []*corev1.Pod{}, + wantPods: []*corev1.Pod{pod1}, + op: func(ctx context.Context, ds Datastore) { + ds.PodUpdateOrAddIfNotExist(pod1) + }, + }, + { + name: "Add new pod, with existing pods, should add", + existingPods: []*corev1.Pod{pod1}, + wantPods: []*corev1.Pod{pod1, pod2}, + op: func(ctx context.Context, ds Datastore) { + ds.PodUpdateOrAddIfNotExist(pod2) + }, + }, + { + name: "Update existing pod, new field, should update", + existingPods: []*corev1.Pod{pod1}, + wantPods: []*corev1.Pod{updatedPod}, + op: func(ctx context.Context, ds Datastore) { + ds.PodUpdateOrAddIfNotExist(updatedPod) + }, + }, + { + name: "Update existing pod, no new fields, should not update", + existingPods: []*corev1.Pod{pod1}, + wantPods: []*corev1.Pod{pod1}, + op: func(ctx context.Context, ds Datastore) { + incoming := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pod1", + Namespace: "default", + }, + } + ds.PodUpdateOrAddIfNotExist(incoming) + }, + }, + { + name: "Delete the pod", + wantPods: []*corev1.Pod{pod1}, + op: func(ctx context.Context, ds Datastore) { + ds.PodDelete(pod2NamespacedName) + }, + }, + { + name: "Delete the pod that doesn't exist", + existingPods: []*corev1.Pod{pod1}, + wantPods: []*corev1.Pod{pod1}, + op: func(ctx context.Context, ds Datastore) { + ds.PodDelete(pod2NamespacedName) + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := context.Background() + pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.FakePodMetricsClient{}, time.Second) + ds := NewDatastore(t.Context(), pmf) + for _, pod := range test.existingPods { + ds.PodUpdateOrAddIfNotExist(pod) + } + + test.op(ctx, ds) + var gotPods []*corev1.Pod + for _, pm := range ds.PodGetAll() { + pod := &corev1.Pod{ObjectMeta: metav1.ObjectMeta{Name: pm.GetPod().NamespacedName.Name, Namespace: pm.GetPod().NamespacedName.Namespace}, Status: corev1.PodStatus{PodIP: pm.GetPod().Address}} + gotPods = append(gotPods, pod) + } + if !cmp.Equal(gotPods, test.wantPods, cmpopts.SortSlices(func(a, b *corev1.Pod) bool { return a.Name < b.Name })) { + t.Logf("got (%v) != want (%v);", gotPods, test.wantPods) + } + }) + } +} diff --git a/pkg/epp/handlers/request.go b/pkg/epp/handlers/request.go index d7678fad..cfcd82ec 100644 --- a/pkg/epp/handlers/request.go +++ b/pkg/epp/handlers/request.go @@ -21,44 +21,36 @@ import ( "encoding/json" "fmt" "strconv" + "time" - configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" - "google.golang.org/protobuf/types/known/structpb" "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" - 
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling" + "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" + schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) -// HandleRequestBody handles body of the request to the backend server, such as parsing the "model" -// parameter. -// Envoy sends the request body to ext proc before sending the request to the backend server. -func (s *Server) HandleRequestBody( +// HandleRequestBody always returns the requestContext even in the error case, as the request context is used in error handling. +func (s *StreamingServer) HandleRequestBody( ctx context.Context, reqCtx *RequestContext, req *extProcPb.ProcessingRequest, -) (*extProcPb.ProcessingResponse, error) { + requestBodyMap map[string]interface{}, +) (*RequestContext, error) { + var requestBodyBytes []byte logger := log.FromContext(ctx) - loggerVerbose := logger.V(logutil.VERBOSE) - loggerVerbose.Info("Handling request body") - - // Unmarshal request body (must be JSON). - v := req.Request.(*extProcPb.ProcessingRequest_RequestBody) - var rb map[string]interface{} - if err := json.Unmarshal(v.RequestBody.Body, &rb); err != nil { - logger.V(logutil.DEFAULT).Error(err, "Error unmarshaling request body") - return nil, errutil.Error{Code: errutil.BadRequest, Msg: fmt.Sprintf("error unmarshaling request body: %v", err)} - } - loggerVerbose.Info("Request body unmarshalled", "body", rb) // Resolve target models. - model, ok := rb["model"].(string) + model, ok := requestBodyMap["model"].(string) + if !ok { + return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: "model not found in request"} + } + prompt, ok := requestBodyMap["prompt"].(string) if !ok { - return nil, errutil.Error{Code: errutil.BadRequest, Msg: "model not found in request"} + return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: "prompt not found in request"} } - loggerVerbose.Info("Model requested", "model", model) + modelName := model // NOTE: The nil checking for the modelObject means that we DO allow passthrough currently. @@ -66,145 +58,100 @@ func (s *Server) HandleRequestBody( // are able to be requested by using their distinct name. 
modelObj := s.datastore.ModelGet(model) if modelObj == nil { - return nil, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error finding a model object in InferenceModel for input %v", model)} + return reqCtx, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error finding a model object in InferenceModel for input %v", model)} } if len(modelObj.Spec.TargetModels) > 0 { modelName = RandomWeightedDraw(logger, modelObj, 0) if modelName == "" { - return nil, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error getting target model name for model %v", modelObj.Name)} + return reqCtx, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error getting target model name for model %v", modelObj.Name)} } } - llmReq := &scheduling.LLMRequest{ + llmReq := &schedulingtypes.LLMRequest{ Model: model, ResolvedTargetModel: modelName, - Critical: datastore.IsCritical(modelObj), + Critical: modelObj.Spec.Criticality != nil && *modelObj.Spec.Criticality == v1alpha2.Critical, + Prompt: prompt, } - loggerVerbose.Info("LLM request assembled", "request", llmReq) + logger.V(logutil.DEBUG).Info("LLM request assembled", "request", llmReq) - requestBody := v.RequestBody.Body var err error // Update target models in the body. if llmReq.Model != llmReq.ResolvedTargetModel { - rb["model"] = llmReq.ResolvedTargetModel - requestBody, err = json.Marshal(rb) - if err != nil { - logger.V(logutil.DEFAULT).Error(err, "Error marshaling request body") - return nil, errutil.Error{Code: errutil.Internal, Msg: fmt.Sprintf("error marshaling request body: %v", err)} - } - loggerVerbose.Info("Updated request body marshalled", "body", string(requestBody)) + requestBodyMap["model"] = llmReq.ResolvedTargetModel } - target, err := s.scheduler.Schedule(ctx, llmReq) + requestBodyBytes, err = json.Marshal(requestBodyMap) if err != nil { - return nil, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()} + logger.V(logutil.DEFAULT).Error(err, "Error marshaling request body") + return reqCtx, errutil.Error{Code: errutil.Internal, Msg: fmt.Sprintf("error marshaling request body: %v", err)} } - targetPod := target.GetPod() - logger.V(logutil.DEFAULT).Info("Request handled", - "model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "endpoint", targetPod) + res, err := s.scheduler.Schedule(ctx, llmReq) + if err != nil { + return reqCtx, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()} + } + targetPod := res.TargetPod.GetPod() // Insert target endpoint to instruct Envoy to route requests to the specified target pod. 
// Attach the port number pool, err := s.datastore.PoolGet() if err != nil { - return nil, err + return reqCtx, err } endpoint := targetPod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber)) + logger.V(logutil.DEFAULT).Info("Request handled", + "model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "endpoint", targetPod) + reqCtx.Model = llmReq.Model reqCtx.ResolvedTargetModel = llmReq.ResolvedTargetModel - reqCtx.RequestSize = len(v.RequestBody.Body) + reqCtx.RequestSize = len(requestBodyBytes) reqCtx.TargetPod = targetPod.NamespacedName.String() reqCtx.TargetEndpoint = endpoint - headers := []*configPb.HeaderValueOption{ - { - Header: &configPb.HeaderValue{ - Key: s.destinationEndpointHintKey, - RawValue: []byte(endpoint), - }, - }, - // We need to update the content length header if the body is mutated, see Envoy doc: - // https://www.envoyproxy.io/docs/envoy/latest/api-v3/extensions/filters/http/ext_proc/v3/processing_mode.proto - { - Header: &configPb.HeaderValue{ - Key: "Content-Length", - RawValue: []byte(strconv.Itoa(len(requestBody))), - }, - }, - } - // Print headers for debugging - for _, header := range headers { - logger.V(logutil.DEBUG).Info("Request body header", "key", header.Header.Key, "value", header.Header.RawValue) - } + s.populateRequestHeaderResponse(reqCtx, endpoint, len(requestBodyBytes)) - targetEndpointValue := &structpb.Struct{ - Fields: map[string]*structpb.Value{ - s.destinationEndpointHintKey: { - Kind: &structpb.Value_StringValue{ - StringValue: endpoint, - }, - }, - }, - } - dynamicMetadata := targetEndpointValue - if s.destinationEndpointHintMetadataNamespace != "" { - // If a namespace is defined, wrap the selected endpoint with that. - dynamicMetadata = &structpb.Struct{ - Fields: map[string]*structpb.Value{ - s.destinationEndpointHintMetadataNamespace: { - Kind: &structpb.Value_StructValue{ - StructValue: targetEndpointValue, - }, - }, - }, - } - } - - resp := &extProcPb.ProcessingResponse{ + reqCtx.reqBodyResp = &extProcPb.ProcessingResponse{ // The Endpoint Picker supports two approaches to communicating the target endpoint, as a request header // and as an unstructure ext-proc response metadata key/value pair. This enables different integration // options for gateway providers. Response: &extProcPb.ProcessingResponse_RequestBody{ RequestBody: &extProcPb.BodyResponse{ Response: &extProcPb.CommonResponse{ - HeaderMutation: &extProcPb.HeaderMutation{ - SetHeaders: headers, - }, BodyMutation: &extProcPb.BodyMutation{ - Mutation: &extProcPb.BodyMutation_Body{ - Body: requestBody, + Mutation: &extProcPb.BodyMutation_StreamedResponse{ + StreamedResponse: &extProcPb.StreamedBodyResponse{ + Body: requestBodyBytes, + EndOfStream: true, + }, }, }, }, }, }, - DynamicMetadata: dynamicMetadata, } - return resp, nil + return reqCtx, nil } -func HandleRequestHeaders( - ctx context.Context, - reqCtx *RequestContext, - req *extProcPb.ProcessingRequest, -) *extProcPb.ProcessingResponse { - r := req.Request - h := r.(*extProcPb.ProcessingRequest_RequestHeaders) - log.FromContext(ctx).V(logutil.VERBOSE).Info("Handling request headers", "headers", h) - - resp := &extProcPb.ProcessingResponse{ - Response: &extProcPb.ProcessingResponse_RequestHeaders{ - RequestHeaders: &extProcPb.HeadersResponse{ - Response: &extProcPb.CommonResponse{ - // Set `clear_route_cache = true` to force Envoy to recompute the target cluster - // based on the new "target-pod" header. 
- // See https://www.envoyproxy.io/docs/envoy/latest/api-v3/service/ext_proc/v3/external_processor.proto#service-ext-proc-v3-commonresponse. - ClearRouteCache: true, - }, - }, - }, +func (s *StreamingServer) HandleRequestHeaders(ctx context.Context, reqCtx *RequestContext, req *extProcPb.ProcessingRequest_RequestHeaders) error { + reqCtx.RequestReceivedTimestamp = time.Now() + + // an EoS in the request headers means this request has no body or trailers. + if req.RequestHeaders.EndOfStream { + // We will route this request to a random pod as this is assumed to just be a GET + // More context: https://github.com/kubernetes-sigs/gateway-api-inference-extension/pull/526 + // The above PR will address endpoint admission, but currently any request without a body will be + // routed to a random upstream pod. + pod := GetRandomPod(s.datastore) + if pod == nil { + return errutil.Error{Code: errutil.Internal, Msg: "no pods available in datastore"} + } + pool, err := s.datastore.PoolGet() + if err != nil { + return err + } + endpoint := pod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber)) + s.populateRequestHeaderResponse(reqCtx, endpoint, 0) } - - return resp + return nil } diff --git a/pkg/epp/handlers/response.go b/pkg/epp/handlers/response.go index 991b7d16..04c7a5e9 100644 --- a/pkg/epp/handlers/response.go +++ b/pkg/epp/handlers/response.go @@ -19,14 +19,11 @@ package handlers import ( "context" "encoding/json" - "fmt" "strings" - configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" - "github.com/go-logr/logr" "sigs.k8s.io/controller-runtime/pkg/log" - errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) @@ -35,78 +32,48 @@ const ( streamingEndMsg = "data: [DONE]" ) -// HandleResponseHeaders processes response headers from the backend model server. -func (s *Server) HandleResponseHeaders( +// HandleResponseBody always returns the requestContext even in the error case, as the request context is used in error handling. 
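The new `HandleRequestHeaders` above special-cases `EndOfStream` arriving on the headers frame: such a request has no body, so there is no model name to schedule on, and it is routed to a random pod from the datastore (with an `Internal` error when none are available). A sketch of that fallback with `pod`/`targetPort` as stand-ins for the datastore lookups:

```go
package main

import (
	"errors"
	"fmt"
	"math/rand"
	"strconv"
)

type pod struct{ Address string }

// pickFallbackEndpoint routes a body-less request (e.g. a GET) to a
// random ready pod, mirroring the EndOfStream-on-headers branch above.
func pickFallbackEndpoint(pods []pod, targetPort int32) (string, error) {
	if len(pods) == 0 {
		return "", errors.New("no pods available in datastore")
	}
	p := pods[rand.Intn(len(pods))]
	return p.Address + ":" + strconv.Itoa(int(targetPort)), nil
}

func main() {
	endpoint, err := pickFallbackEndpoint([]pod{{"10.0.0.1"}, {"10.0.0.2"}}, 8000)
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println("routing body-less request to", endpoint)
}
```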
+func (s *StreamingServer) HandleResponseBody( ctx context.Context, reqCtx *RequestContext, - req *extProcPb.ProcessingRequest, -) (*extProcPb.ProcessingResponse, error) { - loggerVerbose := log.FromContext(ctx).V(logutil.VERBOSE) - loggerVerbose.Info("Processing ResponseHeaders") - h := req.Request.(*extProcPb.ProcessingRequest_ResponseHeaders) - loggerVerbose.Info("Headers before", "headers", h) - - // Example header - // { - // "ResponseHeaders": { - // "headers": [ - // { - // "key": ":status", - // "raw_value": "200" - // }, - // { - // "key": "date", - // "raw_value": "Thu, 30 Jan 2025 18:50:48 GMT" - // }, - // { - // "key": "server", - // "raw_value": "uvicorn" - // }, - // { - // "key": "content-type", - // "raw_value": "text/event-stream; charset=utf-8" - // }, - // { - // "key": "transfer-encoding", - // "raw_value": "chunked" - // } - // ] - // } - // } - for _, header := range h.ResponseHeaders.Headers.GetHeaders() { - var statusFound, typeFound bool - if header.Key == "status" { - code := header.RawValue[0] - if string(code) != "200" { - reqCtx.ResponseStatusCode = errutil.ModelServerError - statusFound = true - } - } - if header.Key == "content-type" { - contentType := header.RawValue - if strings.Contains(string(contentType), "text/event-stream") { - reqCtx.modelServerStreaming = true - } - typeFound = true - } - - if statusFound && typeFound { - break + response map[string]interface{}, +) (*RequestContext, error) { + logger := log.FromContext(ctx) + responseBytes, err := json.Marshal(response) + if err != nil { + logger.V(logutil.DEFAULT).Error(err, "error marshalling responseBody") + return reqCtx, err + } + if response["usage"] != nil { + usg := response["usage"].(map[string]interface{}) + usage := Usage{ + PromptTokens: int(usg["prompt_tokens"].(float64)), + CompletionTokens: int(usg["completion_tokens"].(float64)), + TotalTokens: int(usg["total_tokens"].(float64)), } + reqCtx.Usage = usage + logger.V(logutil.VERBOSE).Info("Response generated", "usage", reqCtx.Usage) } + reqCtx.ResponseSize = len(responseBytes) + // ResponseComplete is to indicate the response is complete. In non-streaming + // case, it will be set to be true once the response is processed; in + // streaming case, it will be set to be true once the last chunk is processed. + // TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/178) + // will add the processing for streaming case. + reqCtx.ResponseComplete = true - resp := &extProcPb.ProcessingResponse{ - Response: &extProcPb.ProcessingResponse_ResponseHeaders{ - ResponseHeaders: &extProcPb.HeadersResponse{ + reqCtx.respBodyResp = &extProcPb.ProcessingResponse{ + // The Endpoint Picker supports two approaches to communicating the target endpoint, as a request header + // and as an unstructure ext-proc response metadata key/value pair. This enables different integration + // options for gateway providers. + Response: &extProcPb.ProcessingResponse_ResponseBody{ + ResponseBody: &extProcPb.BodyResponse{ Response: &extProcPb.CommonResponse{ - HeaderMutation: &extProcPb.HeaderMutation{ - SetHeaders: []*configPb.HeaderValueOption{ - { - Header: &configPb.HeaderValue{ - // This is for debugging purpose only. 
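The usage-reading block in the new `HandleResponseBody` asserts the token counts directly (`.(float64)`), which assumes a well-formed upstream response. A slightly more defensive sketch of the same extraction, offered only as an illustration of the shape of the data (`encoding/json` decodes JSON numbers into `float64`, hence the assertions):

```go
package main

import (
	"encoding/json"
	"fmt"
)

type usage struct {
	PromptTokens     int
	CompletionTokens int
	TotalTokens      int
}

// extractUsage reads the OpenAI-style usage object out of a decoded
// response body, using checked assertions so a malformed body yields
// ok=false instead of a panic.
func extractUsage(response map[string]interface{}) (usage, bool) {
	raw, ok := response["usage"].(map[string]interface{})
	if !ok {
		return usage{}, false
	}
	num := func(key string) (int, bool) {
		f, ok := raw[key].(float64) // encoding/json decodes numbers as float64
		return int(f), ok
	}
	var u usage
	var ok1, ok2, ok3 bool
	u.PromptTokens, ok1 = num("prompt_tokens")
	u.CompletionTokens, ok2 = num("completion_tokens")
	u.TotalTokens, ok3 = num("total_tokens")
	return u, ok1 && ok2 && ok3
}

func main() {
	var body map[string]interface{}
	_ = json.Unmarshal([]byte(`{"usage":{"prompt_tokens":11,"completion_tokens":100,"total_tokens":111}}`), &body)
	if u, ok := extractUsage(body); ok {
		fmt.Printf("%+v\n", u)
	}
}
```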
- Key: "x-went-into-resp-headers", - RawValue: []byte("true"), - }, + BodyMutation: &extProcPb.BodyMutation{ + Mutation: &extProcPb.BodyMutation_StreamedResponse{ + StreamedResponse: &extProcPb.StreamedBodyResponse{ + Body: responseBytes, + EndOfStream: true, }, }, }, @@ -114,106 +81,21 @@ func (s *Server) HandleResponseHeaders( }, }, } - return resp, nil + return reqCtx, nil } -// HandleResponseBody parses response body to update information such as number of completion tokens. -// NOTE: The current implementation only supports Buffered mode, which is not enabled by default. To -// use it, you need to configure EnvoyExtensionPolicy to have response body in Buffered mode. -// https://www.envoyproxy.io/docs/envoy/latest/api-v3/extensions/filters/http/ext_proc/v3/processing_mode.proto#envoy-v3-api-msg-extensions-filters-http-ext-proc-v3-processingmode -// Example response -/* -{ - "id": "cmpl-573498d260f2423f9e42817bbba3743a", - "object": "text_completion", - "created": 1732563765, - "model": "meta-llama/Llama-3.1-8B-Instruct", - "choices": [ - { - "index": 0, - "text": " Chronicle\nThe San Francisco Chronicle has a new book review section, and it's a good one. The reviews are short, but they're well-written and well-informed. The Chronicle's book review section is a good place to start if you're looking for a good book review.\nThe Chronicle's book review section is a good place to start if you're looking for a good book review. The Chronicle's book review section", - "logprobs": null, - "finish_reason": "length", - "stop_reason": null, - "prompt_logprobs": null - } - ], - "usage": { - "prompt_tokens": 11, - "total_tokens": 111, - "completion_tokens": 100 - } -}*/ -func (s *Server) HandleResponseBody( +// The function is to handle streaming response if the modelServer is streaming. 
+func (s *StreamingServer) HandleResponseBodyModelStreaming( ctx context.Context, reqCtx *RequestContext, - req *extProcPb.ProcessingRequest, -) (*extProcPb.ProcessingResponse, error) { - logger := log.FromContext(ctx) - loggerVerbose := logger.V(logutil.VERBOSE) - body := req.Request.(*extProcPb.ProcessingRequest_ResponseBody) - - if reqCtx.modelServerStreaming { - logger.V(logutil.DEBUG).Info("Processing HandleResponseBody") - if err := s.HandleStreaming(ctx, reqCtx, body, loggerVerbose); err != nil { - return nil, err - } - } else { - loggerVerbose.Info("Processing HandleResponseBody") - if err := s.HandleNonStreaming(ctx, reqCtx, body, loggerVerbose); err != nil { - return nil, err - } - } - - resp := &extProcPb.ProcessingResponse{ - Response: &extProcPb.ProcessingResponse_ResponseBody{ - ResponseBody: &extProcPb.BodyResponse{ - Response: &extProcPb.CommonResponse{}, - }, - }, - } - return resp, nil -} - -func (s *Server) HandleNonStreaming( - ctx context.Context, - reqCtx *RequestContext, - body *extProcPb.ProcessingRequest_ResponseBody, - loggerVerbose logr.Logger, -) error { - loggerVerbose.Info("Processing HandleResponseBody") - - res := Response{} - if err := json.Unmarshal(body.ResponseBody.Body, &res); err != nil { - return errutil.Error{Code: errutil.Internal, Msg: fmt.Sprintf("unmarshaling response body: %v", err)} - } - reqCtx.Usage = res.Usage - reqCtx.ResponseSize = len(body.ResponseBody.Body) - reqCtx.ResponseComplete = true - loggerVerbose.Info("Response generated", "response", res) - return nil -} - -func (s *Server) HandleStreaming( - ctx context.Context, - reqCtx *RequestContext, - body *extProcPb.ProcessingRequest_ResponseBody, - loggerVerbose logr.Logger, -) error { - responseText := string(body.ResponseBody.Body) + responseText string, +) { if strings.Contains(responseText, streamingEndMsg) { - parsedResp := ParseRespForUsage(ctx, responseText) - reqCtx.Usage = parsedResp.Usage + resp := parseRespForUsage(ctx, responseText) + reqCtx.Usage = resp.Usage + metrics.RecordInputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, resp.Usage.PromptTokens) + metrics.RecordOutputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, resp.Usage.CompletionTokens) } - - if body.ResponseBody.EndOfStream { - loggerVerbose.Info("Streaming is completed") - reqCtx.ResponseComplete = true - } else { - reqCtx.ResponseSize += len(body.ResponseBody.Body) - } - - return nil } // Example message if "stream_options": {"include_usage": "true"} is included in the request: @@ -227,11 +109,12 @@ func (s *Server) HandleStreaming( // // If include_usage is not included in the request, `data: [DONE]` is returned separately, which // indicates end of streaming. 
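The comment above describes the server-sent-events framing the parser relies on: each chunk arrives as a `data: {...}` line, the usage object rides on the final chunk when `stream_options.include_usage` is set, and `data: [DONE]` marks the end. A self-contained sketch of scanning such a payload for the usage chunk, under those format assumptions:

```go
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

type usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

type chunk struct {
	Usage *usage `json:"usage"`
}

// parseStreamUsage scans an SSE payload; once "data: [DONE]" is seen
// it walks the "data: " lines and keeps the last chunk carrying usage.
func parseStreamUsage(payload string) (usage, bool) {
	if !strings.Contains(payload, "data: [DONE]") {
		return usage{}, false // stream not finished yet
	}
	var last usage
	found := false
	for _, line := range strings.Split(payload, "\n") {
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		content := strings.TrimPrefix(line, "data: ")
		if content == "[DONE]" {
			continue
		}
		var c chunk
		if err := json.Unmarshal([]byte(content), &c); err != nil || c.Usage == nil {
			continue // tolerate non-JSON or usage-less chunks
		}
		last, found = *c.Usage, true
	}
	return last, found
}

func main() {
	payload := `data: {"choices":[]}
data: {"choices":[],"usage":{"prompt_tokens":7,"completion_tokens":31,"total_tokens":38}}
data: [DONE]`
	if u, ok := parseStreamUsage(payload); ok {
		fmt.Printf("%+v\n", u)
	}
}
```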
-func ParseRespForUsage( +func parseRespForUsage( ctx context.Context, responseText string, ) Response { response := Response{} + logger := log.FromContext(ctx) lines := strings.Split(responseText, "\n") for _, line := range lines { @@ -245,8 +128,7 @@ func ParseRespForUsage( byteSlice := []byte(content) if err := json.Unmarshal(byteSlice, &response); err != nil { - logger := log.FromContext(ctx) - logger.V(logutil.DEFAULT).Error(err, "unmarshaling response body") + logger.Error(err, "unmarshaling response body") continue } } diff --git a/pkg/epp/handlers/response_test.go b/pkg/epp/handlers/response_test.go index 074b45c9..bfe5a629 100644 --- a/pkg/epp/handlers/response_test.go +++ b/pkg/epp/handlers/response_test.go @@ -18,9 +18,9 @@ package handlers import ( "context" + "encoding/json" "testing" - extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" "github.com/google/go-cmp/cmp" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) @@ -63,40 +63,61 @@ func TestHandleResponseBody(t *testing.T) { tests := []struct { name string - req *extProcPb.ProcessingRequest_ResponseBody + body []byte reqCtx *RequestContext want Usage wantErr bool }{ { name: "success", - req: &extProcPb.ProcessingRequest_ResponseBody{ - ResponseBody: &extProcPb.HttpBody{ - Body: []byte(body), - }, - }, + body: []byte(body), want: Usage{ PromptTokens: 11, TotalTokens: 111, CompletionTokens: 100, }, }, - { - name: "malformed response", - req: &extProcPb.ProcessingRequest_ResponseBody{ - ResponseBody: &extProcPb.HttpBody{ - Body: []byte("malformed json"), - }, - }, - wantErr: true, - }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + server := &StreamingServer{} + reqCtx := test.reqCtx + if reqCtx == nil { + reqCtx = &RequestContext{} + } + var responseMap map[string]interface{} + marshalErr := json.Unmarshal(test.body, &responseMap) + if marshalErr != nil { + t.Error(marshalErr, "Error unmarshaling request body") + } + _, err := server.HandleResponseBody(ctx, reqCtx, responseMap) + if err != nil { + if !test.wantErr { + t.Fatalf("HandleResponseBody returned unexpected error: %v, want %v", err, test.wantErr) + } + return + } + + if diff := cmp.Diff(test.want, reqCtx.Usage); diff != "" { + t.Errorf("HandleResponseBody returned unexpected response, diff(-want, +got): %v", diff) + } + }) + } +} + +func TestHandleStreamedResponseBody(t *testing.T) { + ctx := logutil.NewTestLoggerIntoContext(context.Background()) + tests := []struct { + name string + body string + reqCtx *RequestContext + want Usage + wantErr bool + }{ { name: "streaming request without usage", - req: &extProcPb.ProcessingRequest_ResponseBody{ - ResponseBody: &extProcPb.HttpBody{ - Body: []byte(streamingBodyWithoutUsage), - }, - }, + body: streamingBodyWithoutUsage, reqCtx: &RequestContext{ modelServerStreaming: true, }, @@ -105,11 +126,7 @@ func TestHandleResponseBody(t *testing.T) { }, { name: "streaming request with usage", - req: &extProcPb.ProcessingRequest_ResponseBody{ - ResponseBody: &extProcPb.HttpBody{ - Body: []byte(streamingBodyWithUsage), - }, - }, + body: streamingBodyWithUsage, reqCtx: &RequestContext{ modelServerStreaming: true, }, @@ -124,18 +141,12 @@ func TestHandleResponseBody(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - server := &Server{} + server := &StreamingServer{} reqCtx := test.reqCtx if reqCtx == nil { reqCtx = &RequestContext{} } - _, err := server.HandleResponseBody(ctx, reqCtx, 
&extProcPb.ProcessingRequest{Request: test.req}) - if err != nil { - if !test.wantErr { - t.Fatalf("HandleResponseBody returned unexpected error: %v, want %v", err, test.wantErr) - } - return - } + server.HandleResponseBodyModelStreaming(ctx, reqCtx, test.body) if diff := cmp.Diff(test.want, reqCtx.Usage); diff != "" { t.Errorf("HandleResponseBody returned unexpected response, diff(-want, +got): %v", diff) diff --git a/pkg/epp/handlers/server.go b/pkg/epp/handlers/server.go index a92f091c..630baef3 100644 --- a/pkg/epp/handlers/server.go +++ b/pkg/epp/handlers/server.go @@ -18,24 +18,32 @@ package handlers import ( "context" + "encoding/json" "io" + "math/rand" + "strconv" + "strings" "time" + configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" envoyTypePb "github.com/envoyproxy/go-control-plane/envoy/type/v3" + "github.com/go-logr/logr" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/structpb" "sigs.k8s.io/controller-runtime/pkg/log" - backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling" + schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) -func NewServer(scheduler Scheduler, destinationEndpointHintMetadataNamespace, destinationEndpointHintKey string, datastore datastore.Datastore) *Server { - return &Server{ +func NewStreamingServer(scheduler Scheduler, destinationEndpointHintMetadataNamespace, destinationEndpointHintKey string, datastore datastore.Datastore) *StreamingServer { + return &StreamingServer{ scheduler: scheduler, destinationEndpointHintMetadataNamespace: destinationEndpointHintMetadataNamespace, destinationEndpointHintKey: destinationEndpointHintKey, @@ -45,7 +53,7 @@ func NewServer(scheduler Scheduler, destinationEndpointHintMetadataNamespace, de // Server implements the Envoy external processing server. // https://www.envoyproxy.io/docs/envoy/latest/api-v3/service/ext_proc/v3/external_processor.proto -type Server struct { +type StreamingServer struct { scheduler Scheduler // The key of the header to specify the target pod address. This value needs to match Envoy // configuration. @@ -57,30 +65,78 @@ type Server struct { } type Scheduler interface { - Schedule(ctx context.Context, b *scheduling.LLMRequest) (targetPod backendmetrics.PodMetrics, err error) + Schedule(ctx context.Context, b *schedulingtypes.LLMRequest) (result *schedulingtypes.Result, err error) } -func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { +// RequestContext stores context information during the life time of an HTTP request. 
+type RequestContext struct {
+	TargetPod                 string
+	TargetEndpoint            string
+	Model                     string
+	ResolvedTargetModel       string
+	RequestReceivedTimestamp  time.Time
+	ResponseCompleteTimestamp time.Time
+	RequestSize               int
+	Usage                     Usage
+	ResponseSize              int
+	ResponseComplete          bool
+	ResponseStatusCode        string
+	RequestRunning            bool
+
+	RequestState         StreamRequestState
+	modelServerStreaming bool
+
+	reqHeaderResp  *extProcPb.ProcessingResponse
+	reqBodyResp    *extProcPb.ProcessingResponse
+	reqTrailerResp *extProcPb.ProcessingResponse
+
+	respHeaderResp  *extProcPb.ProcessingResponse
+	respBodyResp    *extProcPb.ProcessingResponse
+	respTrailerResp *extProcPb.ProcessingResponse
+}
+
+type StreamRequestState int
+
+const (
+	RequestReceived                  StreamRequestState = 0
+	HeaderRequestResponseComplete    StreamRequestState = 1
+	BodyRequestResponsesComplete     StreamRequestState = 2
+	TrailerRequestResponsesComplete  StreamRequestState = 3
+	ResponseReceived                 StreamRequestState = 4
+	HeaderResponseResponseComplete   StreamRequestState = 5
+	BodyResponseResponsesComplete    StreamRequestState = 6
+	TrailerResponseResponsesComplete StreamRequestState = 7
+)
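These states surface further down in TRACE logs (`checking request state`) as bare integers. A String method along these lines, hypothetical and not part of this patch, would make those log lines readable:

```go
// Hypothetical helper (not in this patch): readable names for StreamRequestState.
func (s StreamRequestState) String() string {
	names := map[StreamRequestState]string{
		RequestReceived:                  "RequestReceived",
		HeaderRequestResponseComplete:    "HeaderRequestResponseComplete",
		BodyRequestResponsesComplete:     "BodyRequestResponsesComplete",
		TrailerRequestResponsesComplete:  "TrailerRequestResponsesComplete",
		ResponseReceived:                 "ResponseReceived",
		HeaderResponseResponseComplete:   "HeaderResponseResponseComplete",
		BodyResponseResponsesComplete:    "BodyResponseResponsesComplete",
		TrailerResponseResponsesComplete: "TrailerResponseResponsesComplete",
	}
	if name, ok := names[s]; ok {
		return name
	}
	return "Unknown"
}
```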
- loggerVerbose.Error(err, "Cannot receive stream request") + logger.V(logutil.DEFAULT).Error(err, "Cannot receive stream request") return status.Errorf(codes.Unknown, "cannot receive stream request: %v", err) } - var resp *extProcPb.ProcessingResponse switch v := req.Request.(type) { case *extProcPb.ProcessingRequest_RequestHeaders: - reqCtx.RequestReceivedTimestamp = time.Now() - resp = HandleRequestHeaders(ctx, reqCtx, req) - loggerVerbose.Info("Request context after HandleRequestHeaders", "context", reqCtx) + err = s.HandleRequestHeaders(ctx, reqCtx, v) case *extProcPb.ProcessingRequest_RequestBody: - resp, err = s.HandleRequestBody(ctx, reqCtx, req) - if err == nil { - metrics.RecordRequestCounter(reqCtx.Model, reqCtx.ResolvedTargetModel) - metrics.RecordRequestSizes(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.RequestSize) + loggerTrace.Info("Incoming body chunk", "EoS", v.RequestBody.EndOfStream) + // In the stream case, we can receive multiple request bodies. + body = append(body, v.RequestBody.Body...) + + // Message is buffered, we can read and decode. + if v.RequestBody.EndOfStream { + loggerTrace.Info("decoding") + err = json.Unmarshal(body, &requestBody) + if err != nil { + logger.V(logutil.DEFAULT).Error(err, "Error unmarshaling request body") + } + + // Body stream complete. Allocate empty slice for response to use. + body = []byte{} + + reqCtx, err = s.HandleRequestBody(ctx, reqCtx, req, requestBody) + if err != nil { + logger.V(logutil.DEFAULT).Error(err, "Error handling body") + } else { + metrics.RecordRequestCounter(reqCtx.Model, reqCtx.ResolvedTargetModel) + metrics.RecordRequestSizes(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.RequestSize) + } } - loggerVerbose.Info("Request context after HandleRequestBody", "context", reqCtx) + case *extProcPb.ProcessingRequest_RequestTrailers: + // This is currently unused. case *extProcPb.ProcessingRequest_ResponseHeaders: - resp, err = s.HandleResponseHeaders(ctx, reqCtx, req) - loggerVerbose.Info("Request context after HandleResponseHeaders", "context", reqCtx) - case *extProcPb.ProcessingRequest_ResponseBody: - // Don't send a 500 on a response error. Just let the message passthrough and log our error for debugging purposes. - // We assume the body is valid JSON, err messages are not guaranteed to be json, and so capturing and sending a 500 obfuscates the response message. - // using the standard 'err' var will send an immediate error response back to the caller. 
-			var responseErr error
-			resp, responseErr = s.HandleResponseBody(ctx, reqCtx, req)
-			if responseErr != nil {
-				logger.V(logutil.DEFAULT).Error(responseErr, "Failed to process response body", "request", req)
-			} else if reqCtx.ResponseComplete {
-				reqCtx.ResponseCompleteTimestamp = time.Now()
-				metrics.RecordRequestLatencies(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.RequestReceivedTimestamp, reqCtx.ResponseCompleteTimestamp)
-				metrics.RecordResponseSizes(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.ResponseSize)
-				metrics.RecordInputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Usage.PromptTokens)
-				metrics.RecordOutputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Usage.CompletionTokens)
+			for _, header := range v.ResponseHeaders.Headers.GetHeaders() {
+				value := string(header.RawValue)
+
+				loggerTrace.Info("header", "key", header.Key, "value", value)
+				if header.Key == "status" && value != "200" {
+					reqCtx.ResponseStatusCode = errutil.ModelServerError
+				} else if header.Key == "content-type" && strings.Contains(value, "text/event-stream") {
+					reqCtx.modelServerStreaming = true
+					loggerTrace.Info("model server is streaming response")
+				}
 			}
+			reqCtx.RequestState = ResponseReceived
+			reqCtx.respHeaderResp = &extProcPb.ProcessingResponse{
+				Response: &extProcPb.ProcessingResponse_ResponseHeaders{
+					ResponseHeaders: &extProcPb.HeadersResponse{
+						Response: &extProcPb.CommonResponse{
+							HeaderMutation: &extProcPb.HeaderMutation{
+								SetHeaders: []*configPb.HeaderValueOption{
+									{
+										Header: &configPb.HeaderValue{
+											// This is for debugging purpose only.
+											Key:      "x-went-into-resp-headers",
+											RawValue: []byte("true"),
+										},
+									},
+								},
+							},
+						},
+					},
+				},
+			}
+
+		case *extProcPb.ProcessingRequest_ResponseBody:
 			if reqCtx.modelServerStreaming {
-				logger.V(logutil.DEBUG).Info("Request context after HandleResponseBody", "context", reqCtx)
+				// Currently we punt on response parsing if the modelServer is streaming, and we just passthrough.
+
+				responseText := string(v.ResponseBody.Body)
+				s.HandleResponseBodyModelStreaming(ctx, reqCtx, responseText)
+				if v.ResponseBody.EndOfStream {
+					loggerTrace.Info("stream completed")
+
+					reqCtx.ResponseCompleteTimestamp = time.Now()
+					metrics.RecordRequestLatencies(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.RequestReceivedTimestamp, reqCtx.ResponseCompleteTimestamp)
+					metrics.RecordResponseSizes(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.ResponseSize)
+				}
+
+				reqCtx.respBodyResp = &extProcPb.ProcessingResponse{
+					Response: &extProcPb.ProcessingResponse_ResponseBody{
+						ResponseBody: &extProcPb.BodyResponse{
+							Response: &extProcPb.CommonResponse{
+								BodyMutation: &extProcPb.BodyMutation{
+									Mutation: &extProcPb.BodyMutation_StreamedResponse{
+										StreamedResponse: &extProcPb.StreamedBodyResponse{
+											Body:        v.ResponseBody.Body,
+											EndOfStream: v.ResponseBody.EndOfStream,
+										},
+									},
+								},
+							},
+						},
+					},
+				}
 			} else {
-				loggerVerbose.Info("Request context after HandleResponseBody", "context", reqCtx)
+				body = append(body, v.ResponseBody.Body...)
+
+				// Message is buffered, we can read and decode.
+				if v.ResponseBody.EndOfStream {
+					loggerTrace.Info("stream completed")
+					// Don't send a 500 on a response error. Just let the message passthrough and log our error for debugging purposes.
+					// We assume the body is valid JSON, err messages are not guaranteed to be json, and so capturing and sending a 500 obfuscates the response message.
+					// Using the standard 'err' var would send an immediate error response back to the caller.
+					responseErr := json.Unmarshal(body, &responseBody)
+					if responseErr != nil {
+						logger.V(logutil.DEFAULT).Error(responseErr, "Error unmarshaling response body")
+					}
+
+					reqCtx, responseErr = s.HandleResponseBody(ctx, reqCtx, responseBody)
+					if responseErr != nil {
+						logger.V(logutil.DEFAULT).Error(responseErr, "Failed to process response body", "request", req)
+					} else if reqCtx.ResponseComplete {
+						reqCtx.ResponseCompleteTimestamp = time.Now()
+						metrics.RecordRequestLatencies(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.RequestReceivedTimestamp, reqCtx.ResponseCompleteTimestamp)
+						metrics.RecordResponseSizes(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.ResponseSize)
+						metrics.RecordInputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Usage.PromptTokens)
+						metrics.RecordOutputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Usage.CompletionTokens)
+					}
+				}
 			}
-		default:
-			logger.V(logutil.DEFAULT).Error(nil, "Unknown Request type", "request", v)
-			return status.Error(codes.Unknown, "unknown request type")
+		case *extProcPb.ProcessingRequest_ResponseTrailers:
+			// This is currently unused.
 		}
 
+		// Handle the err and fire an immediate response.
 		if err != nil {
 			logger.V(logutil.DEFAULT).Error(err, "Failed to process request", "request", req)
-			resp, err = BuildErrResponse(err)
+			resp, err := BuildErrResponse(err)
 			if err != nil {
 				return err
 			}
+			if err := srv.Send(resp); err != nil {
+				logger.V(logutil.DEFAULT).Error(err, "Send failed")
+				return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err)
+			}
+			return nil
+		}
+		loggerTrace.Info("checking", "request state", reqCtx.RequestState)
+		if err := reqCtx.updateStateAndSendIfNeeded(srv, logger); err != nil {
+			return err
 		}
+	}
 
-		if !reqCtx.modelServerStreaming {
-			loggerVerbose.Info("Response generated", "response", resp)
-		} else {
-			logger.V(logutil.DEBUG).Info("Response generated", "response", resp
+}
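The helper that follows leans on a single invariant: with FULL_DUPLEX_STREAMING, the ext_proc responses for each phase must go back to Envoy in Header -> Body -> Trailer order, with the trailer optional. A toy sketch of that invariant, using illustrative names that are not part of this patch:

```go
package main

import "fmt"

// Illustrative only: the phases of one direction of an ext_proc exchange,
// in the order Envoy expects responses back.
type phase int

const (
	header phase = iota
	body
	trailer // optional
)

// ordered reports whether a sequence of sends respects Header -> Body -> Trailer.
// Repeats are allowed, since bodies stream back as multiple chunks.
func ordered(sends []phase) bool {
	last := header
	for _, p := range sends {
		if p < last {
			return false
		}
		last = p
	}
	return true
}

func main() {
	fmt.Println(ordered([]phase{header, body, body}))    // true: trailer is optional
	fmt.Println(ordered([]phase{body, header, trailer})) // false: header must come first
}
```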
+// updateStateAndSendIfNeeded checks the current state and can send multiple responses in a single pass, but only if they are ordered properly.
+// The order of responses matters in FULL_DUPLEX_STREAMING: for both the request and the response phase, the responses sent back MUST be ordered Header->Body->Trailer, with the trailer being optional.
+func (r *RequestContext) updateStateAndSendIfNeeded(srv extProcPb.ExternalProcessor_ProcessServer, logger logr.Logger) error {
+	loggerTrace := logger.V(logutil.TRACE)
+	// No switch statement as we could send multiple responses in one pass.
+	if r.RequestState == RequestReceived && r.reqHeaderResp != nil {
+		loggerTrace.Info("Sending request header response", "obj", r.reqHeaderResp)
+		if err := srv.Send(r.reqHeaderResp); err != nil {
+			logger.V(logutil.DEFAULT).Error(err, "error sending response")
+			return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err)
+		}
+		r.RequestState = HeaderRequestResponseComplete
+	}
+	if r.RequestState == HeaderRequestResponseComplete && r.reqBodyResp != nil {
+		loggerTrace.Info("Sending request body response")
+		if err := srv.Send(r.reqBodyResp); err != nil {
+			return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err)
+		}
+		r.RequestState = BodyRequestResponsesComplete
+		metrics.IncRunningRequests(r.Model)
+		r.RequestRunning = true
+		// Dump the response so a new stream message can begin
+		r.reqBodyResp = nil
+	}
+	if r.RequestState == BodyRequestResponsesComplete && r.reqTrailerResp != nil {
+		// Trailers in requests are not guaranteed
+		if err := srv.Send(r.reqTrailerResp); err != nil {
+			return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err)
+		}
+	}
+	if r.RequestState == ResponseReceived && r.respHeaderResp != nil {
+		loggerTrace.Info("Sending response header response", "obj", r.respHeaderResp)
+		if err := srv.Send(r.respHeaderResp); err != nil {
+			return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err)
+		}
+		r.RequestState = HeaderResponseResponseComplete
+	}
+	if r.RequestState == HeaderResponseResponseComplete && r.respBodyResp != nil {
+		loggerTrace.Info("Sending response body response")
+		if err := srv.Send(r.respBodyResp); err != nil {
+			return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err)
+		}
+
+		body := r.respBodyResp.Response.(*extProcPb.ProcessingResponse_ResponseBody)
+		if body.ResponseBody.Response.GetBodyMutation().GetStreamedResponse().GetEndOfStream() {
+			r.RequestState = BodyResponseResponsesComplete
+		}
+		// Dump the response so a new stream message can begin
+		r.respBodyResp = nil
+	}
+	if r.RequestState == BodyResponseResponsesComplete && r.respTrailerResp != nil {
+		// Trailers in responses are not guaranteed
+		if err := srv.Send(r.respTrailerResp); err != nil {
+			return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err)
+		}
+	}
+	return nil
+}
+
+func (s *StreamingServer) populateRequestHeaderResponse(reqCtx *RequestContext, endpoint string, requestBodyLength int) {
+	headers := []*configPb.HeaderValueOption{
+		{
+			Header: &configPb.HeaderValue{
+				Key:      s.destinationEndpointHintKey,
+				RawValue: []byte(endpoint),
+			},
+		},
+	}
+	if requestBodyLength > 0 {
+		// We need to update the content length header if the body is mutated, see Envoy doc:
+		// https://www.envoyproxy.io/docs/envoy/latest/api-v3/extensions/filters/http/ext_proc/v3/processing_mode.proto
+		headers = append(headers, &configPb.HeaderValueOption{
+			Header: &configPb.HeaderValue{
+				Key:      "Content-Length",
+				RawValue: []byte(strconv.Itoa(requestBodyLength)),
+			},
+		})
+	}
+
+	targetEndpointValue := &structpb.Struct{
+		Fields: map[string]*structpb.Value{
+			s.destinationEndpointHintKey: {
+				Kind: &structpb.Value_StringValue{
+					StringValue: endpoint,
+				},
+			},
+		},
+	}
+	dynamicMetadata := targetEndpointValue
+	if s.destinationEndpointHintMetadataNamespace != "" {
+		// If a namespace is defined, wrap the selected endpoint with that.
+		dynamicMetadata = &structpb.Struct{
+			Fields: map[string]*structpb.Value{
+				s.destinationEndpointHintMetadataNamespace: {
+					Kind: &structpb.Value_StructValue{
+						StructValue: targetEndpointValue,
+					},
+				},
+			},
+		}
+	}
+
+	reqCtx.reqHeaderResp = &extProcPb.ProcessingResponse{
+		Response: &extProcPb.ProcessingResponse_RequestHeaders{
+			RequestHeaders: &extProcPb.HeadersResponse{
+				Response: &extProcPb.CommonResponse{
+					ClearRouteCache: true,
+					HeaderMutation: &extProcPb.HeaderMutation{
+						SetHeaders: headers,
+					},
+				},
+			},
+		},
+		DynamicMetadata: dynamicMetadata,
+	}
+}
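populateRequestHeaderResponse communicates the picked endpoint both as a request header and as dynamic metadata. To visualize the nested metadata shape that the namespace wrapping produces, here is a standalone sketch; the key, namespace, and endpoint are example values, not defaults guaranteed by this patch:

```go
package main

import (
	"fmt"

	"google.golang.org/protobuf/encoding/protojson"
	"google.golang.org/protobuf/types/known/structpb"
)

func main() {
	// Example values; in the server these come from the destinationEndpointHintKey
	// and destinationEndpointHintMetadataNamespace fields of StreamingServer.
	hintKey, namespace, endpoint := "x-gateway-destination-endpoint", "envoy.lb", "10.0.0.1:8000"

	inner, err := structpb.NewStruct(map[string]interface{}{hintKey: endpoint})
	if err != nil {
		panic(err)
	}
	// With a namespace configured, the endpoint struct is wrapped one level deeper,
	// mirroring the branch above.
	metadata := &structpb.Struct{
		Fields: map[string]*structpb.Value{
			namespace: structpb.NewStructValue(inner),
		},
	}
	fmt.Println(protojson.Format(metadata))
	// Output: {"envoy.lb":{"x-gateway-destination-endpoint":"10.0.0.1:8000"}}
}
```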
+
+func RandomWeightedDraw(logger logr.Logger, model *v1alpha2.InferenceModel, seed int64) string {
+	// TODO: after we are down to 1 server implementation, make these methods a part of the struct
+	// and handle random seeding on the struct.
+	source := rand.NewSource(rand.Int63())
+	if seed > 0 {
+		source = rand.NewSource(seed)
+	}
+	r := rand.New(source)
+
+	// If all the weight values are nil, return a random model name.
+	if model.Spec.TargetModels[0].Weight == nil {
+		index := r.Int31n(int32(len(model.Spec.TargetModels)))
+		return model.Spec.TargetModels[index].Name
+	}
+
+	var weights int32
+	for _, model := range model.Spec.TargetModels {
+		weights += *model.Weight
+	}
+	logger.V(logutil.TRACE).Info("Weights for model computed", "model", model.Name, "weights", weights)
+	randomVal := r.Int31n(weights)
+	// Walk the target models, subtracting each weight until the draw falls within one.
+	// For example, with weights [50, 30, 20] a draw of 65 selects the second model (65-50=15 < 30).
+	// TODO: optimize this without using loop
+	for _, model := range model.Spec.TargetModels {
+		if randomVal < *model.Weight {
+			return model.Name
+		}
+		randomVal -= *model.Weight
+	}
+	return ""
+}
+
+func GetRandomPod(ds datastore.Datastore) *backend.Pod {
+	pods := ds.PodGetAll()
+	if len(pods) == 0 {
+		return nil
+	}
+	number := rand.Intn(len(pods))
+	pod := pods[number]
+	return pod.GetPod()
+}
+
 func BuildErrResponse(err error) (*extProcPb.ProcessingResponse, error) {
 	var resp *extProcPb.ProcessingResponse
@@ -214,43 +510,3 @@ func BuildErrResponse(err error) (*extProcPb.ProcessingResponse, error) {
 	}
 	return resp, nil
 }
-
-// RequestContext stores context information during the life time of an HTTP request.
-type RequestContext struct {
-	TargetPod                 string
-	TargetEndpoint            string
-	Model                     string
-	ResolvedTargetModel       string
-	RequestReceivedTimestamp  time.Time
-	ResponseCompleteTimestamp time.Time
-	RequestSize               int
-	Usage                     Usage
-	ResponseSize              int
-	ResponseComplete          bool
-	ResponseStatusCode        string
-	RequestRunning            bool
-
-	RequestState         StreamRequestState
-	modelServerStreaming bool
-
-	reqHeaderResp  *extProcPb.ProcessingResponse
-	reqBodyResp    *extProcPb.ProcessingResponse
-	reqTrailerResp *extProcPb.ProcessingResponse
-
-	respHeaderResp  *extProcPb.ProcessingResponse
-	respBodyResp    *extProcPb.ProcessingResponse
-	respTrailerResp *extProcPb.ProcessingResponse
-}
-
-type StreamRequestState int
-
-const (
-	RequestReceived                  StreamRequestState = 0
-	HeaderRequestResponseComplete    StreamRequestState = 1
-	BodyRequestResponsesComplete     StreamRequestState = 2
-	TrailerRequestResponsesComplete  StreamRequestState = 3
-	ResponseRecieved                 StreamRequestState = 4
-	HeaderResponseResponseComplete   StreamRequestState = 5
-	BodyResponseResponsesComplete    StreamRequestState = 6
-	TrailerResponseResponsesComplete StreamRequestState = 7
-)
diff --git a/pkg/epp/handlers/streamingserver.go b/pkg/epp/handlers/streamingserver.go
deleted file mode 100644
index 874dd734..00000000
--- a/pkg/epp/handlers/streamingserver.go
+++ /dev/null
@@ -1,592 +0,0 @@
-/*
-Copyright 2025 The Kubernetes Authors.
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package handlers - -import ( - "context" - "encoding/json" - "fmt" - "io" - "math/rand" - "strconv" - "strings" - "time" - - configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" - extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" - "github.com/go-logr/logr" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - "google.golang.org/protobuf/types/known/structpb" - "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" - backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling" - errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" - logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" -) - -func NewStreamingServer(scheduler Scheduler, destinationEndpointHintMetadataNamespace, destinationEndpointHintKey string, datastore datastore.Datastore) *StreamingServer { - return &StreamingServer{ - scheduler: scheduler, - destinationEndpointHintMetadataNamespace: destinationEndpointHintMetadataNamespace, - destinationEndpointHintKey: destinationEndpointHintKey, - datastore: datastore, - } -} - -type StreamingServer struct { - scheduler Scheduler - // The key of the header to specify the target pod address. This value needs to match Envoy - // configuration. - destinationEndpointHintKey string - // The key acting as the outer namespace struct in the metadata extproc response to communicate - // back the picked endpoints. - destinationEndpointHintMetadataNamespace string - datastore datastore.Datastore -} - -func (s *StreamingServer) Process(srv extProcPb.ExternalProcessor_ProcessServer) error { - ctx := srv.Context() - logger := log.FromContext(ctx) - loggerTrace := logger.V(logutil.TRACE) - loggerTrace.Info("Processing") - - // Create request context to share states during life time of an HTTP request. - // See https://github.com/envoyproxy/envoy/issues/17540. - reqCtx := &RequestContext{ - RequestState: RequestReceived, - } - - var body []byte - var requestBody, responseBody map[string]interface{} - - // Create error handling var as each request should only report once for - // error metrics. This doesn't cover the error "Cannot receive stream request" because - // such errors might happen even though response is processed. 
- var err error - defer func(error, *RequestContext) { - if reqCtx.ResponseStatusCode != "" { - metrics.RecordRequestErrCounter(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.ResponseStatusCode) - } else if err != nil { - metrics.RecordRequestErrCounter(reqCtx.Model, reqCtx.ResolvedTargetModel, errutil.CanonicalCode(err)) - } - if reqCtx.RequestRunning { - metrics.DecRunningRequests(reqCtx.Model) - } - }(err, reqCtx) - - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - - req, recvErr := srv.Recv() - if recvErr == io.EOF || status.Code(recvErr) == codes.Canceled { - return nil - } - if recvErr != nil { - // This error occurs very frequently, though it doesn't seem to have any impact. - // TODO Figure out if we can remove this noise. - logger.V(logutil.DEFAULT).Error(err, "Cannot receive stream request") - return status.Errorf(codes.Unknown, "cannot receive stream request: %v", err) - } - - switch v := req.Request.(type) { - case *extProcPb.ProcessingRequest_RequestHeaders: - err = s.HandleRequestHeaders(ctx, reqCtx, v) - case *extProcPb.ProcessingRequest_RequestBody: - loggerTrace.Info("Incoming body chunk", "EoS", v.RequestBody.EndOfStream) - // In the stream case, we can receive multiple request bodies. - body = append(body, v.RequestBody.Body...) - - // Message is buffered, we can read and decode. - if v.RequestBody.EndOfStream { - loggerTrace.Info("decoding") - err = json.Unmarshal(body, &requestBody) - if err != nil { - logger.V(logutil.DEFAULT).Error(err, "Error unmarshaling request body") - } - - // Body stream complete. Allocate empty slice for response to use. - body = []byte{} - - reqCtx, err = s.HandleRequestBody(ctx, reqCtx, req, requestBody) - if err != nil { - logger.V(logutil.DEFAULT).Error(err, "Error handling body") - } else { - metrics.RecordRequestCounter(reqCtx.Model, reqCtx.ResolvedTargetModel) - metrics.RecordRequestSizes(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.RequestSize) - } - } - case *extProcPb.ProcessingRequest_RequestTrailers: - // This is currently unused. - case *extProcPb.ProcessingRequest_ResponseHeaders: - for _, header := range v.ResponseHeaders.Headers.GetHeaders() { - value := string(header.RawValue) - - loggerTrace.Info("header", "key", header.Key, "value", value) - if header.Key == "status" && value != "200" { - reqCtx.ResponseStatusCode = errutil.ModelServerError - } else if header.Key == "content-type" && strings.Contains(value, "text/event-stream") { - reqCtx.modelServerStreaming = true - loggerTrace.Info("model server is streaming response") - } - } - reqCtx.RequestState = ResponseRecieved - reqCtx.respHeaderResp = &extProcPb.ProcessingResponse{ - Response: &extProcPb.ProcessingResponse_ResponseHeaders{ - ResponseHeaders: &extProcPb.HeadersResponse{ - Response: &extProcPb.CommonResponse{ - HeaderMutation: &extProcPb.HeaderMutation{ - SetHeaders: []*configPb.HeaderValueOption{ - { - Header: &configPb.HeaderValue{ - // This is for debugging purpose only. - Key: "x-went-into-resp-headers", - RawValue: []byte("true"), - }, - }, - }, - }, - }, - }, - }, - } - - case *extProcPb.ProcessingRequest_ResponseBody: - if reqCtx.modelServerStreaming { - // Currently we punt on response parsing if the modelServer is streaming, and we just passthrough. 
- - responseText := string(v.ResponseBody.Body) - s.HandleResponseBodyModelStreaming(ctx, reqCtx, responseText) - if v.ResponseBody.EndOfStream { - loggerTrace.Info("stream completed") - - reqCtx.ResponseCompleteTimestamp = time.Now() - metrics.RecordRequestLatencies(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.RequestReceivedTimestamp, reqCtx.ResponseCompleteTimestamp) - metrics.RecordResponseSizes(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.ResponseSize) - } - - reqCtx.respBodyResp = &extProcPb.ProcessingResponse{ - Response: &extProcPb.ProcessingResponse_ResponseBody{ - ResponseBody: &extProcPb.BodyResponse{ - Response: &extProcPb.CommonResponse{ - BodyMutation: &extProcPb.BodyMutation{ - Mutation: &extProcPb.BodyMutation_StreamedResponse{ - StreamedResponse: &extProcPb.StreamedBodyResponse{ - Body: v.ResponseBody.Body, - EndOfStream: v.ResponseBody.EndOfStream, - }, - }, - }, - }, - }, - }, - } - } else { - body = append(body, v.ResponseBody.Body...) - - // Message is buffered, we can read and decode. - if v.ResponseBody.EndOfStream { - loggerTrace.Info("stream completed") - // Don't send a 500 on a response error. Just let the message passthrough and log our error for debugging purposes. - // We assume the body is valid JSON, err messages are not guaranteed to be json, and so capturing and sending a 500 obfuscates the response message. - // using the standard 'err' var will send an immediate error response back to the caller. - var responseErr error - responseErr = json.Unmarshal(body, &responseBody) - if responseErr != nil { - logger.V(logutil.DEFAULT).Error(responseErr, "Error unmarshaling request body") - } - - reqCtx, responseErr = s.HandleResponseBody(ctx, reqCtx, responseBody) - if responseErr != nil { - logger.V(logutil.DEFAULT).Error(responseErr, "Failed to process response body", "request", req) - } else if reqCtx.ResponseComplete { - reqCtx.ResponseCompleteTimestamp = time.Now() - metrics.RecordRequestLatencies(ctx, reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.RequestReceivedTimestamp, reqCtx.ResponseCompleteTimestamp) - metrics.RecordResponseSizes(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.ResponseSize) - metrics.RecordInputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Usage.PromptTokens) - metrics.RecordOutputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, reqCtx.Usage.CompletionTokens) - } - } - } - case *extProcPb.ProcessingRequest_ResponseTrailers: - // This is currently unused. - } - - // Handle the err and fire an immediate response. - if err != nil { - logger.V(logutil.DEFAULT).Error(err, "Failed to process request", "request", req) - resp, err := BuildErrResponse(err) - if err != nil { - return err - } - if err := srv.Send(resp); err != nil { - logger.V(logutil.DEFAULT).Error(err, "Send failed") - return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err) - } - return nil - } - loggerTrace.Info("checking", "request state", reqCtx.RequestState) - if err := reqCtx.updateStateAndSendIfNeeded(srv, logger); err != nil { - return err - } - } -} - -// updateStateAndSendIfNeeded checks state and can send mutiple responses in a single pass, but only if ordered properly. -// Order of requests matter in FULL_DUPLEX_STREAMING. For both request and response, the order of response sent back MUST be: Header->Body->Trailer, with trailer being optional. 
-func (r *RequestContext) updateStateAndSendIfNeeded(srv extProcPb.ExternalProcessor_ProcessServer, logger logr.Logger) error { - loggerTrace := logger.V(logutil.TRACE) - // No switch statement as we could send multiple responses in one pass. - if r.RequestState == RequestReceived && r.reqHeaderResp != nil { - loggerTrace.Info("Sending request header response", "obj", r.reqHeaderResp) - if err := srv.Send(r.reqHeaderResp); err != nil { - logger.V(logutil.DEFAULT).Error(err, "error sending response") - return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err) - } - r.RequestState = HeaderRequestResponseComplete - } - if r.RequestState == HeaderRequestResponseComplete && r.reqBodyResp != nil { - loggerTrace.Info("Sending request body response") - if err := srv.Send(r.reqBodyResp); err != nil { - return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err) - } - r.RequestState = BodyRequestResponsesComplete - metrics.IncRunningRequests(r.Model) - r.RequestRunning = true - // Dump the response so a new stream message can begin - r.reqBodyResp = nil - } - if r.RequestState == BodyRequestResponsesComplete && r.reqTrailerResp != nil { - // Trailers in requests are not guaranteed - if err := srv.Send(r.reqHeaderResp); err != nil { - return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err) - } - } - if r.RequestState == ResponseRecieved && r.respHeaderResp != nil { - loggerTrace.Info("Sending response header response", "obj", r.respHeaderResp) - if err := srv.Send(r.respHeaderResp); err != nil { - return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err) - } - r.RequestState = HeaderResponseResponseComplete - } - if r.RequestState == HeaderResponseResponseComplete && r.respBodyResp != nil { - loggerTrace.Info("Sending response body response") - if err := srv.Send(r.respBodyResp); err != nil { - return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err) - } - - body := r.respBodyResp.Response.(*extProcPb.ProcessingResponse_ResponseBody) - if body.ResponseBody.Response.GetBodyMutation().GetStreamedResponse().GetEndOfStream() { - r.RequestState = BodyResponseResponsesComplete - } - // Dump the response so a new stream message can begin - r.respBodyResp = nil - } - if r.RequestState == BodyResponseResponsesComplete && r.respTrailerResp != nil { - // Trailers in requests are not guaranteed - if err := srv.Send(r.reqHeaderResp); err != nil { - return status.Errorf(codes.Unknown, "failed to send response back to Envoy: %v", err) - } - } - return nil -} - -// HandleRequestBody always returns the requestContext even in the error case, as the request context is used in error handling. -func (s *StreamingServer) HandleRequestBody( - ctx context.Context, - reqCtx *RequestContext, - req *extProcPb.ProcessingRequest, - requestBodyMap map[string]interface{}, -) (*RequestContext, error) { - var requestBodyBytes []byte - logger := log.FromContext(ctx) - - // Resolve target models. - model, ok := requestBodyMap["model"].(string) - if !ok { - return reqCtx, errutil.Error{Code: errutil.BadRequest, Msg: "model not found in request"} - } - - modelName := model - - // NOTE: The nil checking for the modelObject means that we DO allow passthrough currently. - // This might be a security risk in the future where adapters not registered in the InferenceModel - // are able to be requested by using their distinct name. 
- modelObj := s.datastore.ModelGet(model) - if modelObj == nil { - return reqCtx, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error finding a model object in InferenceModel for input %v", model)} - } - if len(modelObj.Spec.TargetModels) > 0 { - modelName = RandomWeightedDraw(logger, modelObj, 0) - if modelName == "" { - return reqCtx, errutil.Error{Code: errutil.BadConfiguration, Msg: fmt.Sprintf("error getting target model name for model %v", modelObj.Name)} - } - } - llmReq := &scheduling.LLMRequest{ - Model: model, - ResolvedTargetModel: modelName, - Critical: datastore.IsCritical(modelObj), - } - logger.V(logutil.DEBUG).Info("LLM request assembled", "model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "critical", llmReq.Critical) - - var err error - // Update target models in the body. - if llmReq.Model != llmReq.ResolvedTargetModel { - requestBodyMap["model"] = llmReq.ResolvedTargetModel - } - - requestBodyBytes, err = json.Marshal(requestBodyMap) - if err != nil { - logger.V(logutil.DEFAULT).Error(err, "Error marshaling request body") - return reqCtx, errutil.Error{Code: errutil.Internal, Msg: fmt.Sprintf("error marshaling request body: %v", err)} - } - - target, err := s.scheduler.Schedule(ctx, llmReq) - if err != nil { - return reqCtx, errutil.Error{Code: errutil.InferencePoolResourceExhausted, Msg: fmt.Errorf("failed to find target pod: %w", err).Error()} - } - targetPod := target.GetPod() - - // Insert target endpoint to instruct Envoy to route requests to the specified target pod. - // Attach the port number - pool, err := s.datastore.PoolGet() - if err != nil { - return reqCtx, err - } - endpoint := targetPod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber)) - - logger.V(logutil.DEFAULT).Info("Request handled", - "model", llmReq.Model, "targetModel", llmReq.ResolvedTargetModel, "endpoint", targetPod, "endpoint metrics", - fmt.Sprintf("%+v", target)) - - reqCtx.Model = llmReq.Model - reqCtx.ResolvedTargetModel = llmReq.ResolvedTargetModel - reqCtx.RequestSize = len(requestBodyBytes) - reqCtx.TargetPod = targetPod.NamespacedName.String() - reqCtx.TargetEndpoint = endpoint - - s.populateRequestHeaderResponse(reqCtx, endpoint, len(requestBodyBytes)) - - reqCtx.reqBodyResp = &extProcPb.ProcessingResponse{ - // The Endpoint Picker supports two approaches to communicating the target endpoint, as a request header - // and as an unstructure ext-proc response metadata key/value pair. This enables different integration - // options for gateway providers. - Response: &extProcPb.ProcessingResponse_RequestBody{ - RequestBody: &extProcPb.BodyResponse{ - Response: &extProcPb.CommonResponse{ - BodyMutation: &extProcPb.BodyMutation{ - Mutation: &extProcPb.BodyMutation_StreamedResponse{ - StreamedResponse: &extProcPb.StreamedBodyResponse{ - Body: requestBodyBytes, - EndOfStream: true, - }, - }, - }, - }, - }, - }, - } - return reqCtx, nil -} - -// HandleResponseBody always returns the requestContext even in the error case, as the request context is used in error handling. 
-func (s *StreamingServer) HandleResponseBody( - ctx context.Context, - reqCtx *RequestContext, - response map[string]interface{}, -) (*RequestContext, error) { - logger := log.FromContext(ctx) - responseBytes, err := json.Marshal(response) - if err != nil { - logger.V(logutil.DEFAULT).Error(err, "error marshalling responseBody") - return reqCtx, err - } - if response["usage"] != nil { - usg := response["usage"].(map[string]interface{}) - usage := Usage{ - PromptTokens: int(usg["prompt_tokens"].(float64)), - CompletionTokens: int(usg["completion_tokens"].(float64)), - TotalTokens: int(usg["total_tokens"].(float64)), - } - reqCtx.Usage = usage - logger.V(logutil.VERBOSE).Info("Response generated", "usage", reqCtx.Usage) - } - reqCtx.ResponseSize = len(responseBytes) - // ResponseComplete is to indicate the response is complete. In non-streaming - // case, it will be set to be true once the response is processed; in - // streaming case, it will be set to be true once the last chunk is processed. - // TODO(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/178) - // will add the processing for streaming case. - reqCtx.ResponseComplete = true - - reqCtx.respBodyResp = &extProcPb.ProcessingResponse{ - // The Endpoint Picker supports two approaches to communicating the target endpoint, as a request header - // and as an unstructure ext-proc response metadata key/value pair. This enables different integration - // options for gateway providers. - Response: &extProcPb.ProcessingResponse_ResponseBody{ - ResponseBody: &extProcPb.BodyResponse{ - Response: &extProcPb.CommonResponse{ - BodyMutation: &extProcPb.BodyMutation{ - Mutation: &extProcPb.BodyMutation_StreamedResponse{ - StreamedResponse: &extProcPb.StreamedBodyResponse{ - Body: responseBytes, - EndOfStream: true, - }, - }, - }, - }, - }, - }, - } - return reqCtx, nil -} - -// The function is to handle streaming response if the modelServer is streaming. -func (s *StreamingServer) HandleResponseBodyModelStreaming( - ctx context.Context, - reqCtx *RequestContext, - responseText string, -) { - if strings.Contains(responseText, streamingEndMsg) { - resp := ParseRespForUsage(ctx, responseText) - metrics.RecordInputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, resp.Usage.PromptTokens) - metrics.RecordOutputTokens(reqCtx.Model, reqCtx.ResolvedTargetModel, resp.Usage.CompletionTokens) - } -} - -func (s *StreamingServer) HandleRequestHeaders(ctx context.Context, reqCtx *RequestContext, req *extProcPb.ProcessingRequest_RequestHeaders) error { - reqCtx.RequestReceivedTimestamp = time.Now() - - // an EoS in the request headers means this request has no body or trailers. - if req.RequestHeaders.EndOfStream { - // We will route this request to a random pod as this is assumed to just be a GET - // More context: https://github.com/kubernetes-sigs/gateway-api-inference-extension/pull/526 - // The above PR will address endpoint admission, but currently any request without a body will be - // routed to a random upstream pod. 
- pod := GetRandomPod(s.datastore) - pool, err := s.datastore.PoolGet() - if err != nil { - return err - } - endpoint := pod.Address + ":" + strconv.Itoa(int(pool.Spec.TargetPortNumber)) - s.populateRequestHeaderResponse(reqCtx, endpoint, 0) - } - return nil -} - -func (s *StreamingServer) populateRequestHeaderResponse(reqCtx *RequestContext, endpoint string, requestBodyLength int) { - headers := []*configPb.HeaderValueOption{ - { - Header: &configPb.HeaderValue{ - Key: s.destinationEndpointHintKey, - RawValue: []byte(endpoint), - }, - }, - } - if requestBodyLength > 0 { - // We need to update the content length header if the body is mutated, see Envoy doc: - // https://www.envoyproxy.io/docs/envoy/latest/api-v3/extensions/filters/http/ext_proc/v3/processing_mode.proto - headers = append(headers, &configPb.HeaderValueOption{ - Header: &configPb.HeaderValue{ - Key: "Content-Length", - RawValue: []byte(strconv.Itoa(requestBodyLength)), - }, - }) - } - - targetEndpointValue := &structpb.Struct{ - Fields: map[string]*structpb.Value{ - s.destinationEndpointHintKey: { - Kind: &structpb.Value_StringValue{ - StringValue: endpoint, - }, - }, - }, - } - dynamicMetadata := targetEndpointValue - if s.destinationEndpointHintMetadataNamespace != "" { - // If a namespace is defined, wrap the selected endpoint with that. - dynamicMetadata = &structpb.Struct{ - Fields: map[string]*structpb.Value{ - s.destinationEndpointHintMetadataNamespace: { - Kind: &structpb.Value_StructValue{ - StructValue: targetEndpointValue, - }, - }, - }, - } - } - - reqCtx.reqHeaderResp = &extProcPb.ProcessingResponse{ - Response: &extProcPb.ProcessingResponse_RequestHeaders{ - RequestHeaders: &extProcPb.HeadersResponse{ - Response: &extProcPb.CommonResponse{ - ClearRouteCache: true, - HeaderMutation: &extProcPb.HeaderMutation{ - SetHeaders: headers, - }, - }, - }, - }, - DynamicMetadata: dynamicMetadata, - } -} - -func RandomWeightedDraw(logger logr.Logger, model *v1alpha2.InferenceModel, seed int64) string { - // TODO: after we are down to 1 server implementation, make these methods a part of the struct - // and handle random seeding on the struct. 
- source := rand.NewSource(rand.Int63()) - if seed > 0 { - source = rand.NewSource(seed) - } - r := rand.New(source) - - // all the weight values are nil, then we should return random model name - if model.Spec.TargetModels[0].Weight == nil { - index := r.Int31n(int32(len(model.Spec.TargetModels))) - return model.Spec.TargetModels[index].Name - } - - var weights int32 - for _, model := range model.Spec.TargetModels { - weights += *model.Weight - } - logger.V(logutil.TRACE).Info("Weights for model computed", "model", model.Name, "weights", weights) - randomVal := r.Int31n(weights) - // TODO: optimize this without using loop - for _, model := range model.Spec.TargetModels { - if randomVal < *model.Weight { - return model.Name - } - randomVal -= *model.Weight - } - return "" -} - -func GetRandomPod(ds datastore.Datastore) *backendmetrics.Pod { - pods := ds.PodGetAll() - number := rand.Intn(len(pods)) - pod := pods[number] - return pod.GetPod() -} diff --git a/pkg/epp/handlers/streamingserver_test.go b/pkg/epp/handlers/streamingserver_test.go index 72f7031a..23d2b68f 100644 --- a/pkg/epp/handlers/streamingserver_test.go +++ b/pkg/epp/handlers/streamingserver_test.go @@ -18,8 +18,14 @@ package handlers import ( "testing" + "time" + + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) @@ -126,6 +132,55 @@ func TestRandomWeightedDraw(t *testing.T) { } } +func TestGetRandomPod(t *testing.T) { + tests := []struct { + name string + storePods []*corev1.Pod + expectNil bool + }{ + { + name: "No pods available", + storePods: []*corev1.Pod{}, + expectNil: true, + }, + { + name: "Single pod available", + storePods: []*corev1.Pod{ + {ObjectMeta: metav1.ObjectMeta{Name: "pod1"}}, + }, + expectNil: false, + }, + { + name: "Multiple pods available", + storePods: []*corev1.Pod{ + {ObjectMeta: metav1.ObjectMeta{Name: "pod1"}}, + {ObjectMeta: metav1.ObjectMeta{Name: "pod2"}}, + {ObjectMeta: metav1.ObjectMeta{Name: "pod3"}}, + }, + expectNil: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + pmf := metrics.NewPodMetricsFactory(&metrics.FakePodMetricsClient{}, time.Millisecond) + ds := datastore.NewDatastore(t.Context(), pmf) + for _, pod := range test.storePods { + ds.PodUpdateOrAddIfNotExist(pod) + } + + gotPod := GetRandomPod(ds) + + if test.expectNil && gotPod != nil { + t.Errorf("expected nil pod, got: %v", gotPod) + } + if !test.expectNil && gotPod == nil { + t.Errorf("expected non-nil pod, got nil") + } + }) + } +} + func pointer(v int32) *int32 { return &v } diff --git a/pkg/epp/metrics/metrics.go b/pkg/epp/metrics/metrics.go index 434b8381..6df3dab3 100644 --- a/pkg/epp/metrics/metrics.go +++ b/pkg/epp/metrics/metrics.go @@ -18,6 +18,7 @@ package metrics import ( "context" + "runtime/debug" "sync" "time" @@ -30,6 +31,13 @@ import ( const ( InferenceModelComponent = "inference_model" InferencePoolComponent = "inference_pool" + EPPComponent = "endpoint_picker" + InferenceExtension = "inference_extension" +) + +var ( + // The git hash of the latest commit in the build. 
+ CommitHash string ) var ( @@ -131,6 +139,21 @@ var ( []string{"model_name"}, ) + // NTPOT - Normalized Time Per Output Token + NormalizedTimePerOutputToken = compbasemetrics.NewHistogramVec( + &compbasemetrics.HistogramOpts{ + Subsystem: InferenceModelComponent, + Name: "normalized_time_per_output_token_seconds", + Help: "Inference model latency divided by number of output tokens in seconds for each model and target model.", + // From few milliseconds per token to multiple seconds per token + Buckets: []float64{ + 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0, + }, + StabilityLevel: compbasemetrics.ALPHA, + }, + []string{"model_name", "target_model_name"}, + ) + // Inference Pool Metrics inferencePoolAvgKVCache = compbasemetrics.NewGaugeVec( &compbasemetrics.GaugeOpts{ @@ -161,6 +184,31 @@ var ( }, []string{"name"}, ) + + // Scheduler Plugin Metrics + SchedulerPluginProcessingLatencies = compbasemetrics.NewHistogramVec( + &compbasemetrics.HistogramOpts{ + Subsystem: EPPComponent, + Name: "scheduler_plugin_duration_seconds", + Help: "Scheduler plugin processing latency distribution in seconds for each plugin type and plugin name.", + Buckets: []float64{ + 0.0001, 0.0002, 0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, + }, + StabilityLevel: compbasemetrics.ALPHA, + }, + []string{"plugin_type", "plugin_name"}, + ) + + // Info Metrics + InferenceExtensionInfo = compbasemetrics.NewGaugeVec( + &compbasemetrics.GaugeOpts{ + Subsystem: InferenceExtension, + Name: "info", + Help: "General information of the current build of Inference Extension.", + StabilityLevel: compbasemetrics.ALPHA, + }, + []string{"commit"}, + ) ) var registerMetrics sync.Once @@ -176,10 +224,15 @@ func Register() { legacyregistry.MustRegister(inputTokens) legacyregistry.MustRegister(outputTokens) legacyregistry.MustRegister(runningRequests) + legacyregistry.MustRegister(NormalizedTimePerOutputToken) legacyregistry.MustRegister(inferencePoolAvgKVCache) legacyregistry.MustRegister(inferencePoolAvgQueueSize) legacyregistry.MustRegister(inferencePoolReadyPods) + + legacyregistry.MustRegister(SchedulerPluginProcessingLatencies) + + legacyregistry.MustRegister(InferenceExtensionInfo) }) } @@ -231,6 +284,27 @@ func RecordOutputTokens(modelName, targetModelName string, size int) { } } +// RecordNormalizedTimePerOutputToken (NTPOT) records the normalized time per output token. +func RecordNormalizedTimePerOutputToken(ctx context.Context, modelName, targetModelName string, received time.Time, complete time.Time, outputTokenCount int) bool { + if !complete.After(received) { + log.FromContext(ctx).Error(nil, "Request latency values are invalid for NTPOT calculation", + "modelName", modelName, "targetModelName", targetModelName, "completeTime", complete, "receivedTime", received) + return false + } + + if outputTokenCount <= 0 { + log.FromContext(ctx).Error(nil, "Output token count must be positive for NTPOT calculation", + "modelName", modelName, "targetModelName", targetModelName, "outputTokenCount", outputTokenCount) + return false + } + + elapsedSeconds := complete.Sub(received).Seconds() + secondsPerToken := elapsedSeconds / float64(outputTokenCount) + + NormalizedTimePerOutputToken.WithLabelValues(modelName, targetModelName).Observe(secondsPerToken) + return true +} + // IncRunningRequests increases the current running requests. 
func IncRunningRequests(modelName string) { if modelName != "" { @@ -256,3 +330,32 @@ func RecordInferencePoolAvgQueueSize(name string, queueSize float64) { func RecordinferencePoolReadyPods(name string, runningPods float64) { inferencePoolReadyPods.WithLabelValues(name).Set(runningPods) } + +// RecordSchedulerPluginProcessingLatency records the processing latency for a scheduler plugin. +func RecordSchedulerPluginProcessingLatency(pluginType, pluginName string, duration time.Duration) { + SchedulerPluginProcessingLatencies.WithLabelValues(pluginType, pluginName).Observe(duration.Seconds()) +} + +func RecordInferenceExtensionInfo() { + if CommitHash != "" { + InferenceExtensionInfo.WithLabelValues(CommitHash).Set(1) + } +} + +func init() { + info, ok := debug.ReadBuildInfo() + if !ok { + return + } + + var Commit = func(i *debug.BuildInfo) string { + for _, setting := range i.Settings { + if setting.Key == "vcs.revision" { + return setting.Value + } + } + return "" + }(info) + + CommitHash = Commit +} diff --git a/pkg/epp/metrics/metrics_test.go b/pkg/epp/metrics/metrics_test.go index dc4c7044..81797e6d 100644 --- a/pkg/epp/metrics/metrics_test.go +++ b/pkg/epp/metrics/metrics_test.go @@ -29,16 +29,17 @@ import ( ) const ( - RequestTotalMetric = InferenceModelComponent + "_request_total" - RequestErrorTotalMetric = InferenceModelComponent + "_request_error_total" - RequestLatenciesMetric = InferenceModelComponent + "_request_duration_seconds" - RequestSizesMetric = InferenceModelComponent + "_request_sizes" - ResponseSizesMetric = InferenceModelComponent + "_response_sizes" - InputTokensMetric = InferenceModelComponent + "_input_tokens" - OutputTokensMetric = InferenceModelComponent + "_output_tokens" - RunningRequestsMetric = InferenceModelComponent + "_running_requests" - KVCacheAvgUsageMetric = InferencePoolComponent + "_average_kv_cache_utilization" - QueueAvgSizeMetric = InferencePoolComponent + "_average_queue_size" + RequestTotalMetric = InferenceModelComponent + "_request_total" + RequestErrorTotalMetric = InferenceModelComponent + "_request_error_total" + RequestLatenciesMetric = InferenceModelComponent + "_request_duration_seconds" + RequestSizesMetric = InferenceModelComponent + "_request_sizes" + ResponseSizesMetric = InferenceModelComponent + "_response_sizes" + InputTokensMetric = InferenceModelComponent + "_input_tokens" + OutputTokensMetric = InferenceModelComponent + "_output_tokens" + NormalizedTimePerOutputTokenMetric = InferenceModelComponent + "_normalized_time_per_output_token_seconds" + RunningRequestsMetric = InferenceModelComponent + "_running_requests" + KVCacheAvgUsageMetric = InferencePoolComponent + "_average_kv_cache_utilization" + QueueAvgSizeMetric = InferencePoolComponent + "_average_queue_size" ) func TestRecordRequestCounterandSizes(t *testing.T) { @@ -252,6 +253,107 @@ func TestRecordRequestLatencies(t *testing.T) { } } +func TestRecordNormalizedTimePerOutputToken(t *testing.T) { + ctx := logutil.NewTestLoggerIntoContext(context.Background()) + timeBaseline := time.Now() + type tokenRequests struct { + modelName string + targetModelName string + receivedTime time.Time + completeTime time.Time + outputTokens int + } + scenarios := []struct { + name string + reqs []tokenRequests + invalid bool + }{ + { + name: "multiple requests", + reqs: []tokenRequests{ + { + modelName: "m10", + targetModelName: "t10", + receivedTime: timeBaseline, + completeTime: timeBaseline.Add(time.Millisecond * 1000), + outputTokens: 100, // 10ms per token + }, + { + modelName: 
"m10", + targetModelName: "t10", + receivedTime: timeBaseline, + completeTime: timeBaseline.Add(time.Millisecond * 1600), + outputTokens: 80, // 20ms per token + }, + { + modelName: "m10", + targetModelName: "t11", + receivedTime: timeBaseline, + completeTime: timeBaseline.Add(time.Millisecond * 6000), + outputTokens: 300, // 20ms per token + }, + { + modelName: "m20", + targetModelName: "t20", + receivedTime: timeBaseline, + completeTime: timeBaseline.Add(time.Millisecond * 2400), + outputTokens: 400, // 6ms per token + }, + }, + }, + { + name: "invalid elapsed time", + reqs: []tokenRequests{ + { + modelName: "m10", + targetModelName: "t10", + receivedTime: timeBaseline.Add(time.Millisecond * 10), + completeTime: timeBaseline, + outputTokens: 100, + }, + }, + invalid: true, + }, + { + name: "invalid token count", + reqs: []tokenRequests{ + { + modelName: "m10", + targetModelName: "t10", + receivedTime: timeBaseline, + completeTime: timeBaseline.Add(time.Millisecond * 1000), + outputTokens: 0, // Invalid: zero tokens + }, + }, + invalid: true, + }, + } + Register() + for _, scenario := range scenarios { + t.Run(scenario.name, func(t *testing.T) { + for _, req := range scenario.reqs { + success := RecordNormalizedTimePerOutputToken(ctx, req.modelName, req.targetModelName, req.receivedTime, req.completeTime, req.outputTokens) + if success == scenario.invalid { + t.Errorf("got record success(%v), but the request expects invalid(%v)", success, scenario.invalid) + } + } + + wantLatencyPerToken, err := os.Open("testdata/normalized_time_per_output_token_seconds_metric") + defer func() { + if err := wantLatencyPerToken.Close(); err != nil { + t.Error(err) + } + }() + if err != nil { + t.Fatal(err) + } + if err := testutil.GatherAndCompare(legacyregistry.DefaultGatherer, wantLatencyPerToken, NormalizedTimePerOutputTokenMetric); err != nil { + t.Error(err) + } + }) + } +} + func TestRecordResponseMetrics(t *testing.T) { type responses struct { modelName string @@ -454,3 +556,67 @@ func TestInferencePoolMetrics(t *testing.T) { }) } } + +func TestSchedulerPluginProcessingLatencies(t *testing.T) { + type pluginLatency struct { + pluginType string + pluginName string + duration time.Duration + } + scenarios := []struct { + name string + latencies []pluginLatency + }{ + { + name: "multiple plugins", + latencies: []pluginLatency{ + { + pluginType: "PreSchedule", + pluginName: "PluginA", + duration: 100 * time.Millisecond, + }, + { + pluginType: "PostSchedule", + pluginName: "PluginB", + duration: 200 * time.Millisecond, + }, + { + pluginType: "Filter", + pluginName: "PluginC", + duration: 50 * time.Millisecond, + }, + { + pluginType: "Scorer", + pluginName: "PluginD", + duration: 10 * time.Millisecond, + }, + { + pluginType: "Picker", + pluginName: "PluginE", + duration: 10 * time.Microsecond, + }, + }, + }, + } + Register() + for _, scenario := range scenarios { + t.Run(scenario.name, func(t *testing.T) { + for _, latency := range scenario.latencies { + RecordSchedulerPluginProcessingLatency(latency.pluginType, latency.pluginName, latency.duration) + } + + wantPluginLatencies, err := os.Open("testdata/scheduler_plugin_processing_latencies_metric") + defer func() { + if err := wantPluginLatencies.Close(); err != nil { + t.Error(err) + } + }() + if err != nil { + t.Fatal(err) + } + if err := testutil.GatherAndCompare(legacyregistry.DefaultGatherer, wantPluginLatencies, "endpoint_picker_scheduler_plugin_processing_latencies"); err != nil { + t.Error(err) + } + }) + } +} diff --git 
a/pkg/epp/metrics/testdata/normalized_time_per_output_token_seconds_metric b/pkg/epp/metrics/testdata/normalized_time_per_output_token_seconds_metric new file mode 100644 index 00000000..bb6e9373 --- /dev/null +++ b/pkg/epp/metrics/testdata/normalized_time_per_output_token_seconds_metric @@ -0,0 +1,50 @@ +# HELP inference_model_normalized_time_per_output_token_seconds [ALPHA] Inference model latency divided by number of output tokens in seconds for each model and target model. +# TYPE inference_model_normalized_time_per_output_token_seconds histogram +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="0.001"} 0 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="0.002"} 0 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="0.005"} 0 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="0.01"} 1 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="0.02"} 2 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="0.05"} 2 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="0.1"} 2 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="0.2"} 2 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="0.5"} 2 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="1.0"} 2 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="2.0"} 2 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="5.0"} 2 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="10.0"} 2 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t10", le="+Inf"} 2 +inference_model_normalized_time_per_output_token_seconds_sum{model_name="m10", target_model_name="t10"} 0.03 +inference_model_normalized_time_per_output_token_seconds_count{model_name="m10", target_model_name="t10"} 2 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="0.001"} 0 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="0.002"} 0 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="0.005"} 0 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="0.01"} 0 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="0.02"} 1 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="0.05"} 1 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="0.1"} 1 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="0.2"} 1 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="0.5"} 
1 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="1.0"} 1 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="2.0"} 1 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="5.0"} 1 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="10.0"} 1 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m10", target_model_name="t11", le="+Inf"} 1 +inference_model_normalized_time_per_output_token_seconds_sum{model_name="m10", target_model_name="t11"} 0.02 +inference_model_normalized_time_per_output_token_seconds_count{model_name="m10", target_model_name="t11"} 1 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="0.001"} 0 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="0.002"} 0 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="0.005"} 0 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="0.01"} 1 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="0.02"} 1 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="0.05"} 1 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="0.1"} 1 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="0.2"} 1 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="0.5"} 1 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="1.0"} 1 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="2.0"} 1 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="5.0"} 1 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="10.0"} 1 +inference_model_normalized_time_per_output_token_seconds_bucket{model_name="m20", target_model_name="t20", le="+Inf"} 1 +inference_model_normalized_time_per_output_token_seconds_sum{model_name="m20", target_model_name="t20"} 0.006 +inference_model_normalized_time_per_output_token_seconds_count{model_name="m20", target_model_name="t20"} 1 diff --git a/pkg/epp/metrics/testdata/scheduler_plugin_processing_latencies_metric b/pkg/epp/metrics/testdata/scheduler_plugin_processing_latencies_metric new file mode 100644 index 00000000..8c11757f --- /dev/null +++ b/pkg/epp/metrics/testdata/scheduler_plugin_processing_latencies_metric @@ -0,0 +1,67 @@ +# HELP endpoint_picker_scheduler_plugin_duration_seconds [ALPHA] Scheduler plugin processing latency distribution in seconds for each plugin type and plugin name. 
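+# Golden output for TestSchedulerPluginProcessingLatencies; the bucket boundaries below are assumed to match the SchedulerPluginProcessingLatencies histogram definition in metrics.go.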
+# TYPE endpoint_picker_scheduler_plugin_duration_seconds histogram +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="0.0001"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="0.0002"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="0.0005"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="0.001"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="0.002"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="0.005"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="0.01"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="0.02"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="0.05"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="0.1"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginA",plugin_type="PreSchedule",le="+Inf"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_sum{plugin_name="PluginA",plugin_type="PreSchedule"} 0.1 +endpoint_picker_scheduler_plugin_duration_seconds_count{plugin_name="PluginA",plugin_type="PreSchedule"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="0.0001"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="0.0002"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="0.0005"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="0.001"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="0.002"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="0.005"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="0.01"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="0.02"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="0.05"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="0.1"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginB",plugin_type="PostSchedule",le="+Inf"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_sum{plugin_name="PluginB",plugin_type="PostSchedule"} 0.2 +endpoint_picker_scheduler_plugin_duration_seconds_count{plugin_name="PluginB",plugin_type="PostSchedule"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="0.0001"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="0.0002"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="0.0005"} 0 
+endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="0.001"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="0.002"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="0.005"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="0.01"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="0.02"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="0.05"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="0.1"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginC",plugin_type="Filter",le="+Inf"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_sum{plugin_name="PluginC",plugin_type="Filter"} 0.05 +endpoint_picker_scheduler_plugin_duration_seconds_count{plugin_name="PluginC",plugin_type="Filter"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="0.0001"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="0.0002"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="0.0005"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="0.001"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="0.002"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="0.005"} 0 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="0.01"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="0.02"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="0.05"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="0.1"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginD",plugin_type="Scorer",le="+Inf"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_sum{plugin_name="PluginD",plugin_type="Scorer"} 0.01 +endpoint_picker_scheduler_plugin_duration_seconds_count{plugin_name="PluginD",plugin_type="Scorer"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="0.0001"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="0.0002"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="0.0005"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="0.001"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="0.002"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="0.005"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="0.01"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="0.02"} 1 
+endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="0.05"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="0.1"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_bucket{plugin_name="PluginE",plugin_type="Picker",le="+Inf"} 1 +endpoint_picker_scheduler_plugin_duration_seconds_sum{plugin_name="PluginE",plugin_type="Picker"} 1e-05 +endpoint_picker_scheduler_plugin_duration_seconds_count{plugin_name="PluginE",plugin_type="Picker"} 1 diff --git a/pkg/epp/scheduling/config.go b/pkg/epp/scheduling/config.go new file mode 100644 index 00000000..4ed109af --- /dev/null +++ b/pkg/epp/scheduling/config.go @@ -0,0 +1,41 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package scheduling + +import "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" + +type SchedulerConfig struct { + preSchedulePlugins []plugins.PreSchedule + filters []plugins.Filter + scorers map[plugins.Scorer]int // map from scorer to weight + picker plugins.Picker + postSchedulePlugins []plugins.PostSchedule +} + +var defPlugin = &defaultPlugin{} + +// When the scheduler is initialized with NewScheduler function, this config will be used as default. +// it's possible to call NewSchedulerWithConfig to pass a different argument. + +// For build time plugins changes, it's recommended to change the defaultConfig variable in this file. +var defaultConfig = &SchedulerConfig{ + preSchedulePlugins: []plugins.PreSchedule{}, + filters: []plugins.Filter{defPlugin}, + scorers: map[plugins.Scorer]int{}, + picker: defPlugin, + postSchedulePlugins: []plugins.PostSchedule{}, +} diff --git a/pkg/epp/scheduling/config/config.go b/pkg/epp/scheduling/config/config.go new file mode 100644 index 00000000..e00b82ae --- /dev/null +++ b/pkg/epp/scheduling/config/config.go @@ -0,0 +1,58 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package config + +import ( + "sigs.k8s.io/controller-runtime/pkg/log" + envutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +// Config holds all the configuration values for the scheduler +type Config struct { + KVCacheThreshold float64 + QueueThresholdCritical int + QueueingThresholdLoRA int + LoraAffinityThreshold float64 +} + +const ( + // Default values to use if environment variables are not set + defaultKVCacheThreshold = 0.8 + defaultQueueThresholdCritical = 5 + defaultQueueingThresholdLoRA = 128 + defaultLoraAffinityThreshold = 0.999 +) + +// LoadConfig loads configuration from environment variables +func LoadConfig() Config { + // Use a default logger for initial configuration loading + baseLogger := log.Log.WithName("scheduling-config") + + config := Config{ + KVCacheThreshold: envutil.GetEnvFloat("KV_CACHE_THRESHOLD", defaultKVCacheThreshold, baseLogger), + QueueThresholdCritical: envutil.GetEnvInt("QUEUE_THRESHOLD_CRITICAL", defaultQueueThresholdCritical, baseLogger), + QueueingThresholdLoRA: envutil.GetEnvInt("QUEUING_THRESHOLD_LORA", defaultQueueingThresholdLoRA, baseLogger), + LoraAffinityThreshold: envutil.GetEnvFloat("LORA_AFFINITY_THRESHOLD", defaultLoraAffinityThreshold, baseLogger), + } + + baseLogger.V(logutil.DEFAULT).Info("Scheduler configuration loaded", "config", config) + + return config +} + +var Conf = LoadConfig() diff --git a/pkg/epp/scheduling/filter_test.go b/pkg/epp/scheduling/filter_test.go deleted file mode 100644 index 127e6c21..00000000 --- a/pkg/epp/scheduling/filter_test.go +++ /dev/null @@ -1,554 +0,0 @@ -/* -Copyright 2025 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package scheduling - -import ( - "errors" - "testing" - - "github.com/go-logr/logr" - "github.com/google/go-cmp/cmp" - "k8s.io/apimachinery/pkg/types" - backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" - logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" -) - -func TestFilter(t *testing.T) { - logger := logutil.NewTestLogger() - - tests := []struct { - name string - req *LLMRequest - input []*backendmetrics.FakePodMetrics - output []*backendmetrics.FakePodMetrics - err bool - filter *filter - }{ - { - name: "simple filter without successor, failure", - filter: &filter{filter: func(logger logr.Logger, req *LLMRequest, pods []backendmetrics.PodMetrics) ([]backendmetrics.PodMetrics, error) { - return nil, errors.New("filter error") - }}, - err: true, - }, - { - name: "default filter, critical request", - filter: defaultFilter, - req: &LLMRequest{ - Model: "critical", - ResolvedTargetModel: "critical", - Critical: true, - }, - // pod2 will be picked because it has relatively low queue size, with the requested - // model being active, and has low KV cache. 
- input: []*backendmetrics.FakePodMetrics{ - { - Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod1"}}, - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 0, - KVCacheUsagePercent: 0.2, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - "bar": 1, - }, - }, - }, - { - Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod2"}}, - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 3, - KVCacheUsagePercent: 0.1, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - "critical": 1, - }, - }, - }, - { - Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod3"}}, - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 10, - KVCacheUsagePercent: 0.2, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - }, - }, - }, - }, - output: []*backendmetrics.FakePodMetrics{ - { - Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod2"}}, - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 3, - KVCacheUsagePercent: 0.1, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - "critical": 1, - }, - }, - }, - }, - }, - { - name: "default filter, sheddable request, accepted", - filter: defaultFilter, - req: &LLMRequest{ - Model: "sheddable", - ResolvedTargetModel: "sheddable", - Critical: false, - }, - // pod1 will be picked because it has capacity for the sheddable request. - input: []*backendmetrics.FakePodMetrics{ - { - Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod1"}}, - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 0, - KVCacheUsagePercent: 0.2, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - "bar": 1, - }, - }, - }, - { - Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod2"}}, - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 3, - KVCacheUsagePercent: 0.1, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - "critical": 1, - }, - }, - }, - { - Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod3"}}, - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 10, - KVCacheUsagePercent: 0.2, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - }, - }, - }, - }, - output: []*backendmetrics.FakePodMetrics{ - { - Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod1"}}, - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 0, - KVCacheUsagePercent: 0.2, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - "bar": 1, - }, - }, - }, - }, - }, - { - name: "default filter, sheddable request, dropped", - filter: defaultFilter, - req: &LLMRequest{ - Model: "sheddable", - ResolvedTargetModel: "sheddable", - Critical: false, - }, - // All pods have higher KV cache thant the threshold, so the sheddable request will be - // dropped. 
- input: []*backendmetrics.FakePodMetrics{ - { - Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod1"}}, - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 10, - KVCacheUsagePercent: 0.9, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - "bar": 1, - }, - }, - }, - { - Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod2"}}, - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 3, - KVCacheUsagePercent: 0.85, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - "critical": 1, - }, - }, - }, - { - Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "pod3"}}, - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 10, - KVCacheUsagePercent: 0.85, - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - }, - }, - }, - }, - output: []*backendmetrics.FakePodMetrics{}, - err: true, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - got, err := test.filter.Filter(logger, test.req, toInterface(test.input)) - if test.err != (err != nil) { - t.Errorf("Unexpected error, got %v, want %v", err, test.err) - } - - if diff := cmp.Diff(test.output, toStruct(got)); diff != "" { - t.Errorf("Unexpected output (-want +got): %v", diff) - } - }) - } -} - -func TestFilterFunc(t *testing.T) { - logger := logutil.NewTestLogger() - - tests := []struct { - name string - f filterFunc - req *LLMRequest - input []*backendmetrics.FakePodMetrics - output []*backendmetrics.FakePodMetrics - err bool - }{ - { - name: "least queuing empty input", - f: leastQueuingFilterFunc, - input: []*backendmetrics.FakePodMetrics{}, - output: []*backendmetrics.FakePodMetrics{}, - }, - { - name: "least queuing", - f: leastQueuingFilterFunc, - input: []*backendmetrics.FakePodMetrics{ - { - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 0, - }, - }, - { - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 3, - }, - }, - { - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 10, - }, - }, - }, - output: []*backendmetrics.FakePodMetrics{ - { - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 0, - }, - }, - { - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 3, - }, - }, - }, - }, - { - name: "least kv cache empty input", - f: leastKVCacheFilterFunc, - input: []*backendmetrics.FakePodMetrics{}, - output: []*backendmetrics.FakePodMetrics{}, - }, - { - name: "least kv cache", - f: leastKVCacheFilterFunc, - input: []*backendmetrics.FakePodMetrics{ - { - Metrics: &backendmetrics.Metrics{ - KVCacheUsagePercent: 0, - }, - }, - { - Metrics: &backendmetrics.Metrics{ - KVCacheUsagePercent: 0.3, - }, - }, - { - Metrics: &backendmetrics.Metrics{ - KVCacheUsagePercent: 1.0, - }, - }, - }, - output: []*backendmetrics.FakePodMetrics{ - { - Metrics: &backendmetrics.Metrics{ - KVCacheUsagePercent: 0, - }, - }, - { - Metrics: &backendmetrics.Metrics{ - KVCacheUsagePercent: 0.3, - }, - }, - }, - }, - { - name: "noQueueAndLessThanKVCacheThresholdPredicate", - f: toFilterFunc(noQueueAndLessThanKVCacheThresholdPredicate(0, 0.8)), - input: []*backendmetrics.FakePodMetrics{ - { - // This pod should be returned. - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 0, - KVCacheUsagePercent: 0, - }, - }, - { - // Queue is non zero, despite low kv cache, should not return. 
- Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 1, - KVCacheUsagePercent: 0.3, - }, - }, - { - // High kv cache despite zero queue, should not return - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 0, - KVCacheUsagePercent: 1.0, - }, - }, - }, - output: []*backendmetrics.FakePodMetrics{ - { - Metrics: &backendmetrics.Metrics{ - WaitingQueueSize: 0, - KVCacheUsagePercent: 0, - }, - }, - }, - }, - { - name: "low LoRA cost", - f: toFilterFunc(lowLoRACostPredicate), - req: &LLMRequest{ - Model: "model", - ResolvedTargetModel: "model", - }, - input: []*backendmetrics.FakePodMetrics{ - // ActiveModels include input model, should be returned. - { - Metrics: &backendmetrics.Metrics{ - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "model": 1, - }, - }, - }, - // Input model is not active, however the server has room to load another adapter. - { - Metrics: &backendmetrics.Metrics{ - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "another-model": 1, - }, - }, - }, - // Input is not active, and the server has reached max active models. - { - Metrics: &backendmetrics.Metrics{ - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "foo": 1, - "bar": 1, - }, - }, - }, - }, - output: []*backendmetrics.FakePodMetrics{ - { - Metrics: &backendmetrics.Metrics{ - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "model": 1, - }, - }, - }, - { - Metrics: &backendmetrics.Metrics{ - MaxActiveModels: 2, - ActiveModels: map[string]int{ - "another-model": 1, - }, - }, - }, - }, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - got, err := test.f(logger, test.req, toInterface(test.input)) - if test.err != (err != nil) { - t.Errorf("Unexpected error, got %v, want %v", err, test.err) - } - - if diff := cmp.Diff(test.output, toStruct(got)); diff != "" { - t.Errorf("Unexpected output (-want +got): %v", diff) - } - }) - } -} - -// TestLoRASoftAffinityDistribution tests that the loRASoftAffinityFilter function -// properly distributes requests according to the loraAffinityThreshold -func TestLoRASoftAffinityDistribution(t *testing.T) { - logger := logutil.NewTestLogger() - - const ( - testModelName = "test-model" - testAffinityModel = "test-affinity-model" - numIterations = 10000 - tolerancePercent = 5.0 // Allow 5% tolerance from expected distribution - ) - - // Save original config value to restore later - originalThreshold := config.LoraAffinityThreshold - - // Set a specific test value for this test - testThreshold := 0.75 // 75% - config.LoraAffinityThreshold = testThreshold - - // Ensure we restore the original threshold when test completes - defer func() { - config.LoraAffinityThreshold = originalThreshold - }() - - // Create a test request and pods - req := &LLMRequest{ - Model: testAffinityModel, - ResolvedTargetModel: testAffinityModel, - } - - // Test setup: One affinity pod and one available pod - pods := []*backendmetrics.FakePodMetrics{ - { - Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "affinity-pod"}}, - Metrics: &backendmetrics.Metrics{ - MaxActiveModels: 2, - ActiveModels: map[string]int{ - testAffinityModel: 1, - }, - }, - }, - { - Pod: &backendmetrics.Pod{NamespacedName: types.NamespacedName{Name: "available-pod"}}, - Metrics: &backendmetrics.Metrics{ - MaxActiveModels: 2, - ActiveModels: map[string]int{}, - }, - }, - } - - // Run the filter function multiple times and count the results - affinityCount := 0 - availableCount := 0 - - // Use the test threshold value - expectedAffinityPercent := 
config.LoraAffinityThreshold * 100 - expectedAvailabilityPercent := 100 - expectedAffinityPercent - - for i := 0; i < numIterations; i++ { - result, err := loRASoftAffinityFilter(logger, req, toInterface(pods)) - if err != nil { - t.Fatalf("Unexpected error: %v", err) - } - - // Check which type of pod was returned - if len(result) != 1 { - t.Fatalf("Expected exactly one pod in result, got %d", len(result)) - } - - // Identify if the returned pod is the affinity pod or available pod - if _, exists := result[0].GetMetrics().ActiveModels[testAffinityModel]; exists { - affinityCount++ - } else { - availableCount++ - } - } - - // Calculate the actual percentages - actualAffinityPercent := float64(affinityCount) / float64(numIterations) * 100 - actualAvailablePercent := float64(availableCount) / float64(numIterations) * 100 - - // Check if the distribution matches expected threshold within tolerance - affinityLowerBound := expectedAffinityPercent - tolerancePercent - affinityUpperBound := expectedAffinityPercent + tolerancePercent - - availableLowerBound := expectedAvailabilityPercent - tolerancePercent - availableUpperBound := expectedAvailabilityPercent + tolerancePercent - - t.Logf("Distribution results over %d iterations:", numIterations) - t.Logf("Expected affinity percent: %.2f%% (threshold: %.2f)", expectedAffinityPercent, config.LoraAffinityThreshold) - t.Logf("Expected availability percent: %.2f%% (threshold: %.2f)", expectedAvailabilityPercent, config.LoraAffinityThreshold) - t.Logf("Actual affinity percent: %.2f%% (%d out of %d)", actualAffinityPercent, affinityCount, numIterations) - t.Logf("Actual available percent: %.2f%% (%d out of %d)", actualAvailablePercent, availableCount, numIterations) - - if actualAffinityPercent < affinityLowerBound || actualAffinityPercent > affinityUpperBound { - t.Errorf("Affinity selection percent %.2f%% outside expected range %.2f%% to %.2f%%", - actualAffinityPercent, affinityLowerBound, affinityUpperBound) - } - if actualAvailablePercent < availableLowerBound || actualAvailablePercent > availableUpperBound { - t.Errorf("Availability selection percent %.2f%% outside expected range %.2f%% to %.2f%%", - actualAvailablePercent, availableLowerBound, availableUpperBound) - } -} - -func toInterface(input []*backendmetrics.FakePodMetrics) []backendmetrics.PodMetrics { - output := []backendmetrics.PodMetrics{} - for _, i := range input { - output = append(output, i) - } - return output -} - -func toStruct(input []backendmetrics.PodMetrics) []*backendmetrics.FakePodMetrics { - if input == nil { - return nil - } - output := []*backendmetrics.FakePodMetrics{} - for _, i := range input { - output = append(output, i.(*backendmetrics.FakePodMetrics)) - } - return output -} diff --git a/pkg/epp/scheduling/filter.go b/pkg/epp/scheduling/plugins/filter/filter.go similarity index 56% rename from pkg/epp/scheduling/filter.go rename to pkg/epp/scheduling/plugins/filter/filter.go index f4848089..86620aa9 100644 --- a/pkg/epp/scheduling/filter.go +++ b/pkg/epp/scheduling/plugins/filter/filter.go @@ -14,103 +14,117 @@ See the License for the specific language governing permissions and limitations under the License. 
 */

-package scheduling
+package filter

 import (
-	"errors"
 	"math"
 	"math/rand"
 	"time"

-	"github.com/go-logr/logr"
-	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/config"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
 	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
 )

-type Filter interface {
-	Name() string
-	Filter(logger logr.Logger, req *LLMRequest, pods []backendmetrics.PodMetrics) ([]backendmetrics.PodMetrics, error)
+type baseFilter struct {
+	name   string
+	filter filterFunc
 }

-// filter applies current filterFunc, and then recursively applies next filters depending success or
-// failure of the current filterFunc.
+func (f *baseFilter) Name() string {
+	if f == nil {
+		return "nil"
+	}
+	return f.name
+}
+
+func (f *baseFilter) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
+	loggerTrace := ctx.Logger.V(logutil.TRACE)
+	loggerTrace.Info("Running a filter", "name", f.Name(), "podCount", len(pods))
+
+	return f.filter(ctx, pods)
+}
+
+// DecisionTreeFilter applies the current filter, and then recursively applies the next
+// filters depending on the success or failure of the current filter.
 // It can be used to construct a flow chart algorithm.
-type filter struct {
-	name   string
-	filter filterFunc
-	// nextOnSuccess filter will be applied after successfully applying the current filter.
+type DecisionTreeFilter struct {
+	Current plugins.Filter
+	// NextOnSuccess filter will be applied after successfully applying the current filter.
 	// The filtered results will be passed to the next filter.
-	nextOnSuccess *filter
-	// nextOnFailure filter will be applied if current filter fails.
+	NextOnSuccess plugins.Filter
+	// NextOnFailure filter will be applied if the current filter fails.
 	// The original input will be passed to the next filter.
-	nextOnFailure *filter
-	// nextOnSuccessOrFailure is a convenience field to configure the next filter regardless of the
+	NextOnFailure plugins.Filter
+	// NextOnSuccessOrFailure is a convenience field to configure the next filter regardless of the
 	// success or failure of the current filter.
-	// NOTE: When using nextOnSuccessOrFailure, both nextOnSuccess and nextOnFailure SHOULD be nil.
+	// NOTE: When using NextOnSuccessOrFailure, both NextOnSuccess and NextOnFailure SHOULD be nil.
 	// However if that's not the case, nextOnSuccess and nextOnFailure will be used, instead of
-	// nextOnSuccessOrFailure, in the success and failure scenarios, respectively.
-	nextOnSuccessOrFailure *filter
+	// NextOnSuccessOrFailure, in the success and failure scenarios, respectively.
+ NextOnSuccessOrFailure plugins.Filter } -func (f *filter) Name() string { +func (f *DecisionTreeFilter) Name() string { if f == nil { return "nil" } - return f.name + return f.Current.Name() } -func (f *filter) Filter(logger logr.Logger, req *LLMRequest, pods []backendmetrics.PodMetrics) ([]backendmetrics.PodMetrics, error) { - loggerTrace := logger.V(logutil.TRACE) - loggerTrace.Info("Running a filter", "name", f.Name(), "podCount", len(pods)) - - filtered, err := f.filter(logger, req, pods) +func (f *DecisionTreeFilter) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { + loggerTrace := ctx.Logger.V(logutil.TRACE) + filtered := f.Current.Filter(ctx, pods) - next := f.nextOnSuccessOrFailure - if err == nil && len(filtered) > 0 { - if f.nextOnSuccess == nil && f.nextOnSuccessOrFailure == nil { + next := f.NextOnSuccessOrFailure + if len(filtered) > 0 { + if f.NextOnSuccess == nil && f.NextOnSuccessOrFailure == nil { // No succeeding filters to run, return. - return filtered, err + return filtered } - if f.nextOnSuccess != nil { - next = f.nextOnSuccess + if f.NextOnSuccess != nil { + next = f.NextOnSuccess } loggerTrace.Info("Filter succeeded", "filter", f.Name(), "next", next.Name(), "filteredPodCount", len(filtered)) // On success, pass the filtered result to the next filter. - return next.Filter(logger, req, filtered) + return next.Filter(ctx, filtered) } else { - if f.nextOnFailure == nil && f.nextOnSuccessOrFailure == nil { + if f.NextOnFailure == nil && f.NextOnSuccessOrFailure == nil { // No succeeding filters to run, return. - return filtered, err + return filtered } - if f.nextOnFailure != nil { - next = f.nextOnFailure + if f.NextOnFailure != nil { + next = f.NextOnFailure } loggerTrace.Info("Filter failed", "filter", f.Name(), "next", next.Name()) // On failure, pass the initial set of pods to the next filter. - return next.Filter(logger, req, pods) + return next.Filter(ctx, pods) } } // filterFunc filters a set of input pods to a subset. -type filterFunc func(logger logr.Logger, req *LLMRequest, pods []backendmetrics.PodMetrics) ([]backendmetrics.PodMetrics, error) +type filterFunc func(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod // toFilterFunc is a helper function to convert a per pod filter func to the FilterFunc. func toFilterFunc(pp podPredicate) filterFunc { - return func(logger logr.Logger, req *LLMRequest, pods []backendmetrics.PodMetrics) ([]backendmetrics.PodMetrics, error) { - filtered := []backendmetrics.PodMetrics{} + return func(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { + filtered := []types.Pod{} for _, pod := range pods { - pass := pp(req, pod) + pass := pp(ctx.Req, pod) if pass { filtered = append(filtered, pod) } } - if len(filtered) == 0 { - return nil, errors.New("no pods left") - } - return filtered, nil + + return filtered } } +var LeastQueueFilter = &baseFilter{ + name: "least queuing", + filter: leastQueuingFilterFunc, +} + // leastQueuingFilterFunc finds the max and min queue size of all pods, divides the whole range // (max-min) by the number of pods, and finds the pods that fall into the first range. // The intuition is that if there are multiple pods that share similar queue size in the low range, @@ -118,10 +132,10 @@ func toFilterFunc(pp podPredicate) filterFunc { // the least one as it gives more choices for the next filter, which on aggregate gave better // results. // TODO: Compare this strategy with other strategies such as top K. 
-func leastQueuingFilterFunc(logger logr.Logger, req *LLMRequest, pods []backendmetrics.PodMetrics) ([]backendmetrics.PodMetrics, error) { +func leastQueuingFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { min := math.MaxInt max := 0 - filtered := []backendmetrics.PodMetrics{} + filtered := []types.Pod{} for _, pod := range pods { if pod.GetMetrics().WaitingQueueSize <= min { @@ -137,11 +151,17 @@ func leastQueuingFilterFunc(logger logr.Logger, req *LLMRequest, pods []backendm filtered = append(filtered, pod) } } - return filtered, nil + return filtered } -func lowQueueingPodPredicate(_ *LLMRequest, pod backendmetrics.PodMetrics) bool { - return pod.GetMetrics().WaitingQueueSize < config.QueueingThresholdLoRA +var LowQueueFilter = &baseFilter{ + name: "low queueing filter", + filter: toFilterFunc((queueThresholdPredicate(config.Conf.QueueingThresholdLoRA))), +} + +var LeastKVCacheFilter = &baseFilter{ + name: "least KV cache percent", + filter: leastKVCacheFilterFunc, } // leastKVCacheFilterFunc finds the max and min KV cache of all pods, divides the whole range @@ -150,10 +170,10 @@ func lowQueueingPodPredicate(_ *LLMRequest, pod backendmetrics.PodMetrics) bool // should consider them all instead of the absolute minimum one. This worked better than picking the // least one as it gives more choices for the next filter, which on aggregate gave better results. // TODO: Compare this strategy with other strategies such as top K. -func leastKVCacheFilterFunc(logger logr.Logger, req *LLMRequest, pods []backendmetrics.PodMetrics) ([]backendmetrics.PodMetrics, error) { +func leastKVCacheFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { min := math.MaxFloat64 var max float64 = 0 - filtered := []backendmetrics.PodMetrics{} + filtered := []types.Pod{} for _, pod := range pods { if pod.GetMetrics().KVCacheUsagePercent <= min { @@ -169,20 +189,12 @@ func leastKVCacheFilterFunc(logger logr.Logger, req *LLMRequest, pods []backendm filtered = append(filtered, pod) } } - return filtered, nil + return filtered } -// podPredicate is a filter function to check whether a pod is desired. -type podPredicate func(req *LLMRequest, pod backendmetrics.PodMetrics) bool - -// We consider serving an adapter low cost it the adapter is active in the model server, or the -// model server has room to load the adapter. The lowLoRACostPredicate ensures weak affinity by -// spreading the load of a LoRA adapter across multiple pods, avoiding "pinning" all requests to -// a single pod. This gave good performance in our initial benchmarking results in the scenario -// where # of lora slots > # of lora adapters. 
-func lowLoRACostPredicate(req *LLMRequest, pod backendmetrics.PodMetrics) bool { - _, ok := pod.GetMetrics().ActiveModels[req.ResolvedTargetModel] - return ok || len(pod.GetMetrics().ActiveModels) < pod.GetMetrics().MaxActiveModels +var LoRAAffinityFilter = &baseFilter{ + name: "affinity LoRA", + filter: loRASoftAffinityFilterFunc, } // loRASoftAffinityPredicate implements a pod selection strategy that prioritizes pods @@ -201,18 +213,20 @@ func lowLoRACostPredicate(req *LLMRequest, pod backendmetrics.PodMetrics) bool { // Returns: // - Filtered slice of pod metrics based on affinity and availability // - Error if any issues occur during filtering -func loRASoftAffinityFilter(logger logr.Logger, req *LLMRequest, pods []backendmetrics.PodMetrics) ([]backendmetrics.PodMetrics, error) { +func loRASoftAffinityFilterFunc(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { // Pre-allocate slices with estimated capacity - filtered_affinity := make([]backendmetrics.PodMetrics, 0, len(pods)) - filtered_available := make([]backendmetrics.PodMetrics, 0, len(pods)) + filtered_affinity := make([]types.Pod, 0, len(pods)) + filtered_available := make([]types.Pod, 0, len(pods)) // Categorize pods based on affinity and availability for _, pod := range pods { + _, active := pod.GetMetrics().ActiveModels[ctx.Req.ResolvedTargetModel] + _, waiting := pod.GetMetrics().WaitingModels[ctx.Req.ResolvedTargetModel] - if _, exists := pod.GetMetrics().ActiveModels[req.ResolvedTargetModel]; exists { + if active || waiting { filtered_affinity = append(filtered_affinity, pod) - } else if len(pod.GetMetrics().ActiveModels) < pod.GetMetrics().MaxActiveModels { + } else if len(pod.GetMetrics().ActiveModels)+len(pod.GetMetrics().WaitingModels) < pod.GetMetrics().MaxActiveModels { filtered_available = append(filtered_available, pod) } } @@ -223,26 +237,42 @@ func loRASoftAffinityFilter(logger logr.Logger, req *LLMRequest, pods []backendm // If both groups have pods, use probability to select which group to return if len(filtered_affinity) > 0 && len(filtered_available) > 0 { - if randGen.Float64() < config.LoraAffinityThreshold { - return filtered_affinity, nil + if randGen.Float64() < config.Conf.LoraAffinityThreshold { + return filtered_affinity } - return filtered_available, nil + return filtered_available } // Return whichever group has pods if len(filtered_affinity) > 0 { - return filtered_affinity, nil + return filtered_affinity } - return filtered_available, nil + return filtered_available } -func criticalRequestPredicate(req *LLMRequest, _ backendmetrics.PodMetrics) bool { - return req.Critical +var HasCapacityFilter = &baseFilter{ + name: "has capacity for sheddable requests", + filter: toFilterFunc(queueThresholdPredicate(config.Conf.QueueThresholdCritical).and(kvCacheThresholdPredicate(config.Conf.KVCacheThreshold))), +} + +// podPredicate is a filter function to check whether a pod is desired. 
+type podPredicate func(req *types.LLMRequest, pod types.Pod) bool + +func queueThresholdPredicate(queueThreshold int) podPredicate { + return func(req *types.LLMRequest, pod types.Pod) bool { + return pod.GetMetrics().WaitingQueueSize <= queueThreshold + } +} + +func kvCacheThresholdPredicate(kvCacheThreshold float64) podPredicate { + return func(req *types.LLMRequest, pod types.Pod) bool { + return pod.GetMetrics().KVCacheUsagePercent <= kvCacheThreshold + } } -func noQueueAndLessThanKVCacheThresholdPredicate(queueThreshold int, kvCacheThreshold float64) podPredicate { - return func(req *LLMRequest, pod backendmetrics.PodMetrics) bool { - return pod.GetMetrics().WaitingQueueSize <= queueThreshold && pod.GetMetrics().KVCacheUsagePercent <= kvCacheThreshold +func (pp podPredicate) and(another podPredicate) podPredicate { + return func(req *types.LLMRequest, pod types.Pod) bool { + return pp(req, pod) && another(req, pod) } } diff --git a/pkg/epp/scheduling/plugins/filter/filter_test.go b/pkg/epp/scheduling/plugins/filter/filter_test.go new file mode 100644 index 00000000..2354c3ef --- /dev/null +++ b/pkg/epp/scheduling/plugins/filter/filter_test.go @@ -0,0 +1,298 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package filter + +import ( + "context" + "testing" + + "github.com/google/go-cmp/cmp" + k8stypes "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/config" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" +) + +func TestFilter(t *testing.T) { + tests := []struct { + name string + req *types.LLMRequest + input []types.Pod + output []types.Pod + filter *DecisionTreeFilter + }{ + { + name: "simple filter without available pods", + filter: &DecisionTreeFilter{ + Current: &baseFilter{ + name: "filter all", + filter: func(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod { + return []types.Pod{} + }, + }, + }, + output: []types.Pod{}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := types.NewSchedulingContext(context.Background(), test.req, test.input) + got := test.filter.Filter(ctx, test.input) + + if diff := cmp.Diff(test.output, got); diff != "" { + t.Errorf("Unexpected output (-want +got): %v", diff) + } + }) + } +} + +func TestFilterFunc(t *testing.T) { + tests := []struct { + name string + f filterFunc + req *types.LLMRequest + input []types.Pod + output []types.Pod + }{ + { + name: "least queuing empty input", + f: leastQueuingFilterFunc, + input: []types.Pod{}, + output: []types.Pod{}, + }, + { + name: "least queuing", + f: leastQueuingFilterFunc, + input: []types.Pod{ + &types.PodMetrics{ + Metrics: &backendmetrics.Metrics{ + WaitingQueueSize: 0, + }, + }, + &types.PodMetrics{ + Metrics: &backendmetrics.Metrics{ + WaitingQueueSize: 3, + }, + }, + &types.PodMetrics{ + Metrics: 
&backendmetrics.Metrics{ + WaitingQueueSize: 10, + }, + }, + }, + output: []types.Pod{ + &types.PodMetrics{ + Metrics: &backendmetrics.Metrics{ + WaitingQueueSize: 0, + }, + }, + &types.PodMetrics{ + Metrics: &backendmetrics.Metrics{ + WaitingQueueSize: 3, + }, + }, + }, + }, + { + name: "least kv cache empty input", + f: leastKVCacheFilterFunc, + input: []types.Pod{}, + output: []types.Pod{}, + }, + { + name: "least kv cache", + f: leastKVCacheFilterFunc, + input: []types.Pod{ + &types.PodMetrics{ + Metrics: &backendmetrics.Metrics{ + KVCacheUsagePercent: 0, + }, + }, + &types.PodMetrics{ + Metrics: &backendmetrics.Metrics{ + KVCacheUsagePercent: 0.3, + }, + }, + &types.PodMetrics{ + Metrics: &backendmetrics.Metrics{ + KVCacheUsagePercent: 1.0, + }, + }, + }, + output: []types.Pod{ + &types.PodMetrics{ + Metrics: &backendmetrics.Metrics{ + KVCacheUsagePercent: 0, + }, + }, + &types.PodMetrics{ + Metrics: &backendmetrics.Metrics{ + KVCacheUsagePercent: 0.3, + }, + }, + }, + }, + { + name: "lowQueueAndLessThanKVCacheThresholdPredicate", + f: toFilterFunc(queueThresholdPredicate(0).and(kvCacheThresholdPredicate(0.8))), + input: []types.Pod{ + &types.PodMetrics{ + // This pod should be returned. + Metrics: &backendmetrics.Metrics{ + WaitingQueueSize: 0, + KVCacheUsagePercent: 0, + }, + }, + &types.PodMetrics{ + // Queue is non zero, despite low kv cache, should not return. + Metrics: &backendmetrics.Metrics{ + WaitingQueueSize: 1, + KVCacheUsagePercent: 0.3, + }, + }, + &types.PodMetrics{ + // High kv cache despite zero queue, should not return + Metrics: &backendmetrics.Metrics{ + WaitingQueueSize: 0, + KVCacheUsagePercent: 1.0, + }, + }, + }, + output: []types.Pod{ + &types.PodMetrics{ + Metrics: &backendmetrics.Metrics{ + WaitingQueueSize: 0, + KVCacheUsagePercent: 0, + }, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := types.NewSchedulingContext(context.Background(), test.req, test.input) + got := test.f(ctx, test.input) + + if diff := cmp.Diff(test.output, got); diff != "" { + t.Errorf("Unexpected output (-want +got): %v", diff) + } + }) + } +} + +// TestLoRASoftAffinityDistribution tests that the loRASoftAffinityFilter function +// properly distributes requests according to the loraAffinityThreshold +func TestLoRASoftAffinityDistribution(t *testing.T) { + const ( + testModelName = "test-model" + testAffinityModel = "test-affinity-model" + numIterations = 10000 + tolerancePercent = 5.0 // Allow 5% tolerance from expected distribution + ) + + // Save original config value to restore later + originalThreshold := config.Conf.LoraAffinityThreshold + + // Set a specific test value for this test + testThreshold := 0.75 // 75% + config.Conf.LoraAffinityThreshold = testThreshold + + // Ensure we restore the original threshold when test completes + defer func() { + config.Conf.LoraAffinityThreshold = originalThreshold + }() + + // Create a test request and pods + req := &types.LLMRequest{ + Model: testAffinityModel, + ResolvedTargetModel: testAffinityModel, + } + + // Test setup: One affinity pod and one available pod + pods := []types.Pod{ + &types.PodMetrics{ + Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "affinity-pod"}}, + Metrics: &backendmetrics.Metrics{ + MaxActiveModels: 2, + ActiveModels: map[string]int{ + testAffinityModel: 1, + }, + }, + }, + &types.PodMetrics{ + Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "available-pod"}}, + Metrics: &backendmetrics.Metrics{ + MaxActiveModels: 2, + 
ActiveModels: map[string]int{}, + }, + }, + } + ctx := types.NewSchedulingContext(context.Background(), req, pods) + + // Run the filter function multiple times and count the results + affinityCount := 0 + availableCount := 0 + + // Use the test threshold value + expectedAffinityPercent := config.Conf.LoraAffinityThreshold * 100 + expectedAvailabilityPercent := 100 - expectedAffinityPercent + + for i := 0; i < numIterations; i++ { + result := loRASoftAffinityFilterFunc(ctx, pods) + + // Check which type of pod was returned + if len(result) != 1 { + t.Fatalf("Expected exactly one pod in result, got %d", len(result)) + } + + // Identify if the returned pod is the affinity pod or available pod + if _, exists := result[0].GetMetrics().ActiveModels[testAffinityModel]; exists { + affinityCount++ + } else { + availableCount++ + } + } + + // Calculate the actual percentages + actualAffinityPercent := float64(affinityCount) / float64(numIterations) * 100 + actualAvailablePercent := float64(availableCount) / float64(numIterations) * 100 + + // Check if the distribution matches expected threshold within tolerance + affinityLowerBound := expectedAffinityPercent - tolerancePercent + affinityUpperBound := expectedAffinityPercent + tolerancePercent + + availableLowerBound := expectedAvailabilityPercent - tolerancePercent + availableUpperBound := expectedAvailabilityPercent + tolerancePercent + + t.Logf("Distribution results over %d iterations:", numIterations) + t.Logf("Expected affinity percent: %.2f%% (threshold: %.2f)", expectedAffinityPercent, config.Conf.LoraAffinityThreshold) + t.Logf("Expected availability percent: %.2f%% (threshold: %.2f)", expectedAvailabilityPercent, config.Conf.LoraAffinityThreshold) + t.Logf("Actual affinity percent: %.2f%% (%d out of %d)", actualAffinityPercent, affinityCount, numIterations) + t.Logf("Actual available percent: %.2f%% (%d out of %d)", actualAvailablePercent, availableCount, numIterations) + + if actualAffinityPercent < affinityLowerBound || actualAffinityPercent > affinityUpperBound { + t.Errorf("Affinity selection percent %.2f%% outside expected range %.2f%% to %.2f%%", + actualAffinityPercent, affinityLowerBound, affinityUpperBound) + } + if actualAvailablePercent < availableLowerBound || actualAvailablePercent > availableUpperBound { + t.Errorf("Availability selection percent %.2f%% outside expected range %.2f%% to %.2f%%", + actualAvailablePercent, availableLowerBound, availableUpperBound) + } +} diff --git a/pkg/epp/scheduling/plugins/picker/max_score_picker.go b/pkg/epp/scheduling/plugins/picker/max_score_picker.go new file mode 100644 index 00000000..1705b7dd --- /dev/null +++ b/pkg/epp/scheduling/plugins/picker/max_score_picker.go @@ -0,0 +1,49 @@ +package picker + +import ( + "fmt" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" +) + +var _ plugins.Picker = &MaxScorePicker{} + +func NewMaxScorePicker() plugins.Picker { + return &MaxScorePicker{ + random: &RandomPicker{}, + } +} + +// MaxScorePicker picks the pod with the maximum score from the list of candidates. +type MaxScorePicker struct { + random *RandomPicker +} + +// Name returns the name of the picker. +func (p *MaxScorePicker) Name() string { + return "max_score" +} + +// Pick selects the pod with the maximum score from the list of candidates. 
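+// If several pods share the highest score, the tie is broken by delegating to the
+// embedded RandomPicker, so each top-scoring pod is equally likely to be chosen.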
+func (p *MaxScorePicker) Pick(ctx *types.SchedulingContext, scoredPods []*types.ScoredPod) *types.Result {
+	ctx.Logger.V(logutil.DEBUG).Info(fmt.Sprintf("Selecting a pod with the max score from %d candidates: %+v", len(scoredPods), scoredPods))
+
+	highestScorePods := []*types.ScoredPod{}
+	maxScore := -1.0 // the minimum possible score is 0; starting below it guarantees at least one pod is selected
+	for _, pod := range scoredPods {
+		if pod.Score > maxScore {
+			maxScore = pod.Score
+			highestScorePods = []*types.ScoredPod{pod}
+		} else if pod.Score == maxScore {
+			highestScorePods = append(highestScorePods, pod)
+		}
+	}
+
+	if len(highestScorePods) > 1 {
+		return p.random.Pick(ctx, highestScorePods) // pick randomly from the highest score pods
+	}
+
+	return &types.Result{TargetPod: highestScorePods[0]}
+}
diff --git a/pkg/epp/scheduling/plugins/picker/random_picker.go b/pkg/epp/scheduling/plugins/picker/random_picker.go
new file mode 100644
index 00000000..fb9f9a29
--- /dev/null
+++ b/pkg/epp/scheduling/plugins/picker/random_picker.go
@@ -0,0 +1,41 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package picker
+
+import (
+	"fmt"
+	"math/rand"
+
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+	logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+)
+
+var _ plugins.Picker = &RandomPicker{}
+
+// RandomPicker picks a random pod from the list of candidates.
+type RandomPicker struct{}
+
+func (p *RandomPicker) Name() string {
+	return "random"
+}
+
+func (p *RandomPicker) Pick(ctx *types.SchedulingContext, scoredPods []*types.ScoredPod) *types.Result {
+	ctx.Logger.V(logutil.DEBUG).Info(fmt.Sprintf("Selecting a random pod from %d candidates: %+v", len(scoredPods), scoredPods))
+	i := rand.Intn(len(scoredPods))
+	return &types.Result{TargetPod: scoredPods[i]}
+}
diff --git a/pkg/epp/scheduling/plugins/plugins.go b/pkg/epp/scheduling/plugins/plugins.go
new file mode 100644
index 00000000..f3412ab7
--- /dev/null
+++ b/pkg/epp/scheduling/plugins/plugins.go
@@ -0,0 +1,76 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/ + +package plugins + +import ( + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" +) + +const ( + PreSchedulerPluginType = "PreSchedule" + FilterPluginType = "Filter" + ScorerPluginType = "Scorer" + PostSchedulePluginType = "PostSchedule" + PickerPluginType = "Picker" + PostResponsePluginType = "PostResponse" +) + +// Plugin defines the interface for scheduler plugins, combining scoring, filtering, +// and event handling capabilities. +type Plugin interface { + // Name returns the name of the plugin. + Name() string +} + +// PreSchedule is called when the scheduler receives a new request. It can be used for various +// initialization work. +type PreSchedule interface { + Plugin + PreSchedule(ctx *types.SchedulingContext) +} + +// Filter defines the interface for filtering a list of pods based on context. +type Filter interface { + Plugin + Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod +} + +// Scorer defines the interface for scoring a list of pods based on context. +// Scorers must score pods with a value within the range of [0,1] where 1 is the highest score. +type Scorer interface { + Plugin + Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 +} + +// Picker picks the final pod(s) to send the request to. +type Picker interface { + Plugin + Pick(ctx *types.SchedulingContext, scoredPods []*types.ScoredPod) *types.Result +} + +// PostSchedule is called by the scheduler after it selects a targetPod for the request. +type PostSchedule interface { + Plugin + PostSchedule(ctx *types.SchedulingContext, res *types.Result) +} + +// PostResponse is called by the scheduler after a successful response was sent. +// The given pod argument is the pod that served the request. +type PostResponse interface { + Plugin + PostResponse(ctx *types.SchedulingContext, pod types.Pod) +} diff --git a/pkg/epp/scheduling/scheduler.go b/pkg/epp/scheduling/scheduler.go index e874724d..1a1d67b5 100644 --- a/pkg/epp/scheduling/scheduler.go +++ b/pkg/epp/scheduling/scheduler.go @@ -20,140 +20,200 @@ package scheduling import ( "context" "fmt" - "math/rand" + "time" - "github.com/go-logr/logr" "sigs.k8s.io/controller-runtime/pkg/log" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" - envutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/env" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/filter" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins/picker" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) -// Config holds all the configuration values for the scheduler -type Config struct { - KVCacheThreshold float64 - QueueThresholdCritical int - QueueingThresholdLoRA int - LoraAffinityThreshold float64 -} +var ( + lowLatencyFilter = &filter.DecisionTreeFilter{ + Current: filter.LowQueueFilter, + NextOnSuccess: &filter.DecisionTreeFilter{ + Current: filter.LoRAAffinityFilter, + NextOnSuccessOrFailure: &filter.DecisionTreeFilter{ + Current: filter.LeastQueueFilter, + NextOnSuccessOrFailure: &filter.DecisionTreeFilter{ + Current: filter.LeastKVCacheFilter, + }, + }, + }, + NextOnFailure: 
&filter.DecisionTreeFilter{
+			Current: filter.LeastQueueFilter,
+			NextOnSuccessOrFailure: &filter.DecisionTreeFilter{
+				Current: filter.LoRAAffinityFilter,
+				NextOnSuccessOrFailure: &filter.DecisionTreeFilter{
+					Current: filter.LeastKVCacheFilter,
+				},
+			},
+		},
+	}
 
-const (
-	// Default values to use if environment variables are not set
-	defaultKVCacheThreshold = 0.8
-	defaultQueueThresholdCritical = 5
-	defaultQueueingThresholdLoRA = 128
-	defaultLoraAffinityThreshold = 0.999
+	sheddableRequestFilter = &filter.DecisionTreeFilter{
+		// When there is at least one model server that's not queuing requests and still has KV
+		// cache below a certain threshold, we consider that model server to have capacity to
+		// handle a sheddable request without impacting critical requests.
+		Current:       filter.HasCapacityFilter,
+		NextOnSuccess: lowLatencyFilter,
+		// If all pods are queuing or running above the KVCache threshold, we drop the sheddable
+		// request to make room for critical requests. For this reason, we don't define nextOnFailure.
+	}
 )
 
-// LoadConfig loads configuration from environment variables
-func LoadConfig() Config {
-	// Use a default logger for initial configuration loading
-	baseLogger := log.Log.WithName("scheduling-config")
+func NewScheduler(datastore Datastore) *Scheduler {
+	return NewSchedulerWithConfig(datastore, defaultConfig)
+}
 
-	config := Config{
-		KVCacheThreshold: envutil.GetEnvFloat("KV_CACHE_THRESHOLD", defaultKVCacheThreshold, baseLogger),
-		QueueThresholdCritical: envutil.GetEnvInt("QUEUE_THRESHOLD_CRITICAL", defaultQueueThresholdCritical, baseLogger),
-		QueueingThresholdLoRA: envutil.GetEnvInt("QUEUING_THRESHOLD_LORA", defaultQueueingThresholdLoRA, baseLogger),
-		LoraAffinityThreshold: envutil.GetEnvFloat("LORA_AFFINITY_THRESHOLD", defaultLoraAffinityThreshold, baseLogger),
+func NewSchedulerWithConfig(datastore Datastore, config *SchedulerConfig) *Scheduler {
+	return &Scheduler{
+		datastore:           datastore,
+		preSchedulePlugins:  config.preSchedulePlugins,
+		filters:             config.filters,
+		scorers:             config.scorers,
+		picker:              config.picker,
+		postSchedulePlugins: config.postSchedulePlugins,
 	}
+}
 
-	baseLogger.V(logutil.DEFAULT).Info("Scheduler configuration loaded", "config", config)
+type Scheduler struct {
+	datastore           Datastore
+	preSchedulePlugins  []plugins.PreSchedule
+	filters             []plugins.Filter
+	scorers             map[plugins.Scorer]int // map from scorer to its weight
+	picker              plugins.Picker
+	postSchedulePlugins []plugins.PostSchedule
+}
 
-	return config
+type Datastore interface {
+	PodGetAll() []backendmetrics.PodMetrics
 }
 
-var config = LoadConfig()
+// Schedule finds the target pod based on metrics and the requested lora adapter.
+func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (*types.Result, error) {
+	logger := log.FromContext(ctx).WithValues("request", req)
+	loggerDebug := logger.V(logutil.DEBUG)
 
-var (
-	defaultFilter = &filter{
-		name: "critical request",
-		filter: toFilterFunc(criticalRequestPredicate),
-		nextOnSuccess: lowLatencyFilter,
-		nextOnFailure: sheddableRequestFilter,
-	}
+	// Snapshot pod metrics from the datastore to:
+	// 1. Reduce concurrent access to the datastore.
+	// 2. Ensure consistent data during the scheduling operation of a request.
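+	// The snapshot is a clone of the datastore state (see types.ToSchedulerPodMetrics), so
+	// plugins read a stable copy and never mutate the live metrics.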
+ sCtx := types.NewSchedulingContext(ctx, req, types.ToSchedulerPodMetrics(s.datastore.PodGetAll())) + loggerDebug.Info(fmt.Sprintf("Scheduling a request, Metrics: %+v", sCtx.PodsSnapshot)) - // queueLoRAAndKVCacheFilter applied least queue -> low cost lora -> least KV Cache filter - queueLoRAAndKVCacheFilter = &filter{ - name: "least queuing", - filter: leastQueuingFilterFunc, - nextOnSuccessOrFailure: &filter{ - name: "low cost LoRA", - filter: loRASoftAffinityFilter, - nextOnSuccessOrFailure: &filter{ - name: "least KV cache percent", - filter: leastKVCacheFilterFunc, - }, - }, + s.runPreSchedulePlugins(sCtx) + + pods := s.runFilterPlugins(sCtx) + if len(pods) == 0 { + return nil, errutil.Error{Code: errutil.Internal, Msg: "no pods available for the given request"} } + // if we got here, there is at least one pod to score + weightedScorePerPod := s.runScorerPlugins(sCtx, pods) - // queueAndKVCacheFilter applies least queue followed by least KV Cache filter - queueAndKVCacheFilter = &filter{ - name: "least queuing", - filter: leastQueuingFilterFunc, - nextOnSuccessOrFailure: &filter{ - name: "least KV cache percent", - filter: leastKVCacheFilterFunc, - }, + result := s.runPickerPlugin(sCtx, weightedScorePerPod) + + s.runPostSchedulePlugins(sCtx, result) + + return result, nil +} + +func (s *Scheduler) runPreSchedulePlugins(ctx *types.SchedulingContext) { + for _, plugin := range s.preSchedulePlugins { + ctx.Logger.V(logutil.DEBUG).Info("Running pre-schedule plugin", "plugin", plugin.Name()) + before := time.Now() + plugin.PreSchedule(ctx) + metrics.RecordSchedulerPluginProcessingLatency(plugins.PreSchedulerPluginType, plugin.Name(), time.Since(before)) } +} - lowLatencyFilter = &filter{ - name: "low queueing filter", - filter: toFilterFunc((lowQueueingPodPredicate)), - nextOnSuccess: &filter{ - name: "affinity LoRA", - filter: loRASoftAffinityFilter, - nextOnSuccessOrFailure: queueAndKVCacheFilter, - }, - nextOnFailure: queueLoRAAndKVCacheFilter, +func (s *Scheduler) runFilterPlugins(ctx *types.SchedulingContext) []types.Pod { + loggerDebug := ctx.Logger.V(logutil.DEBUG) + filteredPods := ctx.PodsSnapshot + loggerDebug.Info("Before running filter plugins", "pods", filteredPods) + + for _, filter := range s.filters { + loggerDebug.Info("Running filter plugin", "plugin", filter.Name()) + before := time.Now() + filteredPods = filter.Filter(ctx, filteredPods) + metrics.RecordSchedulerPluginProcessingLatency(plugins.FilterPluginType, filter.Name(), time.Since(before)) + loggerDebug.Info("Filter plugin result", "plugin", filter.Name(), "pods", filteredPods) + if len(filteredPods) == 0 { + break + } } + loggerDebug.Info("After running filter plugins") - sheddableRequestFilter = &filter{ - // When there is at least one model server that's not queuing requests, and still has KV - // cache below a certain threshold, we consider this model server has capacity to handle - // a sheddable request without impacting critical requests. - name: "has capacity for sheddable requests", - filter: toFilterFunc(noQueueAndLessThanKVCacheThresholdPredicate(config.QueueThresholdCritical, config.KVCacheThreshold)), - nextOnSuccess: queueLoRAAndKVCacheFilter, - // If all pods are queuing or running above the KVCache threshold, we drop the sheddable - // request to make room for critical requests. 
-		nextOnFailure: &filter{
-			name: "drop request",
-			filter: func(logger logr.Logger, req *LLMRequest, pods []backendmetrics.PodMetrics) ([]backendmetrics.PodMetrics, error) {
-				logger.V(logutil.DEFAULT).Info("Request dropped", "request", req)
-				return []backendmetrics.PodMetrics{}, errutil.Error{
-					Code: errutil.InferencePoolResourceExhausted, Msg: "dropping request due to limited backend resources",
-				}
-			},
-		},
+	return filteredPods
+}
+
+func (s *Scheduler) runScorerPlugins(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 {
+	loggerDebug := ctx.Logger.V(logutil.DEBUG)
+	loggerDebug.Info("Before running scorer plugins", "pods", pods)
+
+	weightedScorePerPod := make(map[types.Pod]float64, len(pods))
+	for _, pod := range pods {
+		weightedScorePerPod[pod] = float64(0) // initialize each pod's weighted score to 0
 	}
-)
+	// Iterate through each scorer in the chain and accumulate the weighted scores.
+	for scorer, weight := range s.scorers {
+		loggerDebug.Info("Running scorer", "scorer", scorer.Name())
+		before := time.Now()
+		scores := scorer.Score(ctx, pods)
+		metrics.RecordSchedulerPluginProcessingLatency(plugins.ScorerPluginType, scorer.Name(), time.Since(before))
+		for pod, score := range scores { // weight is relative to the sum of weights
+			weightedScorePerPod[pod] += score * float64(weight) // TODO: normalize scores before multiplying by the weight
+		}
+		loggerDebug.Info("After running scorer", "scorer", scorer.Name())
+	}
+	loggerDebug.Info("After running scorer plugins")
 
-func NewScheduler(datastore datastore.Datastore) *Scheduler {
-	return &Scheduler{
-		datastore: datastore,
-		filter: defaultFilter,
+	return weightedScorePerPod
+}
+
+func (s *Scheduler) runPickerPlugin(ctx *types.SchedulingContext, weightedScorePerPod map[types.Pod]float64) *types.Result {
+	loggerDebug := ctx.Logger.V(logutil.DEBUG)
+	scoredPods := make([]*types.ScoredPod, len(weightedScorePerPod))
+	i := 0
+	for pod, score := range weightedScorePerPod {
+		scoredPods[i] = &types.ScoredPod{Pod: pod, Score: score}
+		i++
 	}
+
+	loggerDebug.Info("Before running picker plugin", "pods", weightedScorePerPod)
+	before := time.Now()
+	result := s.picker.Pick(ctx, scoredPods)
+	metrics.RecordSchedulerPluginProcessingLatency(plugins.PickerPluginType, s.picker.Name(), time.Since(before))
+	loggerDebug.Info("After running picker plugin", "result", result)
+
+	return result
 }
 
-type Scheduler struct {
-	datastore datastore.Datastore
-	filter Filter
+func (s *Scheduler) runPostSchedulePlugins(ctx *types.SchedulingContext, res *types.Result) {
+	for _, plugin := range s.postSchedulePlugins {
+		ctx.Logger.V(logutil.DEBUG).Info("Running post-schedule plugin", "plugin", plugin.Name())
+		before := time.Now()
+		plugin.PostSchedule(ctx, res)
+		metrics.RecordSchedulerPluginProcessingLatency(plugins.PostSchedulePluginType, plugin.Name(), time.Since(before))
+	}
 }
 
-// Schedule finds the target pod based on metrics and the requested lora adapter.
-func (s *Scheduler) Schedule(ctx context.Context, req *LLMRequest) (targetPod backendmetrics.PodMetrics, err error) {
-	logger := log.FromContext(ctx).WithValues("request", req)
+type defaultPlugin struct {
+	picker.RandomPicker
 }
 
-	podMetrics := s.datastore.PodGetAll()
-	logger.V(logutil.DEBUG).Info(fmt.Sprintf("Scheduling a request. Metrics: %+v", podMetrics))
+func (p *defaultPlugin) Name() string {
+	return "DefaultPlugin"
+}
 
-	pods, err := s.filter.Filter(logger, req, podMetrics)
-	if err != nil || len(pods) == 0 {
-		return nil, fmt.Errorf("failed to apply filter, resulted %v pods, this should never happen: %w", len(pods), err)
+func (p *defaultPlugin) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
+	if ctx.Req.Critical {
+		return lowLatencyFilter.Filter(ctx, pods)
 	}
-	logger.V(logutil.DEBUG).Info(fmt.Sprintf("Selecting a random pod from %d candidates: %+v", len(pods), pods))
-	i := rand.Intn(len(pods))
-	return pods[i], nil
+
+	return sheddableRequestFilter.Filter(ctx, pods)
 }
diff --git a/pkg/epp/scheduling/scheduler_test.go b/pkg/epp/scheduling/scheduler_test.go
new file mode 100644
index 00000000..b44c7ac2
--- /dev/null
+++ b/pkg/epp/scheduling/scheduler_test.go
@@ -0,0 +1,519 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package scheduling
+
+import (
+	"context"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+	k8stypes "k8s.io/apimachinery/pkg/types"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
+	backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/plugins"
+	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+)
+
+// Tests the default scheduler configuration and expected behavior.
+func TestSchedule(t *testing.T) {
+	tests := []struct {
+		name    string
+		req     *types.LLMRequest
+		input   []*backendmetrics.FakePodMetrics
+		wantRes *types.Result
+		err     bool
+	}{
+		{
+			name: "no pods in datastore",
+			req: &types.LLMRequest{
+				Model:               "any-model",
+				ResolvedTargetModel: "any-model",
+				Critical:            true,
+			},
+			input: []*backendmetrics.FakePodMetrics{},
+			err:   true,
+		},
+		{
+			name: "critical request",
+			req: &types.LLMRequest{
+				Model:               "critical",
+				ResolvedTargetModel: "critical",
+				Critical:            true,
+			},
+			// pod2 will be picked because it has a relatively small queue, has the requested
+			// model active, and has low KV cache usage.
+			input: []*backendmetrics.FakePodMetrics{
+				{
+					Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}},
+					Metrics: &backendmetrics.Metrics{
+						WaitingQueueSize:    0,
+						KVCacheUsagePercent: 0.2,
+						MaxActiveModels:     2,
+						ActiveModels: map[string]int{
+							"foo": 1,
+							"bar": 1,
+						},
+					},
+				},
+				{
+					Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}},
+					Metrics: &backendmetrics.Metrics{
+						WaitingQueueSize:    3,
+						KVCacheUsagePercent: 0.1,
+						MaxActiveModels:     2,
+						ActiveModels: map[string]int{
+							"foo":      1,
+							"critical": 1,
+						},
+					},
+				},
+				{
+					Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}},
+					Metrics: &backendmetrics.Metrics{
+						WaitingQueueSize:    10,
+						KVCacheUsagePercent: 0.2,
+						MaxActiveModels:     2,
+						ActiveModels: map[string]int{
+							"foo": 1,
+						},
+					},
+				},
+			},
+			wantRes: &types.Result{
+				TargetPod: &types.ScoredPod{
+					Pod: &types.PodMetrics{
+						Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}},
+						Metrics: &backendmetrics.Metrics{
+							WaitingQueueSize:    3,
+							KVCacheUsagePercent: 0.1,
+							MaxActiveModels:     2,
+							ActiveModels: map[string]int{
+								"foo":      1,
+								"critical": 1,
+							},
+							WaitingModels: map[string]int{},
+						},
+					},
+				},
+			},
+		},
+		{
+			name: "sheddable request, accepted",
+			req: &types.LLMRequest{
+				Model:               "sheddable",
+				ResolvedTargetModel: "sheddable",
+				Critical:            false,
+			},
+			// pod1 will be picked because it has capacity for the sheddable request.
+			input: []*backendmetrics.FakePodMetrics{
+				{
+					Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}},
+					Metrics: &backendmetrics.Metrics{
+						WaitingQueueSize:    0,
+						KVCacheUsagePercent: 0.2,
+						MaxActiveModels:     2,
+						ActiveModels: map[string]int{
+							"foo": 1,
+							"bar": 1,
+						},
+					},
+				},
+				{
+					Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}},
+					Metrics: &backendmetrics.Metrics{
+						WaitingQueueSize:    3,
+						KVCacheUsagePercent: 0.1,
+						MaxActiveModels:     2,
+						ActiveModels: map[string]int{
+							"foo":      1,
+							"critical": 1,
+						},
+					},
+				},
+				{
+					Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}},
+					Metrics: &backendmetrics.Metrics{
+						WaitingQueueSize:    10,
+						KVCacheUsagePercent: 0.2,
+						MaxActiveModels:     2,
+						ActiveModels: map[string]int{
+							"foo": 1,
+						},
+					},
+				},
+			},
+			wantRes: &types.Result{
+				TargetPod: &types.ScoredPod{
+					Pod: &types.PodMetrics{
+						Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}},
+						Metrics: &backendmetrics.Metrics{
+							WaitingQueueSize:    0,
+							KVCacheUsagePercent: 0.2,
+							MaxActiveModels:     2,
+							ActiveModels: map[string]int{
+								"foo": 1,
+								"bar": 1,
+							},
+							WaitingModels: map[string]int{},
+						},
+					},
+				},
+			},
+		},
+		{
+			name: "sheddable request, dropped",
+			req: &types.LLMRequest{
+				Model:               "sheddable",
+				ResolvedTargetModel: "sheddable",
+				Critical:            false,
+			},
+			// All pods have higher KV cache than the threshold, so the sheddable request will be
+			// dropped.
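+			// (Every pod reports KV cache usage above the default 0.8 capacity threshold, so the
+			// HasCapacityFilter leaves no candidates.)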
+ input: []*backendmetrics.FakePodMetrics{ + { + Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}, + Metrics: &backendmetrics.Metrics{ + WaitingQueueSize: 10, + KVCacheUsagePercent: 0.9, + MaxActiveModels: 2, + ActiveModels: map[string]int{ + "foo": 1, + "bar": 1, + }, + }, + }, + { + Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}, + Metrics: &backendmetrics.Metrics{ + WaitingQueueSize: 3, + KVCacheUsagePercent: 0.85, + MaxActiveModels: 2, + ActiveModels: map[string]int{ + "foo": 1, + "critical": 1, + }, + }, + }, + { + Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}, + Metrics: &backendmetrics.Metrics{ + WaitingQueueSize: 10, + KVCacheUsagePercent: 0.85, + MaxActiveModels: 2, + ActiveModels: map[string]int{ + "foo": 1, + }, + }, + }, + }, + wantRes: nil, + err: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + scheduler := NewScheduler(&fakeDataStore{pods: test.input}) + got, err := scheduler.Schedule(context.Background(), test.req) + if test.err != (err != nil) { + t.Errorf("Unexpected error, got %v, want %v", err, test.err) + } + + if diff := cmp.Diff(test.wantRes, got); diff != "" { + t.Errorf("Unexpected output (-want +got): %v", diff) + } + }) + } +} + +func TestSchedulePlugins(t *testing.T) { + tp1 := &TestPlugin{ + NameRes: "test1", + ScoreRes: 0.3, + FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}, {Name: "pod3"}}, + } + tp2 := &TestPlugin{ + NameRes: "test2", + ScoreRes: 0.8, + FilterRes: []k8stypes.NamespacedName{{Name: "pod1"}, {Name: "pod2"}}, + } + tp_filterAll := &TestPlugin{ + NameRes: "filter all", + FilterRes: []k8stypes.NamespacedName{}, + } + pickerPlugin := &TestPlugin{ + NameRes: "picker", + PickRes: k8stypes.NamespacedName{Name: "pod1"}, + } + + tests := []struct { + name string + config SchedulerConfig + input []*backendmetrics.FakePodMetrics + wantTargetPod k8stypes.NamespacedName + targetPodScore float64 + // Number of expected pods to score (after filter) + numPodsToScore int + err bool + }{ + { + name: "all plugins executed successfully, all scorers with same weight", + config: SchedulerConfig{ + preSchedulePlugins: []plugins.PreSchedule{tp1, tp2}, + filters: []plugins.Filter{tp1, tp2}, + scorers: map[plugins.Scorer]int{ + tp1: 1, + tp2: 1, + }, + picker: pickerPlugin, + postSchedulePlugins: []plugins.PostSchedule{tp1, tp2}, + }, + input: []*backendmetrics.FakePodMetrics{ + {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, + {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, + {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, + }, + wantTargetPod: k8stypes.NamespacedName{Name: "pod1"}, + targetPodScore: 1.1, + numPodsToScore: 2, + err: false, + }, + { + name: "all plugins executed successfully, different scorers weights", + config: SchedulerConfig{ + preSchedulePlugins: []plugins.PreSchedule{tp1, tp2}, + filters: []plugins.Filter{tp1, tp2}, + scorers: map[plugins.Scorer]int{ + tp1: 60, + tp2: 40, + }, + picker: pickerPlugin, + postSchedulePlugins: []plugins.PostSchedule{tp1, tp2}, + }, + input: []*backendmetrics.FakePodMetrics{ + {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}}, + {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}}, + {Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}}, + }, + wantTargetPod: k8stypes.NamespacedName{Name: "pod1"}, + targetPodScore: 50, + 
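			// tp1 (weight 60) scores 0.3 and tp2 (weight 40) scores 0.8: 0.3*60 + 0.8*40 = 50.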
+			numPodsToScore: 2,
+			err:            false,
+		},
+		{
+			name: "filter all",
+			config: SchedulerConfig{
+				preSchedulePlugins: []plugins.PreSchedule{tp1, tp2},
+				filters:            []plugins.Filter{tp1, tp_filterAll},
+				scorers: map[plugins.Scorer]int{
+					tp1: 1,
+					tp2: 1,
+				},
+				picker:              pickerPlugin,
+				postSchedulePlugins: []plugins.PostSchedule{tp1, tp2},
+			},
+			input: []*backendmetrics.FakePodMetrics{
+				{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod1"}}},
+				{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod2"}}},
+				{Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod3"}}},
+			},
+			numPodsToScore: 0,
+			err:            true, // no pods are available to serve after all pods are filtered out
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			// Reset all plugins before each new test case.
+			for _, plugin := range test.config.preSchedulePlugins {
+				plugin.(*TestPlugin).reset()
+			}
+			for _, plugin := range test.config.filters {
+				plugin.(*TestPlugin).reset()
+			}
+			for plugin := range test.config.scorers {
+				plugin.(*TestPlugin).reset()
+			}
+			test.config.picker.(*TestPlugin).reset()
+			for _, plugin := range test.config.postSchedulePlugins {
+				plugin.(*TestPlugin).reset()
+			}
+
+			// Initialize the scheduler
+			scheduler := NewSchedulerWithConfig(&fakeDataStore{pods: test.input}, &test.config)
+
+			req := &types.LLMRequest{Model: "test-model"}
+			got, err := scheduler.Schedule(context.Background(), req)
+
+			// Validate error state
+			if test.err != (err != nil) {
+				t.Fatalf("Unexpected error, got %v, want %v", err, test.err)
+			}
+
+			if err != nil {
+				return
+			}
+
+			// Validate output
+			wantPod := &types.PodMetrics{
+				Pod: &backend.Pod{NamespacedName: test.wantTargetPod},
+			}
+			wantRes := &types.Result{TargetPod: wantPod}
+			if diff := cmp.Diff(wantRes, got); diff != "" {
+				t.Errorf("Unexpected output (-want +got): %v", diff)
+			}
+
+			// Validate plugin execution counts dynamically
+			for _, plugin := range test.config.preSchedulePlugins {
+				tp, _ := plugin.(*TestPlugin)
+				if tp.PreScheduleCallCount != 1 {
+					t.Errorf("Plugin %s PreSchedule() called %d times, expected 1", plugin.Name(), tp.PreScheduleCallCount)
+				}
+			}
+
+			for _, plugin := range test.config.filters {
+				tp, _ := plugin.(*TestPlugin)
+				if tp.FilterCallCount != 1 {
+					t.Errorf("Plugin %s Filter() called %d times, expected 1", plugin.Name(), tp.FilterCallCount)
+				}
+			}
+
+			for plugin := range test.config.scorers {
+				tp, _ := plugin.(*TestPlugin)
+				if tp.ScoreCallCount != 1 {
+					t.Errorf("Plugin %s Score() called %d times, expected 1", plugin.Name(), tp.ScoreCallCount)
+				}
+				if test.numPodsToScore != tp.NumOfScoredPods {
+					t.Errorf("Plugin %s Score() called with %d pods, expected %d", plugin.Name(), tp.NumOfScoredPods, test.numPodsToScore)
+				}
+			}
+
+			tp, _ := test.config.picker.(*TestPlugin)
+			if tp.NumOfPickerCandidates != test.numPodsToScore {
+				t.Errorf("Picker plugin %s Pick() called with %d candidates, expected %d", tp.Name(), tp.NumOfPickerCandidates, test.numPodsToScore)
+			}
+			if tp.PickCallCount != 1 {
+				t.Errorf("Picker plugin %s Pick() called %d times, expected 1", tp.Name(), tp.PickCallCount)
+			}
+			if tp.WinnerPodScore != test.targetPodScore {
+				t.Errorf("winner pod score %v, expected %v", tp.WinnerPodScore, test.targetPodScore)
+			}
+
+			for _, plugin := range test.config.postSchedulePlugins {
+				tp, _ := plugin.(*TestPlugin)
+				if tp.PostScheduleCallCount != 1 {
+					t.Errorf("Plugin %s PostSchedule() called %d times, expected 1", plugin.Name(), tp.PostScheduleCallCount)
+				}
+			}
+		})
+	}
+}
+
+type fakeDataStore struct {
+	pods []*backendmetrics.FakePodMetrics
+}
+
+func (fds *fakeDataStore) PodGetAll() []backendmetrics.PodMetrics {
+	pm := make([]backendmetrics.PodMetrics, 0, len(fds.pods))
+	for _, pod := range fds.pods {
+		pm = append(pm, pod)
+	}
+	return pm
+}
+
+// TestPlugin is an implementation useful in unit tests.
+type TestPlugin struct {
+	NameRes               string
+	ScoreCallCount        int
+	NumOfScoredPods       int
+	ScoreRes              float64
+	FilterCallCount       int
+	FilterRes             []k8stypes.NamespacedName
+	PreScheduleCallCount  int
+	PostScheduleCallCount int
+	PickCallCount         int
+	NumOfPickerCandidates int
+	PickRes               k8stypes.NamespacedName
+	WinnerPodScore        float64
+}
+
+func (tp *TestPlugin) Name() string { return tp.NameRes }
+
+func (tp *TestPlugin) PreSchedule(ctx *types.SchedulingContext) {
+	tp.PreScheduleCallCount++
+}
+
+func (tp *TestPlugin) Filter(ctx *types.SchedulingContext, pods []types.Pod) []types.Pod {
+	tp.FilterCallCount++
+	return findPods(ctx, tp.FilterRes...)
+}
+
+func (tp *TestPlugin) Score(ctx *types.SchedulingContext, pods []types.Pod) map[types.Pod]float64 {
+	tp.ScoreCallCount++
+	scoredPods := make(map[types.Pod]float64, len(pods))
+	for _, pod := range pods {
+		scoredPods[pod] += tp.ScoreRes
+	}
+	tp.NumOfScoredPods = len(scoredPods)
+	return scoredPods
+}
+
+func (tp *TestPlugin) Pick(ctx *types.SchedulingContext, scoredPods []*types.ScoredPod) *types.Result {
+	tp.PickCallCount++
+	tp.NumOfPickerCandidates = len(scoredPods)
+	pod := findPods(ctx, tp.PickRes)[0]
+	tp.WinnerPodScore = getPodScore(scoredPods, pod)
+	return &types.Result{TargetPod: pod}
+}
+
+func (tp *TestPlugin) PostSchedule(ctx *types.SchedulingContext, res *types.Result) {
+	tp.PostScheduleCallCount++
+}
+
+func (tp *TestPlugin) reset() {
+	tp.PreScheduleCallCount = 0
+	tp.FilterCallCount = 0
+	tp.ScoreCallCount = 0
+	tp.NumOfScoredPods = 0
+	tp.PostScheduleCallCount = 0
+	tp.PickCallCount = 0
+	tp.NumOfPickerCandidates = 0
+}
+
+func findPods(ctx *types.SchedulingContext, names ...k8stypes.NamespacedName) []types.Pod {
+	res := []types.Pod{}
+	for _, pod := range ctx.PodsSnapshot {
+		for _, name := range names {
+			if pod.GetPod().NamespacedName.String() == name.String() {
+				res = append(res, pod)
+			}
+		}
+	}
+	return res
+}
+
+func getPodScore(scoredPods []*types.ScoredPod, selectedPod types.Pod) float64 {
+	finalScore := 0.0
+	for _, scoredPod := range scoredPods {
+		if scoredPod.GetPod().NamespacedName.String() == selectedPod.GetPod().NamespacedName.String() {
+			finalScore = scoredPod.Score
+			break
+		}
+	}
+	return finalScore
+}
diff --git a/pkg/epp/scheduling/types/types.go b/pkg/epp/scheduling/types/types.go
new file mode 100644
index 00000000..4f69fae0
--- /dev/null
+++ b/pkg/epp/scheduling/types/types.go
@@ -0,0 +1,104 @@
+/*
+Copyright 2025 The Kubernetes Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/ + +package types + +import ( + "context" + "fmt" + + "github.com/go-logr/logr" + "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" + backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" +) + +// LLMRequest is a structured representation of the fields we parse out of the LLMRequest body. +type LLMRequest struct { + Model string + // Target models is a map of target model name to weight. + TargetModels map[string]int + Prompt string + // Resolved target model is the final target model after traffic split. + ResolvedTargetModel string + Critical bool +} + +func (r *LLMRequest) String() string { + return fmt.Sprintf("Model: %s, TargetModels: %v, ResolvedTargetModel: %s, Critical: %t, PromptLength: %v", r.Model, r.TargetModels, r.ResolvedTargetModel, r.Critical, len(r.Prompt)) +} + +type Pod interface { + GetPod() *backend.Pod + GetMetrics() *backendmetrics.Metrics + String() string +} + +type ScoredPod struct { + Pod + Score float64 +} + +// SchedulingContext holds contextual information during a scheduling operation. +type SchedulingContext struct { + context.Context + Logger logr.Logger + Req *LLMRequest + PodsSnapshot []Pod +} + +func (pm *PodMetrics) String() string { + if pm == nil { + return "" + } + return fmt.Sprintf("%+v", *pm) +} + +func (pm *PodMetrics) GetPod() *backend.Pod { + return pm.Pod +} + +func (pm *PodMetrics) GetMetrics() *backendmetrics.Metrics { + return pm.Metrics +} + +type PodMetrics struct { + *backend.Pod + *backendmetrics.Metrics +} + +func NewSchedulingContext(ctx context.Context, req *LLMRequest, pods []Pod) *SchedulingContext { + logger := log.FromContext(ctx).WithValues("request", req) + return &SchedulingContext{ + Context: ctx, + Logger: logger, + Req: req, + PodsSnapshot: pods, + } +} + +func ToSchedulerPodMetrics(pods []backendmetrics.PodMetrics) []Pod { + pm := make([]Pod, 0, len(pods)) + for _, pod := range pods { + pm = append(pm, &PodMetrics{Pod: pod.GetPod().Clone(), Metrics: pod.GetMetrics().Clone()}) + } + return pm +} + +// Result captures the scheduler result. +type Result struct { + TargetPod Pod +} diff --git a/pkg/epp/server/controller_manager.go b/pkg/epp/server/controller_manager.go index 41fe86a9..e5668210 100644 --- a/pkg/epp/server/controller_manager.go +++ b/pkg/epp/server/controller_manager.go @@ -22,6 +22,7 @@ import ( corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/fields" "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" utilruntime "k8s.io/apimachinery/pkg/util/runtime" clientgoscheme "k8s.io/client-go/kubernetes/scheme" "k8s.io/client-go/rest" @@ -36,32 +37,32 @@ var scheme = runtime.NewScheme() func init() { utilruntime.Must(clientgoscheme.AddToScheme(scheme)) - utilruntime.Must(v1alpha2.AddToScheme(scheme)) + utilruntime.Must(v1alpha2.Install(scheme)) } -// DefaultManagerOptions returns the default options used to create the manager. -func DefaultManagerOptions(namespace, name string) ctrl.Options { +// defaultManagerOptions returns the default options used to create the manager. 
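+// The manager's cache is restricted to the InferencePool's namespace, and InferencePool
+// objects are further filtered to the given name.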
+func defaultManagerOptions(namespacedName types.NamespacedName) ctrl.Options { return ctrl.Options{ Scheme: scheme, Cache: cache.Options{ ByObject: map[client.Object]cache.ByObject{ &corev1.Pod{}: { Namespaces: map[string]cache.Config{ - namespace: {}, + namespacedName.Namespace: {}, }, }, &v1alpha2.InferencePool{}: { Namespaces: map[string]cache.Config{ - namespace: { + namespacedName.Namespace: { FieldSelector: fields.SelectorFromSet(fields.Set{ - "metadata.name": name, + "metadata.name": namespacedName.Name, }), }, }, }, &v1alpha2.InferenceModel{}: { Namespaces: map[string]cache.Config{ - namespace: {}, + namespacedName.Namespace: {}, }, }, }, @@ -70,8 +71,8 @@ func DefaultManagerOptions(namespace, name string) ctrl.Options { } // NewDefaultManager creates a new controller manager with default configuration. -func NewDefaultManager(namespace, name string, restConfig *rest.Config) (ctrl.Manager, error) { - manager, err := ctrl.NewManager(restConfig, DefaultManagerOptions(namespace, name)) +func NewDefaultManager(namespacedName types.NamespacedName, restConfig *rest.Config) (ctrl.Manager, error) { + manager, err := ctrl.NewManager(restConfig, defaultManagerOptions(namespacedName)) if err != nil { return nil, fmt.Errorf("failed to create controller manager: %v", err) } diff --git a/pkg/epp/server/runserver.go b/pkg/epp/server/runserver.go index a6c9f1d3..687a555c 100644 --- a/pkg/epp/server/runserver.go +++ b/pkg/epp/server/runserver.go @@ -35,7 +35,6 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/controller" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/handlers" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling" ) // ExtProcServerRunner provides methods to manage an external process server. @@ -43,13 +42,13 @@ type ExtProcServerRunner struct { GrpcPort int DestinationEndpointHintMetadataNamespace string DestinationEndpointHintKey string - PoolName string - PoolNamespace string + PoolNamespacedName types.NamespacedName Datastore datastore.Datastore SecureServing bool CertPath string UseStreaming bool RefreshPrometheusMetricsInterval time.Duration + Scheduler handlers.Scheduler // This should only be used in tests. We won't need this once we don't inject metrics in the tests. // TODO:(https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/432) Cleanup @@ -73,9 +72,9 @@ func NewDefaultExtProcServerRunner() *ExtProcServerRunner { GrpcPort: DefaultGrpcPort, DestinationEndpointHintKey: DefaultDestinationEndpointHintKey, DestinationEndpointHintMetadataNamespace: DefaultDestinationEndpointHintMetadataNamespace, - PoolName: DefaultPoolName, - PoolNamespace: DefaultPoolNamespace, + PoolNamespacedName: types.NamespacedName{Name: DefaultPoolName, Namespace: DefaultPoolNamespace}, SecureServing: DefaultSecureServing, + RefreshPrometheusMetricsInterval: DefaultRefreshPrometheusMetricsInterval, // Datastore can be assigned later. 
} } @@ -86,23 +85,16 @@ func (r *ExtProcServerRunner) SetupWithManager(ctx context.Context, mgr ctrl.Man if err := (&controller.InferencePoolReconciler{ Datastore: r.Datastore, Client: mgr.GetClient(), - PoolNamespacedName: types.NamespacedName{ - Name: r.PoolName, - Namespace: r.PoolNamespace, - }, - Record: mgr.GetEventRecorderFor("InferencePool"), + Record: mgr.GetEventRecorderFor("InferencePool"), }).SetupWithManager(mgr); err != nil { return fmt.Errorf("failed setting up InferencePoolReconciler: %w", err) } if err := (&controller.InferenceModelReconciler{ - Datastore: r.Datastore, - Client: mgr.GetClient(), - PoolNamespacedName: types.NamespacedName{ - Name: r.PoolName, - Namespace: r.PoolNamespace, - }, - Record: mgr.GetEventRecorderFor("InferenceModel"), + Datastore: r.Datastore, + Client: mgr.GetClient(), + PoolNamespacedName: r.PoolNamespacedName, + Record: mgr.GetEventRecorderFor("InferenceModel"), }).SetupWithManager(ctx, mgr); err != nil { return fmt.Errorf("failed setting up InferenceModelReconciler: %w", err) } @@ -145,14 +137,7 @@ func (r *ExtProcServerRunner) AsRunnable(logger logr.Logger) manager.Runnable { } else { srv = grpc.NewServer() } - var extProcServer extProcPb.ExternalProcessorServer - if r.UseStreaming { - logger.Info("Using streaming extproc server") - extProcServer = handlers.NewStreamingServer(scheduling.NewScheduler(r.Datastore), r.DestinationEndpointHintMetadataNamespace, r.DestinationEndpointHintKey, r.Datastore) - } else { - logger.Info("Using standard extproc server") - extProcServer = handlers.NewServer(scheduling.NewScheduler(r.Datastore), r.DestinationEndpointHintMetadataNamespace, r.DestinationEndpointHintKey, r.Datastore) - } + extProcServer := handlers.NewStreamingServer(r.Scheduler, r.DestinationEndpointHintMetadataNamespace, r.DestinationEndpointHintKey, r.Datastore) extProcPb.RegisterExternalProcessorServer( srv, extProcServer, diff --git a/pkg/epp/util/env/env.go b/pkg/epp/util/env/env.go index 11e3bde1..0c6d1c6d 100644 --- a/pkg/epp/util/env/env.go +++ b/pkg/epp/util/env/env.go @@ -5,26 +5,25 @@ import ( "strconv" "github.com/go-logr/logr" - logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) // getEnvFloat gets a float64 from an environment variable with a default value func GetEnvFloat(key string, defaultVal float64, logger logr.Logger) float64 { val, exists := os.LookupEnv(key) if !exists { - logger.V(logutil.VERBOSE).Info("Environment variable not set, using default value", + logger.Info("Environment variable not set, using default value", "key", key, "defaultValue", defaultVal) return defaultVal } floatVal, err := strconv.ParseFloat(val, 64) if err != nil { - logger.V(logutil.VERBOSE).Info("Failed to parse environment variable as float, using default value", + logger.Info("Failed to parse environment variable as float, using default value", "key", key, "value", val, "error", err, "defaultValue", defaultVal) return defaultVal } - logger.V(logutil.VERBOSE).Info("Successfully loaded environment variable", + logger.Info("Successfully loaded environment variable", "key", key, "value", floatVal) return floatVal } @@ -33,19 +32,30 @@ func GetEnvFloat(key string, defaultVal float64, logger logr.Logger) float64 { func GetEnvInt(key string, defaultVal int, logger logr.Logger) int { val, exists := os.LookupEnv(key) if !exists { - logger.V(logutil.VERBOSE).Info("Environment variable not set, using default value", + logger.Info("Environment variable not set, using default value", "key", key, "defaultValue", defaultVal) return 
defaultVal } intVal, err := strconv.Atoi(val) if err != nil { - logger.V(logutil.VERBOSE).Info("Failed to parse environment variable as int, using default value", + logger.Info("Failed to parse environment variable as int, using default value", "key", key, "value", val, "error", err, "defaultValue", defaultVal) return defaultVal } - logger.V(logutil.VERBOSE).Info("Successfully loaded environment variable", + logger.Info("Successfully loaded environment variable", "key", key, "value", intVal) return intVal } + +// GetEnvString gets a string from an environment variable with a default value +func GetEnvString(key string, defaultVal string, logger logr.Logger) string { + val, exists := os.LookupEnv(key) + if !exists { + logger.Info("Environment variable not set, using default value", + "key", key, "defaultValue", defaultVal) + return defaultVal + } + return val +} diff --git a/pkg/epp/util/env/env_test.go b/pkg/epp/util/env/env_test.go index 02513e28..105beb28 100644 --- a/pkg/epp/util/env/env_test.go +++ b/pkg/epp/util/env/env_test.go @@ -142,3 +142,64 @@ func TestGetEnvInt(t *testing.T) { }) } } + +func TestGetEnvString(t *testing.T) { + logger := testr.New(t) + + tests := []struct { + name string + key string + value string + defaultVal string + expected string + setup func() + teardown func() + }{ + { + name: "env variable exists and is valid", + key: "TEST_STR", + value: "123", + defaultVal: "default", + expected: "123", + setup: func() { + os.Setenv("TEST_STR", "123") + }, + teardown: func() { + os.Unsetenv("TEST_STR") + }, + }, + { + name: "env variable does not exist", + key: "TEST_STR_MISSING", + defaultVal: "default", + expected: "default", + setup: func() {}, + teardown: func() {}, + }, + { + name: "env variable is empty string", + key: "TEST_STR_EMPTY", + value: "", + defaultVal: "default", + expected: "", + setup: func() { + os.Setenv("TEST_STR_EMPTY", "") + }, + teardown: func() { + os.Unsetenv("TEST_STR_EMPTY") + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + tc.setup() + defer tc.teardown() + + result := GetEnvString(tc.key, tc.defaultVal, logger.V(logutil.VERBOSE)) + if result != tc.expected { + t.Errorf("GetEnvString(%s, %s) = %s, expected %s", tc.key, tc.defaultVal, result, tc.expected) + } + }) + } +} diff --git a/pkg/epp/util/pod/pod.go b/pkg/epp/util/pod/pod.go new file mode 100644 index 00000000..4fcb948f --- /dev/null +++ b/pkg/epp/util/pod/pod.go @@ -0,0 +1,36 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/
+
+package pod
+
+import (
+	corev1 "k8s.io/api/core/v1"
+)
+
+func IsPodReady(pod *corev1.Pod) bool {
+	if !pod.DeletionTimestamp.IsZero() {
+		return false
+	}
+	for _, condition := range pod.Status.Conditions {
+		if condition.Type == corev1.PodReady {
+			if condition.Status == corev1.ConditionTrue {
+				return true
+			}
+			break
+		}
+	}
+	return false
+}
diff --git a/site-src/api-types/inferencepool.md b/site-src/api-types/inferencepool.md
index baa604b6..1494d314 100644
--- a/site-src/api-types/inferencepool.md
+++ b/site-src/api-types/inferencepool.md
@@ -7,28 +7,56 @@
 
 ## Background
 
-The InferencePool resource is a logical grouping of compute resources, e.g. Pods, that run model servers. The InferencePool would deploy its own routing, and offer administrative configuration to the Platform Admin.
+The **InferencePool** API defines a group of Pods (containers) dedicated to serving AI models. Pods within an InferencePool share the same compute configuration, accelerator type, base language model, and model server. This abstraction simplifies the management of AI model serving resources, providing a centralized point of administrative configuration for Platform Admins.
 
-It is expected for the InferencePool to:
+An InferencePool is expected to be bundled with an [Endpoint Picker](https://github.com/kubernetes-sigs/gateway-api-inference-extension/tree/main/pkg/epp) extension. This extension is responsible for tracking key metrics on each model server (e.g. KV-cache utilization, queue length of pending requests, active LoRA adapters) and routing incoming inference requests to the optimal model server replica based on these metrics. An EPP can only be associated with a single InferencePool. The associated InferencePool is specified by the [poolName](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/config/manifests/inferencepool-resources.yaml#L54) and [poolNamespace](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/main/config/manifests/inferencepool-resources.yaml#L56) flags. An HTTPRoute can have multiple backendRefs that reference the same InferencePool, and therefore route to the same EPP, or backendRefs that reference different InferencePools, and therefore route to different EPPs.
 
-  - Enforce fair consumption of resources across competing workloads
-  - Efficiently route requests across shared compute (as displayed by the PoC)
-
-It is _not_ expected for the InferencePool to:
+Additionally, any Pod that seeks to join an InferencePool would need to support the [model server protocol](https://github.com/kubernetes-sigs/gateway-api-inference-extension/tree/main/docs/proposals/003-model-server-protocol), defined by this project, to ensure the Endpoint Picker has adequate information to intelligently route requests.
 
-  - Enforce any common set of adapters or base models are available on the Pods
-  - Manage Deployments of Pods within the Pool
-  - Manage Pod lifecycle of pods within the pool
+## How to Configure an InferencePool
 
-Additionally, any Pod that seeks to join an InferencePool would need to support a protocol, defined by this project, to ensure the Pool has adequate information to intelligently route requests.
+The full spec of the InferencePool is defined [here](/reference/spec/#inferencepool).
 
-`InferencePool` has some small overlap with `Service`, displayed here:
+In summary, the InferencePoolSpec consists of 3 major parts:
+
+- The `selector` field specifies which Pods belong to this pool. The labels in this selector must exactly match the labels applied to your model server Pods.
+- The `targetPortNumber` field defines the port number that the Inference Gateway should route to on model server Pods that belong to this pool.
+- The `extensionRef` field references the [endpoint picker extension](https://github.com/kubernetes-sigs/gateway-api-inference-extension/tree/main/pkg/epp) (EPP) service that monitors key metrics from model servers within the InferencePool and provides intelligent routing decisions.
+
+### Example Configuration
+
+Here is an example InferencePool configuration:
+
+```
+apiVersion: inference.networking.x-k8s.io/v1alpha2
+kind: InferencePool
+metadata:
+  name: vllm-llama3-8b-instruct
+spec:
+  targetPortNumber: 8000
+  selector:
+    app: vllm-llama3-8b-instruct
+  extensionRef:
+    name: vllm-llama3-8b-instruct-epp
+    port: 9002
+    failureMode: FailClose
+```
+
+In this example:
+
+- An InferencePool named `vllm-llama3-8b-instruct` is created in the `default` namespace.
+- It will select Pods that have the label `app: vllm-llama3-8b-instruct`.
+- Traffic routed to this InferencePool will call out to the EPP service `vllm-llama3-8b-instruct-epp` on port `9002` for making routing decisions. If the EPP fails to pick an endpoint, or is not responsive, the request will be dropped.
+- Traffic routed to this InferencePool will be forwarded to the port `8000` on the selected Pods.
+
+## Overlap with Service
+
+**InferencePool** has some small overlap with **Service**, displayed here:
 
 Comparing InferencePool with Service
 
-The InferencePool is _not_ intended to be a mask of the Service object, simply exposing the absolute bare minimum required to allow the Platform Admin to focus less on networking, and more on Pool management.
-
-## Spec
+The InferencePool is not intended to be a mask of the Service object. It provides a specialized abstraction tailored for managing and routing traffic to groups of LLM model servers, allowing Platform Admins to focus on pool-level management rather than low-level networking details.
 
-The full spec of the InferencePool is defined [here](/reference/spec/#inferencepool).
\ No newline at end of file
+## Replacing an InferencePool
+Please refer to the [Replacing an InferencePool](/guides/replacing-inference-pool) guide for details on use cases and how to replace an InferencePool.
diff --git a/site-src/guides/adapter-rollout.md b/site-src/guides/adapter-rollout.md
index fdf62c3a..4e7a3667 100644
--- a/site-src/guides/adapter-rollout.md
+++ b/site-src/guides/adapter-rollout.md
@@ -18,28 +18,28 @@
 
 Modify the LoRA syncer ConfigMap to initiate loading of the new adapter version.
```bash - kubectl edit configmap vllm-llama3-8b-instruct-adapters +kubectl edit configmap vllm-llama3-8b-instruct-adapters ``` Change the ConfigMap to match the following (note the new entry under models): ```yaml - apiVersion: v1 - kind: ConfigMap - metadata: - name: vllm-llama3-8b-instruct-adapters - data: - configmap.yaml: | - vLLMLoRAConfig: - name: vllm-llama3-8b-instruct-adapters - port: 8000 - defaultBaseModel: meta-llama/Llama-3.1-8B-Instruct - ensureExist: - models: - - id: food-review-1 - source: Kawon/llama3.1-food-finetune_v14_r8 - - id: food-review-2 - source: Kawon/llama3.1-food-finetune_v14_r8 +apiVersion: v1 +kind: ConfigMap +metadata: + name: vllm-llama3-8b-instruct-adapters +data: + configmap.yaml: | + vLLMLoRAConfig: + name: vllm-llama3-8b-instruct-adapters + port: 8000 + defaultBaseModel: meta-llama/Llama-3.1-8B-Instruct + ensureExist: + models: + - id: food-review-1 + source: Kawon/llama3.1-food-finetune_v14_r8 + - id: food-review-2 + source: Kawon/llama3.1-food-finetune_v14_r8 ``` The new adapter version is applied to the model servers live, without requiring a restart. @@ -51,35 +51,34 @@ Modify the InferenceModel to configure a canary rollout with traffic splitting. ```bash - kubectl edit inferencemodel food-review +kubectl edit inferencemodel food-review ``` Change the targetModels list in InferenceModel to match the following: ```yaml -apiVersion: inference.networking.x-k8s.io/v1alpha1 +apiVersion: inference.networking.x-k8s.io/v1alpha2 kind: InferenceModel metadata: - name: inferencemodel-sample + name: food-review spec: modelName: food-review - criticality: Critical + criticality: Standard poolRef: - name: vllm-llama3-8b-instruct-pool + name: vllm-llama3-8b-instruct targetModels: - name: food-review-1 weight: 90 - name: food-review-2 weight: 10 - ``` The above configuration means one in every ten requests should be sent to the new version. Try it out: 1. Get the gateway IP: ```bash -IP=$(kubectl get gateway/inference-gateway -o jsonpath='{.status.addresses[0].value}'); PORT=8081 +IP=$(kubectl get gateway/inference-gateway -o jsonpath='{.status.addresses[0].value}'); PORT=80 ``` 2. Send a few requests as follows: @@ -98,34 +97,41 @@ curl -i ${IP}:${PORT}/v1/completions -H 'Content-Type: application/json' -d '{ Modify the InferenceModel to direct 100% of the traffic to the latest version of the adapter. 
```yaml -model: - name: food-review - targetModels: - targetModelName: food-review-2 - weight: 100 +apiVersion: inference.networking.x-k8s.io/v1alpha2 +kind: InferenceModel +metadata: + name: food-review +spec: + modelName: food-review + criticality: Standard + poolRef: + name: vllm-llama3-8b-instruct + targetModels: + - name: food-review-2 + weight: 100 ``` Unload the older versions from the servers by updating the LoRA syncer ConfigMap to list the older version under the `ensureNotExist` list: ```yaml - apiVersion: v1 - kind: ConfigMap - metadata: - name: dynamic-lora-config - data: - configmap.yaml: | - vLLMLoRAConfig: - name: sql-loras-llama - port: 8000 - defaultBaseModel: meta-llama/Llama-3.1-8B-Instruct - ensureExist: - models: - - id: food-review-2 - source: Kawon/llama3.1-food-finetune_v14_r8 - ensureNotExist: - models: - - id: food-review-1 - source: Kawon/llama3.1-food-finetune_v14_r8 +apiVersion: v1 +kind: ConfigMap +metadata: + name: vllm-llama3-8b-instruct-adapters +data: + configmap.yaml: | + vLLMLoRAConfig: + name: vllm-llama3-8b-instruct-adapters + port: 8000 + defaultBaseModel: meta-llama/Llama-3.1-8B-Instruct + ensureExist: + models: + - id: food-review-2 + source: Kawon/llama3.1-food-finetune_v14_r8 + ensureNotExist: + models: + - id: food-review-1 + source: Kawon/llama3.1-food-finetune_v14_r8 ``` With this, all requests should be served by the new adapter version. diff --git a/site-src/guides/implementers.md b/site-src/guides/implementers.md index 5d1c6267..7bfd536a 100644 --- a/site-src/guides/implementers.md +++ b/site-src/guides/implementers.md @@ -1,3 +1,113 @@ # Implementer's Guide -TODO \ No newline at end of file +This guide is intended for developers looking to implement support for the InferencePool custom resources within their Gateway API controller. It outlines how InferencePool fits into the existing resource model, discusses implementation options, explains how to interact with extensions, and provides guidance on testing. + +## InferencePool as a Gateway Backend +Before we dive into the implementation, let’s recap how an InferencePool works. + +Overview of API integration + +**InferencePool** represents a set of Inference-focused Pods and an extension that will be used to route to them. The InferencePool introduces a new type of backend within the Gateway API resource model. Instead of targeting Services, a Gateway can route traffic to an InferencePool. This InferencePool then becomes responsible for intelligent routing to the underlying model server pods based on the associated InferenceModel configurations. + +Here is an example of how to route traffic to an InferencePool using an HTTPRoute: +``` +apiVersion: gateway.networking.k8s.io/v1 +kind: HTTPRoute +metadata: + name: llm-route +spec: + parentRefs: + - group: gateway.networking.k8s.io + kind: Gateway + name: inference-gateway + rules: + - backendRefs: + - group: inference.networking.x-k8s.io + kind: InferencePool + name: base-model + matches: + - path: + type: PathPrefix + value: / +``` + +Note that the `rules.backendRefs` describes which InferencePool should receive the forwarded traffic when the path matches the corresponding path prefix. This is very similar to how we configure a Gateway with an HTTPRoute that directs traffic to a Service (a way to select Pods and specify a port). By using the InferencePool, it provides an abstraction over a set of compute resources (model server pods), and allows the controller to implement specialized routing strategies for these inference workloads. 
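+As a sketch of how this composes with standard Gateway API traffic splitting (the second pool name `fine-tuned-model` below is illustrative), weighted `backendRefs` can split requests across InferencePools the same way they would across Services:
+
+```
+apiVersion: gateway.networking.k8s.io/v1
+kind: HTTPRoute
+metadata:
+  name: llm-split-route
+spec:
+  parentRefs:
+  - group: gateway.networking.k8s.io
+    kind: Gateway
+    name: inference-gateway
+  rules:
+  - backendRefs:
+    - group: inference.networking.x-k8s.io
+      kind: InferencePool
+      name: base-model
+      weight: 90
+    - group: inference.networking.x-k8s.io
+      kind: InferencePool
+      name: fine-tuned-model
+      weight: 10
+    matches:
+    - path:
+        type: PathPrefix
+        value: /
+```
+
+The weights follow normal HTTPRoute semantics, so roughly 90% of matching requests would go to `base-model` and 10% to `fine-tuned-model`.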
+
+## Building the Gateway controller
+The general idea of implementing a Gateway controller supporting the InferencePool involves two major steps:
+
+1. Tracking the endpoints for InferencePool backends
+2. Calling out to an extension to make intelligent routing decisions
+
+### Endpoint Tracking
+Consider a simple InferencePool like this:
+```
+apiVersion: inference.networking.x-k8s.io/v1alpha2
+kind: InferencePool
+metadata:
+  name: vllm-llama3-8b-instruct
+spec:
+  targetPortNumber: 8000
+  selector:
+    app: vllm-llama3-8b-instruct
+  extensionRef:
+    name: vllm-llama3-8b-instruct-epp
+```
+
+There are two main options for how to treat the InferencePool in your controller.
+
+**Option 1: Shadow Service Creation**
+
+If your Gateway controller already handles Service as a backend, you can choose to create a headless Service that mirrors the endpoints defined by the InferencePool, like this:
+
+```
+apiVersion: v1
+kind: Service
+metadata:
+  name: vllm-llama3-8b-instruct-shadow-service
+spec:
+  ports:
+    - port: 54321
+      protocol: TCP
+      targetPort: 8000
+  selector:
+    app: vllm-llama3-8b-instruct
+  type: ClusterIP
+  clusterIP: None
+```
+
+The gateway controller would then treat this shadow service just like any other backend service it routes traffic to.
+
+This approach likely allows you to leverage existing service discovery, health-check infrastructure, and load-balancing mechanisms that your controller already supports. However, it does come with the overhead of managing additional Service objects, and hence may increase the overall latency of Gateway reconciliation.
+
+**Option 2: Tracking InferencePool Endpoints Separately**
+
+You can also choose to directly select and monitor the endpoints belonging to the InferencePool. For the simple InferencePool example above, the controller would use the label `app: vllm-llama3-8b-instruct` to discover the pods matching the criteria, and get their endpoints (i.e. IP and port number). It would then need to monitor these pods for health and availability.
+
+With this approach, you can tailor the endpoint tracking and routing logic specifically to the characteristics and requirements of your InferencePool.
+
+### Callout Extension
+
+The [Endpoint Picker](https://github.com/kubernetes-sigs/gateway-api-inference-extension/tree/main/pkg/epp), or EPP, is a core component of the inference extension. The primary interaction for routing requests is defined between the proxy (e.g., Envoy) and the EPP using the Envoy [external processing service protocol](https://www.envoyproxy.io/docs/envoy/latest/api-v3/service/ext_proc/v3/external_processor.proto). See the [Endpoint Picker Protocol](https://github.com/kubernetes-sigs/gateway-api-inference-extension/tree/main/docs/proposals/004-endpoint-picker-protocol) for more information.
+
+#### How to Call Out to EPP
+
+For each HTTP request, the proxy CAN communicate the subset of endpoints the EPP MUST pick from by setting the `x-gateway-destination-endpoint-subset` key in the filter metadata field of the ext-proc request. If this key is set, the EPP must select from this endpoint list. If the list is empty or no endpoints are eligible, it should return a 503 error. If the key isn't set, the EPP selects from the endpoints defined by the InferencePool selector.
+
+#### Response from the extension
+
+The EPP communicates the chosen endpoint to the proxy via the `x-gateway-destination-endpoint` HTTP header and the `dynamic_metadata` field of the ext-proc response. Failure to communicate the endpoint using both methods results in a 503 error if no endpoints are ready, or a 429 error if the request should be dropped. The header and metadata values must match. In addition to the chosen endpoint, a single fallback endpoint CAN be set using the key `x-gateway-destination-endpoint-fallback` in the same metadata namespace as the one used for `x-gateway-destination-endpoint`.
+
+## Testing Tips
+
+Here are some tips for testing your controller end-to-end:
+
+- **Focus on Key Scenarios**: Add common scenarios like creating, updating, and deleting InferencePool resources, as well as different routing rules that target InferencePool backends.
+- **Verify Routing Behaviors**: Design more complex routing scenarios and verify that requests are correctly routed to the appropriate model server pods within the InferencePool based on the InferenceModel configuration.
+- **Test Error Handling**: Verify that the controller correctly handles scenarios like unsupported model names or resource constraints (if criticality-based shedding is implemented). Test with state transitions (such as constant requests while Pods behind the EPP and Pods behind the InferencePool are being replaced) to ensure that the system is resilient to failures and can automatically recover by redirecting traffic to healthy Pods.
+- **Using Reference EPP Implementation + Echoserver**: You can use the [reference EPP implementation](https://github.com/kubernetes-sigs/gateway-api-inference-extension/tree/main/pkg/epp) for testing your controller end-to-end. Instead of a full-fledged model server, a simple mock server (like the [echoserver](https://github.com/kubernetes-sigs/ingress-controller-conformance/tree/master/images/echoserver)) can be very useful for verifying that requests are routed to the correct pod.
+- **Performance Test**: Run end-to-end [benchmarks](https://gateway-api-inference-extension.sigs.k8s.io/performance/benchmark/) to make sure that your inference gateway can achieve the desired latency target.
+
+### Conformance Tests
+
+A set of conformance tests will be developed soon to help verify that a controller is working as expected. This guide will be updated once we have more information. Stay tuned!
diff --git a/site-src/guides/index.md b/site-src/guides/index.md
index 7fdb211c..bcd1068d 100644
--- a/site-src/guides/index.md
+++ b/site-src/guides/index.md
@@ -58,7 +58,7 @@ This quickstart guide is intended for engineers familiar with k8s and model serv
 === "Latest Release"
 
    ```bash
-   VERSION=v0.2.0
+   VERSION=v0.3.0
   kubectl apply -f https://github.com/kubernetes-sigs/gateway-api-inference-extension/releases/download/$VERSION/manifests.yaml
   ```
 
@@ -70,8 +70,7 @@ This quickstart guide is intended for engineers familiar with k8s and model serv
 
   ### Deploy InferenceModel
 
-  Deploy the sample InferenceModel which is configured to load balance traffic between the `food-review-0` and `food-review-1`
-  [LoRA adapters](https://docs.vllm.ai/en/latest/features/lora.html) of the sample model server.
+  Deploy the sample InferenceModel which is configured to forward traffic to the `food-review-1` [LoRA adapter](https://docs.vllm.ai/en/latest/features/lora.html) of the sample model server.
 
   ```bash
   kubectl apply -f https://github.com/kubernetes-sigs/gateway-api-inference-extension/raw/main/config/manifests/inferencemodel.yaml
@@ -120,9 +119,9 @@ This quickstart guide is intended for engineers familiar with k8s and model serv
 
   5. Given that the default connection timeout may be insufficient for most inference workloads, it is recommended to configure a timeout appropriate for your intended use case.
 
-  ```bash
-  kubectl apply -f https://github.com/kubernetes-sigs/gateway-api-inference-extension/raw/main/config/manifests/gateway/gke/gcp-backend-policy.yaml
-  ```
+     ```bash
+     kubectl apply -f https://github.com/kubernetes-sigs/gateway-api-inference-extension/raw/main/config/manifests/gateway/gke/gcp-backend-policy.yaml
+     ```
 
 === "Istio"
 
@@ -201,7 +200,7 @@ This quickstart guide is intended for engineers familiar with k8s and model serv
   2. Set the Kgateway version and install the Kgateway CRDs.
 
   ```bash
-  KGTW_VERSION=v2.0.0-rc.2
+  KGTW_VERSION=v2.0.0
   helm upgrade -i --create-namespace --namespace kgateway-system --version $KGTW_VERSION kgateway-crds oci://cr.kgateway.dev/kgateway-dev/charts/kgateway-crds
   ```
 
@@ -240,24 +239,40 @@ This quickstart guide is intended for engineers familiar with k8s and model serv
 
   Wait until the gateway is ready.
 
-  ```bash
-  IP=$(kubectl get gateway/inference-gateway -o jsonpath='{.status.addresses[0].value}')
-  PORT=80
-
-  curl -i ${IP}:${PORT}/v1/completions -H 'Content-Type: application/json' -d '{
-  "model": "food-review",
-  "prompt": "Write as if you were a critic: San Francisco",
-  "max_tokens": 100,
-  "temperature": 0
-  }'
-  ```
+=== "GPU-Based Model Server"
+
+    ```bash
+    IP=$(kubectl get gateway/inference-gateway -o jsonpath='{.status.addresses[0].value}')
+    PORT=80
+
+    curl -i ${IP}:${PORT}/v1/completions -H 'Content-Type: application/json' -d '{
+    "model": "food-review",
+    "prompt": "Write as if you were a critic: San Francisco",
+    "max_tokens": 100,
+    "temperature": 0
+    }'
+    ```
+
+=== "CPU-Based Model Server"
+
+    ```bash
+    IP=$(kubectl get gateway/inference-gateway -o jsonpath='{.status.addresses[0].value}')
+    PORT=80
+
+    curl -i ${IP}:${PORT}/v1/completions -H 'Content-Type: application/json' -d '{
+    "model": "Qwen/Qwen2.5-1.5B-Instruct",
+    "prompt": "Write as if you were a critic: San Francisco",
+    "max_tokens": 100,
+    "temperature": 0
+    }'
+    ```
 
   ### Cleanup
 
-  The following cleanup assumes you would like to clean ALL resources that were created in this quickstart guide.
+  The following instructions assume you would like to clean up ALL resources that were created in this quickstart guide. Please be careful not to delete resources you'd like to keep.
 
-  1. Uninstall the Inference Pool
+  1. Uninstall the InferencePool, InferenceModel, and model server resources
 
   ```bash
   kubectl delete -f https://github.com/kubernetes-sigs/gateway-api-inference-extension/raw/main/config/manifests/inferencepool-resources.yaml --ignore-not-found
   ```
 
   kubectl delete secret hf-token --ignore-not-found
   ```
 
-  1. Uninstall the Gateway
+  1. Uninstall the Gateway API resources
 
   ```bash
   kubectl delete -f https://github.com/kubernetes-sigs/gateway-api-inference-extension/raw/main/config/manifests/gateway/gke/gateway.yaml --ignore-not-found
   ```
 
   kubectl delete -f https://github.com/kubernetes-sigs/gateway-api-inference-extension/raw/main/config/manifests/gateway/kgateway/httproute.yaml --ignore-not-found
   ```
 
-  1. Uninstall the CRDs
+  1. Uninstall the Gateway API Inference Extension CRDs
 
   ```bash
   kubectl delete -k https://github.com/kubernetes-sigs/gateway-api-inference-extension/config/crd --ignore-not-found
   ```
+
+  1.
Choose one of the following options to cleanup the Inference Gateway. + +=== "GKE" + + **TODO** + +=== "Istio" + + **TODO** + +=== "Kgateway" + + The following instructions assume you would like to cleanup ALL Kgateway resources that were created in this quickstart guide. + + 1. Uninstall Kgateway + + ```bash + helm uninstall kgateway -n kgateway-system + ``` + + 1. Uninstall the Kgateway CRDs. + + ```bash + helm uninstall kgateway-crds -n kgateway-system + ``` + + 1. Remove the Kgateway namespace. + + ```bash + kubectl delete ns kgateway-system + ``` diff --git a/site-src/guides/metrics.md b/site-src/guides/metrics.md index a781f721..ab3ba3fd 100644 --- a/site-src/guides/metrics.md +++ b/site-src/guides/metrics.md @@ -26,6 +26,7 @@ curl -i ${IP}:${PORT}/v1/completions -H 'Content-Type: application/json' -d '{ | inference_model_request_total | Counter | The counter of requests broken out for each model. | `model_name`=<model-name>
`target_model_name`=<target-model-name> | ALPHA | | inference_model_request_error_total | Counter | The counter of requests errors broken out for each model. | `model_name`=<model-name>
`target_model_name`=<target-model-name> | ALPHA | | inference_model_request_duration_seconds | Distribution | Distribution of response latency. | `model_name`=<model-name>
`target_model_name`=<target-model-name> | ALPHA | +| normalized_time_per_output_token_seconds | Distribution | Distribution of ntpot (response latency per output token) | `model_name`=<model-name>
`target_model_name`=<target-model-name> | ALPHA | | inference_model_request_sizes | Distribution | Distribution of request size in bytes. | `model_name`=<model-name>
`target_model_name`=<target-model-name> | ALPHA | | inference_model_response_sizes | Distribution | Distribution of response size in bytes. | `model_name`=<model-name>
`target_model_name`=<target-model-name> | ALPHA | | inference_model_input_tokens | Distribution | Distribution of input token count. | `model_name`=<model-name>
`target_model_name`=<target-model-name> | ALPHA |
@@ -34,6 +35,8 @@ curl -i ${IP}:${PORT}/v1/completions -H 'Content-Type: application/json' -d '{
| inference_pool_average_kv_cache_utilization | Gauge | The average kv cache utilization for an inference server pool. | `name`=<inference-pool-name> | ALPHA |
| inference_pool_average_queue_size | Gauge | The average number of requests pending in the model server queue. | `name`=<inference-pool-name> | ALPHA |
| inference_pool_ready_pods | Gauge | The number of ready pods for an inference server pool. | `name`=<inference-pool-name> | ALPHA |
+| inference_extension_info | Gauge | General information about the current build. | `commit`=<hash-of-the-build> | ALPHA |
+

## Scrape Metrics
diff --git a/site-src/guides/replacing-inference-pool.md b/site-src/guides/replacing-inference-pool.md
new file mode 100644
index 00000000..21294570
--- /dev/null
+++ b/site-src/guides/replacing-inference-pool.md
@@ -0,0 +1,59 @@
+# Replacing an InferencePool
+
+## Background
+
+Replacing an InferencePool is a powerful technique for performing various infrastructure and model updates with minimal disruption and built-in rollback capabilities. This method allows you to introduce changes incrementally, monitor their impact, and revert to the previous state if necessary.
+
+## Use Cases
+Common use cases for replacing an InferencePool include:
+
+- Upgrading or replacing your model server framework
+- Upgrading or replacing your base model
+- Transitioning to new hardware
+
+## How to replace an InferencePool
+
+To replace an InferencePool:
+
+1. **Deploy new infrastructure**: Create a new InferencePool configured with the new hardware / model server / base model that you chose.
+1. **Configure traffic splitting**: Use an HTTPRoute to split traffic between the existing InferencePool and the new InferencePool. The `backendRefs.weight` field controls the traffic percentage allocated to each pool.
+1. **Maintain InferenceModel integrity**: Keep your InferenceModel configuration unchanged. This ensures that the system applies the same LoRA adapters consistently across both base model versions.
+1. **Preserve rollback capability**: Retain the original nodes and InferencePool during the rollout to facilitate a rollback if necessary.
+
+### Example
+
+You start with an existing InferencePool named `llm-pool-v1`. To replace the original InferencePool, you create a new InferencePool named `llm-pool-v2`. By configuring an **HTTPRoute**, as shown below, you can incrementally split traffic between the original `llm-pool-v1` and the new `llm-pool-v2`.
+
+1. Save the following sample manifest as `httproute.yaml`:
+
+    ```yaml
+    apiVersion: gateway.networking.k8s.io/v1
+    kind: HTTPRoute
+    metadata:
+      name: llm-route
+    spec:
+      parentRefs:
+      - group: gateway.networking.k8s.io
+        kind: Gateway
+        name: inference-gateway
+      rules:
+      - backendRefs:
+        - group: inference.networking.x-k8s.io
+          kind: InferencePool
+          name: llm-pool-v1
+          weight: 90
+        - group: inference.networking.x-k8s.io
+          kind: InferencePool
+          name: llm-pool-v2
+          weight: 10
+    ```
+
+1. Apply the sample manifest to your cluster:
+
+    ```
+    kubectl apply -f httproute.yaml
+    ```
+
+    The original `llm-pool-v1` InferencePool receives most of the traffic, while the `llm-pool-v2` InferencePool receives the rest.
+
+1. Increase the traffic weight gradually for the `llm-pool-v2` InferencePool to complete the new InferencePool rollout, as sketched below.
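+
+    For example, once you are confident in the new pool, a final revision of the same manifest could shift all traffic to `llm-pool-v2` (a sketch; pick increments that match your own canary policy):
+
+    ```yaml
+    # Final weights of the rollout sketched in this guide (illustrative values).
+    rules:
+    - backendRefs:
+      - group: inference.networking.x-k8s.io
+        kind: InferencePool
+        name: llm-pool-v1
+        weight: 0
+      - group: inference.networking.x-k8s.io
+        kind: InferencePool
+        name: llm-pool-v2
+        weight: 100
+    ```
+
+    Re-apply the manifest with `kubectl apply -f httproute.yaml` after each weight change, and monitor error rates and latency before shifting more traffic.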
diff --git a/site-src/images/favicon-64.png b/site-src/images/favicon-64.png new file mode 100644 index 00000000..f2bd3d64 Binary files /dev/null and b/site-src/images/favicon-64.png differ diff --git a/site-src/images/logo/logo-text-xl-dark.png b/site-src/images/logo/logo-text-xl-dark.png new file mode 100644 index 00000000..4d878e5c Binary files /dev/null and b/site-src/images/logo/logo-text-xl-dark.png differ diff --git a/site-src/implementations.md b/site-src/implementations.md deleted file mode 100644 index 89acb436..00000000 --- a/site-src/implementations.md +++ /dev/null @@ -1,56 +0,0 @@ -# Implementations - -This project has several implementations that are planned or in progress: - -* [Envoy Gateway][1] -* [Kgateway][2] -* [Google Kubernetes Engine][3] - -[1]:#envoy-gateway -[2]:#kgateway -[3]:#google-kubernetes-engine - -## Envoy Gateway - -[Envoy Gateway][eg-home] is an [Envoy][envoy-org] subproject for managing -Envoy-based application gateways. The supported APIs and fields of the Gateway -API are outlined [here][eg-supported]. Use the [quickstart][eg-quickstart] to -get Envoy Gateway running with Gateway API in a few simple steps. - -Progress towards supporting this project is tracked with a [GitHub -Issue](https://github.com/envoyproxy/gateway/issues/4423). - -[eg-home]:https://gateway.envoyproxy.io/ -[envoy-org]:https://github.com/envoyproxy -[eg-supported]:https://gateway.envoyproxy.io/docs/tasks/quickstart/ -[eg-quickstart]:https://gateway.envoyproxy.io/docs/tasks/quickstart - -## Kgateway - -[Kgateway](https://kgateway.dev/) is a feature-rich, Kubernetes-native -ingress controller and next-generation API gateway. Kgateway brings the -full power and community support of Gateway API to its existing control-plane -implementation. - -Progress towards supporting this project is tracked with a [GitHub -Issue](https://github.com/kgateway-dev/kgateway/issues/10411). - -## Google Kubernetes Engine - -[Google Kubernetes Engine (GKE)][gke] is a managed Kubernetes platform offered -by Google Cloud. GKE's implementation of the Gateway API is through the [GKE -Gateway controller][gke-gateway] which provisions Google Cloud Load Balancers -for Pods in GKE clusters. - -The GKE Gateway controller supports weighted traffic splitting, mirroring, -advanced routing, multi-cluster load balancing and more. See the docs to deploy -[private or public Gateways][gke-gateway-deploy] and also [multi-cluster -Gateways][gke-multi-cluster-gateway]. - -Progress towards supporting this project is tracked with a [GitHub -Issue](https://github.com/GoogleCloudPlatform/gke-gateway-api/issues/20). 
- -[gke]:https://cloud.google.com/kubernetes-engine -[gke-gateway]:https://cloud.google.com/kubernetes-engine/docs/concepts/gateway-api -[gke-gateway-deploy]:https://cloud.google.com/kubernetes-engine/docs/how-to/deploying-gateways -[gke-multi-cluster-gateway]:https://cloud.google.com/kubernetes-engine/docs/how-to/deploying-multi-cluster-gateways diff --git a/site-src/implementations/gateways.md b/site-src/implementations/gateways.md new file mode 100644 index 00000000..950c0833 --- /dev/null +++ b/site-src/implementations/gateways.md @@ -0,0 +1,88 @@ +# Gateway Implementations + +This project has several implementations that are planned or in progress: + +* [Envoy AI Gateway][1] +* [Kgateway][2] +* [Google Kubernetes Engine][3] +* [Istio][4] +* [Alibaba Cloud Container Service for Kubernetes][5] + +[1]:#envoy-gateway +[2]:#kgateway +[3]:#google-kubernetes-engine +[4]:#istio +[5]:#alibaba-cloud-container-service-for-kubernetes + +## Envoy AI Gateway + +[Envoy AI Gateway][aigw-home] is an open source project built on top of +[Envoy][envoy-org] and [Envoy Gateway][envoy-gateway] to handle request traffic +from application clients to GenAI services. The features and capabilities are outlined [here][aigw-capabilities]. Use the [quickstart][aigw-quickstart] to get Envoy AI Gateway running with Gateway API in a few simple steps. + +Progress towards supporting this project is tracked with a [GitHub +Issue](https://github.com/envoyproxy/ai-gateway/issues/423). + +[aigw-home]:https://aigateway.envoyproxy.io/ +[envoy-org]:https://github.com/envoyproxy +[envoy-gateway]: https://gateway.envoyproxy.io/ +[aigw-capabilities]:https://aigateway.envoyproxy.io/docs/capabilities/ +[aigw-quickstart]:https://aigateway.envoyproxy.io/docs/capabilities/gateway-api-inference-extension + +## Kgateway + +[Kgateway](https://kgateway.dev/) is a feature-rich, Kubernetes-native +ingress controller and next-generation API gateway. Kgateway brings the +full power and community support of Gateway API to its existing control-plane +implementation. + +Progress towards supporting this project is tracked with a [GitHub +Issue](https://github.com/kgateway-dev/kgateway/issues/10411). + +## Google Kubernetes Engine + +[Google Kubernetes Engine (GKE)][gke] is a managed Kubernetes platform offered +by Google Cloud. GKE's implementation of the Gateway API is through the [GKE +Gateway controller][gke-gateway] which provisions Google Cloud Load Balancers +for Pods in GKE clusters. + +The GKE Gateway controller supports weighted traffic splitting, mirroring, +advanced routing, multi-cluster load balancing and more. See the docs to deploy +[private or public Gateways][gke-gateway-deploy] and also [multi-cluster +Gateways][gke-multi-cluster-gateway]. + +Progress towards supporting this project is tracked with a [GitHub +Issue](https://github.com/GoogleCloudPlatform/gke-gateway-api/issues/20). + +[gke]:https://cloud.google.com/kubernetes-engine +[gke-gateway]:https://cloud.google.com/kubernetes-engine/docs/concepts/gateway-api +[gke-gateway-deploy]:https://cloud.google.com/kubernetes-engine/docs/how-to/deploying-gateways +[gke-multi-cluster-gateway]:https://cloud.google.com/kubernetes-engine/docs/how-to/deploying-multi-cluster-gateways + +## Istio + +[Istio](https://istio.io/) is an open source service mesh and gateway implementation. +It provides a fully compliant implementation of the Kubernetes Gateway API for cluster ingress traffic control. 
+For service mesh users, Istio also fully supports east-west (including [GAMMA](https://gateway-api.sigs.k8s.io/mesh/)) traffic management within the mesh.
+
+Gateway API Inference Extension support is being tracked by this [GitHub
+Issue](https://github.com/istio/istio/issues/55768).
+
+## Alibaba Cloud Container Service for Kubernetes
+
+[Alibaba Cloud Container Service for Kubernetes (ACK)][ack] is a managed Kubernetes platform
+offered by Alibaba Cloud. The implementation of the Gateway API in ACK is through the
+[ACK Gateway with Inference Extension][ack-gie] component, which introduces model-aware,
+GPU-efficient load balancing for AI workloads beyond basic HTTP routing.
+
+The ACK Gateway with Inference Extension implements the Gateway API Inference Extension
+and provides optimized routing for serving generative AI workloads,
+including weighted traffic splitting, mirroring, and advanced routing.
+See the [usage documentation][ack-gie-usage] for details.
+
+Progress towards supporting Gateway API Inference Extension is being tracked
+by [this Issue](https://github.com/AliyunContainerService/ack-gateway-api/issues/1).
+
+[ack]:https://www.alibabacloud.com/help/en/ack
+[ack-gie]:https://www.alibabacloud.com/help/en/ack/product-overview/ack-gateway-with-inference-extension
+[ack-gie-usage]:https://www.alibabacloud.com/help/en/ack/ack-managed-and-ack-dedicated/user-guide/intelligent-routing-and-traffic-management-with-ack-gateway-inference-extension
\ No newline at end of file
diff --git a/site-src/implementations/model-servers.md b/site-src/implementations/model-servers.md
new file mode 100644
index 00000000..3d475aaa
--- /dev/null
+++ b/site-src/implementations/model-servers.md
@@ -0,0 +1,38 @@
+
+
+# Supported Model Servers
+
+Any model server that conforms to the [model server protocol](https://github.com/kubernetes-sigs/gateway-api-inference-extension/tree/main/docs/proposals/003-model-server-protocol) is supported by the inference extension.
+
+## Compatible Model Server Versions
+
+| Model Server | Version | Commit | Notes |
+| -------------------- | ---------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------- |
+| vLLM V0 | v0.6.4 and above | [commit 0ad216f](https://github.com/vllm-project/vllm/commit/0ad216f5750742115c686723bf38698372d483fd) | |
+| vLLM V1 | v0.8.0 and above | [commit bc32bc7](https://github.com/vllm-project/vllm/commit/bc32bc73aad076849ac88565cff745b01b17d89c) | |
+| Triton (TensorRT-LLM) | [25.03](https://docs.nvidia.com/deeplearning/triton-inference-server/release-notes/rel-25-03.html#rel-25-03) and above | [commit 15cb989](https://github.com/triton-inference-server/tensorrtllm_backend/commit/15cb989b00523d8e92dce5165b9b9846c047a70d) | The LoRA affinity feature is not available, as the required LoRA metrics haven't been implemented in Triton yet. |
+
+## vLLM
+
+vLLM is configured as the default in the [endpoint picker extension](https://github.com/kubernetes-sigs/gateway-api-inference-extension/tree/main/pkg/epp). No further configuration is required.
+
+## Triton with TensorRT-LLM Backend
+
+Triton-specific metric names need to be specified when starting the EPP.
+ +### Option 1: Use Helm + +Use `--set inferencePool.modelServerType=triton-tensorrt-llm` to install the [`inferencepool` via helm](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/42eb5ff1c5af1275df43ac384df0ddf20da95134/config/charts/inferencepool). See the [`inferencepool` helm guide](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/42eb5ff1c5af1275df43ac384df0ddf20da95134/config/charts/inferencepool/README.md) for more details. + +### Option 2: Edit EPP deployment yaml + + Add the following to the `args` of the [EPP deployment](https://github.com/kubernetes-sigs/gateway-api-inference-extension/blob/42eb5ff1c5af1275df43ac384df0ddf20da95134/config/manifests/inferencepool-resources.yaml#L32) + + ``` +- -totalQueuedRequestsMetric +- "nv_trt_llm_request_metrics{request_type=waiting}" +- -kvCacheUsagePercentageMetric +- "nv_trt_llm_kv_cache_block_metrics{kv_cache_block_type=fraction}" +- -loraInfoMetric +- "" # Set an empty metric to disable LoRA metric scraping as they are not supported by Triton yet. +``` \ No newline at end of file diff --git a/site-src/index.md b/site-src/index.md index 04d1fadb..61bece27 100644 --- a/site-src/index.md +++ b/site-src/index.md @@ -91,7 +91,7 @@ This project is being driven by [WG-Serving](https://github.com/kubernetes/community/tree/master/wg-serving) [SIG-Network](https://github.com/kubernetes/community/tree/master/sig-network) to improve and standardize routing to inference workloads in Kubernetes. Check -out the [implementations reference](implementations.md) to see the latest +out the [implementations reference](implementations/gateways.md) to see the latest projects & products that support this project. If you are interested in contributing to or building an implementation using Gateway API then don’t hesitate to [get involved!](/contributing) diff --git a/site-src/reference/spec.md b/site-src/reference/spec.md index e16c113c..d8e0c95b 100644 --- a/site-src/reference/spec.md +++ b/site-src/reference/spec.md @@ -1,12 +1,14 @@ # API Reference ## Packages -- [inference.networking.x-k8s.io/v1alpha1](#inferencenetworkingx-k8siov1alpha1) +- [inference.networking.x-k8s.io/v1alpha2](#inferencenetworkingx-k8siov1alpha2) -## inference.networking.x-k8s.io/v1alpha1 +## inference.networking.x-k8s.io/v1alpha2 + +Package v1alpha2 contains API Schema definitions for the +inference.networking.x-k8s.io API group. -Package v1alpha1 contains API Schema definitions for the gateway v1alpha1 API group ### Resource Types - [InferenceModel](#inferencemodel) @@ -18,26 +20,152 @@ Package v1alpha1 contains API Schema definitions for the gateway v1alpha1 API gr _Underlying type:_ _string_ -Defines how important it is to serve the model compared to other models. +Criticality defines how important it is to serve the model compared to other models. +Criticality is intentionally a bounded enum to contain the possibilities that need to be supported by the load balancing algorithm. Any reference to the Criticality field must be optional(use a pointer), and set no default. +This allows us to union this with a oneOf field in the future should we wish to adjust/extend this behavior. _Validation:_ -- Enum: [Critical Default Sheddable] +- Enum: [Critical Standard Sheddable] _Appears in:_ - [InferenceModelSpec](#inferencemodelspec) | Field | Description | | --- | --- | -| `Critical` | Most important. Requests to this band will be shed last.
| -| `Default` | More important than Sheddable, less important than Critical.
Requests in this band will be shed before critical traffic.
+kubebuilder:default=Default
| -| `Sheddable` | Least important. Requests to this band will be shed before all other bands.
| +| `Critical` | Critical defines the highest level of criticality. Requests to this band will be shed last.
| +| `Standard` | Standard defines the base criticality level and is more important than Sheddable but less
important than Critical. Requests in this band will be shed before critical traffic.
Most models are expected to fall within this band.
| +| `Sheddable` | Sheddable defines the lowest level of criticality. Requests to this band will be shed before
all other bands.
| + + +#### EndpointPickerConfig + + + +EndpointPickerConfig specifies the configuration needed by the proxy to discover and connect to the endpoint picker extension. +This type is intended to be a union of mutually exclusive configuration options that we may add in the future. + + + +_Appears in:_ +- [InferencePoolSpec](#inferencepoolspec) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `extensionRef` _[Extension](#extension)_ | Extension configures an endpoint picker as an extension service. | | Required: \{\}
| + + +#### Extension + + + +Extension specifies how to configure an extension that runs the endpoint picker. + + + +_Appears in:_ +- [EndpointPickerConfig](#endpointpickerconfig) +- [InferencePoolSpec](#inferencepoolspec) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `group` _[Group](#group)_ | Group is the group of the referent.
The default value is "", representing the Core API group. | | MaxLength: 253
Pattern: `^$\|^[a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*$`
| +| `kind` _[Kind](#kind)_ | Kind is the Kubernetes resource kind of the referent. For example
"Service".
Defaults to "Service" when not specified.
ExternalName services can refer to CNAME DNS records that may live
outside of the cluster and as such are difficult to reason about in
terms of conformance. They also may not be safe to forward to (see
CVE-2021-25740 for more information). Implementations MUST NOT
support ExternalName Services. | Service | MaxLength: 63
MinLength: 1
Pattern: `^[a-zA-Z]([-a-zA-Z0-9]*[a-zA-Z0-9])?$`
| +| `name` _[ObjectName](#objectname)_ | Name is the name of the referent. | | MaxLength: 253
MinLength: 1
Required: \{\}
| +| `portNumber` _[PortNumber](#portnumber)_ | The port number on the service running the extension. When unspecified,
implementations SHOULD infer a default value of 9002 when the Kind is
Service. | | Maximum: 65535
Minimum: 1
| +| `failureMode` _[ExtensionFailureMode](#extensionfailuremode)_ | Configures how the gateway handles the case when the extension is not responsive.
Defaults to failClose. | FailClose | Enum: [FailOpen FailClose]
| + + +#### ExtensionConnection + + + +ExtensionConnection encapsulates options that configures the connection to the extension. + + + +_Appears in:_ +- [Extension](#extension) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `failureMode` _[ExtensionFailureMode](#extensionfailuremode)_ | Configures how the gateway handles the case when the extension is not responsive.
Defaults to failClose. | FailClose | Enum: [FailOpen FailClose]
|
+
+
+#### ExtensionFailureMode
+
+_Underlying type:_ _string_
+
+ExtensionFailureMode defines the options for how the gateway handles the case when the extension is not
+responsive.
+
+_Validation:_
+- Enum: [FailOpen FailClose]
+
+_Appears in:_
+- [Extension](#extension)
+- [ExtensionConnection](#extensionconnection)
+
+| Field | Description |
+| --- | --- |
+| `FailOpen` | FailOpen specifies that the proxy should not drop the request and should forward the request to an endpoint of its picking.
| +| `FailClose` | FailClose specifies that the proxy should drop the request.
| + + +#### ExtensionReference + + + +ExtensionReference is a reference to the extension deployment. + + + +_Appears in:_ +- [Extension](#extension) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `group` _[Group](#group)_ | Group is the group of the referent.
The default value is "", representing the Core API group. | | MaxLength: 253
Pattern: `^$\|^[a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*$`
| +| `kind` _[Kind](#kind)_ | Kind is the Kubernetes resource kind of the referent. For example
"Service".
Defaults to "Service" when not specified.
ExternalName services can refer to CNAME DNS records that may live
outside of the cluster and as such are difficult to reason about in
terms of conformance. They also may not be safe to forward to (see
CVE-2021-25740 for more information). Implementations MUST NOT
support ExternalName Services. | Service | MaxLength: 63
MinLength: 1
Pattern: `^[a-zA-Z]([-a-zA-Z0-9]*[a-zA-Z0-9])?$`
| +| `name` _[ObjectName](#objectname)_ | Name is the name of the referent. | | MaxLength: 253
MinLength: 1
Required: \{\}
| +| `portNumber` _[PortNumber](#portnumber)_ | The port number on the service running the extension. When unspecified,
implementations SHOULD infer a default value of 9002 when the Kind is
Service. | | Maximum: 65535
Minimum: 1
| + + +#### Group + +_Underlying type:_ _string_ + +Group refers to a Kubernetes Group. It must either be an empty string or a +RFC 1123 subdomain. + +This validation is based off of the corresponding Kubernetes validation: +https://github.com/kubernetes/apimachinery/blob/02cfb53916346d085a6c6c7c66f882e3c6b0eca6/pkg/util/validation/validation.go#L208 + +Valid values include: + +* "" - empty string implies core Kubernetes API group +* "gateway.networking.k8s.io" +* "foo.example.com" + +Invalid values include: + +* "example.com/bar" - "/" is an invalid character + +_Validation:_ +- MaxLength: 253 +- Pattern: `^$|^[a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*$` + +_Appears in:_ +- [Extension](#extension) +- [ExtensionReference](#extensionreference) +- [PoolObjectReference](#poolobjectreference) + #### InferenceModel -InferenceModel is the Schema for the InferenceModels API +InferenceModel is the Schema for the InferenceModels API. @@ -45,29 +173,31 @@ InferenceModel is the Schema for the InferenceModels API | Field | Description | Default | Validation | | --- | --- | --- | --- | -| `apiVersion` _string_ | `inference.networking.x-k8s.io/v1alpha1` | | | +| `apiVersion` _string_ | `inference.networking.x-k8s.io/v1alpha2` | | | | `kind` _string_ | `InferenceModel` | | | | `metadata` _[ObjectMeta](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.31/#objectmeta-v1-meta)_ | Refer to Kubernetes API documentation for fields of `metadata`. | | | | `spec` _[InferenceModelSpec](#inferencemodelspec)_ | | | | | `status` _[InferenceModelStatus](#inferencemodelstatus)_ | | | | + + + + #### InferenceModelSpec -InferenceModelSpec represents a specific model use case. This resource is +InferenceModelSpec represents the desired state of a specific model use case. This resource is managed by the "Inference Workload Owner" persona. - -The Inference Workload Owner persona is: a team that trains, verifies, and +The Inference Workload Owner persona is someone that trains, verifies, and leverages a large language model from a model frontend, drives the lifecycle and rollout of new versions of those models, and defines the specific performance and latency goals for the model. These workloads are expected to operate within an InferencePool sharing compute capacity with other InferenceModels, defined by the Inference Platform Admin. - InferenceModel's modelName (not the ObjectMeta name) is unique for a given InferencePool, if the name is reused, an error will be shown on the status of a InferenceModel that attempted to reuse. The oldest InferenceModel, based on @@ -81,10 +211,10 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | -| `modelName` _string_ | The name of the model as the users set in the "model" parameter in the requests.
The name should be unique among the workloads that reference the same backend pool.
This is the parameter that will be used to match the request with. In the future, we may
allow to match on other request parameters. The other approach to support matching on
on other request parameters is to use a different ModelName per HTTPFilter.
Names can be reserved without implementing an actual model in the pool.
This can be done by specifying a target model and setting the weight to zero,
an error will be returned specifying that no valid target model is found. | | MaxLength: 253
| -| `criticality` _[Criticality](#criticality)_ | Defines how important it is to serve the model compared to other models referencing the same pool. | Default | Enum: [Critical Default Sheddable]
| -| `targetModels` _[TargetModel](#targetmodel) array_ | Allow multiple versions of a model for traffic splitting.
If not specified, the target model name is defaulted to the modelName parameter.
modelName is often in reference to a LoRA adapter. | | MaxItems: 10
| -| `poolRef` _[PoolObjectReference](#poolobjectreference)_ | Reference to the inference pool, the pool must exist in the same namespace. | | Required: \{\}
| +| `modelName` _string_ | ModelName is the name of the model as it will be set in the "model" parameter for an incoming request.
ModelNames must be unique for a referencing InferencePool
(names can be reused for a different pool in the same cluster).
The modelName with the oldest creation timestamp is retained, and the incoming
InferenceModel has its Ready status set to false with a corresponding reason.
In the rare case of a race condition, one Model will be selected randomly to be considered valid, and the other rejected.
Names can be reserved without an underlying model configured in the pool.
This can be done by specifying a target model and setting the weight to zero,
an error will be returned specifying that no valid target model is found. | | MaxLength: 256
Required: \{\}
| +| `criticality` _[Criticality](#criticality)_ | Criticality defines how important it is to serve the model compared to other models referencing the same pool.
Criticality impacts how traffic is handled in resource constrained situations. It handles this by
queuing or rejecting requests of lower criticality. InferenceModels of an equivalent Criticality will
fairly share resources over throughput of tokens. In the future, the metric used to calculate fairness,
and the proportionality of fairness will be configurable.
Default values for this field will not be set, to allow for future additions of new fields that may 'one of' with this field.
Any implementations that may consume this field may treat an unset value as the 'Standard' range. | | Enum: [Critical Standard Sheddable]
| +| `targetModels` _[TargetModel](#targetmodel) array_ | TargetModels allow multiple versions of a model for traffic splitting.
If not specified, the target model name is defaulted to the modelName parameter.
modelName is often in reference to a LoRA adapter. | | MaxItems: 10
| +| `poolRef` _[PoolObjectReference](#poolobjectreference)_ | PoolRef is a reference to the inference pool, the pool must exist in the same namespace. | | Required: \{\}
| #### InferenceModelStatus @@ -100,14 +230,14 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | -| `conditions` _[Condition](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.31/#condition-v1-meta) array_ | Conditions track the state of the InferencePool. | | | +| `conditions` _[Condition](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.31/#condition-v1-meta) array_ | Conditions track the state of the InferenceModel.
Known condition types are:
* "Accepted" | [map[lastTransitionTime:1970-01-01T00:00:00Z message:Waiting for controller reason:Pending status:Unknown type:Ready]] | MaxItems: 8
| #### InferencePool -InferencePool is the Schema for the Inferencepools API +InferencePool is the Schema for the InferencePools API. @@ -115,13 +245,17 @@ InferencePool is the Schema for the Inferencepools API | Field | Description | Default | Validation | | --- | --- | --- | --- | -| `apiVersion` _string_ | `inference.networking.x-k8s.io/v1alpha1` | | | +| `apiVersion` _string_ | `inference.networking.x-k8s.io/v1alpha2` | | | | `kind` _string_ | `InferencePool` | | | | `metadata` _[ObjectMeta](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.31/#objectmeta-v1-meta)_ | Refer to Kubernetes API documentation for fields of `metadata`. | | | | `spec` _[InferencePoolSpec](#inferencepoolspec)_ | | | | | `status` _[InferencePoolStatus](#inferencepoolstatus)_ | | | | + + + + #### InferencePoolSpec @@ -135,8 +269,9 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | -| `selector` _object (keys:[LabelKey](#labelkey), values:[LabelValue](#labelvalue))_ | Selector uses a map of label to watch model server pods
that should be included in the InferencePool. ModelServers should not
be with any other Service or InferencePool, that behavior is not supported
and will result in sub-optimal utilization.
In some cases, implementations may translate this to a Service selector, so this matches the simple
map used for Service selectors instead of the full Kubernetes LabelSelector type. | | Required: \{\}
| -| `targetPortNumber` _integer_ | TargetPortNumber is the port number that the model servers within the pool expect
to receive traffic from.
This maps to the TargetPort in: https://pkg.go.dev/k8s.io/api/core/v1#ServicePort | | Maximum: 65535
Minimum: 0
Required: \{\}
| +| `selector` _object (keys:[LabelKey](#labelkey), values:[LabelValue](#labelvalue))_ | Selector defines a map of labels to watch model server pods
that should be included in the InferencePool.
In some cases, implementations may translate this field to a Service selector, so this matches the simple
map used for Service selectors instead of the full Kubernetes LabelSelector type.
If specified, it will be applied to match the model server pods in the same namespace as the InferencePool.
Cross-namespace selectors are not supported. | | Required: \{\}
| +| `targetPortNumber` _integer_ | TargetPortNumber defines the port number to access the selected model servers.
The number must be in the range 1 to 65535. | | Maximum: 65535
Minimum: 1
Required: \{\}
| +| `extensionRef` _[Extension](#extension)_ | Extension configures an endpoint picker as an extension service. | | Required: \{\}
| #### InferencePoolStatus @@ -152,33 +287,56 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | -| `conditions` _[Condition](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.31/#condition-v1-meta) array_ | Conditions track the state of the InferencePool. | | | +| `parent` _[PoolStatus](#poolstatus) array_ | Parents is a list of parent resources (usually Gateways) that are
associated with the InferencePool, and the status of the InferencePool with respect to
each parent.
A maximum of 32 Gateways will be represented in this list. An empty list
means the InferencePool has not been attached to any Gateway. | | MaxItems: 32
| + + +#### Kind + +_Underlying type:_ _string_ + +Kind refers to a Kubernetes Kind. + +Valid values include: + +* "Service" +* "HTTPRoute" + +Invalid values include: + +* "invalid/kind" - "/" is an invalid character + +_Validation:_ +- MaxLength: 63 +- MinLength: 1 +- Pattern: `^[a-zA-Z]([-a-zA-Z0-9]*[a-zA-Z0-9])?$` + +_Appears in:_ +- [Extension](#extension) +- [ExtensionReference](#extensionreference) +- [PoolObjectReference](#poolobjectreference) + #### LabelKey _Underlying type:_ _string_ -Originally copied from: https://github.com/kubernetes-sigs/gateway-api/blob/99a3934c6bc1ce0874f3a4c5f20cafd8977ffcb4/apis/v1/shared_types.go#L694-L731 +LabelKey was originally copied from: https://github.com/kubernetes-sigs/gateway-api/blob/99a3934c6bc1ce0874f3a4c5f20cafd8977ffcb4/apis/v1/shared_types.go#L694-L731 Duplicated as to not take an unexpected dependency on gw's API. - LabelKey is the key of a label. This is used for validation of maps. This matches the Kubernetes "qualified name" validation that is used for labels. - +Labels are case sensitive, so: my-label and My-Label are considered distinct. Valid values include: - * example * example.com * example.com/path * example.com/path.html - Invalid values include: - * example~ - "~" is an invalid character * example.com. - can not start or end with "." @@ -202,10 +360,8 @@ of maps. This matches the Kubernetes label validation rules: * unless empty, must begin and end with an alphanumeric character ([a-z0-9A-Z]), * could contain dashes (-), underscores (_), dots (.), and alphanumerics between. - Valid values include: - * MyValue * my.name * 123-my-value @@ -220,6 +376,25 @@ _Appears in:_ +#### ObjectName + +_Underlying type:_ _string_ + +ObjectName refers to the name of a Kubernetes object. +Object names can have a variety of forms, including RFC 1123 subdomains, +RFC 1123 labels, or RFC 1035 labels. + +_Validation:_ +- MaxLength: 253 +- MinLength: 1 + +_Appears in:_ +- [Extension](#extension) +- [ExtensionReference](#extensionreference) +- [PoolObjectReference](#poolobjectreference) + + + #### PoolObjectReference @@ -234,9 +409,42 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | -| `group` _string_ | Group is the group of the referent. | inference.networking.x-k8s.io | MaxLength: 253
Pattern: `^$\|^[a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*$`
| -| `kind` _string_ | Kind is kind of the referent. For example "InferencePool". | InferencePool | MaxLength: 63
MinLength: 1
Pattern: `^[a-zA-Z]([-a-zA-Z0-9]*[a-zA-Z0-9])?$`
| -| `name` _string_ | Name is the name of the referent. | | MaxLength: 253
MinLength: 1
Required: \{\}
| +| `group` _[Group](#group)_ | Group is the group of the referent. | inference.networking.x-k8s.io | MaxLength: 253
Pattern: `^$\|^[a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*$`
| +| `kind` _[Kind](#kind)_ | Kind is kind of the referent. For example "InferencePool". | InferencePool | MaxLength: 63
MinLength: 1
Pattern: `^[a-zA-Z]([-a-zA-Z0-9]*[a-zA-Z0-9])?$`
| +| `name` _[ObjectName](#objectname)_ | Name is the name of the referent. | | MaxLength: 253
MinLength: 1
Required: \{\}
|
+
+
+#### PoolStatus
+
+
+
+PoolStatus defines the observed state of InferencePool from a Gateway.
+
+
+
+_Appears in:_
+- [InferencePoolStatus](#inferencepoolstatus)
+
+| Field | Description | Default | Validation |
+| --- | --- | --- | --- |
+| `parentRef` _[ObjectReference](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.31/#objectreference-v1-core)_ | ParentRef indicates the Gateway that observed the state of the InferencePool. | | |
+| `conditions` _[Condition](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.31/#condition-v1-meta) array_ | Conditions track the state of the InferencePool.
Known condition types are:
* "Accepted"
* "ResolvedRefs" | [map[lastTransitionTime:1970-01-01T00:00:00Z message:Waiting for controller reason:Pending status:Unknown type:Accepted]] | MaxItems: 8
| + + +#### PortNumber + +_Underlying type:_ _integer_ + +PortNumber defines a network port. + +_Validation:_ +- Maximum: 65535 +- Minimum: 1 + +_Appears in:_ +- [Extension](#extension) +- [ExtensionReference](#extensionreference) + #### TargetModel @@ -246,10 +454,10 @@ _Appears in:_ TargetModel represents a deployed model or a LoRA adapter. The Name field is expected to match the name of the LoRA adapter (or base model) as it is registered within the model server. Inference -Gateway assumes that the model exists on the model server and is the +Gateway assumes that the model exists on the model server and it's the responsibility of the user to validate a correct match. Should a model fail -to exist at request time, the error is processed by the Instance Gateway, -and then emitted on the appropriate InferenceModel object. +to exist at request time, the error is processed by the Inference Gateway +and emitted on the appropriate InferenceModel object. @@ -258,7 +466,7 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | -| `name` _string_ | The name of the adapter as expected by the ModelServer. | | MaxLength: 253
| -| `weight` _integer_ | Weight is used to determine the proportion of traffic that should be
sent to this target model when multiple versions of the model are specified. | 1 | Maximum: 1e+06
Minimum: 0
| +| `name` _string_ | Name is the name of the adapter or base model, as expected by the ModelServer. | | MaxLength: 253
Required: \{\}
| +| `weight` _integer_ | Weight is used to determine the proportion of traffic that should be
sent to this model when multiple target models are specified.
Weight defines the proportion of requests forwarded to the specified
model. This is computed as weight/(sum of all weights in this
TargetModels list). For non-zero values, there may be some epsilon from
the exact proportion defined here depending on the precision an
implementation supports. Weight is not a percentage and the sum of
weights does not need to equal 100.
If a weight is set for any targetModel, it must be set for all targetModels.
Conversely weights are optional, so long as ALL targetModels do not specify a weight. | | Maximum: 1e+06
Minimum: 1
| diff --git a/test/e2e/epp/README.md b/test/e2e/epp/README.md index 247e8b12..fcc974b8 100644 --- a/test/e2e/epp/README.md +++ b/test/e2e/epp/README.md @@ -28,6 +28,13 @@ Follow these steps to run the end-to-end tests: export HF_TOKEN= ``` +1. **(Optional): Set the test namespace**: By default, the e2e test creates resources in the `inf-ext-e2e` namespace. + If you would like to change this namespace, set the following environment variable: + + ```sh + export E2E_NS= + ``` + 1. **Run the Tests**: Run the `test-e2e` target: ```sh diff --git a/test/e2e/epp/e2e_suite_test.go b/test/e2e/epp/e2e_suite_test.go index 643bbf75..01ed639d 100644 --- a/test/e2e/epp/e2e_suite_test.go +++ b/test/e2e/epp/e2e_suite_test.go @@ -30,6 +30,7 @@ import ( corev1 "k8s.io/api/core/v1" rbacv1 "k8s.io/api/rbac/v1" apiextv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" + v1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/serializer" @@ -49,13 +50,14 @@ const ( defaultReadyTimeout = 3 * time.Minute // defaultModelReadyTimeout is the default timeout for the model server deployment to report a ready state. defaultModelReadyTimeout = 10 * time.Minute + // defaultCurlTimeout is the default timeout for the curl command to get a response. + defaultCurlTimeout = 30 * time.Second // defaultInterval is the default interval to check if a resource exists or ready conditions. defaultInterval = time.Millisecond * 250 // defaultCurlInterval is the default interval to run the test curl command. defaultCurlInterval = time.Second * 5 - // nsName is the name of the Namespace used for tests. - // TODO [danehans]: Must be "default" until https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/227 is fixed - nsName = "default" + // defaultNsName is the default name of the Namespace used for tests. Can override using the E2E_NS environment variable. + defaultNsName = "inf-ext-e2e" // modelServerName is the name of the model server test resources. modelServerName = "vllm-llama3-8b-instruct" // modelName is the test model name. @@ -75,7 +77,7 @@ const ( // inferModelManifest is the manifest for the inference model CRD. inferModelManifest = "../../../config/crd/bases/inference.networking.x-k8s.io_inferencemodels.yaml" // inferExtManifest is the manifest for the inference extension test resources. - inferExtManifest = "../../../config/manifests/inferencepool-resources.yaml" + inferExtManifest = "../../testdata/inferencepool-e2e.yaml" // envoyManifest is the manifest for the envoy proxy test resources. envoyManifest = "../../testdata/envoy.yaml" // modelServerManifestFilepathEnvVar is the env var that holds absolute path to the manifest for the model server test resource. 
@@ -89,6 +91,7 @@ var ( kubeCli *kubernetes.Clientset scheme = runtime.NewScheme() cfg = config.GetConfigOrDie() + nsName string ) func TestAPIs(t *testing.T) { @@ -99,6 +102,11 @@ func TestAPIs(t *testing.T) { } var _ = ginkgo.BeforeSuite(func() { + nsName = os.Getenv("E2E_NS") + if nsName == "" { + nsName = defaultNsName + } + ginkgo.By("Setting up the test suite") setupSuite() @@ -107,17 +115,24 @@ var _ = ginkgo.BeforeSuite(func() { }) func setupInfra() { - modelServerManifest := readModelServerManifestPath() + createNamespace(cli, nsName) + + modelServerManifestPath := readModelServerManifestPath() + modelServerManifestArray := getYamlsFromModelServerManifest(modelServerManifestPath) + if strings.Contains(modelServerManifestArray[0], "hf-token") { + createHfSecret(cli, modelServerSecretManifest) + } crds := map[string]string{ "inferencepools.inference.networking.x-k8s.io": inferPoolManifest, "inferencemodels.inference.networking.x-k8s.io": inferModelManifest, } + createCRDs(cli, crds) createInferExt(cli, inferExtManifest) createClient(cli, clientManifest) createEnvoy(cli, envoyManifest) // Run this step last, as it requires additional time for the model server to become ready. - createModelServer(cli, modelServerSecretManifest, modelServerManifest) + createModelServer(cli, modelServerManifestArray, modelServerManifestPath) } var _ = ginkgo.AfterSuite(func() { @@ -137,7 +152,7 @@ func setupSuite() { err = apiextv1.AddToScheme(scheme) gomega.ExpectWithOffset(1, err).NotTo(gomega.HaveOccurred()) - err = infextv1a2.AddToScheme(scheme) + err = infextv1a2.Install(scheme) gomega.ExpectWithOffset(1, err).NotTo(gomega.HaveOccurred()) cli, err = client.New(cfg, client.Options{Scheme: scheme}) @@ -171,10 +186,22 @@ var ( existsTimeout = getTimeout("EXISTS_TIMEOUT", defaultExistsTimeout) readyTimeout = getTimeout("READY_TIMEOUT", defaultReadyTimeout) modelReadyTimeout = getTimeout("MODEL_READY_TIMEOUT", defaultModelReadyTimeout) + curlTimeout = getTimeout("CURL_TIMEOUT", defaultCurlTimeout) interval = defaultInterval curlInterval = defaultCurlInterval ) +func createNamespace(k8sClient client.Client, ns string) { + ginkgo.By("Creating e2e namespace: " + ns) + obj := &corev1.Namespace{ + ObjectMeta: v1.ObjectMeta{ + Name: ns, + }, + } + err := k8sClient.Create(ctx, obj) + gomega.Expect(err).NotTo(gomega.HaveOccurred(), "Failed to create e2e test namespace") +} + // namespaceExists ensures that a specified namespace exists and is ready for use. func namespaceExists(k8sClient client.Client, ns string) { ginkgo.By("Ensuring namespace exists: " + ns) @@ -191,6 +218,13 @@ func readModelServerManifestPath() string { return modelServerManifestFilepath } +func getYamlsFromModelServerManifest(modelServerManifestPath string) []string { + ginkgo.By("Ensuring the model server manifest points to an existing file") + modelServerManifestArray := readYaml(modelServerManifestPath) + gomega.Expect(modelServerManifestArray).NotTo(gomega.BeEmpty()) + return modelServerManifestArray +} + // createCRDs creates the Inference Extension CRDs used for testing. func createCRDs(k8sClient client.Client, crds map[string]string) { for name, path := range crds { @@ -224,15 +258,7 @@ func createClient(k8sClient client.Client, filePath string) { } // createModelServer creates the model server resources used for testing from the given filePaths. 
-func createModelServer(k8sClient client.Client, secretPath, deployPath string) { - ginkgo.By("Ensuring the model server manifest points to an existing file") - modelServerManifestArray := readYaml(deployPath) - gomega.Expect(modelServerManifestArray).NotTo(gomega.BeEmpty()) - modelServerManifestYaml := modelServerManifestArray[0] - if strings.Contains(modelServerManifestYaml, "hf-token") { - createHfSecret(k8sClient, secretPath) - } - +func createModelServer(k8sClient client.Client, modelServerManifestArray []string, deployPath string) { ginkgo.By("Creating model server resources from manifest: " + deployPath) createObjsFromYaml(k8sClient, modelServerManifestArray) @@ -270,8 +296,15 @@ func createHfSecret(k8sClient client.Client, secretPath string) { // createEnvoy creates the envoy proxy resources used for testing from the given filePath. func createEnvoy(k8sClient client.Client, filePath string) { + inManifests := readYaml(filePath) + ginkgo.By("Replacing placeholder namespace with E2E_NS environment variable") + outManifests := []string{} + for _, m := range inManifests { + outManifests = append(outManifests, strings.ReplaceAll(m, "$E2E_NS", nsName)) + } + ginkgo.By("Creating envoy proxy resources from manifest: " + filePath) - applyYAMLFile(k8sClient, filePath) + createObjsFromYaml(k8sClient, outManifests) // Wait for the configmap to exist before proceeding with test. cfgMap := &corev1.ConfigMap{} @@ -296,8 +329,15 @@ func createEnvoy(k8sClient client.Client, filePath string) { // createInferExt creates the inference extension resources used for testing from the given filePath. func createInferExt(k8sClient client.Client, filePath string) { + inManifests := readYaml(filePath) + ginkgo.By("Replacing placeholder namespace with E2E_NS environment variable") + outManifests := []string{} + for _, m := range inManifests { + outManifests = append(outManifests, strings.ReplaceAll(m, "$E2E_NS", nsName)) + } + ginkgo.By("Creating inference extension resources from manifest: " + filePath) - applyYAMLFile(k8sClient, filePath) + createObjsFromYaml(k8sClient, outManifests) // Wait for the clusterrole to exist. testutils.EventuallyExists(ctx, func() error { diff --git a/test/e2e/epp/e2e_test.go b/test/e2e/epp/e2e_test.go index 09c8835a..7240cebc 100644 --- a/test/e2e/epp/e2e_test.go +++ b/test/e2e/epp/e2e_test.go @@ -18,7 +18,9 @@ package epp import ( "fmt" + "strconv" "strings" + "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" @@ -53,7 +55,7 @@ var _ = ginkgo.Describe("InferencePool", func() { }, existsTimeout, interval).Should(gomega.Succeed()) ginkgo.By("Verifying connectivity through the inference extension") - curlCmd := getCurlCommand(envoyName, nsName, envoyPort, modelName) + curlCmd := getCurlCommand(envoyName, nsName, envoyPort, modelName, curlTimeout) // Ensure the expected responses include the inferencemodel target model names. var expected []string @@ -94,11 +96,11 @@ var _ = ginkgo.Describe("InferencePool", func() { func newInferenceModel(ns string) *v1alpha2.InferenceModel { targets := []v1alpha2.TargetModel{ { - Name: modelName + "-0", + Name: modelName, Weight: ptr.To(int32(50)), }, { - Name: modelName + "-1", + Name: "cad-fabricator", Weight: ptr.To(int32(50)), }, } @@ -112,10 +114,12 @@ func newInferenceModel(ns string) *v1alpha2.InferenceModel { // getCurlCommand returns the command, as a slice of strings, for curl'ing // the test model server at the given name, namespace, port, and model name. 
-func getCurlCommand(name, ns, port, model string) []string { +func getCurlCommand(name, ns, port, model string, timeout time.Duration) []string { return []string{ "curl", "-i", + "--max-time", + strconv.Itoa((int)(timeout.Seconds())), fmt.Sprintf("%s.%s.svc:%s/v1/completions", name, ns, port), "-H", "Content-Type: application/json", diff --git a/test/integration/bbr/hermetic_test.go b/test/integration/bbr/hermetic_test.go index 718bfedf..b99186db 100644 --- a/test/integration/bbr/hermetic_test.go +++ b/test/integration/bbr/hermetic_test.go @@ -19,20 +19,19 @@ package bbr import ( "context" - "encoding/json" "fmt" "testing" "time" configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" - "github.com/go-logr/logr" "github.com/google/go-cmp/cmp" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/protobuf/testing/protocmp" - runserver "sigs.k8s.io/gateway-api-inference-extension/pkg/body-based-routing/server" + runserver "sigs.k8s.io/gateway-api-inference-extension/pkg/bbr/server" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + integrationutils "sigs.k8s.io/gateway-api-inference-extension/test/integration" ) var logger = logutil.NewTestLogger().V(logutil.VERBOSE) @@ -46,7 +45,7 @@ func TestBodyBasedRouting(t *testing.T) { }{ { name: "success adding model parameter to header", - req: generateRequest(logger, "llama"), + req: integrationutils.GenerateRequest(logger, "test", "llama"), wantHeaders: []*configPb.HeaderValueOption{ { Header: &configPb.HeaderValue{ @@ -59,7 +58,7 @@ func TestBodyBasedRouting(t *testing.T) { }, { name: "no model parameter", - req: generateRequest(logger, ""), + req: integrationutils.GenerateRequest(logger, "test1", ""), wantHeaders: []*configPb.HeaderValueOption{}, wantErr: false, }, @@ -67,7 +66,7 @@ func TestBodyBasedRouting(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - client, cleanup := setUpHermeticServer() + client, cleanup := setUpHermeticServer(false) t.Cleanup(cleanup) want := &extProcPb.ProcessingResponse{} @@ -88,7 +87,7 @@ func TestBodyBasedRouting(t *testing.T) { } } - res, err := sendRequest(t, client, test.req) + res, err := integrationutils.SendRequest(t, client, test.req) if err != nil && !test.wantErr { t.Errorf("Unexpected error, got: %v, want error: %v", err, test.wantErr) } @@ -99,12 +98,171 @@ func TestBodyBasedRouting(t *testing.T) { } } -func setUpHermeticServer() (client extProcPb.ExternalProcessor_ProcessClient, cleanup func()) { +func TestFullDuplexStreamed_BodyBasedRouting(t *testing.T) { + tests := []struct { + name string + reqs []*extProcPb.ProcessingRequest + wantResponses []*extProcPb.ProcessingResponse + wantErr bool + }{ + { + name: "success adding model parameter to header", + reqs: integrationutils.GenerateStreamedRequestSet(logger, "test", "foo"), + wantResponses: []*extProcPb.ProcessingResponse{ + { + Response: &extProcPb.ProcessingResponse_RequestHeaders{ + RequestHeaders: &extProcPb.HeadersResponse{ + Response: &extProcPb.CommonResponse{ + ClearRouteCache: true, + HeaderMutation: &extProcPb.HeaderMutation{ + SetHeaders: []*configPb.HeaderValueOption{ + { + Header: &configPb.HeaderValue{ + Key: "X-Gateway-Model-Name", + RawValue: []byte("foo"), + }, + }, + }}, + }, + }, + }, + }, + { + Response: &extProcPb.ProcessingResponse_RequestBody{ + RequestBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + BodyMutation: 
&extProcPb.BodyMutation{ + Mutation: &extProcPb.BodyMutation_StreamedResponse{ + StreamedResponse: &extProcPb.StreamedBodyResponse{ + Body: []byte("{\"max_tokens\":100,\"model\":\"foo\",\"prompt\":\"test\",\"temperature\":0}"), + EndOfStream: true, + }, + }, + }, + }, + }, + }, + }, + }, + }, + { + name: "success adding model parameter to header with multiple body chunks", + reqs: []*extProcPb.ProcessingRequest{ + { + Request: &extProcPb.ProcessingRequest_RequestHeaders{ + RequestHeaders: &extProcPb.HttpHeaders{ + Headers: &configPb.HeaderMap{ + Headers: []*configPb.HeaderValue{ + { + Key: "hi", + Value: "mom", + }, + }, + }, + }, + }, + }, + { + Request: &extProcPb.ProcessingRequest_RequestBody{ + RequestBody: &extProcPb.HttpBody{Body: []byte("{\"max_tokens\":100,\"model\":\"sql-lo"), EndOfStream: false}, + }, + }, + { + Request: &extProcPb.ProcessingRequest_RequestBody{ + RequestBody: &extProcPb.HttpBody{Body: []byte("ra-sheddable\",\"prompt\":\"test\",\"temperature\":0}"), EndOfStream: true}, + }, + }, + }, + wantResponses: []*extProcPb.ProcessingResponse{ + { + Response: &extProcPb.ProcessingResponse_RequestHeaders{ + RequestHeaders: &extProcPb.HeadersResponse{ + Response: &extProcPb.CommonResponse{ + ClearRouteCache: true, + HeaderMutation: &extProcPb.HeaderMutation{ + SetHeaders: []*configPb.HeaderValueOption{ + { + Header: &configPb.HeaderValue{ + Key: "X-Gateway-Model-Name", + RawValue: []byte("sql-lora-sheddable"), + }, + }, + }}, + }, + }, + }, + }, + { + Response: &extProcPb.ProcessingResponse_RequestBody{ + RequestBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + BodyMutation: &extProcPb.BodyMutation{ + Mutation: &extProcPb.BodyMutation_StreamedResponse{ + StreamedResponse: &extProcPb.StreamedBodyResponse{ + Body: []byte("{\"max_tokens\":100,\"model\":\"sql-lora-sheddable\",\"prompt\":\"test\",\"temperature\":0}"), + EndOfStream: true, + }, + }, + }, + }, + }, + }, + }, + }, + }, + { + name: "no model parameter", + reqs: integrationutils.GenerateStreamedRequestSet(logger, "test", ""), + wantResponses: []*extProcPb.ProcessingResponse{ + { + Response: &extProcPb.ProcessingResponse_RequestHeaders{ + RequestHeaders: &extProcPb.HeadersResponse{}, + }, + }, + { + Response: &extProcPb.ProcessingResponse_RequestBody{ + RequestBody: &extProcPb.BodyResponse{ + Response: &extProcPb.CommonResponse{ + BodyMutation: &extProcPb.BodyMutation{ + Mutation: &extProcPb.BodyMutation_StreamedResponse{ + StreamedResponse: &extProcPb.StreamedBodyResponse{ + Body: []byte("{\"max_tokens\":100,\"prompt\":\"test\",\"temperature\":0}"), + EndOfStream: true, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + client, cleanup := setUpHermeticServer(true) + t.Cleanup(cleanup) + + responses, err := integrationutils.StreamedRequest(t, client, test.reqs, len(test.wantResponses)) + if err != nil && !test.wantErr { + t.Errorf("Unexpected error, got: %v, want error: %v", err, test.wantErr) + } + + if diff := cmp.Diff(test.wantResponses, responses, protocmp.Transform()); diff != "" { + t.Errorf("Unexpected response, (-want +got): %v", diff) + } + }) + } +} + +func setUpHermeticServer(streaming bool) (client extProcPb.ExternalProcessor_ProcessClient, cleanup func()) { port := 9004 serverCtx, stopServer := context.WithCancel(context.Background()) serverRunner := runserver.NewDefaultExtProcServerRunner(port, false) serverRunner.SecureServing = false + serverRunner.Streaming = streaming go func() { if err := 
serverRunner.AsRunnable(logger.WithName("ext-proc")).Start(serverCtx); err != nil { @@ -133,41 +291,3 @@ func setUpHermeticServer() (client extProcPb.ExternalProcessor_ProcessClient, cl time.Sleep(5 * time.Second) } } - -func generateRequest(logger logr.Logger, model string) *extProcPb.ProcessingRequest { - j := map[string]interface{}{ - "prompt": "test1", - "max_tokens": 100, - "temperature": 0, - } - if model != "" { - j["model"] = model - } - - llmReq, err := json.Marshal(j) - if err != nil { - logutil.Fatal(logger, err, "Failed to unmarshal LLM request") - } - req := &extProcPb.ProcessingRequest{ - Request: &extProcPb.ProcessingRequest_RequestBody{ - RequestBody: &extProcPb.HttpBody{Body: llmReq}, - }, - } - return req -} - -func sendRequest(t *testing.T, client extProcPb.ExternalProcessor_ProcessClient, req *extProcPb.ProcessingRequest) (*extProcPb.ProcessingResponse, error) { - t.Logf("Sending request: %v", req) - if err := client.Send(req); err != nil { - t.Logf("Failed to send request %+v: %v", req, err) - return nil, err - } - - res, err := client.Recv() - if err != nil { - t.Logf("Failed to receive: %v", err) - return nil, err - } - t.Logf("Received request %+v", res) - return res, err -} diff --git a/test/integration/epp/hermetic_test.go b/test/integration/epp/hermetic_test.go index 2acdacf8..c63fd017 100644 --- a/test/integration/epp/hermetic_test.go +++ b/test/integration/epp/hermetic_test.go @@ -44,6 +44,7 @@ import ( "google.golang.org/protobuf/testing/protocmp" "google.golang.org/protobuf/types/known/structpb" corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/fields" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" @@ -60,19 +61,22 @@ import ( "sigs.k8s.io/controller-runtime/pkg/envtest" "sigs.k8s.io/controller-runtime/pkg/manager" "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha2" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datastore" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/server" runserver "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/server" logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" - utiltesting "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" + epptestutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/testing" + integrationutils "sigs.k8s.io/gateway-api-inference-extension/test/integration" "sigs.k8s.io/yaml" ) const ( port = runserver.DefaultGrpcPort - metricsPort = 8888 + metricsPort = 8889 ) var ( @@ -90,318 +94,11 @@ func TestMain(m *testing.M) { os.Exit(code) } -func TestKubeInferenceModelRequest(t *testing.T) { - tests := []struct { - name string - req *extProcPb.ProcessingRequest - pods map[backendmetrics.Pod]*backendmetrics.Metrics - wantHeaders []*configPb.HeaderValueOption - wantMetadata *structpb.Struct - wantBody []byte - wantMetrics string - wantErr bool - immediateResponse *extProcPb.ImmediateResponse - }{ - { - name: "select lower queue and kv cache, no active lora", - req: utiltesting.GenerateRequest(logger, "test1", "my-model"), - // pod-1 will be picked because it has relatively low queue size and low KV cache. 
- pods: map[backendmetrics.Pod]*backendmetrics.Metrics{ - fakePod(0): { - WaitingQueueSize: 3, - KVCacheUsagePercent: 0.2, - }, - fakePod(1): { - WaitingQueueSize: 0, - KVCacheUsagePercent: 0.1, - }, - fakePod(2): { - WaitingQueueSize: 10, - KVCacheUsagePercent: 0.2, - }, - }, - wantHeaders: []*configPb.HeaderValueOption{ - { - Header: &configPb.HeaderValue{ - Key: runserver.DefaultDestinationEndpointHintKey, - RawValue: []byte("192.168.1.2:8000"), - }, - }, - { - Header: &configPb.HeaderValue{ - Key: "Content-Length", - RawValue: []byte("76"), - }, - }, - }, - wantMetadata: makeMetadata("192.168.1.2:8000"), - wantBody: []byte("{\"max_tokens\":100,\"model\":\"my-model-12345\",\"prompt\":\"test1\",\"temperature\":0}"), - wantMetrics: ` - # HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model. - # TYPE inference_model_request_total counter - inference_model_request_total{model_name="my-model",target_model_name="my-model-12345"} 1 - `, - wantErr: false, - }, - { - name: "select active lora, low queue", - req: utiltesting.GenerateRequest(logger, "test2", "sql-lora"), - // pod-1 will be picked because it has relatively low queue size, with the requested - // model being active, and has low KV cache. - pods: map[backendmetrics.Pod]*backendmetrics.Metrics{ - fakePod(0): { - WaitingQueueSize: 0, - KVCacheUsagePercent: 0.2, - ActiveModels: map[string]int{ - "foo": 1, - "bar": 1, - }, - }, - fakePod(1): { - WaitingQueueSize: 0, - KVCacheUsagePercent: 0.1, - ActiveModels: map[string]int{ - "foo": 1, - "sql-lora-1fdg2": 1, - }, - }, - fakePod(2): { - WaitingQueueSize: 10, - KVCacheUsagePercent: 0.2, - ActiveModels: map[string]int{ - "foo": 1, - "bar": 1, - }, - }, - }, - wantHeaders: []*configPb.HeaderValueOption{ - { - Header: &configPb.HeaderValue{ - Key: runserver.DefaultDestinationEndpointHintKey, - RawValue: []byte("192.168.1.2:8000"), - }, - }, - { - Header: &configPb.HeaderValue{ - Key: "Content-Length", - RawValue: []byte("76"), - }, - }, - }, - wantMetadata: makeMetadata("192.168.1.2:8000"), - wantBody: []byte("{\"max_tokens\":100,\"model\":\"sql-lora-1fdg2\",\"prompt\":\"test2\",\"temperature\":0}"), - wantMetrics: ` - # HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model. - # TYPE inference_model_request_total counter - inference_model_request_total{model_name="sql-lora",target_model_name="sql-lora-1fdg2"} 1 - `, - wantErr: false, - }, - { - name: "select no lora despite active model, avoid excessive queue size", - req: utiltesting.GenerateRequest(logger, "test3", "sql-lora"), - // pod-2 will be picked despite it NOT having the requested model being active - // as it's above the affinity for queue size. 
Also is critical, so we should - // still honor request despite all queues > 5 - pods: map[backendmetrics.Pod]*backendmetrics.Metrics{ - fakePod(0): { - WaitingQueueSize: 10, - KVCacheUsagePercent: 0.2, - ActiveModels: map[string]int{ - "foo": 1, - "bar": 1, - }, - }, - fakePod(1): { - WaitingQueueSize: 200, - KVCacheUsagePercent: 0.1, - ActiveModels: map[string]int{ - "foo": 1, - "sql-lora-1fdg2": 1, - }, - }, - fakePod(2): { - WaitingQueueSize: 6, - KVCacheUsagePercent: 0.2, - ActiveModels: map[string]int{ - "foo": 1, - }, - }, - }, - wantHeaders: []*configPb.HeaderValueOption{ - { - Header: &configPb.HeaderValue{ - Key: runserver.DefaultDestinationEndpointHintKey, - RawValue: []byte("192.168.1.3:8000"), - }, - }, - { - Header: &configPb.HeaderValue{ - Key: "Content-Length", - RawValue: []byte("76"), - }, - }, - }, - wantMetadata: makeMetadata("192.168.1.3:8000"), - wantBody: []byte("{\"max_tokens\":100,\"model\":\"sql-lora-1fdg2\",\"prompt\":\"test3\",\"temperature\":0}"), - wantMetrics: ` - # HELP inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model. - # TYPE inference_model_request_total counter - inference_model_request_total{model_name="sql-lora",target_model_name="sql-lora-1fdg2"} 1 - `, - wantErr: false, - }, - { - name: "noncritical and all models past threshold, shed request", - req: utiltesting.GenerateRequest(logger, "test4", "sql-lora-sheddable"), - // no pods will be picked as all models are either above kv threshold, - // queue threshold, or both. - pods: map[backendmetrics.Pod]*backendmetrics.Metrics{ - fakePod(0): { - WaitingQueueSize: 6, - KVCacheUsagePercent: 0.2, - ActiveModels: map[string]int{ - "foo": 1, - "bar": 1, - "sql-lora-1fdg3": 1, - }, - }, - fakePod(1): { - WaitingQueueSize: 0, - KVCacheUsagePercent: 0.85, - ActiveModels: map[string]int{ - "foo": 1, - "sql-lora-1fdg3": 1, - }, - }, - fakePod(2): { - WaitingQueueSize: 10, - KVCacheUsagePercent: 0.9, - ActiveModels: map[string]int{ - "foo": 1, - "sql-lora-1fdg3": 1, - }, - }, - }, - wantHeaders: []*configPb.HeaderValueOption{}, - wantMetadata: &structpb.Struct{}, - wantBody: []byte(""), - wantErr: false, - immediateResponse: &extProcPb.ImmediateResponse{ - Status: &envoyTypePb.HttpStatus{ - Code: envoyTypePb.StatusCode_TooManyRequests, - }, - }, - wantMetrics: "", - }, - { - name: "noncritical, but one server has capacity, do not shed", - req: utiltesting.GenerateRequest(logger, "test5", "sql-lora-sheddable"), - // pod 0 will be picked as all other models are above threshold - pods: map[backendmetrics.Pod]*backendmetrics.Metrics{ - fakePod(0): { - WaitingQueueSize: 4, - KVCacheUsagePercent: 0.2, - ActiveModels: map[string]int{ - "foo": 1, - "bar": 1, - "sql-lora-1fdg3": 1, - }, - }, - fakePod(1): { - WaitingQueueSize: 0, - KVCacheUsagePercent: 0.85, - ActiveModels: map[string]int{ - "foo": 1, - "sql-lora-1fdg3": 1, - }, - }, - fakePod(2): { - WaitingQueueSize: 10, - KVCacheUsagePercent: 0.9, - ActiveModels: map[string]int{ - "foo": 1, - "sql-lora-1fdg3": 1, - }, - }, - }, - wantHeaders: []*configPb.HeaderValueOption{ - { - Header: &configPb.HeaderValue{ - Key: runserver.DefaultDestinationEndpointHintKey, - RawValue: []byte("192.168.1.1:8000"), - }, - }, - { - Header: &configPb.HeaderValue{ - Key: "Content-Length", - RawValue: []byte("76"), - }, - }, - }, - wantMetadata: makeMetadata("192.168.1.1:8000"), - wantBody: []byte("{\"max_tokens\":100,\"model\":\"sql-lora-1fdg3\",\"prompt\":\"test5\",\"temperature\":0}"), - wantMetrics: ` - # HELP 
inference_model_request_total [ALPHA] Counter of inference model requests broken out for each model and target model. - # TYPE inference_model_request_total counter - inference_model_request_total{model_name="sql-lora-sheddable",target_model_name="sql-lora-1fdg3"} 1 - `, - wantErr: false, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - client, cleanup := setUpHermeticServer(t, test.pods, false) - t.Cleanup(cleanup) - want := &extProcPb.ProcessingResponse{ - Response: &extProcPb.ProcessingResponse_RequestBody{ - RequestBody: &extProcPb.BodyResponse{ - Response: &extProcPb.CommonResponse{ - HeaderMutation: &extProcPb.HeaderMutation{ - SetHeaders: test.wantHeaders, - }, - BodyMutation: &extProcPb.BodyMutation{ - Mutation: &extProcPb.BodyMutation_Body{ - Body: test.wantBody, - }, - }, - }, - }, - }, - DynamicMetadata: test.wantMetadata, - } - res, err := sendRequest(t, client, test.req) - - if err != nil && !test.wantErr { - t.Errorf("Unexpected error, got: %v, want error: %v", err, test.wantErr) - } - if test.immediateResponse != nil { - want = &extProcPb.ProcessingResponse{ - Response: &extProcPb.ProcessingResponse_ImmediateResponse{ - ImmediateResponse: test.immediateResponse, - }, - } - } - if diff := cmp.Diff(want, res, protocmp.Transform()); diff != "" { - t.Errorf("Unexpected response, (-want +got): %v", diff) - } - - if test.wantMetrics != "" { - if err := metricsutils.GatherAndCompare(legacyregistry.DefaultGatherer, strings.NewReader(test.wantMetrics), "inference_model_request_total"); err != nil { - t.Error(err) - } - } - - legacyregistry.Reset() - }) - } -} - func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { tests := []struct { name string requests []*extProcPb.ProcessingRequest - pods map[backendmetrics.Pod]*backendmetrics.Metrics + pods map[backend.Pod]*backendmetrics.Metrics wantResponses []*extProcPb.ProcessingResponse wantMetrics map[string]string wantErr bool @@ -410,9 +107,9 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { // Request flow tests { name: "select lower queue and kv cache, no active lora", - requests: utiltesting.GenerateStreamedRequestSet(logger, "test1", "my-model"), + requests: integrationutils.GenerateStreamedRequestSet(logger, "test1", "my-model"), // pod-1 will be picked because it has relatively low queue size and low KV cache. - pods: map[backendmetrics.Pod]*backendmetrics.Metrics{ + pods: map[backend.Pod]*backendmetrics.Metrics{ fakePod(0): { WaitingQueueSize: 3, KVCacheUsagePercent: 0.2, @@ -484,10 +181,10 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { }, { name: "select active lora, low queue", - requests: utiltesting.GenerateStreamedRequestSet(logger, "test2", "sql-lora"), + requests: integrationutils.GenerateStreamedRequestSet(logger, "test2", "sql-lora"), // pod-1 will be picked because it has relatively low queue size, with the requested // model being active, and has low KV cache. 
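+ // ActiveModels tracks LoRA adapters currently loaded/running on a pod; the + // WaitingModels map added below tracks adapters with requests still waiting + // to be served.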
- pods: map[backendmetrics.Pod]*backendmetrics.Metrics{ + pods: map[backend.Pod]*backendmetrics.Metrics{ fakePod(0): { WaitingQueueSize: 0, KVCacheUsagePercent: 0.2, @@ -495,6 +192,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "bar": 1, }, + WaitingModels: map[string]int{}, }, fakePod(1): { WaitingQueueSize: 0, @@ -503,6 +201,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg2": 1, }, + WaitingModels: map[string]int{}, }, fakePod(2): { WaitingQueueSize: 10, @@ -511,6 +210,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "bar": 1, }, + WaitingModels: map[string]int{}, }, }, wantMetrics: map[string]string{`inference_model_request_total`: ` @@ -565,11 +265,11 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { }, { name: "select no lora despite active model, avoid excessive queue size", - requests: utiltesting.GenerateStreamedRequestSet(logger, "test3", "sql-lora"), + requests: integrationutils.GenerateStreamedRequestSet(logger, "test3", "sql-lora"), // pod-2 will be picked despite it NOT having the requested model being active // as it's above the affinity for queue size. Also is critical, so we should // still honor request despite all queues > 5 - pods: map[backendmetrics.Pod]*backendmetrics.Metrics{ + pods: map[backend.Pod]*backendmetrics.Metrics{ fakePod(0): { WaitingQueueSize: 10, KVCacheUsagePercent: 0.2, @@ -577,6 +277,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "bar": 1, }, + WaitingModels: map[string]int{}, }, fakePod(1): { WaitingQueueSize: 200, @@ -585,6 +286,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg2": 1, }, + WaitingModels: map[string]int{}, }, fakePod(2): { WaitingQueueSize: 6, @@ -592,6 +294,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { ActiveModels: map[string]int{ "foo": 1, }, + WaitingModels: map[string]int{}, }, }, wantMetrics: map[string]string{`inference_model_request_total`: ` @@ -646,10 +349,10 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { }, { name: "noncritical and all models past threshold, shed request", - requests: utiltesting.GenerateStreamedRequestSet(logger, "test4", "sql-lora-sheddable"), + requests: integrationutils.GenerateStreamedRequestSet(logger, "test4", "sql-lora-sheddable"), // no pods will be picked as all models are either above kv threshold, // queue threshold, or both. 
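+ // Sheddable (noncritical) requests are rejected with an immediate + // TooManyRequests response rather than queued when every endpoint is + // past its thresholds.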
- pods: map[backendmetrics.Pod]*backendmetrics.Metrics{ + pods: map[backend.Pod]*backendmetrics.Metrics{ fakePod(0): { WaitingQueueSize: 6, KVCacheUsagePercent: 0.2, @@ -658,6 +361,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "bar": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, fakePod(1): { WaitingQueueSize: 0, @@ -666,6 +370,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, fakePod(2): { WaitingQueueSize: 10, @@ -674,6 +379,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, }, wantErr: false, @@ -692,9 +398,9 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { }, { name: "noncritical, but one server has capacity, do not shed", - requests: utiltesting.GenerateStreamedRequestSet(logger, "test5", "sql-lora-sheddable"), + requests: integrationutils.GenerateStreamedRequestSet(logger, "test5", "sql-lora-sheddable"), // pod 0 will be picked as all other models are above threshold - pods: map[backendmetrics.Pod]*backendmetrics.Metrics{ + pods: map[backend.Pod]*backendmetrics.Metrics{ fakePod(0): { WaitingQueueSize: 4, KVCacheUsagePercent: 0.2, @@ -703,6 +409,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "bar": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, fakePod(1): { WaitingQueueSize: 0, @@ -711,6 +418,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, fakePod(2): { WaitingQueueSize: 10, @@ -719,6 +427,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, }, wantMetrics: map[string]string{`inference_model_request_total`: ` @@ -802,7 +511,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { // // pod 0 will be picked as all other models are above threshold - pods: map[backendmetrics.Pod]*backendmetrics.Metrics{ + pods: map[backend.Pod]*backendmetrics.Metrics{ fakePod(0): { WaitingQueueSize: 4, KVCacheUsagePercent: 0.2, @@ -811,6 +520,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "bar": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, fakePod(1): { WaitingQueueSize: 0, @@ -819,6 +529,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, fakePod(2): { WaitingQueueSize: 10, @@ -827,6 +538,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, }, wantMetrics: map[string]string{`inference_model_request_total`: ` @@ -910,7 +622,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { // // pod 0 will be picked as all other models are above threshold - pods: map[backendmetrics.Pod]*backendmetrics.Metrics{ + pods: map[backend.Pod]*backendmetrics.Metrics{ fakePod(0): { WaitingQueueSize: 4, KVCacheUsagePercent: 0.2, @@ -919,6 +631,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "bar": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, fakePod(1): { WaitingQueueSize: 0, @@ -927,6 +640,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: 
map[string]int{}, }, fakePod(2): { WaitingQueueSize: 10, @@ -935,6 +649,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, }, wantMetrics: map[string]string{`inference_model_request_total`: ` @@ -1019,7 +734,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { // // pod 0 will be picked as all other models are above threshold - pods: map[backendmetrics.Pod]*backendmetrics.Metrics{ + pods: map[backend.Pod]*backendmetrics.Metrics{ fakePod(0): { WaitingQueueSize: 4, KVCacheUsagePercent: 0.2, @@ -1028,6 +743,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "bar": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, fakePod(1): { WaitingQueueSize: 0, @@ -1036,6 +752,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, fakePod(2): { WaitingQueueSize: 10, @@ -1044,6 +761,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, }, wantErr: false, @@ -1115,7 +833,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { // // pod 0 will be picked as all other models are above threshold - pods: map[backendmetrics.Pod]*backendmetrics.Metrics{ + pods: map[backend.Pod]*backendmetrics.Metrics{ fakePod(0): { WaitingQueueSize: 4, KVCacheUsagePercent: 0.2, @@ -1124,6 +842,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "bar": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, fakePod(1): { WaitingQueueSize: 0, @@ -1132,6 +851,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, fakePod(2): { WaitingQueueSize: 10, @@ -1140,6 +860,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "foo": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, }, wantErr: false, @@ -1460,7 +1181,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { DynamicMetadata: makeMetadata("192.168.1.1:8000"), }, }, - pods: map[backendmetrics.Pod]*backendmetrics.Metrics{ + pods: map[backend.Pod]*backendmetrics.Metrics{ fakePod(0): { WaitingQueueSize: 4, KVCacheUsagePercent: 0.2, @@ -1469,6 +1190,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { "bar": 1, "sql-lora-1fdg3": 1, }, + WaitingModels: map[string]int{}, }, }, wantMetrics: map[string]string{`inference_pool_ready_pods`: ` @@ -1483,7 +1205,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { t.Run(test.name, func(t *testing.T) { client, cleanup := setUpHermeticServer(t, test.pods, true) t.Cleanup(cleanup) - responses, err := streamedRequest(t, client, test.requests, len(test.wantResponses)) + responses, err := integrationutils.StreamedRequest(t, client, test.requests, len(test.wantResponses)) if err != nil && !test.wantErr { t.Errorf("Unexpected error, got: %v, want error: %v", err, test.wantErr) @@ -1505,7 +1227,7 @@ func TestFullDuplexStreamed_KubeInferenceModelRequest(t *testing.T) { } } -func setUpHermeticServer(t *testing.T, podAndMetrics map[backendmetrics.Pod]*backendmetrics.Metrics, streamed bool) (client extProcPb.ExternalProcessor_ProcessClient, cleanup func()) { +func setUpHermeticServer(t *testing.T, podAndMetrics map[backend.Pod]*backendmetrics.Metrics, streamed bool) (client 
extProcPb.ExternalProcessor_ProcessClient, cleanup func()) { // Reconfigure the TestPodMetricsClient. res := map[types.NamespacedName]*backendmetrics.Metrics{} for pod, metrics := range podAndMetrics { res[pod.NamespacedName] = metrics } @@ -1522,7 +1244,7 @@ func setUpHermeticServer(t *testing.T, podAndMetrics map[backendmetrics.Pod]*bac } for pod := range podAndMetrics { - pod := utiltesting.MakePod(pod.NamespacedName.Name). + pod := epptestutil.MakePod(pod.NamespacedName.Name). Namespace(pod.NamespacedName.Namespace). ReadyCondition(). Labels(podLabels). @@ -1547,6 +1269,8 @@ func setUpHermeticServer(t *testing.T, podAndMetrics map[backendmetrics.Pod]*bac } }() + time.Sleep(serverRunner.RefreshPrometheusMetricsInterval) // wait for metrics to become available before running tests that rely on these metrics + // check if all pods are synced to datastore assert.EventuallyWithT(t, func(t *assert.CollectT) { assert.Len(t, serverRunner.Datastore.PodGetAll(), len(podAndMetrics), "Datastore not synced") @@ -1571,7 +1295,7 @@ func setUpHermeticServer(t *testing.T, podAndMetrics map[backendmetrics.Pod]*bac // clear created pods for pod := range podAndMetrics { - pod := utiltesting.MakePod(pod.NamespacedName.Name). + pod := epptestutil.MakePod(pod.NamespacedName.Name). Namespace(pod.NamespacedName.Namespace).Complete().ObjRef() if err := k8sClient.Delete(context.Background(), pod); err != nil { @@ -1581,8 +1305,8 @@ func setUpHermeticServer(t *testing.T, podAndMetrics map[backendmetrics.Pod]*bac } } -func fakePod(index int) backendmetrics.Pod { - return backendmetrics.Pod{ +func fakePod(index int) backend.Pod { + return backend.Pod{ NamespacedName: types.NamespacedName{Name: fmt.Sprintf("pod-%v", index), Namespace: "default"}, Address: fmt.Sprintf("192.168.1.%d", index+1), } @@ -1601,7 +1325,7 @@ func BeforeSuite() func() { } utilruntime.Must(clientgoscheme.AddToScheme(scheme)) - utilruntime.Must(v1alpha2.AddToScheme(scheme)) + utilruntime.Must(v1alpha2.Install(scheme)) k8sClient, err = k8sclient.New(cfg, k8sclient.Options{Scheme: scheme}) if err != nil { @@ -1626,8 +1350,9 @@ func BeforeSuite() func() { serverRunner.TestPodMetricsClient = &backendmetrics.FakePodMetricsClient{} pmf := backendmetrics.NewPodMetricsFactory(serverRunner.TestPodMetricsClient, 10*time.Millisecond) // Adjust from defaults - serverRunner.PoolName = "vllm-llama3-8b-instruct-pool" + serverRunner.PoolNamespacedName = types.NamespacedName{Name: "vllm-llama3-8b-instruct-pool", Namespace: "default"} serverRunner.Datastore = datastore.NewDatastore(context.Background(), pmf) + serverRunner.Scheduler = scheduling.NewScheduler(serverRunner.Datastore) serverRunner.SecureServing = false if err := serverRunner.SetupWithManager(context.Background(), mgr); err != nil { @@ -1651,27 +1376,13 @@ func BeforeSuite() func() { } for _, doc := range docs { - inferenceModel := &v1alpha2.InferenceModel{} - if err = yaml.Unmarshal(doc, inferenceModel); err != nil { + obj := &unstructured.Unstructured{} + if err = yaml.Unmarshal(doc, obj); err != nil { logutil.Fatal(logger, err, "Can't unmarshal object", "document", doc) } - if inferenceModel.Kind == "InferenceModel" { - logger.Info("Creating inference model", "model", inferenceModel) - if err := k8sClient.Create(context.Background(), inferenceModel); err != nil { - logutil.Fatal(logger, err, "Unable to create inferenceModel", "modelName", inferenceModel.Name) - } - } - } - for _, doc := range docs { - inferencePool := &v1alpha2.InferencePool{} - if err = yaml.Unmarshal(doc, inferencePool); err != nil { - logutil.Fatal(logger, err,
"Can't unmarshal object", "document", doc) - } - if inferencePool.Kind == "InferencePool" { - logger.Info("Creating inference pool", "pool", inferencePool) - if err := k8sClient.Create(context.Background(), inferencePool); err != nil { - logutil.Fatal(logger, err, "Unable to create inferencePool", "poolName", inferencePool.Name) - } + logger.Info("Creating object", "kind", obj.GetKind(), "object", obj) + if err := k8sClient.Create(context.Background(), obj); err != nil { + logutil.Fatal(logger, err, "Unable to create object", "object", obj.GetName()) } } @@ -1688,55 +1399,6 @@ func BeforeSuite() func() { } } -func sendRequest(t *testing.T, client extProcPb.ExternalProcessor_ProcessClient, req *extProcPb.ProcessingRequest) (*extProcPb.ProcessingResponse, error) { - t.Logf("Sending request: %v", req) - if err := client.Send(req); err != nil { - t.Logf("Failed to send request %+v: %v", req, err) - return nil, err - } - - res, err := client.Recv() - if err != nil { - t.Logf("Failed to receive: %v", err) - return nil, err - } - t.Logf("Received request %+v", res) - return res, err -} - -func streamedRequest(t *testing.T, client extProcPb.ExternalProcessor_ProcessClient, requests []*extProcPb.ProcessingRequest, expectedResponses int) ([]*extProcPb.ProcessingResponse, error) { - for _, req := range requests { - t.Logf("Sending request: %v", req) - if err := client.Send(req); err != nil { - t.Logf("Failed to send request %+v: %v", req, err) - return nil, err - } - } - responses := []*extProcPb.ProcessingResponse{} - - // Make an incredible simple timeout func in the case where - // there is less than the expected amount of responses; bail and fail. - var simpleTimeout bool - go func() { - time.Sleep(10 * time.Second) - simpleTimeout = true - }() - - for range expectedResponses { - if simpleTimeout { - break - } - res, err := client.Recv() - if err != nil && err != io.EOF { - t.Logf("Failed to receive: %v", err) - return nil, err - } - t.Logf("Received request %+v", res) - responses = append(responses, res) - } - return responses, nil -} - // readDocuments reads documents from file. func readDocuments(fp string) ([][]byte, error) { b, err := os.ReadFile(fp) diff --git a/pkg/epp/util/testing/request.go b/test/integration/util.go similarity index 56% rename from pkg/epp/util/testing/request.go rename to test/integration/util.go index 30772ad5..5fcc9d18 100644 --- a/pkg/epp/util/testing/request.go +++ b/test/integration/util.go @@ -14,10 +14,13 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ -package integration import ( "encoding/json" + "io" + "testing" + "time" envoyCorev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" @@ -25,13 +28,64 @@ import ( logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) +func SendRequest(t *testing.T, client extProcPb.ExternalProcessor_ProcessClient, req *extProcPb.ProcessingRequest) (*extProcPb.ProcessingResponse, error) { + t.Logf("Sending request: %v", req) + if err := client.Send(req); err != nil { + t.Logf("Failed to send request %+v: %v", req, err) + return nil, err + } + + res, err := client.Recv() + if err != nil { + t.Logf("Failed to receive: %v", err) + return nil, err + } + t.Logf("Received response %+v", res) + return res, err +} + +func StreamedRequest(t *testing.T, client extProcPb.ExternalProcessor_ProcessClient, requests []*extProcPb.ProcessingRequest, expectedResponses int) ([]*extProcPb.ProcessingResponse, error) { + for _, req := range requests { + t.Logf("Sending request: %v", req) + if err := client.Send(req); err != nil { + t.Logf("Failed to send request %+v: %v", req, err) + return nil, err + } + } + responses := []*extProcPb.ProcessingResponse{} + + // Use an incredibly simple timeout for the case where fewer than the + // expected number of responses arrive; bail and fail. Checking a deadline + // between receives keeps this free of shared mutable state. + deadline := time.Now().Add(10 * time.Second) + + for range expectedResponses { + if time.Now().After(deadline) { + break + } + res, err := client.Recv() + if err != nil && err != io.EOF { + t.Logf("Failed to receive: %v", err) + return nil, err + } + t.Logf("Received response %+v", res) + responses = append(responses, res) + } + return responses, nil +} + func GenerateRequest(logger logr.Logger, prompt, model string) *extProcPb.ProcessingRequest { j := map[string]interface{}{ - "model": model, "prompt": prompt, "max_tokens": 100, "temperature": 0, } + if model != "" { + j["model"] = model + } llmReq, err := json.Marshal(j) if err != nil { diff --git a/test/testdata/envoy.yaml b/test/testdata/envoy.yaml index fc32b5aa..3fff8598 100644 --- a/test/testdata/envoy.yaml +++ b/test/testdata/envoy.yaml @@ -100,14 +100,15 @@ data: grpc_service: envoy_grpc: cluster_name: ext_proc - authority: vllm-llama3-8b-instruct-epp.default:9002 + authority: vllm-llama3-8b-instruct-epp.$E2E_NS:9002 timeout: 10s processing_mode: request_header_mode: SEND - response_header_mode: SKIP - request_body_mode: BUFFERED - request_trailer_mode: SKIP - response_trailer_mode: SKIP + response_header_mode: SEND + request_body_mode: FULL_DUPLEX_STREAMED + response_body_mode: FULL_DUPLEX_STREAMED + request_trailer_mode: SEND + response_trailer_mode: SEND message_timeout: 1000s # Mark it as disabled if needed for troubleshooting: # disabled: true @@ -194,7 +195,7 @@ data: - endpoint: address: socket_address: - address: vllm-llama3-8b-instruct-epp.default + address: vllm-llama3-8b-instruct-epp.$E2E_NS port_value: 9002 health_status: HEALTHY load_balancing_weight: 1 @@ -221,10 +222,10 @@ spec: spec: containers: - name: envoy - image: docker.io/envoyproxy/envoy:distroless-v1.32.2 + image: docker.io/envoyproxy/envoy:distroless-v1.33.2 args: - "--service-cluster" - - "default/inference-gateway" + - "$E2E_NS/inference-gateway" - "--service-node" - "$(ENVOY_POD_NAME)" - "--log-level" diff --git a/test/testdata/inferencepool-e2e.yaml b/test/testdata/inferencepool-e2e.yaml new file mode 100644 index 00000000..79339c5b
--- /dev/null +++ b/test/testdata/inferencepool-e2e.yaml @@ -0,0 +1,126 @@ +apiVersion: inference.networking.x-k8s.io/v1alpha2 +kind: InferencePool +metadata: + labels: + name: vllm-llama3-8b-instruct +spec: + targetPortNumber: 8000 + selector: + app: vllm-llama3-8b-instruct + extensionRef: + name: vllm-llama3-8b-instruct-epp + namespace: $E2E_NS +--- +apiVersion: v1 +kind: Service +metadata: + name: vllm-llama3-8b-instruct-epp + namespace: $E2E_NS +spec: + selector: + app: vllm-llama3-8b-instruct-epp + ports: + - protocol: TCP + port: 9002 + targetPort: 9002 + appProtocol: http2 + type: ClusterIP +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: vllm-llama3-8b-instruct-epp + namespace: $E2E_NS + labels: + app: vllm-llama3-8b-instruct-epp +spec: + replicas: 1 + selector: + matchLabels: + app: vllm-llama3-8b-instruct-epp + template: + metadata: + labels: + app: vllm-llama3-8b-instruct-epp + spec: + # Conservatively, this timeout should mirror the longest grace period of the pods within the pool + terminationGracePeriodSeconds: 130 + containers: + - name: epp + image: us-central1-docker.pkg.dev/k8s-staging-images/gateway-api-inference-extension/epp:main + imagePullPolicy: Always + args: + - -poolName + - "vllm-llama3-8b-instruct" + - -poolNamespace + - "$E2E_NS" + - -v + - "4" + - --zap-encoder + - "json" + - -grpcPort + - "9002" + - -grpcHealthPort + - "9003" + env: + - name: USE_STREAMING + value: "true" + ports: + - containerPort: 9002 + - containerPort: 9003 + - name: metrics + containerPort: 9090 + livenessProbe: + grpc: + port: 9003 + service: inference-extension + initialDelaySeconds: 5 + periodSeconds: 10 + readinessProbe: + grpc: + port: 9003 + service: inference-extension + initialDelaySeconds: 5 + periodSeconds: 10 +--- +kind: ClusterRole +apiVersion: rbac.authorization.k8s.io/v1 +metadata: + name: pod-read +rules: +- apiGroups: ["inference.networking.x-k8s.io"] + resources: ["inferencemodels"] + verbs: ["get", "watch", "list"] +- apiGroups: [""] + resources: ["pods"] + verbs: ["get", "watch", "list"] +- apiGroups: ["inference.networking.x-k8s.io"] + resources: ["inferencepools"] + verbs: ["get", "watch", "list"] +- apiGroups: ["discovery.k8s.io"] + resources: ["endpointslices"] + verbs: ["get", "watch", "list"] +- apiGroups: + - authentication.k8s.io + resources: + - tokenreviews + verbs: + - create +- apiGroups: + - authorization.k8s.io + resources: + - subjectaccessreviews + verbs: + - create +--- +kind: ClusterRoleBinding +apiVersion: rbac.authorization.k8s.io/v1 +metadata: + name: pod-read-binding +subjects: +- kind: ServiceAccount + name: default + namespace: $E2E_NS +roleRef: + kind: ClusterRole + name: pod-read
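The `$E2E_NS` placeholder that appears throughout `envoy.yaml` and `inferencepool-e2e.yaml` is resolved at test time by the `strings.ReplaceAll` calls in `createEnvoy` and `createInferExt` above. Below is a minimal sketch of that substitution-and-decode pipeline; the helper name `expandManifests` and the naive `---` document splitting are assumptions of this sketch (the harness uses its own `readYaml` and `createObjsFromYaml` helpers), but decoding into `unstructured.Unstructured` mirrors the `BeforeSuite` change above.

```go
package e2eutil

import (
	"fmt"
	"os"
	"strings"

	"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
	"sigs.k8s.io/yaml"
)

// expandManifests reads a multi-document YAML file, substitutes the $E2E_NS
// placeholder with the namespace the suite is running in, and decodes each
// document into an unstructured object so any kind can be created without
// registering its Go type in a scheme. Hypothetical helper for illustration.
func expandManifests(path, ns string) ([]*unstructured.Unstructured, error) {
	raw, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}
	var objs []*unstructured.Unstructured
	for _, doc := range strings.Split(string(raw), "\n---") {
		doc = strings.ReplaceAll(doc, "$E2E_NS", ns)
		if strings.TrimSpace(doc) == "" {
			continue // skip empty documents, e.g. a trailing separator
		}
		obj := &unstructured.Unstructured{}
		if err := yaml.Unmarshal([]byte(doc), obj); err != nil {
			return nil, fmt.Errorf("decoding manifest document: %w", err)
		}
		objs = append(objs, obj)
	}
	return objs, nil
}
```

Because the substitution is plain string replacement, the placeholder must not collide with other `$`-prefixed values in the manifests; Envoy's `$(ENVOY_POD_NAME)` survives because only the exact `$E2E_NS` token is replaced.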