From f1b18a2c5ab620f20a392a414f2123f1d021ecfa Mon Sep 17 00:00:00 2001 From: Hugo Dutka Date: Wed, 6 Aug 2025 15:34:55 +0200 Subject: [PATCH 1/4] feat: allowed hosts --- README.md | 55 +++++- cmd/server/server.go | 80 ++++++-- cmd/server/server_test.go | 287 +++++++++++++++++++++++++-- go.mod | 1 + go.sum | 2 + lib/httpapi/server.go | 71 ++++++- lib/httpapi/server_test.go | 384 ++++++++++++++++++++++++++++++++++++- 7 files changed, 835 insertions(+), 45 deletions(-) diff --git a/README.md b/README.md index 77c571f..8dff3c4 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,6 @@ Control [Claude Code](https://github.com/anthropics/claude-code), [Goose](https: ![agentapi-chat](https://github.com/user-attachments/assets/57032c9f-4146-4b66-b219-09e38ab7690d) - You can use AgentAPI: - to build a unified chat interface for coding agents @@ -54,9 +53,6 @@ You can use AgentAPI: Run an HTTP server that lets you control an agent. If you'd like to start an agent with additional arguments, pass the full agent command after the `--` flag. -> [!NOTE] -> When using Codex, always specify the agent type explicitly (`agentapi server --type=codex -- codex`), or message formatting may break. - ```bash agentapi server -- claude --allowedTools "Bash(git*) Edit Replace" ``` @@ -68,6 +64,9 @@ agentapi server -- aider --model sonnet --api-key anthropic=sk-ant-apio3-XXX agentapi server -- goose ``` +> [!NOTE] +> When using Codex, always specify the agent type explicitly (`agentapi server --type=codex -- codex`), or message formatting may break. + An OpenAPI schema is available in [openapi.json](openapi.json). By default, the server runs on port 3284. Additionally, the server exposes the same OpenAPI schema at http://localhost:3284/openapi.json and the available endpoints in a documentation UI at http://localhost:3284/docs. @@ -79,6 +78,54 @@ There are 4 endpoints: - GET `/status` - returns the current status of the agent, either "stable" or "running" - GET `/events` - an SSE stream of events from the agent: message and status updates +#### Allowed hosts + +By default, the server only allows requests with the host header set to localhost:3284. If you'd like to host AgentAPI elsewhere, you can change this by using the `AGENTAPI_ALLOWED_HOSTS` environment variable or the `--allowed-hosts` flag. + +To allow requests from any host, use `*` as the allowed host. + +```bash +agentapi server --allowed-hosts '*' -- claude +``` + +To allow a specific host, use: + +```bash +agentapi server --allowed-hosts 'example.com' -- claude +``` + +To specify multiple hosts, use a comma-separated list when using the `--allowed-hosts` flag, or a space-separated list when using the `AGENTAPI_ALLOWED_HOSTS` environment variable. + +```bash +agentapi server --allowed-hosts 'example.com,example.org' -- claude +# or +AGENTAPI_ALLOWED_HOSTS='example.com example.org' agentapi server -- claude +``` + +#### Allowed origins + +By default, the server allows CORS requests from `http://localhost:3284`, `http://localhost:3000`, and `http://localhost:3001`. If you'd like to change which origins can make cross-origin requests to AgentAPI, you can change this by using the `AGENTAPI_ALLOWED_ORIGINS` environment variable or the `--allowed-origins` flag. + +To allow requests from any origin, use `*` as the allowed origin: + +```bash +agentapi server --allowed-origins '*' -- claude +``` + +To allow a specific origin, use: + +```bash +agentapi server --allowed-origins 'https://example.com' -- claude +``` + +To specify multiple origins, use a comma-separated list when using the `--allowed-origins` flag, or a space-separated list when using the `AGENTAPI_ALLOWED_ORIGINS` environment variable. Origins must include the protocol (`http://` or `https://`) and support wildcards (e.g., `https://*.example.com`): + +```bash +agentapi server --allowed-origins 'https://example.com,http://localhost:3000' -- claude +# or +AGENTAPI_ALLOWED_ORIGINS='https://example.com http://localhost:3000' agentapi server -- claude +``` + ### `agentapi attach` Attach to a running agent's terminal session. diff --git a/cmd/server/server.go b/cmd/server/server.go index 3313b51..be57053 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -9,6 +9,7 @@ import ( "os" "sort" "strings" + "unicode" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -58,6 +59,26 @@ func parseAgentType(firstArg string, agentTypeVar string) (AgentType, error) { return AgentTypeCustom, nil } +// Validate allowed hosts or origins don't contain whitespace or commas. +// Viper/Cobra use different separators (space for env vars, comma for flags), +// so these characters likely indicate user error. +func validateAllowedHostsOrOrigins(input []string) error { + if len(input) == 0 { + return fmt.Errorf("the list must not be empty") + } + for _, item := range input { + for _, r := range item { + if unicode.IsSpace(r) { + return fmt.Errorf("'%s' contains whitespace characters, which are not allowed", item) + } + } + if strings.Contains(item, ",") { + return fmt.Errorf("'%s' contains comma characters, which are not allowed", item) + } + } + return nil +} + func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) error { agent := argsToPass[0] agentTypeValue := viper.GetString(FlagType) @@ -95,12 +116,17 @@ func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) er } } port := viper.GetInt(FlagPort) - srv := httpapi.NewServer(ctx, httpapi.ServerConfig{ - AgentType: agentType, - Process: process, - Port: port, - ChatBasePath: viper.GetString(FlagChatBasePath), + srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ + AgentType: agentType, + Process: process, + Port: port, + ChatBasePath: viper.GetString(FlagChatBasePath), + AllowedHosts: viper.GetStringSlice(FlagAllowedHosts), + AllowedOrigins: viper.GetStringSlice(FlagAllowedOrigins), }) + if err != nil { + return xerrors.Errorf("failed to create server: %w", err) + } if printOpenAPI { fmt.Println(srv.GetOpenAPI()) return nil @@ -150,12 +176,15 @@ type flagSpec struct { } const ( - FlagType = "type" - FlagPort = "port" - FlagPrintOpenAPI = "print-openapi" - FlagChatBasePath = "chat-base-path" - FlagTermWidth = "term-width" - FlagTermHeight = "term-height" + FlagType = "type" + FlagPort = "port" + FlagPrintOpenAPI = "print-openapi" + FlagChatBasePath = "chat-base-path" + FlagTermWidth = "term-width" + FlagTermHeight = "term-height" + FlagAllowedHosts = "allowed-hosts" + FlagAllowedOrigins = "allowed-origins" + FlagExit = "exit" ) func CreateServerCmd() *cobra.Command { @@ -164,7 +193,22 @@ func CreateServerCmd() *cobra.Command { Short: "Run the server", Long: fmt.Sprintf("Run the server with the specified agent (one of: %s)", strings.Join(agentNames, ", ")), Args: cobra.MinimumNArgs(1), + PreRunE: func(cmd *cobra.Command, args []string) error { + allowedHosts := viper.GetStringSlice(FlagAllowedHosts) + if err := validateAllowedHostsOrOrigins(allowedHosts); err != nil { + return xerrors.Errorf("failed to validate allowed hosts: %w", err) + } + allowedOrigins := viper.GetStringSlice(FlagAllowedOrigins) + if err := validateAllowedHostsOrOrigins(allowedOrigins); err != nil { + return xerrors.Errorf("failed to validate allowed origins: %w", err) + } + return nil + }, Run: func(cmd *cobra.Command, args []string) { + // The --exit flag is used for testing validation of flags in the test suite + if viper.GetBool(FlagExit) { + return + } logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) ctx := logctx.WithLogger(context.Background(), logger) if err := runServer(ctx, logger, cmd.Flags().Args()); err != nil { @@ -181,6 +225,10 @@ func CreateServerCmd() *cobra.Command { {FlagChatBasePath, "c", "/chat", "Base path for assets and routes used in the static files of the chat interface", "string"}, {FlagTermWidth, "W", uint16(80), "Width of the emulated terminal", "uint16"}, {FlagTermHeight, "H", uint16(1000), "Height of the emulated terminal", "uint16"}, + // localhost:3284 is the default host for the server + {FlagAllowedHosts, "a", []string{"localhost:3284"}, "HTTP allowed hosts. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_HOSTS env var", "stringSlice"}, + // localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development. + {FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"}, } for _, spec := range flagSpecs { @@ -193,6 +241,8 @@ func CreateServerCmd() *cobra.Command { serverCmd.Flags().BoolP(spec.name, spec.shorthand, spec.defaultValue.(bool), spec.usage) case "uint16": serverCmd.Flags().Uint16P(spec.name, spec.shorthand, spec.defaultValue.(uint16), spec.usage) + case "stringSlice": + serverCmd.Flags().StringSliceP(spec.name, spec.shorthand, spec.defaultValue.([]string), spec.usage) default: panic(fmt.Sprintf("unknown flag type: %s", spec.flagType)) } @@ -201,6 +251,14 @@ func CreateServerCmd() *cobra.Command { } } + serverCmd.Flags().Bool(FlagExit, false, "Exit immediately after parsing arguments") + if err := serverCmd.Flags().MarkHidden(FlagExit); err != nil { + panic(fmt.Sprintf("failed to mark flag %s as hidden: %v", FlagExit, err)) + } + if err := viper.BindPFlag(FlagExit, serverCmd.Flags().Lookup(FlagExit)); err != nil { + panic(fmt.Sprintf("failed to bind flag %s: %v", FlagExit, err)) + } + viper.SetEnvPrefix("AGENTAPI") viper.AutomaticEnv() viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_")) diff --git a/cmd/server/server_test.go b/cmd/server/server_test.go index 59b1ccc..4044d22 100644 --- a/cmd/server/server_test.go +++ b/cmd/server/server_test.go @@ -12,6 +12,20 @@ import ( "github.com/stretchr/testify/require" ) +type nullWriter struct{} + +func (w *nullWriter) Write(p []byte) (int, error) { + return len(p), nil +} + +// setupCommandOutput configures a cobra command to use a null writer for output capture. +func setupCommandOutput(t *testing.T, cmd *cobra.Command) { + t.Helper() + + cmd.SetOut(&nullWriter{}) + cmd.SetErr(&nullWriter{}) +} + func TestParseAgentType(t *testing.T) { tests := []struct { firstArg string @@ -141,17 +155,17 @@ func TestServerCmd_AllArgs_Defaults(t *testing.T) { {"chat-base-path default", FlagChatBasePath, "/chat", func() any { return viper.GetString(FlagChatBasePath) }}, {"term-width default", FlagTermWidth, uint16(80), func() any { return viper.GetUint16(FlagTermWidth) }}, {"term-height default", FlagTermHeight, uint16(1000), func() any { return viper.GetUint16(FlagTermHeight) }}, + {"allowed-hosts default", FlagAllowedHosts, []string{"localhost:3284"}, func() any { return viper.GetStringSlice(FlagAllowedHosts) }}, + {"allowed-origins default", FlagAllowedOrigins, []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, func() any { return viper.GetStringSlice(FlagAllowedOrigins) }}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { isolateViper(t) serverCmd := CreateServerCmd() - cmd := &cobra.Command{} - cmd.AddCommand(serverCmd) - - // Execute with no args to get defaults - serverCmd.SetArgs([]string{"--help"}) // Use help to avoid actual execution + setupCommandOutput(t, serverCmd) + // Execute with --exit to get defaults + serverCmd.SetArgs([]string{"--exit", "dummy-command"}) if err := serverCmd.Execute(); err != nil { t.Fatalf("Failed to execute server command: %v", err) } @@ -175,6 +189,8 @@ func TestServerCmd_AllEnvVars(t *testing.T) { {"AGENTAPI_CHAT_BASE_PATH", "AGENTAPI_CHAT_BASE_PATH", "/api", "/api", func() any { return viper.GetString(FlagChatBasePath) }}, {"AGENTAPI_TERM_WIDTH", "AGENTAPI_TERM_WIDTH", "120", uint16(120), func() any { return viper.GetUint16(FlagTermWidth) }}, {"AGENTAPI_TERM_HEIGHT", "AGENTAPI_TERM_HEIGHT", "500", uint16(500), func() any { return viper.GetUint16(FlagTermHeight) }}, + {"AGENTAPI_ALLOWED_HOSTS", "AGENTAPI_ALLOWED_HOSTS", "localhost:3284 localhost:3285", []string{"localhost:3284", "localhost:3285"}, func() any { return viper.GetStringSlice(FlagAllowedHosts) }}, + {"AGENTAPI_ALLOWED_ORIGINS", "AGENTAPI_ALLOWED_ORIGINS", "https://example.com http://localhost:3000", []string{"https://example.com", "http://localhost:3000"}, func() any { return viper.GetStringSlice(FlagAllowedOrigins) }}, } for _, tt := range tests { @@ -183,10 +199,8 @@ func TestServerCmd_AllEnvVars(t *testing.T) { t.Setenv(tt.envVar, tt.envValue) serverCmd := CreateServerCmd() - cmd := &cobra.Command{} - cmd.AddCommand(serverCmd) - - serverCmd.SetArgs([]string{"--help"}) + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--exit", "dummy-command"}) if err := serverCmd.Execute(); err != nil { t.Fatalf("Failed to execute server command: %v", err) } @@ -247,6 +261,13 @@ func TestServerCmd_ArgsPrecedenceOverEnv(t *testing.T) { uint16(600), func() any { return viper.GetUint16(FlagTermHeight) }, }, + { + "allowed-origins: CLI overrides env", + "AGENTAPI_ALLOWED_ORIGINS", "https://env-example.com http://localhost:3000", + []string{"--allowed-origins", "https://cli-example.com"}, + []string{"https://cli-example.com"}, + func() any { return viper.GetStringSlice(FlagAllowedOrigins) }, + }, } for _, tt := range tests { @@ -254,9 +275,9 @@ func TestServerCmd_ArgsPrecedenceOverEnv(t *testing.T) { isolateViper(t) t.Setenv(tt.envVar, tt.envValue) - // Mock execution to test arg parsing without running server - args := append(tt.args, "--help") + args := append(tt.args, "--exit", "dummy-command") serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) serverCmd.SetArgs(args) if err := serverCmd.Execute(); err != nil { t.Fatalf("Failed to execute server command: %v", err) @@ -277,7 +298,8 @@ func TestMixed_ConfigurationScenarios(t *testing.T) { // Set some CLI args serverCmd := CreateServerCmd() - serverCmd.SetArgs([]string{"--port", "9999", "--print-openapi", "--help"}) + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs([]string{"--port", "9999", "--print-openapi", "--exit", "dummy-command"}) if err := serverCmd.Execute(); err != nil { t.Fatalf("Failed to execute server command: %v", err) } @@ -291,3 +313,244 @@ func TestMixed_ConfigurationScenarios(t *testing.T) { assert.Equal(t, uint16(1000), viper.GetUint16(FlagTermHeight)) // default }) } + +func TestServerCmd_AllowedHosts(t *testing.T) { + tests := []struct { + name string + env map[string]string + args []string + expectedErr string + expected []string // only checked if expectedErr is empty + }{ + // Environment variable scenarios (space-separated format) + { + name: "env: single valid host", + env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:3284"}, + args: []string{}, + expected: []string{"localhost:3284"}, + }, + { + name: "env: multiple valid hosts space-separated", + env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:3284 example.com 192.168.1.1:8080"}, + args: []string{}, + expected: []string{"localhost:3284", "example.com", "192.168.1.1:8080"}, + }, + { + name: "env: host with tab", + env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:3284\texample.com"}, + args: []string{}, + expected: []string{"localhost:3284", "example.com"}, + }, + { + name: "env: host with comma (invalid)", + env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:3284,example.com"}, + args: []string{}, + expectedErr: "contains comma characters", + }, + + // CLI flag scenarios (comma-separated format) + { + name: "flag: single valid host", + args: []string{"--allowed-hosts", "localhost:3284"}, + expected: []string{"localhost:3284"}, + }, + { + name: "flag: multiple valid hosts comma-separated", + args: []string{"--allowed-hosts", "localhost:3284,example.com,192.168.1.1:8080"}, + expected: []string{"localhost:3284", "example.com", "192.168.1.1:8080"}, + }, + { + name: "flag: multiple valid hosts with multiple flags", + args: []string{"--allowed-hosts", "localhost:3284", "--allowed-hosts", "example.com"}, + expected: []string{"localhost:3284", "example.com"}, + }, + { + name: "flag: host with newline", + args: []string{"--allowed-hosts", "localhost:3284\n"}, + expected: []string{"localhost:3284"}, + }, + { + name: "flag: host with space in comma-separated list (invalid)", + args: []string{"--allowed-hosts", "localhost:3284,example .com"}, + expectedErr: "contains whitespace characters", + }, + + // Mixed scenarios (env + flag precedence) + { + name: "mixed: flag overrides env", + env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:8080"}, + args: []string{"--allowed-hosts", "override.com"}, + expected: []string{"override.com"}, + }, + { + name: "mixed: flag overrides env but flag is invalid", + env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:8080"}, + args: []string{"--allowed-hosts", "invalid .com"}, + expectedErr: "contains whitespace characters", + }, + + // Empty hosts are not allowed + { + name: "empty host", + args: []string{"--allowed-hosts", ""}, + expectedErr: "the list must not be empty", + }, + + // Default behavior + { + name: "default hosts when neither env nor flag provided", + args: []string{}, + expected: []string{"localhost:3284"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + isolateViper(t) + + // Set environment variables if provided + for key, value := range tt.env { + t.Setenv(key, value) + } + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs(append(tt.args, "--exit", "dummy-command")) + err := serverCmd.Execute() + + if tt.expectedErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedErr) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expected, viper.GetStringSlice(FlagAllowedHosts)) + } + }) + } +} + +func TestServerCmd_AllowedOrigins(t *testing.T) { + tests := []struct { + name string + env map[string]string + args []string + expectedErr string + expected []string // only checked if expectedErr is empty + }{ + // Environment variable scenarios (space-separated format) + { + name: "env: single valid origin", + env: map[string]string{"AGENTAPI_ALLOWED_ORIGINS": "https://example.com"}, + args: []string{}, + expected: []string{"https://example.com"}, + }, + { + name: "env: multiple valid origins space-separated", + env: map[string]string{"AGENTAPI_ALLOWED_ORIGINS": "https://example.com http://localhost:3000 https://app.example.com"}, + args: []string{}, + expected: []string{"https://example.com", "http://localhost:3000", "https://app.example.com"}, + }, + { + name: "env: wildcard origin", + env: map[string]string{"AGENTAPI_ALLOWED_ORIGINS": "*"}, + args: []string{}, + expected: []string{"*"}, + }, + { + name: "env: origin with tab", + env: map[string]string{"AGENTAPI_ALLOWED_ORIGINS": "https://example.com\thttp://localhost:3000"}, + args: []string{}, + expected: []string{"https://example.com", "http://localhost:3000"}, + }, + { + name: "env: origin with comma (invalid)", + env: map[string]string{"AGENTAPI_ALLOWED_ORIGINS": "https://example.com,http://localhost:3000"}, + args: []string{}, + expectedErr: "contains comma characters", + }, + + // CLI flag scenarios (comma-separated format) + { + name: "flag: single valid origin", + args: []string{"--allowed-origins", "https://example.com"}, + expected: []string{"https://example.com"}, + }, + { + name: "flag: multiple valid origins comma-separated", + args: []string{"--allowed-origins", "https://example.com,http://localhost:3000,https://app.example.com"}, + expected: []string{"https://example.com", "http://localhost:3000", "https://app.example.com"}, + }, + { + name: "flag: multiple valid origins with multiple flags", + args: []string{"--allowed-origins", "https://example.com", "--allowed-origins", "http://localhost:3000"}, + expected: []string{"https://example.com", "http://localhost:3000"}, + }, + { + name: "flag: wildcard origin", + args: []string{"--allowed-origins", "*"}, + expected: []string{"*"}, + }, + { + name: "flag: origin with newline", + args: []string{"--allowed-origins", "https://example.com\n"}, + expected: []string{"https://example.com"}, + }, + { + name: "flag: origin with space in comma-separated list (invalid)", + args: []string{"--allowed-origins", "https://example.com,http://localhost :3000"}, + expectedErr: "contains whitespace characters", + }, + + // Mixed scenarios (env + flag precedence) + { + name: "mixed: flag overrides env", + env: map[string]string{"AGENTAPI_ALLOWED_ORIGINS": "https://env-example.com"}, + args: []string{"--allowed-origins", "https://override.com"}, + expected: []string{"https://override.com"}, + }, + { + name: "mixed: flag overrides env but flag is invalid", + env: map[string]string{"AGENTAPI_ALLOWED_ORIGINS": "https://env-example.com"}, + args: []string{"--allowed-origins", "invalid origin"}, + expectedErr: "contains whitespace characters", + }, + + // Empty origins are not allowed + { + name: "empty origin", + args: []string{"--allowed-origins", ""}, + expectedErr: "the list must not be empty", + }, + + // Default behavior + { + name: "default origins when neither env nor flag provided", + args: []string{}, + expected: []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + isolateViper(t) + + // Set environment variables if provided + for key, value := range tt.env { + t.Setenv(key, value) + } + + serverCmd := CreateServerCmd() + setupCommandOutput(t, serverCmd) + serverCmd.SetArgs(append(tt.args, "--exit", "dummy-command")) + err := serverCmd.Execute() + + if tt.expectedErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedErr) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expected, viper.GetStringSlice(FlagAllowedOrigins)) + } + }) + } +} diff --git a/go.mod b/go.mod index b4212f0..d0587ad 100644 --- a/go.mod +++ b/go.mod @@ -54,6 +54,7 @@ require ( github.com/rogpeppe/go-internal v1.14.1 // indirect github.com/spf13/afero v1.14.0 github.com/spf13/pflag v1.0.6 // indirect + github.com/unrolled/secure v1.17.0 golang.org/x/sync v0.12.0 // indirect golang.org/x/sys v0.31.0 // indirect golang.org/x/text v0.23.0 // indirect diff --git a/go.sum b/go.sum index afb727f..1fff5e8 100644 --- a/go.sum +++ b/go.sum @@ -113,6 +113,8 @@ github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8 github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/tmaxmax/go-sse v0.10.0 h1:j9F93WB4Hxt8wUf6oGffMm4dutALvUPoDDxfuDQOSqA= github.com/tmaxmax/go-sse v0.10.0/go.mod h1:u/2kZQR1tyngo1lKaNCj1mJmhXGZWS1Zs5yiSOD+Eg8= +github.com/unrolled/secure v1.17.0 h1:Io7ifFgo99Bnh0J7+Q+qcMzWM6kaDPCA5FroFZEdbWU= +github.com/unrolled/secure v1.17.0/go.mod h1:BmF5hyM6tXczk3MpQkFf1hpKSRqCyhqcbiQtiAF7+40= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index e9e71cb..e6c89be 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -7,6 +7,7 @@ import ( "log/slog" "net/http" "net/url" + "slices" "strings" "sync" "time" @@ -20,6 +21,7 @@ import ( "github.com/danielgtaylor/huma/v2/sse" "github.com/go-chi/chi/v5" "github.com/go-chi/cors" + "github.com/unrolled/secure" "golang.org/x/xerrors" ) @@ -60,18 +62,71 @@ func (s *Server) GetOpenAPI() string { const snapshotInterval = 25 * time.Millisecond type ServerConfig struct { - AgentType mf.AgentType - Process *termexec.Process - Port int - ChatBasePath string + AgentType mf.AgentType + Process *termexec.Process + Port int + ChatBasePath string + AllowedHosts []string + AllowedOrigins []string +} + +func parseAllowedHosts(hosts []string) ([]string, error) { + if slices.Contains(hosts, "*") { + return []string{}, nil + } + for _, host := range hosts { + if strings.Contains(host, "*") { + return nil, xerrors.Errorf("wildcard characters are not supported: %q", host) + } + if strings.Contains(host, "http://") || strings.Contains(host, "https://") { + return nil, xerrors.Errorf("host must not contain http:// or https://: %q", host) + } + } + return hosts, nil +} + +func parseAllowedOrigins(origins []string) ([]string, error) { + for _, origin := range origins { + if !strings.Contains(origin, "*") && !(strings.Contains(origin, "http://") || strings.Contains(origin, "https://")) { + return nil, xerrors.Errorf("origin must contain http:// or https://: %q", origin) + } + } + return origins, nil } // NewServer creates a new server instance -func NewServer(ctx context.Context, config ServerConfig) *Server { +func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { router := chi.NewMux() + logger := logctx.From(ctx) + + allowedHosts, err := parseAllowedHosts(config.AllowedHosts) + if err != nil { + return nil, xerrors.Errorf("failed to validate allowed hosts: %w", err) + } + if len(allowedHosts) > 0 { + logger.Info(fmt.Sprintf("Allowed hosts: %s", strings.Join(allowedHosts, ", "))) + } else { + logger.Info("Allowed hosts: *") + } + + allowedOrigins, err := parseAllowedOrigins(config.AllowedOrigins) + if err != nil { + return nil, xerrors.Errorf("failed to validate allowed origins: %w", err) + } + logger.Info(fmt.Sprintf("Allowed origins: %s", strings.Join(allowedOrigins, ", "))) + + secureMiddleware := secure.New(secure.Options{ + AllowedHosts: allowedHosts, + }) + badHostHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "Invalid host header. Allowed hosts: "+strings.Join(allowedHosts, ", "), http.StatusBadRequest) + }) + secureMiddleware.SetBadHostHandler(badHostHandler) + router.Use(secureMiddleware.Handler) + corsMiddleware := cors.New(cors.Options{ - AllowedOrigins: []string{"*"}, + AllowedOrigins: allowedOrigins, AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"}, ExposedHeaders: []string{"Link"}, @@ -101,7 +156,7 @@ func NewServer(ctx context.Context, config ServerConfig) *Server { api: api, port: config.Port, conversation: conversation, - logger: logctx.From(ctx), + logger: logger, agentio: config.Process, agentType: config.AgentType, emitter: emitter, @@ -111,7 +166,7 @@ func NewServer(ctx context.Context, config ServerConfig) *Server { // Register API routes s.registerRoutes() - return s + return s, nil } // Handler returns the underlying chi.Router for testing purposes. diff --git a/lib/httpapi/server_test.go b/lib/httpapi/server_test.go index badc974..f154ef8 100644 --- a/lib/httpapi/server_test.go +++ b/lib/httpapi/server_test.go @@ -46,12 +46,15 @@ func TestOpenAPISchema(t *testing.T) { t.Parallel() ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) - srv := httpapi.NewServer(ctx, httpapi.ServerConfig{ - AgentType: msgfmt.AgentTypeClaude, - Process: nil, - Port: 0, - ChatBasePath: "/chat", + srv, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ + AgentType: msgfmt.AgentTypeClaude, + Process: nil, + Port: 0, + ChatBasePath: "/chat", + AllowedHosts: []string{"*"}, + AllowedOrigins: []string{"*"}, }) + require.NoError(t, err) currentSchemaStr := srv.GetOpenAPI() var currentSchema any if err := json.Unmarshal([]byte(currentSchemaStr), ¤tSchema); err != nil { @@ -95,12 +98,15 @@ func TestServer_redirectToChat(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() tCtx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) - s := httpapi.NewServer(tCtx, httpapi.ServerConfig{ - AgentType: msgfmt.AgentTypeClaude, - Process: nil, - Port: 0, - ChatBasePath: tc.chatBasePath, + s, err := httpapi.NewServer(tCtx, httpapi.ServerConfig{ + AgentType: msgfmt.AgentTypeClaude, + Process: nil, + Port: 0, + ChatBasePath: tc.chatBasePath, + AllowedHosts: []string{"*"}, + AllowedOrigins: []string{"*"}, }) + require.NoError(t, err) tsServer := httptest.NewServer(s.Handler()) t.Cleanup(tsServer.Close) @@ -120,3 +126,361 @@ func TestServer_redirectToChat(t *testing.T) { }) } } + +func TestServer_AllowedHosts(t *testing.T) { + cases := []struct { + name string + allowedHosts []string + hostHeader string + expectedStatusCode int + expectedErrorMsg string + }{ + { + name: "wildcard hosts - any host allowed", + allowedHosts: []string{"*"}, + hostHeader: "example.com", + expectedStatusCode: http.StatusOK, + }, + { + name: "wildcard hosts - another host allowed", + allowedHosts: []string{"*"}, + hostHeader: "malicious.com", + expectedStatusCode: http.StatusOK, + }, + { + name: "specific hosts - valid host allowed", + allowedHosts: []string{"localhost:3000", "app.example.com"}, + hostHeader: "localhost:3000", + expectedStatusCode: http.StatusOK, + }, + { + name: "specific hosts - another valid host allowed", + allowedHosts: []string{"localhost:3000", "app.example.com"}, + hostHeader: "app.example.com", + expectedStatusCode: http.StatusOK, + }, + { + name: "specific hosts - invalid host rejected", + allowedHosts: []string{"localhost:3000", "app.example.com"}, + hostHeader: "malicious.com", + expectedStatusCode: http.StatusBadRequest, + expectedErrorMsg: "Invalid host header. Allowed hosts: localhost:3000, app.example.com", + }, + { + name: "empty hosts - any host allowed", + allowedHosts: []string{}, + hostHeader: "anything.com", + expectedStatusCode: http.StatusOK, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) + s, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ + AgentType: msgfmt.AgentTypeClaude, + Process: nil, + Port: 0, + ChatBasePath: "/chat", + AllowedHosts: tc.allowedHosts, + AllowedOrigins: []string{"https://example.com"}, // Set a default to isolate host testing + }) + require.NoError(t, err) + tsServer := httptest.NewServer(s.Handler()) + t.Cleanup(tsServer.Close) + + req, err := http.NewRequest("GET", tsServer.URL+"/status", nil) + require.NoError(t, err) + + if tc.hostHeader != "" { + req.Host = tc.hostHeader + } + + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + t.Cleanup(func() { + _ = resp.Body.Close() + }) + + require.Equal(t, tc.expectedStatusCode, resp.StatusCode, + "expected status code %d, got %d", tc.expectedStatusCode, resp.StatusCode) + + if tc.expectedErrorMsg != "" { + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Contains(t, string(body), tc.expectedErrorMsg) + } + }) + } +} + +func TestServer_CORSPreflightWithHosts(t *testing.T) { + cases := []struct { + name string + allowedHosts []string + hostHeader string + originHeader string + expectedStatusCode int + expectCORSHeaders bool + }{ + { + name: "preflight with wildcard hosts", + allowedHosts: []string{"*"}, + hostHeader: "example.com", + originHeader: "https://example.com", + expectedStatusCode: http.StatusOK, + expectCORSHeaders: true, + }, + { + name: "preflight with specific valid host", + allowedHosts: []string{"localhost:3000"}, + hostHeader: "localhost:3000", + originHeader: "https://localhost:3000", + expectedStatusCode: http.StatusOK, + expectCORSHeaders: true, + }, + { + name: "preflight with invalid host", + allowedHosts: []string{"localhost:3000"}, + hostHeader: "malicious.com", + originHeader: "https://malicious.com", + expectedStatusCode: http.StatusBadRequest, + expectCORSHeaders: false, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) + s, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ + AgentType: msgfmt.AgentTypeClaude, + Process: nil, + Port: 0, + ChatBasePath: "/chat", + AllowedHosts: tc.allowedHosts, + AllowedOrigins: []string{"*"}, // Set wildcard origins to isolate host testing + }) + require.NoError(t, err) + tsServer := httptest.NewServer(s.Handler()) + t.Cleanup(tsServer.Close) + + // Test CORS preflight request + req, err := http.NewRequest("OPTIONS", tsServer.URL+"/status", nil) + require.NoError(t, err) + + if tc.hostHeader != "" { + req.Host = tc.hostHeader + } + if tc.originHeader != "" { + req.Header.Set("Origin", tc.originHeader) + } + req.Header.Set("Access-Control-Request-Method", "GET") + req.Header.Set("Access-Control-Request-Headers", "Content-Type") + + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + t.Cleanup(func() { + _ = resp.Body.Close() + }) + + require.Equal(t, tc.expectedStatusCode, resp.StatusCode, + "expected status code %d, got %d", tc.expectedStatusCode, resp.StatusCode) + + if tc.expectCORSHeaders { + allowMethods := resp.Header.Get("Access-Control-Allow-Methods") + require.Contains(t, allowMethods, "GET", "expected GET in allowed methods") + + allowHeaders := resp.Header.Get("Access-Control-Allow-Headers") + require.Contains(t, allowHeaders, "Content-Type", "expected Content-Type in allowed headers") + } + }) + } +} + +func TestServer_CORSOrigins(t *testing.T) { + cases := []struct { + name string + allowedOrigins []string + originHeader string + expectedStatusCode int + expectedCORSOrigin string + expectCORSOriginHeader bool + }{ + { + name: "wildcard origins - any origin allowed", + allowedOrigins: []string{"*"}, + originHeader: "https://example.com", + expectedStatusCode: http.StatusOK, + expectedCORSOrigin: "*", + expectCORSOriginHeader: true, + }, + { + name: "wildcard origins - malicious origin allowed", + allowedOrigins: []string{"*"}, + originHeader: "http://malicious.com", + expectedStatusCode: http.StatusOK, + expectedCORSOrigin: "*", + expectCORSOriginHeader: true, + }, + { + name: "specific origins - valid origin allowed https", + allowedOrigins: []string{"https://localhost:3000", "http://app.example.com"}, + originHeader: "https://localhost:3000", + expectedStatusCode: http.StatusOK, + expectedCORSOrigin: "https://localhost:3000", + expectCORSOriginHeader: true, + }, + { + name: "specific origins - valid origin allowed http", + allowedOrigins: []string{"https://localhost:3000", "http://app.example.com"}, + originHeader: "http://app.example.com", + expectedStatusCode: http.StatusOK, + expectedCORSOrigin: "http://app.example.com", + expectCORSOriginHeader: true, + }, + { + name: "specific origins - invalid origin rejected", + allowedOrigins: []string{"https://localhost:3000", "http://app.example.com"}, + originHeader: "https://malicious.com", + expectedStatusCode: http.StatusOK, // Server allows request - CORS is enforced by browser + expectCORSOriginHeader: false, + }, + { + name: "no origin header - request not coming from a browser", + allowedOrigins: []string{"https://example.com"}, + originHeader: "", + expectedStatusCode: http.StatusOK, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) + s, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ + AgentType: msgfmt.AgentTypeClaude, + Process: nil, + Port: 0, + ChatBasePath: "/chat", + AllowedHosts: []string{"*"}, // Set wildcard to isolate CORS testing + AllowedOrigins: tc.allowedOrigins, + }) + require.NoError(t, err) + tsServer := httptest.NewServer(s.Handler()) + t.Cleanup(tsServer.Close) + + req, err := http.NewRequest("GET", tsServer.URL+"/status", nil) + require.NoError(t, err) + + if tc.originHeader != "" { + req.Header.Set("Origin", tc.originHeader) + } + + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + t.Cleanup(func() { + _ = resp.Body.Close() + }) + + require.Equal(t, tc.expectedStatusCode, resp.StatusCode, + "expected status code %d, got %d", tc.expectedStatusCode, resp.StatusCode) + + if tc.expectCORSOriginHeader { + corsOrigin := resp.Header.Get("Access-Control-Allow-Origin") + require.Equal(t, tc.expectedCORSOrigin, corsOrigin, + "expected CORS origin %q, got %q", tc.expectedCORSOrigin, corsOrigin) + } else if tc.expectedStatusCode == http.StatusOK && tc.originHeader != "" { + corsOrigin := resp.Header.Get("Access-Control-Allow-Origin") + require.Empty(t, corsOrigin, "expected no CORS origin header, got %q", corsOrigin) + } + }) + } +} + +func TestServer_CORSPreflightOrigins(t *testing.T) { + cases := []struct { + name string + allowedOrigins []string + originHeader string + expectedStatusCode int + expectCORSHeaders bool + }{ + { + name: "preflight with wildcard origins", + allowedOrigins: []string{"*"}, + originHeader: "https://example.com", + expectedStatusCode: http.StatusOK, + expectCORSHeaders: true, + }, + { + name: "preflight with specific valid origin", + allowedOrigins: []string{"https://localhost:3000"}, + originHeader: "https://localhost:3000", + expectedStatusCode: http.StatusOK, + expectCORSHeaders: true, + }, + { + name: "preflight with invalid origin", + allowedOrigins: []string{"https://localhost:3000"}, + originHeader: "https://malicious.com", + expectedStatusCode: http.StatusOK, // Request succeeds but no CORS headers + expectCORSHeaders: false, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + ctx := logctx.WithLogger(context.Background(), slog.New(slog.NewTextHandler(os.Stdout, nil))) + s, err := httpapi.NewServer(ctx, httpapi.ServerConfig{ + AgentType: msgfmt.AgentTypeClaude, + Process: nil, + Port: 0, + ChatBasePath: "/chat", + AllowedHosts: []string{"*"}, // Set wildcard to isolate CORS testing + AllowedOrigins: tc.allowedOrigins, + }) + require.NoError(t, err) + tsServer := httptest.NewServer(s.Handler()) + t.Cleanup(tsServer.Close) + + req, err := http.NewRequest("OPTIONS", tsServer.URL+"/status", nil) + require.NoError(t, err) + + if tc.originHeader != "" { + req.Header.Set("Origin", tc.originHeader) + } + req.Header.Set("Access-Control-Request-Method", "GET") + req.Header.Set("Access-Control-Request-Headers", "Content-Type") + + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + t.Cleanup(func() { + _ = resp.Body.Close() + }) + + require.Equal(t, tc.expectedStatusCode, resp.StatusCode, + "expected status code %d, got %d", tc.expectedStatusCode, resp.StatusCode) + + if tc.expectCORSHeaders { + allowMethods := resp.Header.Get("Access-Control-Allow-Methods") + require.Contains(t, allowMethods, "GET", "expected GET in allowed methods") + + allowHeaders := resp.Header.Get("Access-Control-Allow-Headers") + require.Contains(t, allowHeaders, "Content-Type", "expected Content-Type in allowed headers") + + corsOrigin := resp.Header.Get("Access-Control-Allow-Origin") + require.NotEmpty(t, corsOrigin, "expected CORS origin header for valid preflight") + } else if tc.originHeader != "" { + corsOrigin := resp.Header.Get("Access-Control-Allow-Origin") + require.Empty(t, corsOrigin, "expected no CORS origin header for invalid origin") + } + }) + } +} From d2400b91f9424cdc5db19520aa12b2a57a90f81f Mon Sep 17 00:00:00 2001 From: Hugo Dutka Date: Fri, 8 Aug 2025 20:08:05 +0200 Subject: [PATCH 2/4] fix: strip port from host --- README.md | 2 +- cmd/server/server.go | 66 +++++++++++++++++++++++++++--- cmd/server/server_test.go | 82 +++++++++++++++++++++++++++++--------- go.mod | 1 - go.sum | 2 - lib/httpapi/server.go | 69 ++++++++++++++++++++++++++++---- lib/httpapi/server_test.go | 37 ++++++++++++++--- 7 files changed, 217 insertions(+), 42 deletions(-) diff --git a/README.md b/README.md index 8dff3c4..c4c39bc 100644 --- a/README.md +++ b/README.md @@ -80,7 +80,7 @@ There are 4 endpoints: #### Allowed hosts -By default, the server only allows requests with the host header set to localhost:3284. If you'd like to host AgentAPI elsewhere, you can change this by using the `AGENTAPI_ALLOWED_HOSTS` environment variable or the `--allowed-hosts` flag. +By default, the server only allows requests with the host header set to `localhost`. If you'd like to host AgentAPI elsewhere, you can change this by using the `AGENTAPI_ALLOWED_HOSTS` environment variable or the `--allowed-hosts` flag. Hosts must be hostnames only (no ports); the server ignores the port portion of incoming requests when authorizing. To allow requests from any host, use `*` as the allowed host. diff --git a/cmd/server/server.go b/cmd/server/server.go index be57053..69757b8 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -6,6 +6,7 @@ import ( "fmt" "log/slog" "net/http" + "net/url" "os" "sort" "strings" @@ -59,10 +60,63 @@ func parseAgentType(firstArg string, agentTypeVar string) (AgentType, error) { return AgentTypeCustom, nil } -// Validate allowed hosts or origins don't contain whitespace or commas. +// Validate allowed hosts don't contain whitespace, commas, schemes, or ports. // Viper/Cobra use different separators (space for env vars, comma for flags), // so these characters likely indicate user error. -func validateAllowedHostsOrOrigins(input []string) error { +func validateAllowedHosts(input []string) error { + if len(input) == 0 { + return fmt.Errorf("the list must not be empty") + } + // First pass: whitespace & comma checks (surface these errors first) + for _, item := range input { + for _, r := range item { + if unicode.IsSpace(r) { + return fmt.Errorf("'%s' contains whitespace characters, which are not allowed", item) + } + } + if strings.Contains(item, ",") { + return fmt.Errorf("'%s' contains comma characters, which are not allowed", item) + } + } + // Second pass: scheme check + for _, item := range input { + if strings.Contains(item, "http://") || strings.Contains(item, "https://") { + return fmt.Errorf("'%s' must not include http:// or https://", item) + } + } + // Third pass: port check (but allow IPv6 literals without ports) + for _, item := range input { + trimmed := strings.TrimSpace(item) + colonCount := strings.Count(trimmed, ":") + // If bracketed, rely on url.Parse to detect a port in "]:" form. + if strings.HasPrefix(trimmed, "[") { + if u, err := url.Parse("http://" + trimmed); err == nil { + if u.Port() != "" { + return fmt.Errorf("'%s' must not include a port", item) + } + } + continue + } + // Unbracketed IPv6: multiple colons and no brackets; treat as valid (no ports allowed here) + if colonCount >= 2 { + continue + } + // IPv4 or hostname: if URL parsing finds a port or there's a single colon, it's invalid + if u, err := url.Parse("http://" + trimmed); err == nil { + if u.Port() != "" { + return fmt.Errorf("'%s' must not include a port", item) + } + } + if colonCount == 1 { + return fmt.Errorf("'%s' must not include a port", item) + } + } + return nil +} + +// Validate allowed origins don't contain whitespace or commas. +// Origins must include a scheme, validated later by the HTTP layer. +func validateAllowedOrigins(input []string) error { if len(input) == 0 { return fmt.Errorf("the list must not be empty") } @@ -195,11 +249,11 @@ func CreateServerCmd() *cobra.Command { Args: cobra.MinimumNArgs(1), PreRunE: func(cmd *cobra.Command, args []string) error { allowedHosts := viper.GetStringSlice(FlagAllowedHosts) - if err := validateAllowedHostsOrOrigins(allowedHosts); err != nil { + if err := validateAllowedHosts(allowedHosts); err != nil { return xerrors.Errorf("failed to validate allowed hosts: %w", err) } allowedOrigins := viper.GetStringSlice(FlagAllowedOrigins) - if err := validateAllowedHostsOrOrigins(allowedOrigins); err != nil { + if err := validateAllowedOrigins(allowedOrigins); err != nil { return xerrors.Errorf("failed to validate allowed origins: %w", err) } return nil @@ -225,8 +279,8 @@ func CreateServerCmd() *cobra.Command { {FlagChatBasePath, "c", "/chat", "Base path for assets and routes used in the static files of the chat interface", "string"}, {FlagTermWidth, "W", uint16(80), "Width of the emulated terminal", "uint16"}, {FlagTermHeight, "H", uint16(1000), "Height of the emulated terminal", "uint16"}, - // localhost:3284 is the default host for the server - {FlagAllowedHosts, "a", []string{"localhost:3284"}, "HTTP allowed hosts. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_HOSTS env var", "stringSlice"}, + // localhost is the default host for the server. Port is ignored during matching. + {FlagAllowedHosts, "a", []string{"localhost"}, "HTTP allowed hosts (hostnames only, no ports). Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_HOSTS env var", "stringSlice"}, // localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development. {FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"}, } diff --git a/cmd/server/server_test.go b/cmd/server/server_test.go index 4044d22..eb339f4 100644 --- a/cmd/server/server_test.go +++ b/cmd/server/server_test.go @@ -155,7 +155,7 @@ func TestServerCmd_AllArgs_Defaults(t *testing.T) { {"chat-base-path default", FlagChatBasePath, "/chat", func() any { return viper.GetString(FlagChatBasePath) }}, {"term-width default", FlagTermWidth, uint16(80), func() any { return viper.GetUint16(FlagTermWidth) }}, {"term-height default", FlagTermHeight, uint16(1000), func() any { return viper.GetUint16(FlagTermHeight) }}, - {"allowed-hosts default", FlagAllowedHosts, []string{"localhost:3284"}, func() any { return viper.GetStringSlice(FlagAllowedHosts) }}, + {"allowed-hosts default", FlagAllowedHosts, []string{"localhost"}, func() any { return viper.GetStringSlice(FlagAllowedHosts) }}, {"allowed-origins default", FlagAllowedOrigins, []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, func() any { return viper.GetStringSlice(FlagAllowedOrigins) }}, } @@ -189,7 +189,7 @@ func TestServerCmd_AllEnvVars(t *testing.T) { {"AGENTAPI_CHAT_BASE_PATH", "AGENTAPI_CHAT_BASE_PATH", "/api", "/api", func() any { return viper.GetString(FlagChatBasePath) }}, {"AGENTAPI_TERM_WIDTH", "AGENTAPI_TERM_WIDTH", "120", uint16(120), func() any { return viper.GetUint16(FlagTermWidth) }}, {"AGENTAPI_TERM_HEIGHT", "AGENTAPI_TERM_HEIGHT", "500", uint16(500), func() any { return viper.GetUint16(FlagTermHeight) }}, - {"AGENTAPI_ALLOWED_HOSTS", "AGENTAPI_ALLOWED_HOSTS", "localhost:3284 localhost:3285", []string{"localhost:3284", "localhost:3285"}, func() any { return viper.GetStringSlice(FlagAllowedHosts) }}, + {"AGENTAPI_ALLOWED_HOSTS", "AGENTAPI_ALLOWED_HOSTS", "localhost example.com", []string{"localhost", "example.com"}, func() any { return viper.GetStringSlice(FlagAllowedHosts) }}, {"AGENTAPI_ALLOWED_ORIGINS", "AGENTAPI_ALLOWED_ORIGINS", "https://example.com http://localhost:3000", []string{"https://example.com", "http://localhost:3000"}, func() any { return viper.GetStringSlice(FlagAllowedOrigins) }}, } @@ -325,21 +325,21 @@ func TestServerCmd_AllowedHosts(t *testing.T) { // Environment variable scenarios (space-separated format) { name: "env: single valid host", - env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:3284"}, + env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost"}, args: []string{}, - expected: []string{"localhost:3284"}, + expected: []string{"localhost"}, }, { name: "env: multiple valid hosts space-separated", - env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:3284 example.com 192.168.1.1:8080"}, + env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost example.com 192.168.1.1"}, args: []string{}, - expected: []string{"localhost:3284", "example.com", "192.168.1.1:8080"}, + expected: []string{"localhost", "example.com", "192.168.1.1"}, }, { name: "env: host with tab", - env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:3284\texample.com"}, + env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost\texample.com"}, args: []string{}, - expected: []string{"localhost:3284", "example.com"}, + expected: []string{"localhost", "example.com"}, }, { name: "env: host with comma (invalid)", @@ -347,44 +347,88 @@ func TestServerCmd_AllowedHosts(t *testing.T) { args: []string{}, expectedErr: "contains comma characters", }, + { + name: "env: host with port (invalid)", + env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:3284"}, + args: []string{}, + expectedErr: "must not include a port", + }, + { + name: "env: ipv6 literal", + env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "2001:db8::1"}, + args: []string{}, + expected: []string{"2001:db8::1"}, + }, + { + name: "env: ipv6 bracketed literal", + env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "[2001:db8::1]"}, + args: []string{}, + expected: []string{"[2001:db8::1]"}, + }, + { + name: "env: ipv6 with port (invalid)", + env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "[2001:db8::1]:443"}, + args: []string{}, + expectedErr: "must not include a port", + }, // CLI flag scenarios (comma-separated format) { name: "flag: single valid host", - args: []string{"--allowed-hosts", "localhost:3284"}, - expected: []string{"localhost:3284"}, + args: []string{"--allowed-hosts", "localhost"}, + expected: []string{"localhost"}, }, { name: "flag: multiple valid hosts comma-separated", - args: []string{"--allowed-hosts", "localhost:3284,example.com,192.168.1.1:8080"}, - expected: []string{"localhost:3284", "example.com", "192.168.1.1:8080"}, + args: []string{"--allowed-hosts", "localhost,example.com,192.168.1.1"}, + expected: []string{"localhost", "example.com", "192.168.1.1"}, }, { name: "flag: multiple valid hosts with multiple flags", - args: []string{"--allowed-hosts", "localhost:3284", "--allowed-hosts", "example.com"}, - expected: []string{"localhost:3284", "example.com"}, + args: []string{"--allowed-hosts", "localhost", "--allowed-hosts", "example.com"}, + expected: []string{"localhost", "example.com"}, }, { name: "flag: host with newline", - args: []string{"--allowed-hosts", "localhost:3284\n"}, - expected: []string{"localhost:3284"}, + args: []string{"--allowed-hosts", "localhost\n"}, + expected: []string{"localhost"}, }, { name: "flag: host with space in comma-separated list (invalid)", args: []string{"--allowed-hosts", "localhost:3284,example .com"}, expectedErr: "contains whitespace characters", }, + { + name: "flag: host with port (invalid)", + args: []string{"--allowed-hosts", "localhost:3284"}, + expectedErr: "must not include a port", + }, + { + name: "flag: ipv6 literal", + args: []string{"--allowed-hosts", "2001:db8::1"}, + expected: []string{"2001:db8::1"}, + }, + { + name: "flag: ipv6 bracketed literal", + args: []string{"--allowed-hosts", "[2001:db8::1]"}, + expected: []string{"[2001:db8::1]"}, + }, + { + name: "flag: ipv6 with port (invalid)", + args: []string{"--allowed-hosts", "[2001:db8::1]:443"}, + expectedErr: "must not include a port", + }, // Mixed scenarios (env + flag precedence) { name: "mixed: flag overrides env", - env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:8080"}, + env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost"}, args: []string{"--allowed-hosts", "override.com"}, expected: []string{"override.com"}, }, { name: "mixed: flag overrides env but flag is invalid", - env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:8080"}, + env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost"}, args: []string{"--allowed-hosts", "invalid .com"}, expectedErr: "contains whitespace characters", }, @@ -400,7 +444,7 @@ func TestServerCmd_AllowedHosts(t *testing.T) { { name: "default hosts when neither env nor flag provided", args: []string{}, - expected: []string{"localhost:3284"}, + expected: []string{"localhost"}, }, } diff --git a/go.mod b/go.mod index d0587ad..b4212f0 100644 --- a/go.mod +++ b/go.mod @@ -54,7 +54,6 @@ require ( github.com/rogpeppe/go-internal v1.14.1 // indirect github.com/spf13/afero v1.14.0 github.com/spf13/pflag v1.0.6 // indirect - github.com/unrolled/secure v1.17.0 golang.org/x/sync v0.12.0 // indirect golang.org/x/sys v0.31.0 // indirect golang.org/x/text v0.23.0 // indirect diff --git a/go.sum b/go.sum index 1fff5e8..afb727f 100644 --- a/go.sum +++ b/go.sum @@ -113,8 +113,6 @@ github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8 github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/tmaxmax/go-sse v0.10.0 h1:j9F93WB4Hxt8wUf6oGffMm4dutALvUPoDDxfuDQOSqA= github.com/tmaxmax/go-sse v0.10.0/go.mod h1:u/2kZQR1tyngo1lKaNCj1mJmhXGZWS1Zs5yiSOD+Eg8= -github.com/unrolled/secure v1.17.0 h1:Io7ifFgo99Bnh0J7+Q+qcMzWM6kaDPCA5FroFZEdbWU= -github.com/unrolled/secure v1.17.0/go.mod h1:BmF5hyM6tXczk3MpQkFf1hpKSRqCyhqcbiQtiAF7+40= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index e6c89be..29c9332 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "log/slog" + "net" "net/http" "net/url" "slices" @@ -21,7 +22,6 @@ import ( "github.com/danielgtaylor/huma/v2/sse" "github.com/go-chi/chi/v5" "github.com/go-chi/cors" - "github.com/unrolled/secure" "golang.org/x/xerrors" ) @@ -82,7 +82,32 @@ func parseAllowedHosts(hosts []string) ([]string, error) { return nil, xerrors.Errorf("host must not contain http:// or https://: %q", host) } } - return hosts, nil + // Normalize hosts to bare hostnames/IPs by stripping any port and brackets. + // This ensures allowed entries match the Host header hostname only. + normalized := make([]string, 0, len(hosts)) + for _, raw := range hosts { + h := strings.TrimSpace(raw) + // If it's an IPv6 literal (possibly bracketed) without an obvious port, keep the literal. + unbracketed := strings.Trim(h, "[]") + if ip := net.ParseIP(unbracketed); ip != nil { + // It's an IP literal; use the bare form without brackets. + normalized = append(normalized, unbracketed) + continue + } + // If likely host:port (single colon) or bracketed host, use url.Parse to extract hostname. + if strings.Count(h, ":") == 1 || (strings.HasPrefix(h, "[") && strings.Contains(h, "]")) { + if u, err := url.Parse("http://" + h); err == nil { + hn := u.Hostname() + if hn != "" { + normalized = append(normalized, hn) + continue + } + } + } + // Fallback: use as-is (e.g., hostname without port) + normalized = append(normalized, h) + } + return normalized, nil } func parseAllowedOrigins(origins []string) ([]string, error) { @@ -116,14 +141,11 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { } logger.Info(fmt.Sprintf("Allowed origins: %s", strings.Join(allowedOrigins, ", "))) - secureMiddleware := secure.New(secure.Options{ - AllowedHosts: allowedHosts, - }) + // Enforce allowed hosts in a custom middleware that ignores the port during matching. badHostHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.Error(w, "Invalid host header. Allowed hosts: "+strings.Join(allowedHosts, ", "), http.StatusBadRequest) }) - secureMiddleware.SetBadHostHandler(badHostHandler) - router.Use(secureMiddleware.Handler) + router.Use(hostAuthorizationMiddleware(allowedHosts, badHostHandler)) corsMiddleware := cors.New(cors.Options{ AllowedOrigins: allowedOrigins, @@ -174,6 +196,39 @@ func (s *Server) Handler() http.Handler { return s.router } +// hostAuthorizationMiddleware enforces that the request Host header matches one of the allowed +// hosts, ignoring any port in the comparison. If allowedHosts is empty, all hosts are allowed. +// Always uses url.Parse("http://" + r.Host) to robustly extract the hostname (handles IPv6). +func hostAuthorizationMiddleware(allowedHosts []string, badHostHandler http.Handler) func(next http.Handler) http.Handler { + // Copy for safety; also build a map for O(1) lookups with case-insensitive keys. + allowed := make(map[string]struct{}, len(allowedHosts)) + for _, h := range allowedHosts { + allowed[strings.ToLower(h)] = struct{}{} + } + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if len(allowedHosts) == 0 { // wildcard semantics: allow all + next.ServeHTTP(w, r) + return + } + // Extract hostname from the Host header using url.Parse; ignore any port. + hostHeader := r.Host + if hostHeader == "" { + badHostHandler.ServeHTTP(w, r) + return + } + if u, err := url.Parse("http://" + hostHeader); err == nil { + hostname := u.Hostname() + if _, ok := allowed[strings.ToLower(hostname)]; ok { + next.ServeHTTP(w, r) + return + } + } + badHostHandler.ServeHTTP(w, r) + }) + } +} + func (s *Server) StartSnapshotLoop(ctx context.Context) { s.conversation.StartSnapshotLoop(ctx) go func() { diff --git a/lib/httpapi/server_test.go b/lib/httpapi/server_test.go index f154ef8..5a80ee9 100644 --- a/lib/httpapi/server_test.go +++ b/lib/httpapi/server_test.go @@ -149,22 +149,22 @@ func TestServer_AllowedHosts(t *testing.T) { }, { name: "specific hosts - valid host allowed", - allowedHosts: []string{"localhost:3000", "app.example.com"}, + allowedHosts: []string{"localhost", "app.example.com"}, hostHeader: "localhost:3000", expectedStatusCode: http.StatusOK, }, { name: "specific hosts - another valid host allowed", - allowedHosts: []string{"localhost:3000", "app.example.com"}, + allowedHosts: []string{"localhost", "app.example.com"}, hostHeader: "app.example.com", expectedStatusCode: http.StatusOK, }, { name: "specific hosts - invalid host rejected", - allowedHosts: []string{"localhost:3000", "app.example.com"}, + allowedHosts: []string{"localhost", "app.example.com"}, hostHeader: "malicious.com", expectedStatusCode: http.StatusBadRequest, - expectedErrorMsg: "Invalid host header. Allowed hosts: localhost:3000, app.example.com", + expectedErrorMsg: "Invalid host header. Allowed hosts: localhost, app.example.com", }, { name: "empty hosts - any host allowed", @@ -172,6 +172,31 @@ func TestServer_AllowedHosts(t *testing.T) { hostHeader: "anything.com", expectedStatusCode: http.StatusOK, }, + { + name: "ipv6 literal allowed - no port", + allowedHosts: []string{"2001:db8::1"}, + hostHeader: "[2001:db8::1]", + expectedStatusCode: http.StatusOK, + }, + { + name: "ipv6 literal allowed - with port", + allowedHosts: []string{"2001:db8::1"}, + hostHeader: "[2001:db8::1]:1234", + expectedStatusCode: http.StatusOK, + }, + { + name: "ipv6 bracketed configured allowed - with port", + allowedHosts: []string{"[2001:db8::1]"}, + hostHeader: "[2001:db8::1]:80", + expectedStatusCode: http.StatusOK, + }, + { + name: "ipv6 literal invalid host rejected", + allowedHosts: []string{"2001:db8::1"}, + hostHeader: "[2001:db8::2]", + expectedStatusCode: http.StatusBadRequest, + expectedErrorMsg: "Invalid host header. Allowed hosts: 2001:db8::1", + }, } for _, tc := range cases { @@ -235,7 +260,7 @@ func TestServer_CORSPreflightWithHosts(t *testing.T) { }, { name: "preflight with specific valid host", - allowedHosts: []string{"localhost:3000"}, + allowedHosts: []string{"localhost"}, hostHeader: "localhost:3000", originHeader: "https://localhost:3000", expectedStatusCode: http.StatusOK, @@ -243,7 +268,7 @@ func TestServer_CORSPreflightWithHosts(t *testing.T) { }, { name: "preflight with invalid host", - allowedHosts: []string{"localhost:3000"}, + allowedHosts: []string{"localhost"}, hostHeader: "malicious.com", originHeader: "https://malicious.com", expectedStatusCode: http.StatusBadRequest, From 6167d5e8594d1f9c96b8d53123fa5307d08af6dd Mon Sep 17 00:00:00 2001 From: Hugo Dutka Date: Mon, 11 Aug 2025 13:46:50 +0200 Subject: [PATCH 3/4] format --- lib/httpapi/server_test.go | 50 +++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/lib/httpapi/server_test.go b/lib/httpapi/server_test.go index 5a80ee9..ca5fbcd 100644 --- a/lib/httpapi/server_test.go +++ b/lib/httpapi/server_test.go @@ -172,31 +172,31 @@ func TestServer_AllowedHosts(t *testing.T) { hostHeader: "anything.com", expectedStatusCode: http.StatusOK, }, - { - name: "ipv6 literal allowed - no port", - allowedHosts: []string{"2001:db8::1"}, - hostHeader: "[2001:db8::1]", - expectedStatusCode: http.StatusOK, - }, - { - name: "ipv6 literal allowed - with port", - allowedHosts: []string{"2001:db8::1"}, - hostHeader: "[2001:db8::1]:1234", - expectedStatusCode: http.StatusOK, - }, - { - name: "ipv6 bracketed configured allowed - with port", - allowedHosts: []string{"[2001:db8::1]"}, - hostHeader: "[2001:db8::1]:80", - expectedStatusCode: http.StatusOK, - }, - { - name: "ipv6 literal invalid host rejected", - allowedHosts: []string{"2001:db8::1"}, - hostHeader: "[2001:db8::2]", - expectedStatusCode: http.StatusBadRequest, - expectedErrorMsg: "Invalid host header. Allowed hosts: 2001:db8::1", - }, + { + name: "ipv6 literal allowed - no port", + allowedHosts: []string{"2001:db8::1"}, + hostHeader: "[2001:db8::1]", + expectedStatusCode: http.StatusOK, + }, + { + name: "ipv6 literal allowed - with port", + allowedHosts: []string{"2001:db8::1"}, + hostHeader: "[2001:db8::1]:1234", + expectedStatusCode: http.StatusOK, + }, + { + name: "ipv6 bracketed configured allowed - with port", + allowedHosts: []string{"[2001:db8::1]"}, + hostHeader: "[2001:db8::1]:80", + expectedStatusCode: http.StatusOK, + }, + { + name: "ipv6 literal invalid host rejected", + allowedHosts: []string{"2001:db8::1"}, + hostHeader: "[2001:db8::2]", + expectedStatusCode: http.StatusBadRequest, + expectedErrorMsg: "Invalid host header. Allowed hosts: 2001:db8::1", + }, } for _, tc := range cases { From 2441cdc496b2ed3ae7d7073f9c9265d61a11f696 Mon Sep 17 00:00:00 2001 From: Hugo Dutka Date: Mon, 11 Aug 2025 14:27:47 +0200 Subject: [PATCH 4/4] chore: move allowed origins and hosts validation into the httpapi package --- cmd/server/server.go | 88 +-------------------- cmd/server/server_test.go | 110 +------------------------- lib/httpapi/server.go | 122 +++++++++++++++++----------- lib/httpapi/server_test.go | 158 ++++++++++++++++++++++++++++++++----- 4 files changed, 222 insertions(+), 256 deletions(-) diff --git a/cmd/server/server.go b/cmd/server/server.go index 69757b8..b236532 100644 --- a/cmd/server/server.go +++ b/cmd/server/server.go @@ -6,11 +6,9 @@ import ( "fmt" "log/slog" "net/http" - "net/url" "os" "sort" "strings" - "unicode" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -60,79 +58,6 @@ func parseAgentType(firstArg string, agentTypeVar string) (AgentType, error) { return AgentTypeCustom, nil } -// Validate allowed hosts don't contain whitespace, commas, schemes, or ports. -// Viper/Cobra use different separators (space for env vars, comma for flags), -// so these characters likely indicate user error. -func validateAllowedHosts(input []string) error { - if len(input) == 0 { - return fmt.Errorf("the list must not be empty") - } - // First pass: whitespace & comma checks (surface these errors first) - for _, item := range input { - for _, r := range item { - if unicode.IsSpace(r) { - return fmt.Errorf("'%s' contains whitespace characters, which are not allowed", item) - } - } - if strings.Contains(item, ",") { - return fmt.Errorf("'%s' contains comma characters, which are not allowed", item) - } - } - // Second pass: scheme check - for _, item := range input { - if strings.Contains(item, "http://") || strings.Contains(item, "https://") { - return fmt.Errorf("'%s' must not include http:// or https://", item) - } - } - // Third pass: port check (but allow IPv6 literals without ports) - for _, item := range input { - trimmed := strings.TrimSpace(item) - colonCount := strings.Count(trimmed, ":") - // If bracketed, rely on url.Parse to detect a port in "]:" form. - if strings.HasPrefix(trimmed, "[") { - if u, err := url.Parse("http://" + trimmed); err == nil { - if u.Port() != "" { - return fmt.Errorf("'%s' must not include a port", item) - } - } - continue - } - // Unbracketed IPv6: multiple colons and no brackets; treat as valid (no ports allowed here) - if colonCount >= 2 { - continue - } - // IPv4 or hostname: if URL parsing finds a port or there's a single colon, it's invalid - if u, err := url.Parse("http://" + trimmed); err == nil { - if u.Port() != "" { - return fmt.Errorf("'%s' must not include a port", item) - } - } - if colonCount == 1 { - return fmt.Errorf("'%s' must not include a port", item) - } - } - return nil -} - -// Validate allowed origins don't contain whitespace or commas. -// Origins must include a scheme, validated later by the HTTP layer. -func validateAllowedOrigins(input []string) error { - if len(input) == 0 { - return fmt.Errorf("the list must not be empty") - } - for _, item := range input { - for _, r := range item { - if unicode.IsSpace(r) { - return fmt.Errorf("'%s' contains whitespace characters, which are not allowed", item) - } - } - if strings.Contains(item, ",") { - return fmt.Errorf("'%s' contains comma characters, which are not allowed", item) - } - } - return nil -} - func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) error { agent := argsToPass[0] agentTypeValue := viper.GetString(FlagType) @@ -247,17 +172,6 @@ func CreateServerCmd() *cobra.Command { Short: "Run the server", Long: fmt.Sprintf("Run the server with the specified agent (one of: %s)", strings.Join(agentNames, ", ")), Args: cobra.MinimumNArgs(1), - PreRunE: func(cmd *cobra.Command, args []string) error { - allowedHosts := viper.GetStringSlice(FlagAllowedHosts) - if err := validateAllowedHosts(allowedHosts); err != nil { - return xerrors.Errorf("failed to validate allowed hosts: %w", err) - } - allowedOrigins := viper.GetStringSlice(FlagAllowedOrigins) - if err := validateAllowedOrigins(allowedOrigins); err != nil { - return xerrors.Errorf("failed to validate allowed origins: %w", err) - } - return nil - }, Run: func(cmd *cobra.Command, args []string) { // The --exit flag is used for testing validation of flags in the test suite if viper.GetBool(FlagExit) { @@ -280,7 +194,7 @@ func CreateServerCmd() *cobra.Command { {FlagTermWidth, "W", uint16(80), "Width of the emulated terminal", "uint16"}, {FlagTermHeight, "H", uint16(1000), "Height of the emulated terminal", "uint16"}, // localhost is the default host for the server. Port is ignored during matching. - {FlagAllowedHosts, "a", []string{"localhost"}, "HTTP allowed hosts (hostnames only, no ports). Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_HOSTS env var", "stringSlice"}, + {FlagAllowedHosts, "a", []string{"localhost", "127.0.0.1", "[::1]"}, "HTTP allowed hosts (hostnames only, no ports). Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_HOSTS env var", "stringSlice"}, // localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development. {FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"}, } diff --git a/cmd/server/server_test.go b/cmd/server/server_test.go index eb339f4..ed88fce 100644 --- a/cmd/server/server_test.go +++ b/cmd/server/server_test.go @@ -155,7 +155,7 @@ func TestServerCmd_AllArgs_Defaults(t *testing.T) { {"chat-base-path default", FlagChatBasePath, "/chat", func() any { return viper.GetString(FlagChatBasePath) }}, {"term-width default", FlagTermWidth, uint16(80), func() any { return viper.GetUint16(FlagTermWidth) }}, {"term-height default", FlagTermHeight, uint16(1000), func() any { return viper.GetUint16(FlagTermHeight) }}, - {"allowed-hosts default", FlagAllowedHosts, []string{"localhost"}, func() any { return viper.GetStringSlice(FlagAllowedHosts) }}, + {"allowed-hosts default", FlagAllowedHosts, []string{"localhost", "127.0.0.1", "[::1]"}, func() any { return viper.GetStringSlice(FlagAllowedHosts) }}, {"allowed-origins default", FlagAllowedOrigins, []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, func() any { return viper.GetStringSlice(FlagAllowedOrigins) }}, } @@ -341,37 +341,6 @@ func TestServerCmd_AllowedHosts(t *testing.T) { args: []string{}, expected: []string{"localhost", "example.com"}, }, - { - name: "env: host with comma (invalid)", - env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:3284,example.com"}, - args: []string{}, - expectedErr: "contains comma characters", - }, - { - name: "env: host with port (invalid)", - env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:3284"}, - args: []string{}, - expectedErr: "must not include a port", - }, - { - name: "env: ipv6 literal", - env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "2001:db8::1"}, - args: []string{}, - expected: []string{"2001:db8::1"}, - }, - { - name: "env: ipv6 bracketed literal", - env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "[2001:db8::1]"}, - args: []string{}, - expected: []string{"[2001:db8::1]"}, - }, - { - name: "env: ipv6 with port (invalid)", - env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "[2001:db8::1]:443"}, - args: []string{}, - expectedErr: "must not include a port", - }, - // CLI flag scenarios (comma-separated format) { name: "flag: single valid host", @@ -394,30 +363,10 @@ func TestServerCmd_AllowedHosts(t *testing.T) { expected: []string{"localhost"}, }, { - name: "flag: host with space in comma-separated list (invalid)", - args: []string{"--allowed-hosts", "localhost:3284,example .com"}, - expectedErr: "contains whitespace characters", - }, - { - name: "flag: host with port (invalid)", - args: []string{"--allowed-hosts", "localhost:3284"}, - expectedErr: "must not include a port", + name: "flag: ipv6 bracketed literal", + args: []string{"--allowed-hosts", "[2001:db8::1]"}, + expected: []string{"[2001:db8::1]"}, }, - { - name: "flag: ipv6 literal", - args: []string{"--allowed-hosts", "2001:db8::1"}, - expected: []string{"2001:db8::1"}, - }, - { - name: "flag: ipv6 bracketed literal", - args: []string{"--allowed-hosts", "[2001:db8::1]"}, - expected: []string{"[2001:db8::1]"}, - }, - { - name: "flag: ipv6 with port (invalid)", - args: []string{"--allowed-hosts", "[2001:db8::1]:443"}, - expectedErr: "must not include a port", - }, // Mixed scenarios (env + flag precedence) { @@ -426,26 +375,6 @@ func TestServerCmd_AllowedHosts(t *testing.T) { args: []string{"--allowed-hosts", "override.com"}, expected: []string{"override.com"}, }, - { - name: "mixed: flag overrides env but flag is invalid", - env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost"}, - args: []string{"--allowed-hosts", "invalid .com"}, - expectedErr: "contains whitespace characters", - }, - - // Empty hosts are not allowed - { - name: "empty host", - args: []string{"--allowed-hosts", ""}, - expectedErr: "the list must not be empty", - }, - - // Default behavior - { - name: "default hosts when neither env nor flag provided", - args: []string{}, - expected: []string{"localhost"}, - }, } for _, tt := range tests { @@ -506,12 +435,6 @@ func TestServerCmd_AllowedOrigins(t *testing.T) { args: []string{}, expected: []string{"https://example.com", "http://localhost:3000"}, }, - { - name: "env: origin with comma (invalid)", - env: map[string]string{"AGENTAPI_ALLOWED_ORIGINS": "https://example.com,http://localhost:3000"}, - args: []string{}, - expectedErr: "contains comma characters", - }, // CLI flag scenarios (comma-separated format) { @@ -539,11 +462,6 @@ func TestServerCmd_AllowedOrigins(t *testing.T) { args: []string{"--allowed-origins", "https://example.com\n"}, expected: []string{"https://example.com"}, }, - { - name: "flag: origin with space in comma-separated list (invalid)", - args: []string{"--allowed-origins", "https://example.com,http://localhost :3000"}, - expectedErr: "contains whitespace characters", - }, // Mixed scenarios (env + flag precedence) { @@ -552,26 +470,6 @@ func TestServerCmd_AllowedOrigins(t *testing.T) { args: []string{"--allowed-origins", "https://override.com"}, expected: []string{"https://override.com"}, }, - { - name: "mixed: flag overrides env but flag is invalid", - env: map[string]string{"AGENTAPI_ALLOWED_ORIGINS": "https://env-example.com"}, - args: []string{"--allowed-origins", "invalid origin"}, - expectedErr: "contains whitespace characters", - }, - - // Empty origins are not allowed - { - name: "empty origin", - args: []string{"--allowed-origins", ""}, - expectedErr: "the list must not be empty", - }, - - // Default behavior - { - name: "default origins when neither env nor flag provided", - args: []string{}, - expected: []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, - }, } for _, tt := range tests { diff --git a/lib/httpapi/server.go b/lib/httpapi/server.go index 29c9332..4f1bb14 100644 --- a/lib/httpapi/server.go +++ b/lib/httpapi/server.go @@ -5,13 +5,13 @@ import ( "encoding/json" "fmt" "log/slog" - "net" "net/http" "net/url" "slices" "strings" "sync" "time" + "unicode" "github.com/coder/agentapi/lib/logctx" mf "github.com/coder/agentapi/lib/msgfmt" @@ -70,52 +70,87 @@ type ServerConfig struct { AllowedOrigins []string } -func parseAllowedHosts(hosts []string) ([]string, error) { - if slices.Contains(hosts, "*") { - return []string{}, nil +// Validate allowed hosts don't contain whitespace, commas, schemes, or ports. +// Viper/Cobra use different separators (space for env vars, comma for flags), +// so these characters likely indicate user error. +func parseAllowedHosts(input []string) ([]string, error) { + if len(input) == 0 { + return nil, fmt.Errorf("the list must not be empty") } - for _, host := range hosts { - if strings.Contains(host, "*") { - return nil, xerrors.Errorf("wildcard characters are not supported: %q", host) + if slices.Contains(input, "*") { + return []string{"*"}, nil + } + // First pass: whitespace & comma checks (surface these errors first) + // Viper/Cobra use different separators (space for env vars, comma for flags), + // so these characters likely indicate user error. + for _, item := range input { + for _, r := range item { + if unicode.IsSpace(r) { + return nil, fmt.Errorf("'%s' contains whitespace characters, which are not allowed", item) + } } - if strings.Contains(host, "http://") || strings.Contains(host, "https://") { - return nil, xerrors.Errorf("host must not contain http:// or https://: %q", host) + if strings.Contains(item, ",") { + return nil, fmt.Errorf("'%s' contains comma characters, which are not allowed", item) } } - // Normalize hosts to bare hostnames/IPs by stripping any port and brackets. - // This ensures allowed entries match the Host header hostname only. - normalized := make([]string, 0, len(hosts)) - for _, raw := range hosts { - h := strings.TrimSpace(raw) - // If it's an IPv6 literal (possibly bracketed) without an obvious port, keep the literal. - unbracketed := strings.Trim(h, "[]") - if ip := net.ParseIP(unbracketed); ip != nil { - // It's an IP literal; use the bare form without brackets. - normalized = append(normalized, unbracketed) - continue + // Second pass: scheme check + for _, item := range input { + if strings.Contains(item, "http://") || strings.Contains(item, "https://") { + return nil, fmt.Errorf("'%s' must not include http:// or https://", item) } - // If likely host:port (single colon) or bracketed host, use url.Parse to extract hostname. - if strings.Count(h, ":") == 1 || (strings.HasPrefix(h, "[") && strings.Contains(h, "]")) { - if u, err := url.Parse("http://" + h); err == nil { - hn := u.Hostname() - if hn != "" { - normalized = append(normalized, hn) - continue - } - } + } + hosts := make([]*url.URL, 0, len(input)) + // Third pass: url parse + for _, item := range input { + trimmed := strings.TrimSpace(item) + u, err := url.Parse("http://" + trimmed) + if err != nil { + return nil, fmt.Errorf("'%s' is not a valid host: %w", item, err) + } + hosts = append(hosts, u) + } + // Fourth pass: port check + for _, u := range hosts { + if u.Port() != "" { + return nil, fmt.Errorf("'%s' must not include a port", u.Host) } - // Fallback: use as-is (e.g., hostname without port) - normalized = append(normalized, h) } - return normalized, nil + hostStrings := make([]string, 0, len(hosts)) + for _, u := range hosts { + hostStrings = append(hostStrings, u.Hostname()) + } + return hostStrings, nil } -func parseAllowedOrigins(origins []string) ([]string, error) { - for _, origin := range origins { - if !strings.Contains(origin, "*") && !(strings.Contains(origin, "http://") || strings.Contains(origin, "https://")) { - return nil, xerrors.Errorf("origin must contain http:// or https://: %q", origin) +// Validate allowed origins +func parseAllowedOrigins(input []string) ([]string, error) { + if len(input) == 0 { + return nil, fmt.Errorf("the list must not be empty") + } + if slices.Contains(input, "*") { + return []string{"*"}, nil + } + // Viper/Cobra use different separators (space for env vars, comma for flags), + // so these characters likely indicate user error. + for _, item := range input { + for _, r := range item { + if unicode.IsSpace(r) { + return nil, fmt.Errorf("'%s' contains whitespace characters, which are not allowed", item) + } + } + if strings.Contains(item, ",") { + return nil, fmt.Errorf("'%s' contains comma characters, which are not allowed", item) } } + origins := make([]string, 0, len(input)) + for _, item := range input { + trimmed := strings.TrimSpace(item) + u, err := url.Parse(trimmed) + if err != nil { + return nil, fmt.Errorf("'%s' is not a valid origin: %w", item, err) + } + origins = append(origins, fmt.Sprintf("%s://%s", u.Scheme, u.Host)) + } return origins, nil } @@ -127,18 +162,14 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) { allowedHosts, err := parseAllowedHosts(config.AllowedHosts) if err != nil { - return nil, xerrors.Errorf("failed to validate allowed hosts: %w", err) - } - if len(allowedHosts) > 0 { - logger.Info(fmt.Sprintf("Allowed hosts: %s", strings.Join(allowedHosts, ", "))) - } else { - logger.Info("Allowed hosts: *") + return nil, xerrors.Errorf("failed to parse allowed hosts: %w", err) } - allowedOrigins, err := parseAllowedOrigins(config.AllowedOrigins) if err != nil { - return nil, xerrors.Errorf("failed to validate allowed origins: %w", err) + return nil, xerrors.Errorf("failed to parse allowed origins: %w", err) } + + logger.Info(fmt.Sprintf("Allowed hosts: %s", strings.Join(allowedHosts, ", "))) logger.Info(fmt.Sprintf("Allowed origins: %s", strings.Join(allowedOrigins, ", "))) // Enforce allowed hosts in a custom middleware that ignores the port during matching. @@ -205,9 +236,10 @@ func hostAuthorizationMiddleware(allowedHosts []string, badHostHandler http.Hand for _, h := range allowedHosts { allowed[strings.ToLower(h)] = struct{}{} } + wildcard := slices.Contains(allowedHosts, "*") return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if len(allowedHosts) == 0 { // wildcard semantics: allow all + if wildcard { // wildcard semantics: allow all next.ServeHTTP(w, r) return } diff --git a/lib/httpapi/server_test.go b/lib/httpapi/server_test.go index ca5fbcd..bc50d3e 100644 --- a/lib/httpapi/server_test.go +++ b/lib/httpapi/server_test.go @@ -134,6 +134,7 @@ func TestServer_AllowedHosts(t *testing.T) { hostHeader string expectedStatusCode int expectedErrorMsg string + validationErrorMsg string }{ { name: "wildcard hosts - any host allowed", @@ -167,35 +168,93 @@ func TestServer_AllowedHosts(t *testing.T) { expectedErrorMsg: "Invalid host header. Allowed hosts: localhost, app.example.com", }, { - name: "empty hosts - any host allowed", - allowedHosts: []string{}, - hostHeader: "anything.com", + name: "ipv6 bracketed configured allowed - with port", + allowedHosts: []string{"[2001:db8::1]"}, + hostHeader: "[2001:db8::1]:80", expectedStatusCode: http.StatusOK, }, { - name: "ipv6 literal allowed - no port", - allowedHosts: []string{"2001:db8::1"}, - hostHeader: "[2001:db8::1]", - expectedStatusCode: http.StatusOK, + name: "ipv6 literal invalid host rejected", + allowedHosts: []string{"[2001:db8::1]"}, + hostHeader: "[2001:db8::2]", + expectedStatusCode: http.StatusBadRequest, + expectedErrorMsg: "Invalid host header. Allowed hosts: 2001:db8::1", }, { - name: "ipv6 literal allowed - with port", + name: "allowed hosts must not be empty", + allowedHosts: []string{}, + validationErrorMsg: "the list must not be empty", + }, + { + name: "ipv6 literal without square brackets is invalid", allowedHosts: []string{"2001:db8::1"}, - hostHeader: "[2001:db8::1]:1234", + validationErrorMsg: "must not include a port", + }, + { + name: "host with port in config is invalid", + allowedHosts: []string{"example.com:8080"}, + validationErrorMsg: "must not include a port", + }, + { + name: "bracketed ipv6 with port in config is invalid", + allowedHosts: []string{"[2001:db8::1]:443"}, + validationErrorMsg: "must not include a port", + }, + { + name: "hostname with http scheme is invalid", + allowedHosts: []string{"http://example.com"}, + validationErrorMsg: "must not include http:// or https://", + }, + { + name: "hostname with https scheme is invalid", + allowedHosts: []string{"https://example.com"}, + validationErrorMsg: "must not include http:// or https://", + }, + { + name: "hostname containing comma is invalid", + allowedHosts: []string{"example.com,malicious.com"}, + validationErrorMsg: "contains comma characters, which are not allowed", + }, + { + name: "hostname with leading whitespace is invalid", + allowedHosts: []string{" example.com"}, + validationErrorMsg: "contains whitespace characters, which are not allowed", + }, + { + name: "hostname with internal whitespace is invalid", + allowedHosts: []string{"exa mple.com"}, + validationErrorMsg: "contains whitespace characters, which are not allowed", + }, + { + name: "uppercase allowed host matches lowercase request", + allowedHosts: []string{"EXAMPLE.COM"}, + hostHeader: "example.com:80", expectedStatusCode: http.StatusOK, }, { - name: "ipv6 bracketed configured allowed - with port", - allowedHosts: []string{"[2001:db8::1]"}, - hostHeader: "[2001:db8::1]:80", + name: "wildcard with extra invalid entries still allows all", + allowedHosts: []string{"*", "https://bad.com", "example.com:8080", " space.com"}, + hostHeader: "malicious.com", expectedStatusCode: http.StatusOK, }, { - name: "ipv6 literal invalid host rejected", - allowedHosts: []string{"2001:db8::1"}, - hostHeader: "[2001:db8::2]", + name: "trailing dot in allowed host requires trailing dot in request (no match)", + allowedHosts: []string{"example.com."}, + hostHeader: "example.com", expectedStatusCode: http.StatusBadRequest, - expectedErrorMsg: "Invalid host header. Allowed hosts: 2001:db8::1", + expectedErrorMsg: "Invalid host header. Allowed hosts: example.com.", + }, + { + name: "trailing dot in allowed host matches trailing dot in request", + allowedHosts: []string{"example.com."}, + hostHeader: "example.com.:80", + expectedStatusCode: http.StatusOK, + }, + { + name: "ipv6 bracketed configured allowed - without port header", + allowedHosts: []string{"[2001:db8::1]"}, + hostHeader: "[2001:db8::1]", + expectedStatusCode: http.StatusOK, }, } @@ -211,7 +270,13 @@ func TestServer_AllowedHosts(t *testing.T) { AllowedHosts: tc.allowedHosts, AllowedOrigins: []string{"https://example.com"}, // Set a default to isolate host testing }) - require.NoError(t, err) + if tc.validationErrorMsg != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tc.validationErrorMsg) + return + } else { + require.NoError(t, err) + } tsServer := httptest.NewServer(s.Handler()) t.Cleanup(tsServer.Close) @@ -334,6 +399,7 @@ func TestServer_CORSOrigins(t *testing.T) { expectedStatusCode int expectedCORSOrigin string expectCORSOriginHeader bool + validationErrorMsg string }{ { name: "wildcard origins - any origin allowed", @@ -380,6 +446,58 @@ func TestServer_CORSOrigins(t *testing.T) { originHeader: "", expectedStatusCode: http.StatusOK, }, + { + name: "allowed origins must not be empty", + allowedOrigins: []string{}, + validationErrorMsg: "the list must not be empty", + }, + { + name: "origin containing comma is invalid", + allowedOrigins: []string{"https://example.com,http://localhost:3000"}, + validationErrorMsg: "contains comma characters, which are not allowed", + }, + { + name: "origin with internal whitespace is invalid", + allowedOrigins: []string{"https://exa mple.com"}, + validationErrorMsg: "contains whitespace characters, which are not allowed", + }, + { + name: "origin with leading whitespace is invalid", + allowedOrigins: []string{" https://example.com"}, + validationErrorMsg: "contains whitespace characters, which are not allowed", + }, + { + name: "wildcard with extra invalid entries still allows all", + allowedOrigins: []string{"*", "https://bad.com,too", "http://bad host"}, + originHeader: "http://malicious.com", + expectedCORSOrigin: "*", + expectCORSOriginHeader: true, + expectedStatusCode: http.StatusOK, + }, + { + name: "ipv6 origin allowed", + allowedOrigins: []string{"http://[2001:db8::1]:8080"}, + originHeader: "http://[2001:db8::1]:8080", + expectedCORSOrigin: "http://[2001:db8::1]:8080", + expectCORSOriginHeader: true, + expectedStatusCode: http.StatusOK, + }, + { + name: "origin with path, query, and fragment normalizes to scheme+host", + allowedOrigins: []string{"https://example.com/path?x=1#frag"}, + originHeader: "https://example.com", + expectedCORSOrigin: "https://example.com", + expectCORSOriginHeader: true, + expectedStatusCode: http.StatusOK, + }, + { + name: "trailing slash is ignored for matching", + allowedOrigins: []string{"https://example.com/"}, + originHeader: "https://example.com", + expectedCORSOrigin: "https://example.com", + expectCORSOriginHeader: true, + expectedStatusCode: http.StatusOK, + }, } for _, tc := range cases { @@ -394,7 +512,11 @@ func TestServer_CORSOrigins(t *testing.T) { AllowedHosts: []string{"*"}, // Set wildcard to isolate CORS testing AllowedOrigins: tc.allowedOrigins, }) - require.NoError(t, err) + if tc.validationErrorMsg != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tc.validationErrorMsg) + return + } tsServer := httptest.NewServer(s.Handler()) t.Cleanup(tsServer.Close)