Skip to content

Commit 2441cdc

Browse files
committed
chore: move allowed origins and hosts validation into the httpapi package
1 parent 6167d5e commit 2441cdc

File tree

4 files changed

+222
-256
lines changed

4 files changed

+222
-256
lines changed

cmd/server/server.go

Lines changed: 1 addition & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,9 @@ import (
66
"fmt"
77
"log/slog"
88
"net/http"
9-
"net/url"
109
"os"
1110
"sort"
1211
"strings"
13-
"unicode"
1412

1513
"github.com/spf13/cobra"
1614
"github.com/spf13/viper"
@@ -60,79 +58,6 @@ func parseAgentType(firstArg string, agentTypeVar string) (AgentType, error) {
6058
return AgentTypeCustom, nil
6159
}
6260

63-
// Validate allowed hosts don't contain whitespace, commas, schemes, or ports.
64-
// Viper/Cobra use different separators (space for env vars, comma for flags),
65-
// so these characters likely indicate user error.
66-
func validateAllowedHosts(input []string) error {
67-
if len(input) == 0 {
68-
return fmt.Errorf("the list must not be empty")
69-
}
70-
// First pass: whitespace & comma checks (surface these errors first)
71-
for _, item := range input {
72-
for _, r := range item {
73-
if unicode.IsSpace(r) {
74-
return fmt.Errorf("'%s' contains whitespace characters, which are not allowed", item)
75-
}
76-
}
77-
if strings.Contains(item, ",") {
78-
return fmt.Errorf("'%s' contains comma characters, which are not allowed", item)
79-
}
80-
}
81-
// Second pass: scheme check
82-
for _, item := range input {
83-
if strings.Contains(item, "http://") || strings.Contains(item, "https://") {
84-
return fmt.Errorf("'%s' must not include http:// or https://", item)
85-
}
86-
}
87-
// Third pass: port check (but allow IPv6 literals without ports)
88-
for _, item := range input {
89-
trimmed := strings.TrimSpace(item)
90-
colonCount := strings.Count(trimmed, ":")
91-
// If bracketed, rely on url.Parse to detect a port in "]:<port>" form.
92-
if strings.HasPrefix(trimmed, "[") {
93-
if u, err := url.Parse("http://" + trimmed); err == nil {
94-
if u.Port() != "" {
95-
return fmt.Errorf("'%s' must not include a port", item)
96-
}
97-
}
98-
continue
99-
}
100-
// Unbracketed IPv6: multiple colons and no brackets; treat as valid (no ports allowed here)
101-
if colonCount >= 2 {
102-
continue
103-
}
104-
// IPv4 or hostname: if URL parsing finds a port or there's a single colon, it's invalid
105-
if u, err := url.Parse("http://" + trimmed); err == nil {
106-
if u.Port() != "" {
107-
return fmt.Errorf("'%s' must not include a port", item)
108-
}
109-
}
110-
if colonCount == 1 {
111-
return fmt.Errorf("'%s' must not include a port", item)
112-
}
113-
}
114-
return nil
115-
}
116-
117-
// Validate allowed origins don't contain whitespace or commas.
118-
// Origins must include a scheme, validated later by the HTTP layer.
119-
func validateAllowedOrigins(input []string) error {
120-
if len(input) == 0 {
121-
return fmt.Errorf("the list must not be empty")
122-
}
123-
for _, item := range input {
124-
for _, r := range item {
125-
if unicode.IsSpace(r) {
126-
return fmt.Errorf("'%s' contains whitespace characters, which are not allowed", item)
127-
}
128-
}
129-
if strings.Contains(item, ",") {
130-
return fmt.Errorf("'%s' contains comma characters, which are not allowed", item)
131-
}
132-
}
133-
return nil
134-
}
135-
13661
func runServer(ctx context.Context, logger *slog.Logger, argsToPass []string) error {
13762
agent := argsToPass[0]
13863
agentTypeValue := viper.GetString(FlagType)
@@ -247,17 +172,6 @@ func CreateServerCmd() *cobra.Command {
247172
Short: "Run the server",
248173
Long: fmt.Sprintf("Run the server with the specified agent (one of: %s)", strings.Join(agentNames, ", ")),
249174
Args: cobra.MinimumNArgs(1),
250-
PreRunE: func(cmd *cobra.Command, args []string) error {
251-
allowedHosts := viper.GetStringSlice(FlagAllowedHosts)
252-
if err := validateAllowedHosts(allowedHosts); err != nil {
253-
return xerrors.Errorf("failed to validate allowed hosts: %w", err)
254-
}
255-
allowedOrigins := viper.GetStringSlice(FlagAllowedOrigins)
256-
if err := validateAllowedOrigins(allowedOrigins); err != nil {
257-
return xerrors.Errorf("failed to validate allowed origins: %w", err)
258-
}
259-
return nil
260-
},
261175
Run: func(cmd *cobra.Command, args []string) {
262176
// The --exit flag is used for testing validation of flags in the test suite
263177
if viper.GetBool(FlagExit) {
@@ -280,7 +194,7 @@ func CreateServerCmd() *cobra.Command {
280194
{FlagTermWidth, "W", uint16(80), "Width of the emulated terminal", "uint16"},
281195
{FlagTermHeight, "H", uint16(1000), "Height of the emulated terminal", "uint16"},
282196
// localhost is the default host for the server. Port is ignored during matching.
283-
{FlagAllowedHosts, "a", []string{"localhost"}, "HTTP allowed hosts (hostnames only, no ports). Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_HOSTS env var", "stringSlice"},
197+
{FlagAllowedHosts, "a", []string{"localhost", "127.0.0.1", "[::1]"}, "HTTP allowed hosts (hostnames only, no ports). Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_HOSTS env var", "stringSlice"},
284198
// localhost:3284 is the default origin when you open the chat interface in your browser. localhost:3000 and 3001 are used during development.
285199
{FlagAllowedOrigins, "o", []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, "HTTP allowed origins. Use '*' for all, comma-separated list via flag, space-separated list via AGENTAPI_ALLOWED_ORIGINS env var", "stringSlice"},
286200
}

cmd/server/server_test.go

Lines changed: 4 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ func TestServerCmd_AllArgs_Defaults(t *testing.T) {
155155
{"chat-base-path default", FlagChatBasePath, "/chat", func() any { return viper.GetString(FlagChatBasePath) }},
156156
{"term-width default", FlagTermWidth, uint16(80), func() any { return viper.GetUint16(FlagTermWidth) }},
157157
{"term-height default", FlagTermHeight, uint16(1000), func() any { return viper.GetUint16(FlagTermHeight) }},
158-
{"allowed-hosts default", FlagAllowedHosts, []string{"localhost"}, func() any { return viper.GetStringSlice(FlagAllowedHosts) }},
158+
{"allowed-hosts default", FlagAllowedHosts, []string{"localhost", "127.0.0.1", "[::1]"}, func() any { return viper.GetStringSlice(FlagAllowedHosts) }},
159159
{"allowed-origins default", FlagAllowedOrigins, []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"}, func() any { return viper.GetStringSlice(FlagAllowedOrigins) }},
160160
}
161161

@@ -341,37 +341,6 @@ func TestServerCmd_AllowedHosts(t *testing.T) {
341341
args: []string{},
342342
expected: []string{"localhost", "example.com"},
343343
},
344-
{
345-
name: "env: host with comma (invalid)",
346-
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:3284,example.com"},
347-
args: []string{},
348-
expectedErr: "contains comma characters",
349-
},
350-
{
351-
name: "env: host with port (invalid)",
352-
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost:3284"},
353-
args: []string{},
354-
expectedErr: "must not include a port",
355-
},
356-
{
357-
name: "env: ipv6 literal",
358-
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "2001:db8::1"},
359-
args: []string{},
360-
expected: []string{"2001:db8::1"},
361-
},
362-
{
363-
name: "env: ipv6 bracketed literal",
364-
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "[2001:db8::1]"},
365-
args: []string{},
366-
expected: []string{"[2001:db8::1]"},
367-
},
368-
{
369-
name: "env: ipv6 with port (invalid)",
370-
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "[2001:db8::1]:443"},
371-
args: []string{},
372-
expectedErr: "must not include a port",
373-
},
374-
375344
// CLI flag scenarios (comma-separated format)
376345
{
377346
name: "flag: single valid host",
@@ -394,30 +363,10 @@ func TestServerCmd_AllowedHosts(t *testing.T) {
394363
expected: []string{"localhost"},
395364
},
396365
{
397-
name: "flag: host with space in comma-separated list (invalid)",
398-
args: []string{"--allowed-hosts", "localhost:3284,example .com"},
399-
expectedErr: "contains whitespace characters",
400-
},
401-
{
402-
name: "flag: host with port (invalid)",
403-
args: []string{"--allowed-hosts", "localhost:3284"},
404-
expectedErr: "must not include a port",
366+
name: "flag: ipv6 bracketed literal",
367+
args: []string{"--allowed-hosts", "[2001:db8::1]"},
368+
expected: []string{"[2001:db8::1]"},
405369
},
406-
{
407-
name: "flag: ipv6 literal",
408-
args: []string{"--allowed-hosts", "2001:db8::1"},
409-
expected: []string{"2001:db8::1"},
410-
},
411-
{
412-
name: "flag: ipv6 bracketed literal",
413-
args: []string{"--allowed-hosts", "[2001:db8::1]"},
414-
expected: []string{"[2001:db8::1]"},
415-
},
416-
{
417-
name: "flag: ipv6 with port (invalid)",
418-
args: []string{"--allowed-hosts", "[2001:db8::1]:443"},
419-
expectedErr: "must not include a port",
420-
},
421370

422371
// Mixed scenarios (env + flag precedence)
423372
{
@@ -426,26 +375,6 @@ func TestServerCmd_AllowedHosts(t *testing.T) {
426375
args: []string{"--allowed-hosts", "override.com"},
427376
expected: []string{"override.com"},
428377
},
429-
{
430-
name: "mixed: flag overrides env but flag is invalid",
431-
env: map[string]string{"AGENTAPI_ALLOWED_HOSTS": "localhost"},
432-
args: []string{"--allowed-hosts", "invalid .com"},
433-
expectedErr: "contains whitespace characters",
434-
},
435-
436-
// Empty hosts are not allowed
437-
{
438-
name: "empty host",
439-
args: []string{"--allowed-hosts", ""},
440-
expectedErr: "the list must not be empty",
441-
},
442-
443-
// Default behavior
444-
{
445-
name: "default hosts when neither env nor flag provided",
446-
args: []string{},
447-
expected: []string{"localhost"},
448-
},
449378
}
450379

451380
for _, tt := range tests {
@@ -506,12 +435,6 @@ func TestServerCmd_AllowedOrigins(t *testing.T) {
506435
args: []string{},
507436
expected: []string{"https://example.com", "http://localhost:3000"},
508437
},
509-
{
510-
name: "env: origin with comma (invalid)",
511-
env: map[string]string{"AGENTAPI_ALLOWED_ORIGINS": "https://example.com,http://localhost:3000"},
512-
args: []string{},
513-
expectedErr: "contains comma characters",
514-
},
515438

516439
// CLI flag scenarios (comma-separated format)
517440
{
@@ -539,11 +462,6 @@ func TestServerCmd_AllowedOrigins(t *testing.T) {
539462
args: []string{"--allowed-origins", "https://example.com\n"},
540463
expected: []string{"https://example.com"},
541464
},
542-
{
543-
name: "flag: origin with space in comma-separated list (invalid)",
544-
args: []string{"--allowed-origins", "https://example.com,http://localhost :3000"},
545-
expectedErr: "contains whitespace characters",
546-
},
547465

548466
// Mixed scenarios (env + flag precedence)
549467
{
@@ -552,26 +470,6 @@ func TestServerCmd_AllowedOrigins(t *testing.T) {
552470
args: []string{"--allowed-origins", "https://override.com"},
553471
expected: []string{"https://override.com"},
554472
},
555-
{
556-
name: "mixed: flag overrides env but flag is invalid",
557-
env: map[string]string{"AGENTAPI_ALLOWED_ORIGINS": "https://env-example.com"},
558-
args: []string{"--allowed-origins", "invalid origin"},
559-
expectedErr: "contains whitespace characters",
560-
},
561-
562-
// Empty origins are not allowed
563-
{
564-
name: "empty origin",
565-
args: []string{"--allowed-origins", ""},
566-
expectedErr: "the list must not be empty",
567-
},
568-
569-
// Default behavior
570-
{
571-
name: "default origins when neither env nor flag provided",
572-
args: []string{},
573-
expected: []string{"http://localhost:3284", "http://localhost:3000", "http://localhost:3001"},
574-
},
575473
}
576474

577475
for _, tt := range tests {

0 commit comments

Comments
 (0)