@@ -18,7 +18,6 @@ import (
18
18
"github.com/coder/coder/coderd/httpmw"
19
19
"github.com/coder/coder/coderd/rbac"
20
20
"github.com/coder/coder/codersdk"
21
- "github.com/coder/coder/testutil"
22
21
)
23
22
24
23
func randRemoteAddr () string {
@@ -34,38 +33,39 @@ func TestRateLimit(t *testing.T) {
34
33
t .Run ("NoUserSucceeds" , func (t * testing.T ) {
35
34
t .Parallel ()
36
35
rtr := chi .NewRouter ()
37
- rtr .Use (httpmw .RateLimit (5 , time .Second ))
36
+ rtr .Use (httpmw .RateLimit (1 , time .Second ))
38
37
rtr .Get ("/" , func (rw http.ResponseWriter , r * http.Request ) {
39
38
rw .WriteHeader (http .StatusOK )
40
39
})
41
40
42
- require . Eventually ( t , func () bool {
41
+ for i := 0 ; i < 5 ; i ++ {
43
42
req := httptest .NewRequest ("GET" , "/" , nil )
44
43
rec := httptest .NewRecorder ()
45
44
rtr .ServeHTTP (rec , req )
46
45
resp := rec .Result ()
47
- defer resp .Body .Close ()
48
- return resp .StatusCode == http .StatusTooManyRequests
49
- }, testutil . WaitShort , testutil . IntervalFast )
46
+ _ = resp .Body .Close ()
47
+ require . Equal ( t , i != 0 , resp .StatusCode == http .StatusTooManyRequests )
48
+ }
50
49
})
51
50
52
51
t .Run ("RandomIPs" , func (t * testing.T ) {
53
52
t .Parallel ()
54
53
rtr := chi .NewRouter ()
55
- rtr .Use (httpmw .RateLimit (5 , time .Second ))
54
+ // Because these are random IPs, the limit should never be hit!
55
+ rtr .Use (httpmw .RateLimit (1 , time .Second ))
56
56
rtr .Get ("/" , func (rw http.ResponseWriter , r * http.Request ) {
57
57
rw .WriteHeader (http .StatusOK )
58
58
})
59
59
60
- require . Never ( t , func () bool {
60
+ for i := 0 ; i < 5 ; i ++ {
61
61
req := httptest .NewRequest ("GET" , "/" , nil )
62
62
rec := httptest .NewRecorder ()
63
63
req .RemoteAddr = randRemoteAddr ()
64
64
rtr .ServeHTTP (rec , req )
65
65
resp := rec .Result ()
66
- defer resp .Body .Close ()
67
- return resp .StatusCode == http .StatusTooManyRequests
68
- }, testutil . WaitShort , testutil . IntervalFast )
66
+ _ = resp .Body .Close ()
67
+ require . False ( t , resp .StatusCode == http .StatusTooManyRequests )
68
+ }
69
69
})
70
70
71
71
t .Run ("RegularUser" , func (t * testing.T ) {
@@ -81,7 +81,7 @@ func TestRateLimit(t *testing.T) {
81
81
Optional : false ,
82
82
}))
83
83
84
- rtr .Use (httpmw .RateLimit (5 , time .Second ))
84
+ rtr .Use (httpmw .RateLimit (1 , time .Second ))
85
85
rtr .Get ("/" , func (rw http.ResponseWriter , r * http.Request ) {
86
86
rw .WriteHeader (http .StatusOK )
87
87
})
@@ -98,17 +98,17 @@ func TestRateLimit(t *testing.T) {
98
98
defer resp .Body .Close ()
99
99
require .Equal (t , http .StatusPreconditionRequired , resp .StatusCode )
100
100
101
- require . Eventually ( t , func () bool {
101
+ for i := 0 ; i < 5 ; i ++ {
102
102
req := httptest .NewRequest ("GET" , "/" , nil )
103
103
req .Header .Set (codersdk .SessionTokenHeader , key )
104
104
rec := httptest .NewRecorder ()
105
105
// Assert we're not using IP address.
106
106
req .RemoteAddr = randRemoteAddr ()
107
107
rtr .ServeHTTP (rec , req )
108
108
resp := rec .Result ()
109
- defer resp .Body .Close ()
110
- return resp .StatusCode == http .StatusTooManyRequests
111
- }, testutil . WaitShort , testutil . IntervalFast )
109
+ _ = resp .Body .Close ()
110
+ require . Equal ( t , i != 0 , resp .StatusCode == http .StatusTooManyRequests )
111
+ }
112
112
})
113
113
114
114
t .Run ("OwnerBypass" , func (t * testing.T ) {
@@ -127,12 +127,12 @@ func TestRateLimit(t *testing.T) {
127
127
Optional : false ,
128
128
}))
129
129
130
- rtr .Use (httpmw .RateLimit (5 , time .Second ))
130
+ rtr .Use (httpmw .RateLimit (1 , time .Second ))
131
131
rtr .Get ("/" , func (rw http.ResponseWriter , r * http.Request ) {
132
132
rw .WriteHeader (http .StatusOK )
133
133
})
134
134
135
- require . Never ( t , func () bool {
135
+ for i := 0 ; i < 5 ; i ++ {
136
136
req := httptest .NewRequest ("GET" , "/" , nil )
137
137
req .Header .Set (codersdk .SessionTokenHeader , key )
138
138
req .Header .Set (codersdk .BypassRatelimitHeader , "true" )
@@ -141,8 +141,8 @@ func TestRateLimit(t *testing.T) {
141
141
req .RemoteAddr = randRemoteAddr ()
142
142
rtr .ServeHTTP (rec , req )
143
143
resp := rec .Result ()
144
- defer resp .Body .Close ()
145
- return resp .StatusCode == http .StatusTooManyRequests
146
- }, testutil . WaitShort , testutil . IntervalFast )
144
+ _ = resp .Body .Close ()
145
+ require . False ( t , resp .StatusCode == http .StatusTooManyRequests )
146
+ }
147
147
})
148
148
}
0 commit comments