Skip to content

Commit 31d18cd

Browse files
committed
convert httpmw
1 parent 854ef5b commit 31d18cd

14 files changed

+235
-175
lines changed

coderd/httpmw/actor_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import (
1212

1313
"github.com/coder/coder/v2/coderd/database"
1414
"github.com/coder/coder/v2/coderd/database/dbgen"
15-
"github.com/coder/coder/v2/coderd/database/dbmem"
15+
"github.com/coder/coder/v2/coderd/database/dbtestutil"
1616
"github.com/coder/coder/v2/coderd/database/dbtime"
1717
"github.com/coder/coder/v2/coderd/httpmw"
1818
"github.com/coder/coder/v2/codersdk"
@@ -38,7 +38,7 @@ func TestRequireAPIKeyOrWorkspaceProxyAuth(t *testing.T) {
3838
t.Parallel()
3939

4040
var (
41-
db = dbmem.New()
41+
db, _ = dbtestutil.NewDB(t)
4242
user = dbgen.User(t, db, database.User{})
4343
_, token = dbgen.APIKey(t, db, database.APIKey{
4444
UserID: user.ID,
@@ -75,7 +75,7 @@ func TestRequireAPIKeyOrWorkspaceProxyAuth(t *testing.T) {
7575
t.Parallel()
7676

7777
var (
78-
db = dbmem.New()
78+
db, _ = dbtestutil.NewDB(t)
7979
user = dbgen.User(t, db, database.User{})
8080
_, userToken = dbgen.APIKey(t, db, database.APIKey{
8181
UserID: user.ID,
@@ -114,7 +114,7 @@ func TestRequireAPIKeyOrWorkspaceProxyAuth(t *testing.T) {
114114
t.Parallel()
115115

116116
var (
117-
db = dbmem.New()
117+
db, _ = dbtestutil.NewDB(t)
118118
proxy, token = dbgen.WorkspaceProxy(t, db, database.WorkspaceProxy{})
119119

120120
r = httptest.NewRequest("GET", "/", nil)

coderd/httpmw/apikey_test.go

Lines changed: 51 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,8 @@ import (
66
"encoding/json"
77
"fmt"
88
"io"
9-
"net"
109
"net/http"
1110
"net/http/httptest"
12-
"slices"
1311
"strings"
1412
"sync/atomic"
1513
"testing"
@@ -18,12 +16,13 @@ import (
1816
"github.com/google/uuid"
1917
"github.com/stretchr/testify/assert"
2018
"github.com/stretchr/testify/require"
19+
"golang.org/x/exp/slices"
2120
"golang.org/x/oauth2"
2221

2322
"github.com/coder/coder/v2/coderd/database"
2423
"github.com/coder/coder/v2/coderd/database/dbauthz"
2524
"github.com/coder/coder/v2/coderd/database/dbgen"
26-
"github.com/coder/coder/v2/coderd/database/dbmem"
25+
"github.com/coder/coder/v2/coderd/database/dbtestutil"
2726
"github.com/coder/coder/v2/coderd/database/dbtime"
2827
"github.com/coder/coder/v2/coderd/httpapi"
2928
"github.com/coder/coder/v2/coderd/httpmw"
@@ -83,9 +82,9 @@ func TestAPIKey(t *testing.T) {
8382
t.Run("NoCookie", func(t *testing.T) {
8483
t.Parallel()
8584
var (
86-
db = dbmem.New()
87-
r = httptest.NewRequest("GET", "/", nil)
88-
rw = httptest.NewRecorder()
85+
db, _ = dbtestutil.NewDB(t)
86+
r = httptest.NewRequest("GET", "/", nil)
87+
rw = httptest.NewRecorder()
8988
)
9089
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
9190
DB: db,
@@ -99,9 +98,9 @@ func TestAPIKey(t *testing.T) {
9998
t.Run("NoCookieRedirects", func(t *testing.T) {
10099
t.Parallel()
101100
var (
102-
db = dbmem.New()
103-
r = httptest.NewRequest("GET", "/", nil)
104-
rw = httptest.NewRecorder()
101+
db, _ = dbtestutil.NewDB(t)
102+
r = httptest.NewRequest("GET", "/", nil)
103+
rw = httptest.NewRecorder()
105104
)
106105
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
107106
DB: db,
@@ -118,9 +117,9 @@ func TestAPIKey(t *testing.T) {
118117
t.Run("InvalidFormat", func(t *testing.T) {
119118
t.Parallel()
120119
var (
121-
db = dbmem.New()
122-
r = httptest.NewRequest("GET", "/", nil)
123-
rw = httptest.NewRecorder()
120+
db, _ = dbtestutil.NewDB(t)
121+
r = httptest.NewRequest("GET", "/", nil)
122+
rw = httptest.NewRecorder()
124123
)
125124
r.Header.Set(codersdk.SessionTokenHeader, "test-wow-hello")
126125

@@ -136,9 +135,9 @@ func TestAPIKey(t *testing.T) {
136135
t.Run("InvalidIDLength", func(t *testing.T) {
137136
t.Parallel()
138137
var (
139-
db = dbmem.New()
140-
r = httptest.NewRequest("GET", "/", nil)
141-
rw = httptest.NewRecorder()
138+
db, _ = dbtestutil.NewDB(t)
139+
r = httptest.NewRequest("GET", "/", nil)
140+
rw = httptest.NewRecorder()
142141
)
143142
r.Header.Set(codersdk.SessionTokenHeader, "test-wow")
144143

@@ -154,9 +153,9 @@ func TestAPIKey(t *testing.T) {
154153
t.Run("InvalidSecretLength", func(t *testing.T) {
155154
t.Parallel()
156155
var (
157-
db = dbmem.New()
158-
r = httptest.NewRequest("GET", "/", nil)
159-
rw = httptest.NewRecorder()
156+
db, _ = dbtestutil.NewDB(t)
157+
r = httptest.NewRequest("GET", "/", nil)
158+
rw = httptest.NewRecorder()
160159
)
161160
r.Header.Set(codersdk.SessionTokenHeader, "testtestid-wow")
162161

@@ -172,7 +171,7 @@ func TestAPIKey(t *testing.T) {
172171
t.Run("NotFound", func(t *testing.T) {
173172
t.Parallel()
174173
var (
175-
db = dbmem.New()
174+
db, _ = dbtestutil.NewDB(t)
176175
id, secret = randomAPIKeyParts()
177176
r = httptest.NewRequest("GET", "/", nil)
178177
rw = httptest.NewRecorder()
@@ -191,10 +190,10 @@ func TestAPIKey(t *testing.T) {
191190
t.Run("UserLinkNotFound", func(t *testing.T) {
192191
t.Parallel()
193192
var (
194-
db = dbmem.New()
195-
r = httptest.NewRequest("GET", "/", nil)
196-
rw = httptest.NewRecorder()
197-
user = dbgen.User(t, db, database.User{
193+
db, _ = dbtestutil.NewDB(t)
194+
r = httptest.NewRequest("GET", "/", nil)
195+
rw = httptest.NewRecorder()
196+
user = dbgen.User(t, db, database.User{
198197
LoginType: database.LoginTypeGithub,
199198
})
200199
// Intentionally not inserting any user link
@@ -219,10 +218,10 @@ func TestAPIKey(t *testing.T) {
219218
t.Run("InvalidSecret", func(t *testing.T) {
220219
t.Parallel()
221220
var (
222-
db = dbmem.New()
223-
r = httptest.NewRequest("GET", "/", nil)
224-
rw = httptest.NewRecorder()
225-
user = dbgen.User(t, db, database.User{})
221+
db, _ = dbtestutil.NewDB(t)
222+
r = httptest.NewRequest("GET", "/", nil)
223+
rw = httptest.NewRecorder()
224+
user = dbgen.User(t, db, database.User{})
226225

227226
// Use a different secret so they don't match!
228227
hashed = sha256.Sum256([]byte("differentsecret"))
@@ -244,7 +243,7 @@ func TestAPIKey(t *testing.T) {
244243
t.Run("Expired", func(t *testing.T) {
245244
t.Parallel()
246245
var (
247-
db = dbmem.New()
246+
db, _ = dbtestutil.NewDB(t)
248247
user = dbgen.User(t, db, database.User{})
249248
_, token = dbgen.APIKey(t, db, database.APIKey{
250249
UserID: user.ID,
@@ -273,7 +272,7 @@ func TestAPIKey(t *testing.T) {
273272
t.Run("Valid", func(t *testing.T) {
274273
t.Parallel()
275274
var (
276-
db = dbmem.New()
275+
db, _ = dbtestutil.NewDB(t)
277276
user = dbgen.User(t, db, database.User{})
278277
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
279278
UserID: user.ID,
@@ -309,7 +308,7 @@ func TestAPIKey(t *testing.T) {
309308
t.Run("ValidWithScope", func(t *testing.T) {
310309
t.Parallel()
311310
var (
312-
db = dbmem.New()
311+
db, _ = dbtestutil.NewDB(t)
313312
user = dbgen.User(t, db, database.User{})
314313
_, token = dbgen.APIKey(t, db, database.APIKey{
315314
UserID: user.ID,
@@ -347,7 +346,7 @@ func TestAPIKey(t *testing.T) {
347346
t.Run("QueryParameter", func(t *testing.T) {
348347
t.Parallel()
349348
var (
350-
db = dbmem.New()
349+
db, _ = dbtestutil.NewDB(t)
351350
user = dbgen.User(t, db, database.User{})
352351
_, token = dbgen.APIKey(t, db, database.APIKey{
353352
UserID: user.ID,
@@ -381,7 +380,7 @@ func TestAPIKey(t *testing.T) {
381380
t.Run("ValidUpdateLastUsed", func(t *testing.T) {
382381
t.Parallel()
383382
var (
384-
db = dbmem.New()
383+
db, _ = dbtestutil.NewDB(t)
385384
user = dbgen.User(t, db, database.User{})
386385
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
387386
UserID: user.ID,
@@ -412,7 +411,7 @@ func TestAPIKey(t *testing.T) {
412411
t.Run("ValidUpdateExpiry", func(t *testing.T) {
413412
t.Parallel()
414413
var (
415-
db = dbmem.New()
414+
db, _ = dbtestutil.NewDB(t)
416415
user = dbgen.User(t, db, database.User{})
417416
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
418417
UserID: user.ID,
@@ -443,7 +442,7 @@ func TestAPIKey(t *testing.T) {
443442
t.Run("NoRefresh", func(t *testing.T) {
444443
t.Parallel()
445444
var (
446-
db = dbmem.New()
445+
db, _ = dbtestutil.NewDB(t)
447446
user = dbgen.User(t, db, database.User{})
448447
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
449448
UserID: user.ID,
@@ -475,7 +474,7 @@ func TestAPIKey(t *testing.T) {
475474
t.Run("OAuthNotExpired", func(t *testing.T) {
476475
t.Parallel()
477476
var (
478-
db = dbmem.New()
477+
db, _ = dbtestutil.NewDB(t)
479478
user = dbgen.User(t, db, database.User{})
480479
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
481480
UserID: user.ID,
@@ -511,7 +510,7 @@ func TestAPIKey(t *testing.T) {
511510
t.Run("APIKeyExpiredOAuthExpired", func(t *testing.T) {
512511
t.Parallel()
513512
var (
514-
db = dbmem.New()
513+
db, _ = dbtestutil.NewDB(t)
515514
user = dbgen.User(t, db, database.User{})
516515
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
517516
UserID: user.ID,
@@ -561,7 +560,7 @@ func TestAPIKey(t *testing.T) {
561560
t.Run("APIKeyExpiredOAuthNotExpired", func(t *testing.T) {
562561
t.Parallel()
563562
var (
564-
db = dbmem.New()
563+
db, _ = dbtestutil.NewDB(t)
565564
user = dbgen.User(t, db, database.User{})
566565
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
567566
UserID: user.ID,
@@ -607,7 +606,7 @@ func TestAPIKey(t *testing.T) {
607606
t.Run("OAuthRefresh", func(t *testing.T) {
608607
t.Parallel()
609608
var (
610-
db = dbmem.New()
609+
db, _ = dbtestutil.NewDB(t)
611610
user = dbgen.User(t, db, database.User{})
612611
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
613612
UserID: user.ID,
@@ -630,7 +629,7 @@ func TestAPIKey(t *testing.T) {
630629
oauthToken := &oauth2.Token{
631630
AccessToken: "wow",
632631
RefreshToken: "moo",
633-
Expiry: dbtime.Now().AddDate(0, 0, 1),
632+
Expiry: dbtestutil.NowInDefaultTimezone().AddDate(0, 0, 1),
634633
}
635634
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
636635
DB: db,
@@ -665,7 +664,7 @@ func TestAPIKey(t *testing.T) {
665664
t.Parallel()
666665
var (
667666
ctx = testutil.Context(t, testutil.WaitShort)
668-
db = dbmem.New()
667+
db, _ = dbtestutil.NewDB(t)
669668
user = dbgen.User(t, db, database.User{})
670669
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
671670
UserID: user.ID,
@@ -715,7 +714,7 @@ func TestAPIKey(t *testing.T) {
715714
t.Run("RemoteIPUpdates", func(t *testing.T) {
716715
t.Parallel()
717716
var (
718-
db = dbmem.New()
717+
db, _ = dbtestutil.NewDB(t)
719718
user = dbgen.User(t, db, database.User{})
720719
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
721720
UserID: user.ID,
@@ -740,15 +739,15 @@ func TestAPIKey(t *testing.T) {
740739
gotAPIKey, err := db.GetAPIKeyByID(r.Context(), sentAPIKey.ID)
741740
require.NoError(t, err)
742741

743-
require.Equal(t, net.ParseIP("1.1.1.1"), gotAPIKey.IPAddress.IPNet.IP)
742+
require.Equal(t, "1.1.1.1", gotAPIKey.IPAddress.IPNet.IP.String())
744743
})
745744

746745
t.Run("RedirectToLogin", func(t *testing.T) {
747746
t.Parallel()
748747
var (
749-
db = dbmem.New()
750-
r = httptest.NewRequest("GET", "/", nil)
751-
rw = httptest.NewRecorder()
748+
db, _ = dbtestutil.NewDB(t)
749+
r = httptest.NewRequest("GET", "/", nil)
750+
rw = httptest.NewRecorder()
752751
)
753752

754753
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
@@ -767,9 +766,9 @@ func TestAPIKey(t *testing.T) {
767766
t.Run("Optional", func(t *testing.T) {
768767
t.Parallel()
769768
var (
770-
db = dbmem.New()
771-
r = httptest.NewRequest("GET", "/", nil)
772-
rw = httptest.NewRecorder()
769+
db, _ = dbtestutil.NewDB(t)
770+
r = httptest.NewRequest("GET", "/", nil)
771+
rw = httptest.NewRecorder()
773772

774773
count int64
775774
handler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
@@ -798,7 +797,7 @@ func TestAPIKey(t *testing.T) {
798797
t.Run("Tokens", func(t *testing.T) {
799798
t.Parallel()
800799
var (
801-
db = dbmem.New()
800+
db, _ = dbtestutil.NewDB(t)
802801
user = dbgen.User(t, db, database.User{})
803802
sentAPIKey, token = dbgen.APIKey(t, db, database.APIKey{
804803
UserID: user.ID,
@@ -831,7 +830,7 @@ func TestAPIKey(t *testing.T) {
831830
t.Run("MissingConfig", func(t *testing.T) {
832831
t.Parallel()
833832
var (
834-
db = dbmem.New()
833+
db, _ = dbtestutil.NewDB(t)
835834
user = dbgen.User(t, db, database.User{})
836835
_, token = dbgen.APIKey(t, db, database.APIKey{
837836
UserID: user.ID,
@@ -866,7 +865,7 @@ func TestAPIKey(t *testing.T) {
866865
t.Run("CustomRoles", func(t *testing.T) {
867866
t.Parallel()
868867
var (
869-
db = dbmem.New()
868+
db, _ = dbtestutil.NewDB(t)
870869
org = dbgen.Organization(t, db, database.Organization{})
871870
customRole = dbgen.CustomRole(t, db, database.CustomRole{
872871
Name: "custom-role",
@@ -933,7 +932,7 @@ func TestAPIKey(t *testing.T) {
933932
t.Parallel()
934933
var (
935934
roleNotExistsName = "role-not-exists"
936-
db = dbmem.New()
935+
db, _ = dbtestutil.NewDB(t)
937936
org = dbgen.Organization(t, db, database.Organization{})
938937
user = dbgen.User(t, db, database.User{
939938
RBACRoles: []string{

0 commit comments

Comments
 (0)