Skip to content

Commit f75a54c

Browse files
authored
feat: Support x-forwarded-for headers for IPs (#4684)
* feat: Support x-forwarded-for headers for IPs Fixes #4430. * Fix realip accepting headers * Fix unused headers
1 parent 795ed3d commit f75a54c

15 files changed

+946
-23
lines changed

cli/deployment/config.go

+10
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,16 @@ func newConfig() codersdk.DeploymentConfig {
113113
Flag: "pprof-address",
114114
Value: "127.0.0.1:6060",
115115
},
116+
ProxyTrustedHeaders: codersdk.DeploymentConfigField[[]string]{
117+
Key: "proxy.trusted_headers",
118+
Flag: "proxy-trusted-headers",
119+
Usage: "Headers to trust for forwarding IP addresses. e.g. Cf-Connecting-IP True-Client-Ip, X-Forwarded-for",
120+
},
121+
ProxyTrustedOrigins: codersdk.DeploymentConfigField[[]string]{
122+
Key: "proxy.trusted_origins",
123+
Flag: "proxy-trusted-origins",
124+
Usage: "Origin addresses to respect \"proxy-trusted-headers\". e.g. example.com",
125+
},
116126
CacheDirectory: codersdk.DeploymentConfigField[string]{
117127
Key: "cache_directory",
118128
Usage: "The directory to cache temporary files. If unspecified and $CACHE_DIRECTORY is set, it will be used for compatibility with systemd.",

cli/server.go

+7
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ import (
5757
"github.com/coder/coder/coderd/devtunnel"
5858
"github.com/coder/coder/coderd/gitsshkey"
5959
"github.com/coder/coder/coderd/httpapi"
60+
"github.com/coder/coder/coderd/httpmw"
6061
"github.com/coder/coder/coderd/prometheusmetrics"
6162
"github.com/coder/coder/coderd/telemetry"
6263
"github.com/coder/coder/coderd/tracing"
@@ -325,6 +326,11 @@ func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*co
325326
}
326327
}
327328

329+
realIPConfig, err := httpmw.ParseRealIPConfig(cfg.ProxyTrustedHeaders.Value, cfg.ProxyTrustedOrigins.Value)
330+
if err != nil {
331+
return xerrors.Errorf("parse real ip config: %w", err)
332+
}
333+
328334
options := &coderd.Options{
329335
AccessURL: accessURLParsed,
330336
AppHostname: appHostname,
@@ -335,6 +341,7 @@ func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*co
335341
Pubsub: database.NewPubsubInMemory(),
336342
CacheDir: cfg.CacheDirectory.Value,
337343
GoogleTokenValidator: googleTokenValidator,
344+
RealIPConfig: realIPConfig,
338345
SecureAuthCookie: cfg.SecureAuthCookie.Value,
339346
SSHKeygenAlgorithm: sshKeygenAlgorithm,
340347
TracerProvider: tracerProvider,

coderd/apikey.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,7 @@ func (api *API) createAPIKey(ctx context.Context, params createAPIKeyParams) (*h
230230
}
231231
}
232232

233-
host, _, _ := net.SplitHostPort(params.RemoteAddr)
234-
ip := net.ParseIP(host)
233+
ip := net.ParseIP(params.RemoteAddr)
235234
if ip == nil {
236235
ip = net.IPv4(0, 0, 0, 0)
237236
}

coderd/audit.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,7 @@ func (api *API) generateFakeAuditLog(rw http.ResponseWriter, r *http.Request) {
117117
return
118118
}
119119

120-
ipRaw, _, _ := net.SplitHostPort(r.RemoteAddr)
121-
ip := net.ParseIP(ipRaw)
120+
ip := net.ParseIP(r.RemoteAddr)
122121
ipNet := pqtype.Inet{}
123122
if ip != nil {
124123
ipNet = pqtype.Inet{

coderd/audit/request.go

+4-16
Original file line numberDiff line numberDiff line change
@@ -129,12 +129,8 @@ func InitRequest[T Auditable](w http.ResponseWriter, p *RequestParams) (*Request
129129
}
130130
}
131131

132-
ip, err := parseIP(p.Request.RemoteAddr)
133-
if err != nil {
134-
p.Log.Warn(logCtx, "parse ip", slog.Error(err))
135-
}
136-
137-
err = p.Audit.Export(ctx, database.AuditLog{
132+
ip := parseIP(p.Request.RemoteAddr)
133+
err := p.Audit.Export(ctx, database.AuditLog{
138134
ID: uuid.New(),
139135
Time: database.Now(),
140136
UserID: httpmw.APIKey(p.Request).UserID,
@@ -166,16 +162,8 @@ func either[T Auditable, R any](old, new T, fn func(T) R) R {
166162
}
167163
}
168164

169-
func parseIP(ipStr string) (pqtype.Inet, error) {
170-
var err error
171-
172-
ipStr, _, err = net.SplitHostPort(ipStr)
173-
if err != nil {
174-
return pqtype.Inet{}, err
175-
}
176-
165+
func parseIP(ipStr string) pqtype.Inet {
177166
ip := net.ParseIP(ipStr)
178-
179167
ipNet := net.IPNet{}
180168
if ip != nil {
181169
ipNet = net.IPNet{
@@ -187,5 +175,5 @@ func parseIP(ipStr string) (pqtype.Inet, error) {
187175
return pqtype.Inet{
188176
IPNet: ipNet,
189177
Valid: ip != nil,
190-
}, nil
178+
}
191179
}

coderd/coderd.go

+2
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ type Options struct {
8282
Telemetry telemetry.Reporter
8383
TracerProvider trace.TracerProvider
8484
AutoImportTemplates []AutoImportTemplate
85+
RealIPConfig *httpmw.RealIPConfig
8586

8687
// TLSCertificates is used to mesh DERP servers securely.
8788
TLSCertificates []tls.Certificate
@@ -198,6 +199,7 @@ func New(options *Options) *API {
198199
r.Use(
199200
httpmw.AttachRequestID,
200201
httpmw.Recover(api.Logger),
202+
httpmw.ExtractRealIP(api.RealIPConfig),
201203
httpmw.Logger(api.Logger),
202204
httpmw.Prometheus(options.PrometheusRegistry),
203205
// handleSubdomainApplications checks if the first subdomain is a valid

coderd/coderdtest/coderdtest.go

+3
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ import (
5757
"github.com/coder/coder/coderd/database/dbtestutil"
5858
"github.com/coder/coder/coderd/gitsshkey"
5959
"github.com/coder/coder/coderd/httpapi"
60+
"github.com/coder/coder/coderd/httpmw"
6061
"github.com/coder/coder/coderd/rbac"
6162
"github.com/coder/coder/coderd/telemetry"
6263
"github.com/coder/coder/coderd/util/ptr"
@@ -77,6 +78,7 @@ type Options struct {
7778
Experimental bool
7879
AzureCertificates x509.VerifyOptions
7980
GithubOAuth2Config *coderd.GithubOAuth2Config
81+
RealIPConfig *httpmw.RealIPConfig
8082
OIDCConfig *coderd.OIDCConfig
8183
GoogleTokenValidator *idtoken.Validator
8284
SSHKeygenAlgorithm gitsshkey.Algorithm
@@ -238,6 +240,7 @@ func NewOptions(t *testing.T, options *Options) (func(http.Handler), context.Can
238240
AWSCertificates: options.AWSCertificates,
239241
AzureCertificates: options.AzureCertificates,
240242
GithubOAuth2Config: options.GithubOAuth2Config,
243+
RealIPConfig: options.RealIPConfig,
241244
OIDCConfig: options.OIDCConfig,
242245
GoogleTokenValidator: options.GoogleTokenValidator,
243246
SSHKeygenAlgorithm: options.SSHKeygenAlgorithm,

coderd/httpmw/apikey.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -250,8 +250,7 @@ func ExtractAPIKey(cfg ExtractAPIKeyConfig) func(http.Handler) http.Handler {
250250
// Only update LastUsed once an hour to prevent database spam.
251251
if now.Sub(key.LastUsed) > time.Hour {
252252
key.LastUsed = now
253-
host, _, _ := net.SplitHostPort(r.RemoteAddr)
254-
remoteIP := net.ParseIP(host)
253+
remoteIP := net.ParseIP(r.RemoteAddr)
255254
if remoteIP == nil {
256255
remoteIP = net.IPv4(0, 0, 0, 0)
257256
}

coderd/httpmw/apikey_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ func TestAPIKey(t *testing.T) {
512512
rw = httptest.NewRecorder()
513513
user = createUser(r.Context(), t, db)
514514
)
515-
r.RemoteAddr = "1.1.1.1:3555"
515+
r.RemoteAddr = "1.1.1.1"
516516
r.Header.Set(codersdk.SessionCustomHeader, fmt.Sprintf("%s-%s", id, secret))
517517

518518
_, err := db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{

coderd/httpmw/realip.go

+225
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
package httpmw
2+
3+
import (
4+
"context"
5+
"net"
6+
"net/http"
7+
"strings"
8+
9+
"golang.org/x/xerrors"
10+
11+
"github.com/coder/coder/coderd/httpapi"
12+
)
13+
14+
const (
15+
headerXForwardedFor string = "X-Forwarded-For"
16+
headerXForwardedProto string = "X-Forwarded-Proto"
17+
)
18+
19+
// RealIPConfig configures the search order for the function, which controls
20+
// which headers to consider trusted.
21+
type RealIPConfig struct {
22+
// TrustedOrigins is a list of networks that will be trusted. If
23+
// any non-trusted address supplies these headers, they will be
24+
// ignored.
25+
TrustedOrigins []*net.IPNet
26+
27+
// TrustedHeaders lists headers that are trusted for forwarding
28+
// IP addresses. e.g. "CF-Connecting-IP", "True-Client-IP", etc.
29+
TrustedHeaders []string
30+
}
31+
32+
// ExtractRealIP is a middleware that uses headers from reverse proxies to
33+
// propagate origin IP address information, when configured to do so.
34+
func ExtractRealIP(config *RealIPConfig) func(next http.Handler) http.Handler {
35+
return func(next http.Handler) http.Handler {
36+
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
37+
// Preserve the original TLS connection state and RemoteAddr
38+
req = req.WithContext(context.WithValue(req.Context(), ctxKey{}, &RealIPState{
39+
Config: config,
40+
OriginalRemoteAddr: req.RemoteAddr,
41+
}))
42+
43+
info, err := ExtractRealIPAddress(config, req)
44+
if err != nil {
45+
httpapi.InternalServerError(w, err)
46+
return
47+
}
48+
req.RemoteAddr = info.String()
49+
50+
next.ServeHTTP(w, req)
51+
})
52+
}
53+
}
54+
55+
// ExtractRealIPAddress returns the original client address according to the
56+
// configuration and headers. It does not mutate the original request.
57+
func ExtractRealIPAddress(config *RealIPConfig, req *http.Request) (net.IP, error) {
58+
if config == nil {
59+
config = &RealIPConfig{}
60+
}
61+
62+
cf := isContainedIn(config.TrustedOrigins, getRemoteAddress(req.RemoteAddr))
63+
if !cf {
64+
// Address is not valid or the origin is not trusted; use the
65+
// original address
66+
return getRemoteAddress(req.RemoteAddr), nil
67+
}
68+
69+
for _, trustedHeader := range config.TrustedHeaders {
70+
addr := getRemoteAddress(req.Header.Get(trustedHeader))
71+
if addr != nil {
72+
return addr, nil
73+
}
74+
}
75+
76+
return getRemoteAddress(req.RemoteAddr), nil
77+
}
78+
79+
// FilterUntrustedOriginHeaders removes all known proxy headers from the
80+
// request for untrusted origins, and ensures that only one copy
81+
// of each proxy header is set.
82+
func FilterUntrustedOriginHeaders(config *RealIPConfig, req *http.Request) {
83+
if config == nil {
84+
config = &RealIPConfig{}
85+
}
86+
87+
cf := isContainedIn(config.TrustedOrigins, getRemoteAddress(req.RemoteAddr))
88+
if !cf {
89+
// Address is not valid or the origin is not trusted; clear
90+
// all known proxy headers and return
91+
for _, header := range config.TrustedHeaders {
92+
req.Header.Del(header)
93+
}
94+
return
95+
}
96+
97+
for _, header := range config.TrustedHeaders {
98+
req.Header.Set(header, req.Header.Get(header))
99+
}
100+
}
101+
102+
// EnsureXForwardedForHeader ensures that the request has an X-Forwarded-For
103+
// header. It uses the following logic:
104+
//
105+
// 1. If we have a direct connection (remoteAddr == proxyAddr), then
106+
// set it to remoteAddr
107+
// 2. If we have a proxied connection (remoteAddr != proxyAddr) and
108+
// X-Forwarded-For doesn't begin with remoteAddr, then overwrite
109+
// it with remoteAddr,proxyAddr
110+
// 3. If we have a proxied connection (remoteAddr != proxyAddr) and
111+
// X-Forwarded-For begins with remoteAddr, then append proxyAddr
112+
// to the original X-Forwarded-For header
113+
// 4. If X-Forwarded-Proto is not set, then it will be set to "https"
114+
// if req.TLS != nil, otherwise it will be set to "http"
115+
func EnsureXForwardedForHeader(req *http.Request) error {
116+
state := RealIP(req.Context())
117+
if state == nil {
118+
return xerrors.New("request does not contain realip.State; was it processed by httpmw.ExtractRealIP?")
119+
}
120+
121+
remoteAddr := getRemoteAddress(req.RemoteAddr)
122+
if remoteAddr == nil {
123+
return xerrors.Errorf("failed to parse remote address: %s", remoteAddr)
124+
}
125+
126+
proxyAddr := getRemoteAddress(state.OriginalRemoteAddr)
127+
if proxyAddr == nil {
128+
return xerrors.Errorf("failed to parse original address: %s", proxyAddr)
129+
}
130+
131+
if remoteAddr.Equal(proxyAddr) {
132+
req.Header.Set(headerXForwardedFor, remoteAddr.String())
133+
} else {
134+
forwarded := req.Header.Get(headerXForwardedFor)
135+
if forwarded == "" || !remoteAddr.Equal(getRemoteAddress(forwarded)) {
136+
req.Header.Set(headerXForwardedFor, remoteAddr.String()+","+proxyAddr.String())
137+
} else {
138+
req.Header.Set(headerXForwardedFor, forwarded+","+proxyAddr.String())
139+
}
140+
}
141+
142+
if req.Header.Get(headerXForwardedProto) == "" {
143+
if req.TLS != nil {
144+
req.Header.Set(headerXForwardedProto, "https")
145+
} else {
146+
req.Header.Set(headerXForwardedProto, "http")
147+
}
148+
}
149+
150+
return nil
151+
}
152+
153+
// getRemoteAddress extracts the IP address from the given string. If
154+
// the string contains commas, it assumes that the first part is the
155+
// original address.
156+
func getRemoteAddress(address string) net.IP {
157+
// X-Forwarded-For may contain multiple addresses, in case the
158+
// proxies are chained; the first value is the client address
159+
i := strings.IndexByte(address, ',')
160+
if i == -1 {
161+
i = len(address)
162+
}
163+
164+
// If the address contains a port, remove it
165+
firstAddress := address[:i]
166+
host, _, err := net.SplitHostPort(firstAddress)
167+
if err != nil {
168+
// This will error if there is no port, so try to parse the address
169+
return net.ParseIP(firstAddress)
170+
}
171+
return net.ParseIP(host)
172+
}
173+
174+
// isContainedIn checks that the given address is contained in the given
175+
// network.
176+
func isContainedIn(networks []*net.IPNet, address net.IP) bool {
177+
for _, network := range networks {
178+
if network.Contains(address) {
179+
return true
180+
}
181+
}
182+
183+
return false
184+
}
185+
186+
// RealIPState is the original state prior to modification by this middleware,
187+
// useful for getting information about the connecting client if needed.
188+
type RealIPState struct {
189+
// Config is the configuration applied in the middleware. Consider
190+
// this read-only and do not modify.
191+
Config *RealIPConfig
192+
193+
// OriginalRemoteAddr is the original RemoteAddr for the request.
194+
OriginalRemoteAddr string
195+
}
196+
197+
type ctxKey struct{}
198+
199+
// FromContext retrieves the state from the given context.Context.
200+
func RealIP(ctx context.Context) *RealIPState {
201+
state, ok := ctx.Value(ctxKey{}).(*RealIPState)
202+
if !ok {
203+
return nil
204+
}
205+
return state
206+
}
207+
208+
// ParseRealIPConfig takes a raw string array of headers and origins
209+
// to produce a config.
210+
func ParseRealIPConfig(headers, origins []string) (*RealIPConfig, error) {
211+
config := &RealIPConfig{}
212+
for _, origin := range origins {
213+
_, network, err := net.ParseCIDR(origin)
214+
if err != nil {
215+
return nil, xerrors.Errorf("parse proxy origin %q: %w", origin, err)
216+
}
217+
config.TrustedOrigins = append(config.TrustedOrigins, network)
218+
}
219+
for index, header := range headers {
220+
headers[index] = http.CanonicalHeaderKey(header)
221+
}
222+
config.TrustedHeaders = headers
223+
224+
return config, nil
225+
}

0 commit comments

Comments
 (0)