Skip to content

Commit 8f2d898

Browse files
committed
fix: strip port from host
1 parent f1b18a2 commit 8f2d898

File tree

5 files changed

+220
-37
lines changed

5 files changed

+220
-37
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ There are 4 endpoints:
8080

8181
#### Allowed hosts
8282

83-
By default, the server only allows requests with the host header set to localhost:3284. If you'd like to host AgentAPI elsewhere, you can change this by using the `AGENTAPI_ALLOWED_HOSTS` environment variable or the `--allowed-hosts` flag.
83+
By default, the server only allows requests with the host header set to `localhost`. If you'd like to host AgentAPI elsewhere, you can change this by using the `AGENTAPI_ALLOWED_HOSTS` environment variable or the `--allowed-hosts` flag. Hosts must be hostnames only (no ports); the server ignores the port portion of incoming requests when authorizing.
8484

8585
To allow requests from any host, use `*` as the allowed host.
8686

cmd/server/server.go

Lines changed: 60 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"log/slog"
88
"net/http"
9+
"net/url"
910
"os"
1011
"sort"
1112
"strings"
@@ -59,10 +60,63 @@ func parseAgentType(firstArg string, agentTypeVar string) (AgentType, error) {
5960
return AgentTypeCustom, nil
6061
}
6162

