Skip to content

Commit a26dfb7

Browse files
authored
fix: add --real-ip-header flag (#9)
1 parent 7027e23 commit a26dfb7

File tree

5 files changed

+101
-7
lines changed

5 files changed

+101
-7
lines changed

cmd/tunneld/main.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,12 @@ func main() {
104104
Value: tunneld.DefaultWireguardNetworkPrefix.String(),
105105
EnvVars: []string{"TUNNELD_WIREGUARD_NETWORK_PREFIX"},
106106
},
107+
&cli.StringFlag{
108+
Name: "real-ip-header",
109+
Usage: "Use the given header as the real IP address rather than the remote socket address.",
110+
Value: "",
111+
EnvVars: []string{"TUNNELD_REAL_IP_HEADER"},
112+
},
107113
&cli.StringFlag{
108114
Name: "pprof-listen-address",
109115
Usage: "The address to listen on for pprof. If set to an empty string, pprof will not be enabled.",
@@ -137,6 +143,7 @@ func runApp(ctx *cli.Context) error {
137143
wireguardMTU = ctx.Int("wireguard-mtu")
138144
wireguardServerIP = ctx.String("wireguard-server-ip")
139145
wireguardNetworkPrefix = ctx.String("wireguard-network-prefix")
146+
realIPHeader = ctx.String("real-ip-header")
140147
pprofListenAddress = ctx.String("pprof-listen-address")
141148
tracingHoneycombTeam = ctx.String("tracing-honeycomb-team")
142149
)
@@ -240,6 +247,7 @@ func runApp(ctx *cli.Context) error {
240247
WireguardMTU: wireguardMTU,
241248
WireguardServerIP: wireguardServerIPParsed,
242249
WireguardNetworkPrefix: wireguardNetworkPrefixParsed,
250+
RealIPHeader: realIPHeader,
243251
}
244252
td, err := tunneld.New(options)
245253
if err != nil {

tunneld/api.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,12 @@ func (api *API) Router() chi.Router {
3333
// Post tunnel middleware, this middleware will never execute on
3434
// tunneled connections.
3535
httpmw.LimitBody(1<<20), // change back to 1MB
36-
httpmw.RateLimit(10, 10*time.Second),
36+
httpmw.RateLimit(httpmw.RateLimitConfig{
37+
Log: api.Log.Named("ratelimier"),
38+
Count: 10,
39+
Window: 10 * time.Second,
40+
RealIPHeader: api.Options.RealIPHeader,
41+
}),
3742
)
3843

3944
r.Post("/tun", api.postTun)

tunneld/httpmw/ratelimit.go

Lines changed: 72 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,31 +2,97 @@ package httpmw
22

33
import (
44
"fmt"
5+
"net"
56
"net/http"
7+
"strings"
8+
"sync"
69
"time"
710

811
"github.com/go-chi/httprate"
912

13+
"cdr.dev/slog"
1014
"github.com/coder/wgtunnel/tunneld/httpapi"
1115
"github.com/coder/wgtunnel/tunnelsdk"
1216
)
1317

18+
type RateLimitConfig struct {
19+
Log slog.Logger
20+
21+
// Count of the amount of requests allowed in the Window. If the Count is
22+
// zero, the rate limiter is disabled.
23+
Count int
24+
Window time.Duration
25+
26+
// RealIPHeader is the header to use to get the real IP address of the
27+
// request. If this is empty, the request's RemoteAddr is used.
28+
RealIPHeader string
29+
}
30+
1431
// RateLimit returns a handler that limits requests based on IP.
15-
func RateLimit(count int, window time.Duration) func(http.Handler) http.Handler {
16-
if count <= 0 {
32+
func RateLimit(cfg RateLimitConfig) func(http.Handler) http.Handler {
33+
if cfg.Count <= 0 {
1734
return func(handler http.Handler) http.Handler {
1835
return handler
1936
}
2037
}
2138

39+
var logMissingHeaderOnce sync.Once
40+
2241
return httprate.Limit(
23-
count,
24-
window,
25-
httprate.WithKeyByIP(),
42+
cfg.Count,
43+
cfg.Window,
44+
httprate.WithKeyFuncs(func(r *http.Request) (string, error) {
45+
if cfg.RealIPHeader != "" {
46+
val := r.Header.Get(cfg.RealIPHeader)
47+
if val != "" {
48+
val = strings.TrimSpace(strings.Split(val, ",")[0])
49+
return canonicalizeIP(val), nil
50+
}
51+
52+
logMissingHeaderOnce.Do(func() {
53+
cfg.Log.Warn(r.Context(), "real IP header not found or invalid on request", slog.F("header", cfg.RealIPHeader), slog.F("value", val))
54+
})
55+
}
56+
57+
return httprate.KeyByIP(r)
58+
}),
2659
httprate.WithLimitHandler(func(rw http.ResponseWriter, r *http.Request) {
2760
httpapi.Write(r.Context(), rw, http.StatusTooManyRequests, tunnelsdk.Response{
28-
Message: fmt.Sprintf("You've been rate limited for sending more than %v requests in %v.", count, window),
61+
Message: fmt.Sprintf("You've been rate limited for sending more than %v requests in %v.", cfg.Count, cfg.Window),
2962
})
3063
}),
3164
)
3265
}
66+
67+
// canonicalizeIP returns a form of ip suitable for comparison to other IPs.
68+
// For IPv4 addresses, this is simply the whole string.
69+
// For IPv6 addresses, this is the /64 prefix.
70+
//
71+
// This function is taken directly from go-chi/httprate:
72+
// https://github.com/go-chi/httprate/blob/0ea2148d09a46ae62efcad05b70d87418d8e4f43/httprate.go#L111
73+
func canonicalizeIP(ip string) string {
74+
isIPv6 := false
75+
// This is how net.ParseIP decides if an address is IPv6
76+
// https://cs.opensource.google/go/go/+/refs/tags/go1.17.7:src/net/ip.go;l=704
77+
for i := 0; !isIPv6 && i < len(ip); i++ {
78+
switch ip[i] {
79+
case '.':
80+
// IPv4
81+
return ip
82+
case ':':
83+
// IPv6
84+
isIPv6 = true
85+
}
86+
}
87+
if !isIPv6 {
88+
// Not an IP address at all
89+
return ip
90+
}
91+
92+
ipv6 := net.ParseIP(ip)
93+
if ipv6 == nil {
94+
return ip
95+
}
96+
97+
return ipv6.Mask(net.CIDRMask(64, 128)).String()
98+
}

tunneld/options.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"encoding/base32"
66
"encoding/hex"
77
"net"
8+
"net/http"
89
"net/netip"
910
"net/url"
1011
"strings"
@@ -60,6 +61,12 @@ type Options struct {
6061
// at least 64 bits of space available. Defaults to fcca::/16.
6162
WireguardNetworkPrefix netip.Prefix
6263

64+
// RealIPHeader is the header to use for getting a request's IP address. If
65+
// not set, the request's RemoteAddr will be used.
66+
//
67+
// Used for rate limiting.
68+
RealIPHeader string
69+
6370
// PeerDialTimeout is the timeout for dialing a peer on a request. Defaults
6471
// to 10 seconds.
6572
PeerDialTimeout time.Duration
@@ -113,6 +120,10 @@ func (options *Options) Validate() error {
113120
return xerrors.New("WireguardServerIP must be contained within WireguardNetworkPrefix")
114121
}
115122

123+
if options.RealIPHeader != "" {
124+
options.RealIPHeader = http.CanonicalHeaderKey(options.RealIPHeader)
125+
}
126+
116127
if options.PeerDialTimeout <= 0 {
117128
options.PeerDialTimeout = DefaultPeerDialTimeout
118129
}

tunneld/options_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ func Test_Option(t *testing.T) {
3939
WireguardMTU: tunneld.DefaultWireguardMTU + 1,
4040
WireguardServerIP: netip.MustParseAddr("feed::1"),
4141
WireguardNetworkPrefix: netip.MustParsePrefix("feed::1/64"),
42+
RealIPHeader: "X-Real-Ip",
4243
PeerDialTimeout: 1 * time.Second,
4344
}
4445

@@ -66,6 +67,7 @@ func Test_Option(t *testing.T) {
6667
WireguardEndpoint: "localhost:1234",
6768
WireguardPort: 1234,
6869
WireguardKey: key,
70+
RealIPHeader: "x-real-ip",
6971
}
7072

7173
err := o.Validate()
@@ -78,6 +80,8 @@ func Test_Option(t *testing.T) {
7880
require.EqualValues(t, tunneld.DefaultWireguardMTU, o.WireguardMTU)
7981
require.Equal(t, tunneld.DefaultWireguardServerIP, o.WireguardServerIP)
8082
require.Equal(t, tunneld.DefaultWireguardNetworkPrefix, o.WireguardNetworkPrefix)
83+
// should be canonicalized.
84+
require.Equal(t, "X-Real-Ip", o.RealIPHeader)
8185
})
8286

8387
t.Run("Invalid", func(t *testing.T) {

0 commit comments

Comments
 (0)