1
1
package httpmw_test
2
2
3
3
import (
4
+ "context"
5
+ "crypto/sha256"
6
+ "fmt"
7
+ "math/rand"
8
+ "net"
4
9
"net/http"
5
10
"net/http/httptest"
6
11
"testing"
7
12
"time"
8
13
9
14
"github.com/go-chi/chi/v5"
15
+ "github.com/google/uuid"
10
16
"github.com/stretchr/testify/require"
11
17
18
+ "github.com/coder/coder/coderd/database"
19
+ "github.com/coder/coder/coderd/database/databasefake"
12
20
"github.com/coder/coder/coderd/httpmw"
21
+ "github.com/coder/coder/coderd/rbac"
22
+ "github.com/coder/coder/codersdk"
13
23
"github.com/coder/coder/testutil"
14
24
)
15
25
26
+ func insertAPIKey (ctx context.Context , t * testing.T , db database.Store , userID uuid.UUID ) string {
27
+ id , secret := randomAPIKeyParts ()
28
+ hashed := sha256 .Sum256 ([]byte (secret ))
29
+
30
+ _ , err := db .InsertAPIKey (ctx , database.InsertAPIKeyParams {
31
+ ID : id ,
32
+ HashedSecret : hashed [:],
33
+ LastUsed : database .Now ().AddDate (0 , 0 , - 1 ),
34
+ ExpiresAt : database .Now ().AddDate (0 , 0 , 1 ),
35
+ UserID : userID ,
36
+ LoginType : database .LoginTypePassword ,
37
+ Scope : database .APIKeyScopeAll ,
38
+ })
39
+ require .NoError (t , err )
40
+
41
+ return fmt .Sprintf ("%s-%s" , id , secret )
42
+ }
43
+
44
+ func randRemoteAddr () string {
45
+ var b [4 ]byte
46
+ rand .Read (b [:])
47
+ return fmt .Sprintf ("%s:%v" , net .IP (b [:]).String (), rand .Int31 ()% (1 << 16 ))
48
+ }
49
+
16
50
func TestRateLimit (t * testing.T ) {
17
51
t .Parallel ()
18
- t .Run ("NoUser " , func (t * testing.T ) {
52
+ t .Run ("NoUserSucceeds " , func (t * testing.T ) {
19
53
t .Parallel ()
20
54
rtr := chi .NewRouter ()
21
55
rtr .Use (httpmw .RateLimit (5 , time .Second ))
@@ -32,4 +66,107 @@ func TestRateLimit(t *testing.T) {
32
66
return resp .StatusCode == http .StatusTooManyRequests
33
67
}, testutil .WaitShort , testutil .IntervalFast )
34
68
})
69
+
70
+ t .Run ("RandomIPs" , func (t * testing.T ) {
71
+ t .Parallel ()
72
+ rtr := chi .NewRouter ()
73
+ rtr .Use (httpmw .RateLimit (5 , time .Second ))
74
+ rtr .Get ("/" , func (rw http.ResponseWriter , r * http.Request ) {
75
+ rw .WriteHeader (http .StatusOK )
76
+ })
77
+
78
+ require .Never (t , func () bool {
79
+ req := httptest .NewRequest ("GET" , "/" , nil )
80
+ rec := httptest .NewRecorder ()
81
+ req .RemoteAddr = randRemoteAddr ()
82
+ rtr .ServeHTTP (rec , req )
83
+ resp := rec .Result ()
84
+ defer resp .Body .Close ()
85
+ return resp .StatusCode == http .StatusTooManyRequests
86
+ }, testutil .WaitShort , testutil .IntervalFast )
87
+ })
88
+
89
+ t .Run ("RegularUser" , func (t * testing.T ) {
90
+ t .Parallel ()
91
+
92
+ ctx := context .Background ()
93
+
94
+ db := databasefake .New ()
95
+
96
+ u := createUser (ctx , t , db )
97
+ key := insertAPIKey (ctx , t , db , u .ID )
98
+
99
+ rtr := chi .NewRouter ()
100
+ rtr .Use (httpmw .ExtractAPIKey (httpmw.ExtractAPIKeyConfig {
101
+ DB : db ,
102
+ Optional : false ,
103
+ }))
104
+
105
+ rtr .Use (httpmw .RateLimit (5 , time .Second ))
106
+ rtr .Get ("/" , func (rw http.ResponseWriter , r * http.Request ) {
107
+ rw .WriteHeader (http .StatusOK )
108
+ })
109
+
110
+ // Bypass must fail
111
+ req := httptest .NewRequest ("GET" , "/" , nil )
112
+ req .Header .Set (codersdk .SessionCustomHeader , key )
113
+ req .Header .Set (codersdk .BypassRatelimitHeader , "true" )
114
+ rec := httptest .NewRecorder ()
115
+ // Assert we're not using IP address.
116
+ req .RemoteAddr = randRemoteAddr ()
117
+ rtr .ServeHTTP (rec , req )
118
+ resp := rec .Result ()
119
+ defer resp .Body .Close ()
120
+ require .Equal (t , http .StatusPreconditionRequired , resp .StatusCode )
121
+
122
+ require .Eventually (t , func () bool {
123
+ req := httptest .NewRequest ("GET" , "/" , nil )
124
+ req .Header .Set (codersdk .SessionCustomHeader , key )
125
+ rec := httptest .NewRecorder ()
126
+ // Assert we're not using IP address.
127
+ req .RemoteAddr = randRemoteAddr ()
128
+ rtr .ServeHTTP (rec , req )
129
+ resp := rec .Result ()
130
+ defer resp .Body .Close ()
131
+ return resp .StatusCode == http .StatusTooManyRequests
132
+ }, testutil .WaitShort , testutil .IntervalFast )
133
+ })
134
+
135
+ t .Run ("OwnerBypass" , func (t * testing.T ) {
136
+ t .Parallel ()
137
+
138
+ ctx := context .Background ()
139
+
140
+ db := databasefake .New ()
141
+
142
+ u := createUser (ctx , t , db , func (u * database.InsertUserParams ) {
143
+ u .RBACRoles = []string {rbac .RoleOwner ()}
144
+ })
145
+
146
+ key := insertAPIKey (ctx , t , db , u .ID )
147
+
148
+ rtr := chi .NewRouter ()
149
+ rtr .Use (httpmw .ExtractAPIKey (httpmw.ExtractAPIKeyConfig {
150
+ DB : db ,
151
+ Optional : false ,
152
+ }))
153
+
154
+ rtr .Use (httpmw .RateLimit (5 , time .Second ))
155
+ rtr .Get ("/" , func (rw http.ResponseWriter , r * http.Request ) {
156
+ rw .WriteHeader (http .StatusOK )
157
+ })
158
+
159
+ require .Never (t , func () bool {
160
+ req := httptest .NewRequest ("GET" , "/" , nil )
161
+ req .Header .Set (codersdk .SessionCustomHeader , key )
162
+ req .Header .Set (codersdk .BypassRatelimitHeader , "true" )
163
+ rec := httptest .NewRecorder ()
164
+ // Assert we're not using IP address.
165
+ req .RemoteAddr = randRemoteAddr ()
166
+ rtr .ServeHTTP (rec , req )
167
+ resp := rec .Result ()
168
+ defer resp .Body .Close ()
169
+ return resp .StatusCode == http .StatusTooManyRequests
170
+ }, testutil .WaitShort , testutil .IntervalFast )
171
+ })
35
172
}
0 commit comments