diff --git a/cli/deployment/config.go b/cli/deployment/config.go index 1272ab8529ee7..77db09cc2d188 100644 --- a/cli/deployment/config.go +++ b/cli/deployment/config.go @@ -113,6 +113,16 @@ func newConfig() codersdk.DeploymentConfig { Flag: "pprof-address", Value: "127.0.0.1:6060", }, + ProxyTrustedHeaders: codersdk.DeploymentConfigField[[]string]{ + Key: "proxy.trusted_headers", + Flag: "proxy-trusted-headers", + Usage: "Headers to trust for forwarding IP addresses. e.g. Cf-Connecting-IP True-Client-Ip, X-Forwarded-for", + }, + ProxyTrustedOrigins: codersdk.DeploymentConfigField[[]string]{ + Key: "proxy.trusted_origins", + Flag: "proxy-trusted-origins", + Usage: "Origin addresses to respect \"proxy-trusted-headers\". e.g. example.com", + }, CacheDirectory: codersdk.DeploymentConfigField[string]{ Key: "cache_directory", Usage: "The directory to cache temporary files. If unspecified and $CACHE_DIRECTORY is set, it will be used for compatibility with systemd.", diff --git a/cli/server.go b/cli/server.go index bb9f7f3558bc3..e937c2ac80aeb 100644 --- a/cli/server.go +++ b/cli/server.go @@ -57,6 +57,7 @@ import ( "github.com/coder/coder/coderd/devtunnel" "github.com/coder/coder/coderd/gitsshkey" "github.com/coder/coder/coderd/httpapi" + "github.com/coder/coder/coderd/httpmw" "github.com/coder/coder/coderd/prometheusmetrics" "github.com/coder/coder/coderd/telemetry" "github.com/coder/coder/coderd/tracing" @@ -325,6 +326,11 @@ func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*co } } + realIPConfig, err := httpmw.ParseRealIPConfig(cfg.ProxyTrustedHeaders.Value, cfg.ProxyTrustedOrigins.Value) + if err != nil { + return xerrors.Errorf("parse real ip config: %w", err) + } + options := &coderd.Options{ AccessURL: accessURLParsed, AppHostname: appHostname, @@ -335,6 +341,7 @@ func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*co Pubsub: database.NewPubsubInMemory(), CacheDir: cfg.CacheDirectory.Value, GoogleTokenValidator: googleTokenValidator, + RealIPConfig: realIPConfig, SecureAuthCookie: cfg.SecureAuthCookie.Value, SSHKeygenAlgorithm: sshKeygenAlgorithm, TracerProvider: tracerProvider, diff --git a/coderd/apikey.go b/coderd/apikey.go index 84e936cb22e16..01e9d7484a42b 100644 --- a/coderd/apikey.go +++ b/coderd/apikey.go @@ -230,8 +230,7 @@ func (api *API) createAPIKey(ctx context.Context, params createAPIKeyParams) (*h } } - host, _, _ := net.SplitHostPort(params.RemoteAddr) - ip := net.ParseIP(host) + ip := net.ParseIP(params.RemoteAddr) if ip == nil { ip = net.IPv4(0, 0, 0, 0) } diff --git a/coderd/audit.go b/coderd/audit.go index f76a6565bce77..5002bb6960c58 100644 --- a/coderd/audit.go +++ b/coderd/audit.go @@ -117,8 +117,7 @@ func (api *API) generateFakeAuditLog(rw http.ResponseWriter, r *http.Request) { return } - ipRaw, _, _ := net.SplitHostPort(r.RemoteAddr) - ip := net.ParseIP(ipRaw) + ip := net.ParseIP(r.RemoteAddr) ipNet := pqtype.Inet{} if ip != nil { ipNet = pqtype.Inet{ diff --git a/coderd/audit/request.go b/coderd/audit/request.go index c658a09038569..f330b321cd1ec 100644 --- a/coderd/audit/request.go +++ b/coderd/audit/request.go @@ -129,12 +129,8 @@ func InitRequest[T Auditable](w http.ResponseWriter, p *RequestParams) (*Request } } - ip, err := parseIP(p.Request.RemoteAddr) - if err != nil { - p.Log.Warn(logCtx, "parse ip", slog.Error(err)) - } - - err = p.Audit.Export(ctx, database.AuditLog{ + ip := parseIP(p.Request.RemoteAddr) + err := p.Audit.Export(ctx, database.AuditLog{ ID: uuid.New(), Time: database.Now(), UserID: httpmw.APIKey(p.Request).UserID, @@ -166,16 +162,8 @@ func either[T Auditable, R any](old, new T, fn func(T) R) R { } } -func parseIP(ipStr string) (pqtype.Inet, error) { - var err error - - ipStr, _, err = net.SplitHostPort(ipStr) - if err != nil { - return pqtype.Inet{}, err - } - +func parseIP(ipStr string) pqtype.Inet { ip := net.ParseIP(ipStr) - ipNet := net.IPNet{} if ip != nil { ipNet = net.IPNet{ @@ -187,5 +175,5 @@ func parseIP(ipStr string) (pqtype.Inet, error) { return pqtype.Inet{ IPNet: ipNet, Valid: ip != nil, - }, nil + } } diff --git a/coderd/coderd.go b/coderd/coderd.go index 1aed8417b0ba8..5f98d83babd9f 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -82,6 +82,7 @@ type Options struct { Telemetry telemetry.Reporter TracerProvider trace.TracerProvider AutoImportTemplates []AutoImportTemplate + RealIPConfig *httpmw.RealIPConfig // TLSCertificates is used to mesh DERP servers securely. TLSCertificates []tls.Certificate @@ -198,6 +199,7 @@ func New(options *Options) *API { r.Use( httpmw.AttachRequestID, httpmw.Recover(api.Logger), + httpmw.ExtractRealIP(api.RealIPConfig), httpmw.Logger(api.Logger), httpmw.Prometheus(options.PrometheusRegistry), // handleSubdomainApplications checks if the first subdomain is a valid diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 4a4198e43c45b..f0b8e58f99c58 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -57,6 +57,7 @@ import ( "github.com/coder/coder/coderd/database/dbtestutil" "github.com/coder/coder/coderd/gitsshkey" "github.com/coder/coder/coderd/httpapi" + "github.com/coder/coder/coderd/httpmw" "github.com/coder/coder/coderd/rbac" "github.com/coder/coder/coderd/telemetry" "github.com/coder/coder/coderd/util/ptr" @@ -77,6 +78,7 @@ type Options struct { Experimental bool AzureCertificates x509.VerifyOptions GithubOAuth2Config *coderd.GithubOAuth2Config + RealIPConfig *httpmw.RealIPConfig OIDCConfig *coderd.OIDCConfig GoogleTokenValidator *idtoken.Validator SSHKeygenAlgorithm gitsshkey.Algorithm @@ -238,6 +240,7 @@ func NewOptions(t *testing.T, options *Options) (func(http.Handler), context.Can AWSCertificates: options.AWSCertificates, AzureCertificates: options.AzureCertificates, GithubOAuth2Config: options.GithubOAuth2Config, + RealIPConfig: options.RealIPConfig, OIDCConfig: options.OIDCConfig, GoogleTokenValidator: options.GoogleTokenValidator, SSHKeygenAlgorithm: options.SSHKeygenAlgorithm, diff --git a/coderd/httpmw/apikey.go b/coderd/httpmw/apikey.go index 54a28a2d1c617..c72d0d40e606a 100644 --- a/coderd/httpmw/apikey.go +++ b/coderd/httpmw/apikey.go @@ -250,8 +250,7 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler { // Only update LastUsed once an hour to prevent database spam. if now.Sub(key.LastUsed) > time.Hour { key.LastUsed = now - host, _, _ := net.SplitHostPort(r.RemoteAddr) - remoteIP := net.ParseIP(host) + remoteIP := net.ParseIP(r.RemoteAddr) if remoteIP == nil { remoteIP = net.IPv4(0, 0, 0, 0) } diff --git a/coderd/httpmw/apikey_test.go b/coderd/httpmw/apikey_test.go index 10166fadd0f63..a9467878aff78 100644 --- a/coderd/httpmw/apikey_test.go +++ b/coderd/httpmw/apikey_test.go @@ -512,7 +512,7 @@ func TestAPIKey(t *testing.T) { rw = httptest.NewRecorder() user = createUser(r.Context(), t, db) ) - r.RemoteAddr = "1.1.1.1:3555" + r.RemoteAddr = "1.1.1.1" r.Header.Set(codersdk.SessionCustomHeader, fmt.Sprintf("%s-%s", id, secret)) _, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{ diff --git a/coderd/httpmw/realip.go b/coderd/httpmw/realip.go new file mode 100644 index 0000000000000..ee708680e03cd --- /dev/null +++ b/coderd/httpmw/realip.go @@ -0,0 +1,225 @@ +package httpmw + +import ( + "context" + "net" + "net/http" + "strings" + + "golang.org/x/xerrors" + + "github.com/coder/coder/coderd/httpapi" +) + +const ( + headerXForwardedFor string = "X-Forwarded-For" + headerXForwardedProto string = "X-Forwarded-Proto" +) + +// RealIPConfig configures the search order for the function, which controls +// which headers to consider trusted. +type RealIPConfig struct { + // TrustedOrigins is a list of networks that will be trusted. If + // any non-trusted address supplies these headers, they will be + // ignored. + TrustedOrigins []*net.IPNet + + // TrustedHeaders lists headers that are trusted for forwarding + // IP addresses. e.g. "CF-Connecting-IP", "True-Client-IP", etc. + TrustedHeaders []string +} + +// ExtractRealIP is a middleware that uses headers from reverse proxies to +// propagate origin IP address information, when configured to do so. +func ExtractRealIP(config *RealIPConfig) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + // Preserve the original TLS connection state and RemoteAddr + req = req.WithContext(context.WithValue(req.Context(), ctxKey{}, &RealIPState{ + Config: config, + OriginalRemoteAddr: req.RemoteAddr, + })) + + info, err := ExtractRealIPAddress(config, req) + if err != nil { + httpapi.InternalServerError(w, err) + return + } + req.RemoteAddr = info.String() + + next.ServeHTTP(w, req) + }) + } +} + +// ExtractRealIPAddress returns the original client address according to the +// configuration and headers. It does not mutate the original request. +func ExtractRealIPAddress(config *RealIPConfig, req *http.Request) (net.IP, error) { + if config == nil { + config = &RealIPConfig{} + } + + cf := isContainedIn(config.TrustedOrigins, getRemoteAddress(req.RemoteAddr)) + if !cf { + // Address is not valid or the origin is not trusted; use the + // original address + return getRemoteAddress(req.RemoteAddr), nil + } + + for _, trustedHeader := range config.TrustedHeaders { + addr := getRemoteAddress(req.Header.Get(trustedHeader)) + if addr != nil { + return addr, nil + } + } + + return getRemoteAddress(req.RemoteAddr), nil +} + +// FilterUntrustedOriginHeaders removes all known proxy headers from the +// request for untrusted origins, and ensures that only one copy +// of each proxy header is set. +func FilterUntrustedOriginHeaders(config *RealIPConfig, req *http.Request) { + if config == nil { + config = &RealIPConfig{} + } + + cf := isContainedIn(config.TrustedOrigins, getRemoteAddress(req.RemoteAddr)) + if !cf { + // Address is not valid or the origin is not trusted; clear + // all known proxy headers and return + for _, header := range config.TrustedHeaders { + req.Header.Del(header) + } + return + } + + for _, header := range config.TrustedHeaders { + req.Header.Set(header, req.Header.Get(header)) + } +} + +// EnsureXForwardedForHeader ensures that the request has an X-Forwarded-For +// header. It uses the following logic: +// +// 1. If we have a direct connection (remoteAddr == proxyAddr), then +// set it to remoteAddr +// 2. If we have a proxied connection (remoteAddr != proxyAddr) and +// X-Forwarded-For doesn't begin with remoteAddr, then overwrite +// it with remoteAddr,proxyAddr +// 3. If we have a proxied connection (remoteAddr != proxyAddr) and +// X-Forwarded-For begins with remoteAddr, then append proxyAddr +// to the original X-Forwarded-For header +// 4. If X-Forwarded-Proto is not set, then it will be set to "https" +// if req.TLS != nil, otherwise it will be set to "http" +func EnsureXForwardedForHeader(req *http.Request) error { + state := RealIP(req.Context()) + if state == nil { + return xerrors.New("request does not contain realip.State; was it processed by httpmw.ExtractRealIP?") + } + + remoteAddr := getRemoteAddress(req.RemoteAddr) + if remoteAddr == nil { + return xerrors.Errorf("failed to parse remote address: %s", remoteAddr) + } + + proxyAddr := getRemoteAddress(state.OriginalRemoteAddr) + if proxyAddr == nil { + return xerrors.Errorf("failed to parse original address: %s", proxyAddr) + } + + if remoteAddr.Equal(proxyAddr) { + req.Header.Set(headerXForwardedFor, remoteAddr.String()) + } else { + forwarded := req.Header.Get(headerXForwardedFor) + if forwarded == "" || !remoteAddr.Equal(getRemoteAddress(forwarded)) { + req.Header.Set(headerXForwardedFor, remoteAddr.String()+","+proxyAddr.String()) + } else { + req.Header.Set(headerXForwardedFor, forwarded+","+proxyAddr.String()) + } + } + + if req.Header.Get(headerXForwardedProto) == "" { + if req.TLS != nil { + req.Header.Set(headerXForwardedProto, "https") + } else { + req.Header.Set(headerXForwardedProto, "http") + } + } + + return nil +} + +// getRemoteAddress extracts the IP address from the given string. If +// the string contains commas, it assumes that the first part is the +// original address. +func getRemoteAddress(address string) net.IP { + // X-Forwarded-For may contain multiple addresses, in case the + // proxies are chained; the first value is the client address + i := strings.IndexByte(address, ',') + if i == -1 { + i = len(address) + } + + // If the address contains a port, remove it + firstAddress := address[:i] + host, _, err := net.SplitHostPort(firstAddress) + if err != nil { + // This will error if there is no port, so try to parse the address + return net.ParseIP(firstAddress) + } + return net.ParseIP(host) +} + +// isContainedIn checks that the given address is contained in the given +// network. +func isContainedIn(networks []*net.IPNet, address net.IP) bool { + for _, network := range networks { + if network.Contains(address) { + return true + } + } + + return false +} + +// RealIPState is the original state prior to modification by this middleware, +// useful for getting information about the connecting client if needed. +type RealIPState struct { + // Config is the configuration applied in the middleware. Consider + // this read-only and do not modify. + Config *RealIPConfig + + // OriginalRemoteAddr is the original RemoteAddr for the request. + OriginalRemoteAddr string +} + +type ctxKey struct{} + +// FromContext retrieves the state from the given context.Context. +func RealIP(ctx context.Context) *RealIPState { + state, ok := ctx.Value(ctxKey{}).(*RealIPState) + if !ok { + return nil + } + return state +} + +// ParseRealIPConfig takes a raw string array of headers and origins +// to produce a config. +func ParseRealIPConfig(headers, origins []string) (*RealIPConfig, error) { + config := &RealIPConfig{} + for _, origin := range origins { + _, network, err := net.ParseCIDR(origin) + if err != nil { + return nil, xerrors.Errorf("parse proxy origin %q: %w", origin, err) + } + config.TrustedOrigins = append(config.TrustedOrigins, network) + } + for index, header := range headers { + headers[index] = http.CanonicalHeaderKey(header) + } + config.TrustedHeaders = headers + + return config, nil +} diff --git a/coderd/httpmw/realip_test.go b/coderd/httpmw/realip_test.go new file mode 100644 index 0000000000000..85036a3c63197 --- /dev/null +++ b/coderd/httpmw/realip_test.go @@ -0,0 +1,649 @@ +package httpmw_test + +import ( + "crypto/tls" + "fmt" + "net" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/coderd/httpmw" +) + +// TestExtractAddress checks the ExtractAddress function. +func TestExtractAddress(t *testing.T) { + t.Parallel() + + tests := []struct { + Name string + Config *httpmw.RealIPConfig + Header http.Header + RemoteAddr string + TLS bool + ExpectedRemoteAddr string + ExpectedTLS bool + }{ + { + Name: "default-nil-config", + RemoteAddr: "123.45.67.89", + ExpectedRemoteAddr: "123.45.67.89", + }, + { + Name: "default-empty-config", + RemoteAddr: "123.45.67.89", + ExpectedRemoteAddr: "123.45.67.89", + Config: &httpmw.RealIPConfig{}, + }, + { + Name: "default-filter-headers", + Config: &httpmw.RealIPConfig{ + TrustedOrigins: []*net.IPNet{ + { + IP: net.ParseIP("10.0.0.0"), + Mask: net.CIDRMask(8, 32), + }, + }, + }, + RemoteAddr: "123.45.67.89", + Header: http.Header{ + "X-Forwarded-For": []string{ + "127.0.0.1", + "10.0.0.5", + "10.0.0.5,4.4.4.4", + }, + }, + ExpectedRemoteAddr: "123.45.67.89", + }, + { + Name: "multiple-x-forwarded-for", + Config: &httpmw.RealIPConfig{ + TrustedOrigins: []*net.IPNet{ + { + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + }, + }, + TrustedHeaders: []string{ + "X-Forwarded-For", + }, + }, + RemoteAddr: "123.45.67.89", + Header: http.Header{ + "X-Forwarded-For": []string{ + "10.24.1.1,1.2.3.4,1.1.1.1,4.5.6.7", + "10.0.0.5", + "10.0.0.5,4.4.4.4", + }, + }, + ExpectedRemoteAddr: "10.24.1.1", + }, + { + Name: "single-real-ip", + Config: &httpmw.RealIPConfig{ + TrustedOrigins: []*net.IPNet{ + { + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + }, + }, + TrustedHeaders: []string{ + "X-Real-Ip", + }, + }, + RemoteAddr: "123.45.67.89", + TLS: true, + Header: http.Header{ + "X-Real-Ip": []string{"8.8.8.8"}, + }, + ExpectedRemoteAddr: "8.8.8.8", + ExpectedTLS: true, + }, + { + Name: "multiple-real-ip", + Config: &httpmw.RealIPConfig{ + TrustedOrigins: []*net.IPNet{ + { + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + }, + }, + TrustedHeaders: []string{ + "X-Real-Ip", + }, + }, + RemoteAddr: "123.45.67.89", + Header: http.Header{ + "X-Real-Ip": []string{"4.4.4.4", "8.8.8.8"}, + }, + ExpectedRemoteAddr: "4.4.4.4", + }, + { + // Has X-Forwarded-For and X-Real-Ip, prefers X-Real-Ip + Name: "prefer-real-ip", + Config: &httpmw.RealIPConfig{ + TrustedOrigins: []*net.IPNet{ + { + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + }, + }, + TrustedHeaders: []string{ + "X-Real-Ip", + "X-Forwarded-For", + }, + }, + RemoteAddr: "123.45.67.89", + Header: http.Header{ + "X-Forwarded-For": []string{"8.8.8.8"}, + "X-Real-Ip": []string{"4.4.4.4"}, + }, + ExpectedRemoteAddr: "4.4.4.4", + }, + { + // Has X-Forwarded-For, X-Real-Ip, and True-Client-Ip, prefers + // True-Client-Ip + Name: "prefer-true-client-ip", + Config: &httpmw.RealIPConfig{ + TrustedOrigins: []*net.IPNet{ + { + IP: net.ParseIP("123.45.0.0"), + Mask: net.CIDRMask(16, 32), + }, + }, + TrustedHeaders: []string{ + "True-Client-Ip", + "X-Forwarded-For", + "X-Real-Ip", + }, + }, + RemoteAddr: "123.45.67.89", + TLS: true, + Header: http.Header{ + "X-Forwarded-For": []string{"1.2.3.4"}, + "X-Real-Ip": []string{"4.4.4.4", "8.8.8.8"}, + "True-Client-Ip": []string{"5.6.7.8", "9.8.7.6"}, + }, + ExpectedRemoteAddr: "5.6.7.8", + ExpectedTLS: true, + }, + { + // Has X-Forwarded-For, X-Real-Ip, True-Client-Ip, and + // Cf-Connecting-Ip, prefers Cf-Connecting-Ip + Name: "prefer-cf-connecting-ip", + Config: &httpmw.RealIPConfig{ + TrustedOrigins: []*net.IPNet{ + { + IP: net.ParseIP("123.45.67.89"), + Mask: net.CIDRMask(32, 32), + }, + }, + TrustedHeaders: []string{ + "Cf-Connecting-Ip", + "X-Forwarded-For", + "X-Real-Ip", + "True-Client-Ip", + }, + }, + RemoteAddr: "123.45.67.89", + Header: http.Header{ + "X-Forwarded-For": []string{"1.2.3.4,100.12.1.3,10.10.10.10"}, + "X-Real-Ip": []string{"4.4.4.4", "8.8.8.8"}, + "True-Client-Ip": []string{"5.6.7.8", "9.8.7.6"}, + "Cf-Connecting-Ip": []string{"100.10.2.2"}, + }, + ExpectedRemoteAddr: "100.10.2.2", + }, + } + + for _, test := range tests { + test := test + t.Run(test.Name, func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + + // Default to a direct (unproxied) connection over HTTP + req.RemoteAddr = test.RemoteAddr + if test.TLS { + req.TLS = &tls.ConnectionState{} + } else { + req.TLS = nil + } + req.Header = test.Header + + info, err := httpmw.ExtractRealIPAddress(test.Config, req) + require.NoError(t, err, "unexpected error in ExtractAddress") + require.Equal(t, test.ExpectedRemoteAddr, info.String(), "expected info.String() to match") + }) + } +} + +// TestTrustedOrigins tests different settings for TrustedOrigins. +func TestTrustedOrigins(t *testing.T) { + t.Parallel() + + // Remote client protocol: HTTP or HTTPS + for _, proto := range []string{"http", "https"} { + // Trusted origin + // all: default behavior, trust all origins + // none: use an empty set (nothing will be accepted in this case) + // ipv4: trust an IPv6 network + // ipv6: trust an IPv4 network + for _, trusted := range []string{"none", "ipv4", "ipv6"} { + for _, header := range []string{"Cf-Connecting-Ip", "True-Client-Ip", "X-Real-Ip", "X-Forwarded-For"} { + trusted := trusted + header := header + proto := proto + name := fmt.Sprintf("%s-%s-%s", trusted, proto, strings.ToLower(header)) + + t.Run(name, func(t *testing.T) { + t.Parallel() + + remoteAddr := "10.10.10.10" + actualAddr := "12.34.56.78" + + config := &httpmw.RealIPConfig{ + TrustedHeaders: []string{ + "Cf-Connecting-Ip", + "X-Forwarded-For", + "X-Real-Ip", + "True-Client-Ip", + }, + } + switch trusted { + case "none": + config.TrustedOrigins = []*net.IPNet{} + case "ipv4": + config.TrustedOrigins = []*net.IPNet{ + { + IP: net.ParseIP("10.0.0.0"), + Mask: net.CIDRMask(24, 32), + }, + } + remoteAddr = "10.0.0.1" + case "ipv6": + config.TrustedOrigins = []*net.IPNet{ + { + IP: net.ParseIP("2606:4700::0"), + Mask: net.CIDRMask(32, 128), + }, + } + remoteAddr = "2606:4700:4700::1111" + } + + middleware := httpmw.ExtractRealIP(config) + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + req.Header.Set(header, actualAddr) + req.RemoteAddr = remoteAddr + if proto == "https" { + req.TLS = &tls.ConnectionState{} + } + + handlerCalled := false + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + // If nothing is trusted, the remoteAddr should be unchanged + if trusted == "none" { + require.Equal(t, remoteAddr, req.RemoteAddr, "remote address should be unchanged") + } else { + require.Equal(t, actualAddr, req.RemoteAddr, "actual address should be trusted") + } + + handlerCalled = true + }) + + middleware(nextHandler).ServeHTTP(httptest.NewRecorder(), req) + + require.True(t, handlerCalled, "expected handler to be invoked") + }) + } + } + } +} + +// TestCorruptedHeaders tests the middleware when the reverse proxy +// supplies unparsable content. +func TestCorruptedHeaders(t *testing.T) { + t.Parallel() + + for _, header := range []string{"Cf-Connecting-Ip", "True-Client-Ip", "X-Real-Ip", "X-Forwarded-For"} { + header := header + name := strings.ToLower(header) + + t.Run(name, func(t *testing.T) { + t.Parallel() + + remoteAddr := "10.10.10.10" + + config := &httpmw.RealIPConfig{ + TrustedOrigins: []*net.IPNet{ + { + IP: net.ParseIP("10.0.0.0"), + Mask: net.CIDRMask(8, 32), + }, + }, + TrustedHeaders: []string{ + "Cf-Connecting-Ip", + "X-Forwarded-For", + "X-Real-Ip", + "True-Client-Ip", + }, + } + + middleware := httpmw.ExtractRealIP(config) + + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + req.Header.Set(header, "12.34.56!78") + req.RemoteAddr = remoteAddr + + handlerCalled := false + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + // Since the header is unparsable, the remoteAddr should be unchanged + require.Equal(t, remoteAddr, req.RemoteAddr, "remote address should be unchanged") + + handlerCalled = true + }) + + middleware(nextHandler).ServeHTTP(httptest.NewRecorder(), req) + + require.True(t, handlerCalled, "expected handler to be invoked") + }) + } +} + +// TestAddressFamilies tests the middleware using different combinations of +// address families for remote and proxy endpoints. +func TestAddressFamilies(t *testing.T) { + t.Parallel() + + for _, clientFamily := range []string{"ipv4", "ipv6"} { + for _, proxyFamily := range []string{"ipv4", "ipv6"} { + for _, header := range []string{"Cf-Connecting-Ip", "True-Client-Ip", "X-Real-Ip", "X-Forwarded-For"} { + clientFamily := clientFamily + proxyFamily := proxyFamily + header := header + name := fmt.Sprintf("%s-%s-%s", strings.ToLower(header), clientFamily, proxyFamily) + + t.Run(name, func(t *testing.T) { + t.Parallel() + + clientAddr := "123.123.123.123" + if clientFamily == "ipv6" { + clientAddr = "2a03:2880:f10c:83:face:b00c:0:25de" + } + + proxyAddr := "4.4.4.4" + if proxyFamily == "ipv6" { + proxyAddr = "2001:4860:4860::8888" + } + + config := &httpmw.RealIPConfig{ + TrustedOrigins: []*net.IPNet{ + { + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + }, + { + IP: net.ParseIP("0::"), + Mask: net.CIDRMask(0, 128), + }, + }, + TrustedHeaders: []string{ + "Cf-Connecting-Ip", + "X-Forwarded-For", + "X-Real-Ip", + "True-Client-Ip", + }, + } + + middleware := httpmw.ExtractRealIP(config) + + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + req.Header.Set(header, clientAddr) + req.RemoteAddr = proxyAddr + + handlerCalled := false + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + require.Equal(t, clientAddr, req.RemoteAddr, "remote address should match remote client") + + handlerCalled = true + }) + + middleware(nextHandler).ServeHTTP(httptest.NewRecorder(), req) + + require.True(t, handlerCalled, "expected handler to be invoked") + }) + } + } + } +} + +// TestFilterUntrusted tests that untrusted headers are removed from the request. +func TestFilterUntrusted(t *testing.T) { + t.Parallel() + + tests := []struct { + Name string + Config *httpmw.RealIPConfig + Header http.Header + RemoteAddr string + ExpectedHeader http.Header + ExpectedRemoteAddr string + }{ + { + Name: "untrusted-origin", + Config: &httpmw.RealIPConfig{ + TrustedOrigins: nil, + TrustedHeaders: []string{ + "Cf-Connecting-Ip", + "X-Forwarded-For", + "X-Real-Ip", + "True-Client-Ip", + }, + }, + Header: http.Header{ + "X-Forwarded-For": []string{"1.2.3.4,123.45.67.89"}, + "X-Forwarded-Proto": []string{"https"}, + "X-Real-Ip": []string{"4.4.4.4"}, + "True-Client-Ip": []string{"5.6.7.8"}, + "Authorization": []string{"Bearer 123"}, + "Accept-Encoding": []string{"gzip", "compress", "deflate", "identity"}, + }, + RemoteAddr: "1.2.3.4", + ExpectedHeader: http.Header{ + "Authorization": []string{"Bearer 123"}, + "Accept-Encoding": []string{"gzip", "compress", "deflate", "identity"}, + "X-Forwarded-Proto": []string{"https"}, + }, + ExpectedRemoteAddr: "1.2.3.4", + }, + } + + for _, test := range tests { + test := test + t.Run(test.Name, func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + req.Header = test.Header + req.RemoteAddr = test.RemoteAddr + + httpmw.FilterUntrustedOriginHeaders(test.Config, req) + require.Equal(t, test.ExpectedRemoteAddr, req.RemoteAddr, "remote address should match") + require.Equal(t, test.ExpectedHeader, req.Header, "filtered headers should match") + }) + } +} + +// TestApplicationProxy checks headers passed to DevURL services are as expected. +func TestApplicationProxy(t *testing.T) { + t.Parallel() + + tests := []struct { + Name string + Config *httpmw.RealIPConfig + Header http.Header + RemoteAddr string + TLS bool + ExpectedHeader http.Header + ExpectedRemoteAddr string + }{ + { + Name: "untrusted-origin-http", + Config: nil, + Header: http.Header{ + "X-Forwarded-For": []string{"123.45.67.89,10.10.10.10"}, + }, + RemoteAddr: "17.18.19.20", + TLS: false, + ExpectedHeader: http.Header{ + "X-Forwarded-For": []string{"17.18.19.20"}, + "X-Forwarded-Proto": []string{"http"}, + }, + ExpectedRemoteAddr: "17.18.19.20", + }, + { + Name: "untrusted-origin-https", + Config: nil, + Header: http.Header{ + "X-Forwarded-For": []string{"123.45.67.89,10.10.10.10"}, + }, + RemoteAddr: "17.18.19.20", + TLS: true, + ExpectedHeader: http.Header{ + "X-Forwarded-For": []string{"17.18.19.20"}, + "X-Forwarded-Proto": []string{"https"}, + }, + ExpectedRemoteAddr: "17.18.19.20", + }, + { + Name: "trusted-real-ip", + Config: &httpmw.RealIPConfig{ + TrustedOrigins: []*net.IPNet{ + { + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + }, + }, + TrustedHeaders: []string{ + "X-Real-Ip", + }, + }, + Header: http.Header{ + "X-Real-Ip": []string{"99.88.77.66"}, + "X-Forwarded-For": []string{"123.45.67.89,10.10.10.10"}, + "X-Forwarded-Proto": []string{"https"}, + }, + RemoteAddr: "17.18.19.20", + TLS: true, + ExpectedHeader: http.Header{ + "X-Real-Ip": []string{"99.88.77.66"}, + "X-Forwarded-For": []string{"99.88.77.66,17.18.19.20"}, + "X-Forwarded-Proto": []string{"https"}, + }, + ExpectedRemoteAddr: "99.88.77.66", + }, + { + Name: "trusted-real-ip-and-forwarded-conflict", + Config: &httpmw.RealIPConfig{ + TrustedOrigins: []*net.IPNet{ + { + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + }, + }, + TrustedHeaders: []string{ + "X-Forwarded-For", + "X-Real-Ip", + }, + }, + Header: http.Header{ + "X-Real-Ip": []string{"99.88.77.66"}, + "X-Forwarded-For": []string{"123.45.67.89,10.10.10.10"}, + "X-Forwarded-Proto": []string{"https"}, + }, + RemoteAddr: "17.18.19.20", + TLS: false, + ExpectedHeader: http.Header{ + "X-Real-Ip": []string{"99.88.77.66"}, + // Even though X-Real-Ip and X-Forwarded-For are both trusted, + // ignore the value of X-Forwarded-For, since they conflict + "X-Forwarded-For": []string{"123.45.67.89,10.10.10.10,17.18.19.20"}, + "X-Forwarded-Proto": []string{"https"}, + }, + ExpectedRemoteAddr: "123.45.67.89", + }, + { + Name: "trusted-real-ip-and-forwarded-same", + Config: &httpmw.RealIPConfig{ + TrustedOrigins: []*net.IPNet{ + { + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + }, + }, + TrustedHeaders: []string{ + "X-Forwarded-For", + "X-Real-Ip", + }, + }, + Header: http.Header{ + "X-Real-Ip": []string{"99.88.77.66"}, + // X-Real-Ip and X-Forwarded-For are both trusted, and since + // they match, append the proxy address to X-Forwarded-For + "X-Forwarded-For": []string{"99.88.77.66,123.45.67.89,10.10.10.10"}, + "X-Forwarded-Proto": []string{"https"}, + }, + RemoteAddr: "17.18.19.20", + TLS: false, + ExpectedHeader: http.Header{ + "X-Real-Ip": []string{"99.88.77.66"}, + "X-Forwarded-For": []string{"99.88.77.66,123.45.67.89,10.10.10.10,17.18.19.20"}, + "X-Forwarded-Proto": []string{"https"}, + }, + ExpectedRemoteAddr: "99.88.77.66", + }, + } + + for _, test := range tests { + test := test + t.Run(test.Name, func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + req.Header = test.Header + req.RemoteAddr = test.RemoteAddr + if test.TLS { + req.TLS = &tls.ConnectionState{} + } else { + req.TLS = nil + } + + middleware := httpmw.ExtractRealIP(test.Config) + + handlerCalled := false + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + require.Equal(t, test.ExpectedRemoteAddr, req.RemoteAddr, "remote address should match") + + httpmw.FilterUntrustedOriginHeaders(test.Config, req) + err := httpmw.EnsureXForwardedForHeader(req) + require.NoError(t, err, "ensure X-Forwarded-For should be successful") + + require.Equal(t, test.ExpectedHeader, req.Header, "filtered headers should match") + + handlerCalled = true + }) + + middleware(nextHandler).ServeHTTP(httptest.NewRecorder(), req) + + require.True(t, handlerCalled, "expected handler to be invoked") + }) + } +} diff --git a/coderd/workspaceapps.go b/coderd/workspaceapps.go index 5a7c192602fb5..1da951c63d46c 100644 --- a/coderd/workspaceapps.go +++ b/coderd/workspaceapps.go @@ -580,6 +580,15 @@ func (api *API) proxyWorkspaceApplication(proxyApp proxyApplication, rw http.Res return } + // Filter IP headers from untrusted origins! + httpmw.FilterUntrustedOriginHeaders(api.RealIPConfig, r) + // Ensure proper IP headers get sent to the forwarded application. + err := httpmw.EnsureXForwardedForHeader(r) + if err != nil { + httpapi.InternalServerError(rw, err) + return + } + // If the app does not exist, but the app name is a port number, then // route to the port as an "anonymous app". We only support HTTP for // port-based URLs. diff --git a/coderd/workspaceapps_test.go b/coderd/workspaceapps_test.go index bbf80746d7987..de8473807150a 100644 --- a/coderd/workspaceapps_test.go +++ b/coderd/workspaceapps_test.go @@ -21,6 +21,7 @@ import ( "github.com/coder/coder/agent" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/coderd/httpapi" + "github.com/coder/coder/coderd/httpmw" "github.com/coder/coder/coderd/rbac" "github.com/coder/coder/codersdk" "github.com/coder/coder/provisioner/echo" @@ -85,6 +86,7 @@ func setupProxyTest(t *testing.T, customAppHost ...string) (*codersdk.Client, co Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, err := r.Cookie(codersdk.SessionTokenKey) assert.ErrorIs(t, err, http.ErrNoCookie) + w.Header().Set("X-Forwarded-For", r.Header.Get("X-Forwarded-For")) w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(proxyTestAppBody)) }), @@ -107,6 +109,15 @@ func setupProxyTest(t *testing.T, customAppHost ...string) (*codersdk.Client, co IncludeProvisionerDaemon: true, AgentStatsRefreshInterval: time.Millisecond * 100, MetricsCacheRefreshInterval: time.Millisecond * 100, + RealIPConfig: &httpmw.RealIPConfig{ + TrustedOrigins: []*net.IPNet{{ + IP: net.ParseIP("127.0.0.1"), + Mask: net.CIDRMask(8, 32), + }}, + TrustedHeaders: []string{ + "CF-Connecting-IP", + }, + }, }) user := coderdtest.CreateFirstUser(t, client) @@ -280,6 +291,24 @@ func TestWorkspaceAppsProxyPath(t *testing.T) { require.Equal(t, http.StatusOK, resp.StatusCode) }) + t.Run("ForwardsIP", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + resp, err := client.Request(ctx, http.MethodGet, fmt.Sprintf("/@me/%s/apps/%s/?%s", workspace.Name, proxyTestAppNameOwner, proxyTestAppQuery), nil, func(r *http.Request) { + r.Header.Set("Cf-Connecting-IP", "1.1.1.1") + }) + require.NoError(t, err) + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, proxyTestAppBody, string(body)) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, "1.1.1.1,127.0.0.1", resp.Header.Get("X-Forwarded-For")) + }) + t.Run("ProxyError", func(t *testing.T) { t.Parallel() diff --git a/codersdk/deploymentconfig.go b/codersdk/deploymentconfig.go index 1ab43ddbb4289..8cdebb4e63093 100644 --- a/codersdk/deploymentconfig.go +++ b/codersdk/deploymentconfig.go @@ -28,6 +28,8 @@ type DeploymentConfig struct { PrometheusAddress DeploymentConfigField[string] `json:"prometheus_address"` PprofEnable DeploymentConfigField[bool] `json:"pprof_enabled"` PprofAddress DeploymentConfigField[string] `json:"pprof_address"` + ProxyTrustedHeaders DeploymentConfigField[[]string] `json:"proxy_trusted_headers"` + ProxyTrustedOrigins DeploymentConfigField[[]string] `json:"proxy_trusted_origins"` CacheDirectory DeploymentConfigField[string] `json:"cache_directory"` InMemoryDatabase DeploymentConfigField[bool] `json:"in_memory_database"` ProvisionerDaemons DeploymentConfigField[int] `json:"provisioner_daemon_count"` diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index a34cda1d8cfd9..8cf296785becc 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -269,6 +269,8 @@ export interface DeploymentConfig { readonly prometheus_address: DeploymentConfigField readonly pprof_enabled: DeploymentConfigField readonly pprof_address: DeploymentConfigField + readonly proxy_trusted_headers: DeploymentConfigField + readonly proxy_trusted_origins: DeploymentConfigField readonly cache_directory: DeploymentConfigField readonly in_memory_database: DeploymentConfigField readonly provisioner_daemon_count: DeploymentConfigField