62-
// Validate allowed hosts or origins don't contain whitespace or commas.
63+
// Validate allowed hosts don't contain whitespace, commas, schemes, or ports.
6364
// Viper/Cobra use different separators (space for env vars, comma for flags),
6465
// so these characters likely indicate user error.
65-
func validateAllowedHostsOrOrigins(input []string) error {
66+
func validateAllowedHosts(input []string) error {
67+
if len(input) == 0 {
68+
return fmt.Errorf("the list must not be empty")
69+
}
70+
// First pass: whitespace & comma checks (surface these errors first)
71+
for _, item := range input {
72+
for _, r := range item {
73+
if unicode.IsSpace(r) {
74+
return fmt.Errorf("'%s' contains whitespace characters, which are not allowed", item)
75+
}
76+
}
77+
if strings.Contains(item, ",") {
78+
return fmt.Errorf("'%s' contains comma characters, which are not allowed", item)
79+
}
80+
}
81+
// Second pass: scheme check
82+
for _, item := range input {
83+
if strings.Contains(item, "http://") || strings.Contains(item, "https://") {
84+
return fmt.Errorf("'%s' must not include http:// or https://", item)
85+
}
86+
}
87+
// Third pass: port check (but allow IPv6 literals without ports)
88+
for _, item := range input {
89+
trimmed := strings.TrimSpace(item)
90+
colonCount := strings.Count(trimmed, ":")
91+
// If bracketed, rely on url.Parse to detect a port in "]:<port>" form.
92+
if strings.HasPrefix(trimmed, "[") {
93+
if u, err := url.Parse("http://" + trimmed); err == nil {
94+
if u.Port() != "" {
95+
return fmt.Errorf("'%s' must not include a port", item)
96+
}
97+
}
98+
continue
99+
}
100+
// Unbracketed IPv6: multiple colons and no brackets; treat as valid (no ports allowed here)
101+
if colonCount >= 2 {
102+
continue
103+
}
104+
// IPv4 or hostname: if URL parsing finds a port or there's a single colon, it's invalid
105+
if u, err := url.Parse("http://" + trimmed); err == nil {
106+
if u.Port() != "" {
107+
return fmt.Errorf("'%s' must not include a port", item)
108+
}
109+
}
110+
if colonCount == 1 {
111+
return fmt.Errorf("'%s' must not include a port", item)
112+
}
113+
}
114+
return nil
115+
}
116+
117+
// Validate allowed origins don't contain whitespace or commas.
118+
// Origins must include a scheme, validated later by the HTTP layer.
119+
func validateAllowedOrigins(input []string) error {
66120
if len(input) == 0 {
67121
return fmt.Errorf("the list must not be empty")
68122
}
@@ -195,11 +249,11 @@ func CreateServerCmd() *cobra.Command {
195249
Args: cobra.MinimumNArgs(1),
196250
PreRunE: func(cmd *cobra.Command, args []string) error {
197251
allowedHosts := viper.GetStringSlice(FlagAllowedHosts)
198-
if err := validateAllowedHostsOrOrigins(allowedHosts); err != nil {
252+
if err := validateAllowedHosts(allowedHosts); err != nil {
199253
return xerrors.Errorf("failed to validate allowed hosts: %w", err)
200254
}
201255
allowedOrigins := viper.GetStringSlice(FlagAllowedOrigins)
202-
if err := validateAllowedHostsOrOrigins(allowedOrigins); err != nil {
256+
if err := validateAllowedOrigins(allowedOrigins); err != nil {
203257
return xerrors.Errorf("failed to validate allowed origins: %w", err)
204258
}
205259
return nil
@@ -225,8 +279,8 @@ func CreateServerCmd() *cobra.Command {
225279
{FlagChatBasePath, "c", "/chat", "Base path for assets and routes used in the static files of the chat interface", "string"},
226280
{FlagTermWidth, "W", uint16(80), "Width of the emulated terminal", "uint16"},
227281
{FlagTermHeight, "H", uint16(1000), "Height of the emulated terminal", "uint16"},
228-
// localhost:3284 is the default host for the server
229-
{FlagAllowedHosts, "a", []string{"localhost:3284"}, "HTTP allowed hosts. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_HOSTS env var", "stringSlice"},
282+
// localhost is the default host for the server. Port is ignored during matching.
283+
{FlagAllowedHosts, "a", []string{"localhost"}, "HTTP allowed hosts (hostnames only, no ports). Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_HOSTS env var", "stringSlice"},
230284
// localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development.
231285
{FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"},
232286
}

cmd/server/server_test.go

Lines changed: 63 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ func TestServerCmd_AllArgs_Defaults(t *testing.T) {
155155
{"chat-base-path default", FlagChatBasePath, "/chat", func() any { return viper.GetString(FlagChatBasePath) }},
156156
{"term-width default", FlagTermWidth, uint16(80), func() any { return viper.GetUint16(FlagTermWidth) }},
157157
{"term-height default", FlagTermHeight, uint16(1000), func() any { return viper.GetUint16(FlagTermHeight) }},
158-
{"allowed-hosts default", FlagAllowedHosts, []string{"localhost:3284"}, func() any { return viper.GetStringSlice(FlagAllowedHosts) }},
158+
{"allowed-hosts default", FlagAllowedHosts, []string{"localhost"}, func() any { return viper.GetStringSlice(FlagAllowedHosts) }},
159159
{"allowed-origins default", FlagAllowedOrigins, []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, func() any { return viper.GetStringSlice(FlagAllowedOrigins) }},
160160
}
161161

@@ -189,7 +189,7 @@ func TestServerCmd_AllEnvVars(t *testing.T) {
189189
{"AGENTAPI_CHAT_BASE_PATH", "AGENTAPI_CHAT_BASE_PATH", "/api", "/api", func() any { return viper.GetString(FlagChatBasePath) }},
190190
{"AGENTAPI_TERM_WIDTH", "AGENTAPI_TERM_WIDTH", "120", uint16(120), func() any { return viper.GetUint16(FlagTermWidth) }},
191191
{"AGENTAPI_TERM_HEIGHT", "AGENTAPI_TERM_HEIGHT", "500", uint16(500), func() any { return viper.GetUint16(FlagTermHeight) }},
192-
{"AGENTAPI_ALLOWED_HOSTS", "AGENTAPI_ALLOWED_HOSTS", "localhost:3284 localhost:3285", []string{"localhost:3284", "localhost:3285"}, func() any { return viper.GetStringSlice(FlagAllowedHosts) }},
192+
{"AGENTAPI_ALLOWED_HOSTS", "AGENTAPI_ALLOWED_HOSTS", "localhost example.com", []string{"localhost", "example.com"}, func() any { return viper.GetStringSlice(FlagAllowedHosts) }},
193193
{"AGENTAPI_ALLOWED_ORIGINS", "AGENTAPI_ALLOWED_ORIGINS", "https://example.com http://localhost:3000", []string{"https://example.com", "http://localhost:3000"}, func() any { return viper.GetStringSlice(FlagAllowedOrigins) }},
194194
}
195195

@@ -325,66 +325,110 @@ func TestServerCmd_AllowedHosts(t *testing.T) {
325325
// Environment variable scenarios (space-separated format)
326326
{
327327
name: "env: single valid host",
328-
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:3284"},
328+
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost"},
329329
args: []string{},
330-
expected: []string{"localhost:3284"},
330+
expected: []string{"localhost"},
331331
},
332332
{
333333
name: "env: multiple valid hosts space-separated",
334-
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:3284 example.com 192.168.1.1:8080"},
334+
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost example.com 192.168.1.1"},
335335
args: []string{},
336-
expected: []string{"localhost:3284", "example.com", "192.168.1.1:8080"},
336+
expected: []string{"localhost", "example.com", "192.168.1.1"},
337337
},
338338
{
339339
name: "env: host with tab",
340-
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:3284\texample.com"},
340+
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost\texample.com"},
341341
args: []string{},
342-
expected: []string{"localhost:3284", "example.com"},
342+
expected: []string{"localhost", "example.com"},
343343
},
344344
{
345345
name: "env: host with comma (invalid)",
346346
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:3284,example.com"},
347347
args: []string{},
348348
expectedErr: "contains comma characters",
349349
},
350+
{
351+
name: "env: host with port (invalid)",
352+
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:3284"},
353+
args: []string{},
354+
expectedErr: "must not include a port",
355+
},
356+
{
357+
name: "env: ipv6 literal",
358+
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "2001:db8::1"},
359+
args: []string{},
360+
expected: []string{"2001:db8::1"},
361+
},
362+
{
363+
name: "env: ipv6 bracketed literal",
364+
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "[2001:db8::1]"},
365+
args: []string{},
366+
expected: []string{"[2001:db8::1]"},
367+
},
368+
{
369+
name: "env: ipv6 with port (invalid)",
370+
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "[2001:db8::1]:443"},
371+
args: []string{},
372+
expectedErr: "must not include a port",
373+
},
350374

351375
// CLI flag scenarios (comma-separated format)
352376
{
353377
name: "flag: single valid host",
354-
args: []string{"--allowed-hosts", "localhost:3284"},
355-
expected: []string{"localhost:3284"},
378+
args: []string{"--allowed-hosts", "localhost"},
379+
expected: []string{"localhost"},
356380
},
357381
{
358382
name: "flag: multiple valid hosts comma-separated",
359-
args: []string{"--allowed-hosts", "localhost:3284,example.com,192.168.1.1:8080"},
360-
expected: []string{"localhost:3284", "example.com", "192.168.1.1:8080"},
383+
args: []string{"--allowed-hosts", "localhost,example.com,192.168.1.1"},
384+
expected: []string{"localhost", "example.com", "192.168.1.1"},
361385
},
362386
{
363387
name: "flag: multiple valid hosts with multiple flags",
364-
args: []string{"--allowed-hosts", "localhost:3284", "--allowed-hosts", "example.com"},
365-
expected: []string{"localhost:3284", "example.com"},
388+
args: []string{"--allowed-hosts", "localhost", "--allowed-hosts", "example.com"},
389+
expected: []string{"localhost", "example.com"},
366390
},
367391
{
368392
name: "flag: host with newline",
369-
args: []string{"--allowed-hosts", "localhost:3284\n"},
370-
expected: []string{"localhost:3284"},
393+
args: []string{"--allowed-hosts", "localhost\n"},
394+
expected: []string{"localhost"},
371395
},
372396
{
373397
name: "flag: host with space in comma-separated list (invalid)",
374398
args: []string{"--allowed-hosts", "localhost:3284,example .com"},
375399
expectedErr: "contains whitespace characters",
376400
},
401+
{
402+
name: "flag: host with port (invalid)",
403+
args: []string{"--allowed-hosts", "localhost:3284"},
404+
expectedErr: "must not include a port",
405+
},
406+
{
407+
name: "flag: ipv6 literal",
408+
args: []string{"--allowed-hosts", "2001:db8::1"},
409+
expected: []string{"2001:db8::1"},
410+
},
411+
{
412+
name: "flag: ipv6 bracketed literal",
413+
args: []string{"--allowed-hosts", "[2001:db8::1]"},
414+
expected: []string{"[2001:db8::1]"},
415+
},
416+
{
417+
name: "flag: ipv6 with port (invalid)",
418+
args: []string{"--allowed-hosts", "[2001:db8::1]:443"},
419+
expectedErr: "must not include a port",
420+
},
377421

378422
// Mixed scenarios (env + flag precedence)
379423
{
380424
name: "mixed: flag overrides env",
381-
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:8080"},
425+
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost"},
382426
args: []string{"--allowed-hosts", "override.com"},
383427
expected: []string{"override.com"},
384428
},
385429
{
386430
name: "mixed: flag overrides env but flag is invalid",
387-
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:8080"},
431+
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost"},
388432
args: []string{"--allowed-hosts", "invalid .com"},
389433
expectedErr: "contains whitespace characters",
390434
},
@@ -400,7 +444,7 @@ func TestServerCmd_AllowedHosts(t *testing.T) {
400444
{
401445
name: "default hosts when neither env nor flag provided",
402446
args: []string{},
403-
expected: []string{"localhost:3284"},
447+
expected: []string{"localhost"},
404448
},
405449
}
406450

lib/httpapi/server.go

Lines changed: 65 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"encoding/json"
66
"fmt"
77
"log/slog"
8+
"net"
89
"net/http"
910
"net/url"
1011
"slices"
@@ -82,7 +83,32 @@ func parseAllowedHosts(hosts []string) ([]string, error) {
8283
return nil, xerrors.Errorf("host must not contain http:// or https://: %q", host)
8384
}
8485
}
85-
return hosts, nil
86+
// Normalize hosts to bare hostnames/IPs by stripping any port and brackets.
87+
// This ensures allowed entries match the Host header hostname only.
88+
normalized := make([]string, 0, len(hosts))
89+
for _, raw := range hosts {
90+
h := strings.TrimSpace(raw)
91+
// If it's an IPv6 literal (possibly bracketed) without an obvious port, keep the literal.
92+
unbracketed := strings.Trim(h, "[]")
93+
if ip := net.ParseIP(unbracketed); ip != nil {
94+
// It's an IP literal; use the bare form without brackets.
95+
normalized = append(normalized, unbracketed)
96+
continue
97+
}
98+
// If likely host:port (single colon) or bracketed host, use url.Parse to extract hostname.
99+
if strings.Count(h, ":") == 1 || (strings.HasPrefix(h, "[") && strings.Contains(h, "]")) {
100+
if u, err := url.Parse("http://" + h); err == nil {
101+
hn := u.Hostname()
102+
if hn != "" {
103+
normalized = append(normalized, hn)
104+
continue
105+
}
106+
}
107+
}
108+
// Fallback: use as-is (e.g., hostname without port)
109+
normalized = append(normalized, h)
110+
}
111+
return normalized, nil
86112
}
87113

88114
func parseAllowedOrigins(origins []string) ([]string, error) {
@@ -116,13 +142,14 @@ func NewServer(ctx context.Context, config ServerConfig) (*Server, error) {
116142
}
117143
logger.Info(fmt.Sprintf("Allowed origins: %s", strings.Join(allowedOrigins, ", ")))
118144

119-
secureMiddleware := secure.New(secure.Options{
120-
AllowedHosts: allowedHosts,
121-
})
145+
// Enforce allowed hosts in a custom middleware that ignores the port during matching.
122146
badHostHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
123147
http.Error(w, "Invalid host header. Allowed hosts: "+strings.Join(allowedHosts, ", "), http.StatusBadRequest)
124148
})
125-
secureMiddleware.SetBadHostHandler(badHostHandler)
149+
router.Use(hostAuthorizationMiddleware(allowedHosts, badHostHandler))
150+
151+
// Keep other security headers/features; do not use AllowedHosts here as we handle it ourselves.
152+
secureMiddleware := secure.New(secure.Options{})
126153
router.Use(secureMiddleware.Handler)
127154

128155
corsMiddleware := cors.New(cors.Options{
@@ -174,6 +201,39 @@ func (s *Server) Handler() http.Handler {
174201
return s.router
175202
}
176203

204+
// hostAuthorizationMiddleware enforces that the request Host header matches one of the allowed
205+
// hosts, ignoring any port in the comparison. If allowedHosts is empty, all hosts are allowed.
206+
// Always uses url.Parse("http://" + r.Host) to robustly extract the hostname (handles IPv6).
207+
func hostAuthorizationMiddleware(allowedHosts []string, badHostHandler http.Handler) func(next http.Handler) http.Handler {
208+
// Copy for safety; also build a map for O(1) lookups with case-insensitive keys.
209+
allowed := make(map[string]struct{}, len(allowedHosts))
210+
for _, h := range allowedHosts {
211+
allowed[strings.ToLower(h)] = struct{}{}
212+
}
213+
return func(next http.Handler) http.Handler {
214+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
215+
if len(allowedHosts) == 0 { // wildcard semantics: allow all
216+
next.ServeHTTP(w, r)
217+
return
218+
}
219+
// Extract hostname from the Host header using url.Parse; ignore any port.
220+
hostHeader := r.Host
221+
if hostHeader == "" {
222+
badHostHandler.ServeHTTP(w, r)
223+
return
224+
}
225+
if u, err := url.Parse("http://" + hostHeader); err == nil {
226+
hostname := u.Hostname()
227+
if _, ok := allowed[strings.ToLower(hostname)]; ok {
228+
next.ServeHTTP(w, r)
229+
return
230+
}
231+
}
232+
badHostHandler.ServeHTTP(w, r)
233+
})
234+
}
235+
}
236+
177237
func (s *Server) StartSnapshotLoop(ctx context.Context) {
178238
s.conversation.StartSnapshotLoop(ctx)
179239
go func() {

0 commit comments

Comments
 (0)