Skip to content

Commit d4d9bf9

Browse files
committed
tests for RequireAPIKeyOrWorkspaceProxyAuth
1 parent cfe484c commit d4d9bf9

File tree

1 file changed

+143
-0
lines changed

1 file changed

+143
-0
lines changed

coderd/httpmw/actor_test.go

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
package httpmw_test
2+
3+
import (
4+
"fmt"
5+
"net/http"
6+
"net/http/httptest"
7+
"net/http/httputil"
8+
"sync/atomic"
9+
"testing"
10+
11+
"github.com/stretchr/testify/require"
12+
13+
"github.com/coder/coder/coderd/database"
14+
"github.com/coder/coder/coderd/database/dbfake"
15+
"github.com/coder/coder/coderd/database/dbgen"
16+
"github.com/coder/coder/coderd/httpmw"
17+
"github.com/coder/coder/codersdk"
18+
)
19+
20+
func TestRequireAPIKeyOrWorkspaceProxyAuth(t *testing.T) {
21+
t.Parallel()
22+
23+
t.Run("None", func(t *testing.T) {
24+
t.Parallel()
25+
26+
r := httptest.NewRequest(http.MethodGet, "/", nil)
27+
rw := httptest.NewRecorder()
28+
29+
httpmw.RequireAPIKeyOrWorkspaceProxyAuth()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
30+
t.Error("should not have been called")
31+
})).ServeHTTP(rw, r)
32+
33+
require.Equal(t, http.StatusUnauthorized, rw.Code)
34+
})
35+
36+
t.Run("APIKey", func(t *testing.T) {
37+
t.Parallel()
38+
39+
var (
40+
db = dbfake.New()
41+
user = dbgen.User(t, db, database.User{})
42+
_, token = dbgen.APIKey(t, db, database.APIKey{
43+
UserID: user.ID,
44+
ExpiresAt: database.Now().AddDate(0, 0, 1),
45+
})
46+
47+
r = httptest.NewRequest("GET", "/", nil)
48+
rw = httptest.NewRecorder()
49+
)
50+
r.Header.Set(codersdk.SessionTokenHeader, token)
51+
52+
var called int64
53+
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
54+
DB: db,
55+
RedirectToLogin: false,
56+
})(
57+
httpmw.RequireAPIKeyOrWorkspaceProxyAuth()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
58+
atomic.AddInt64(&called, 1)
59+
rw.WriteHeader(http.StatusOK)
60+
}))).
61+
ServeHTTP(rw, r)
62+
63+
res := rw.Result()
64+
defer res.Body.Close()
65+
dump, err := httputil.DumpResponse(res, true)
66+
require.NoError(t, err)
67+
t.Log(string(dump))
68+
69+
require.Equal(t, http.StatusOK, rw.Code)
70+
require.Equal(t, int64(1), atomic.LoadInt64(&called))
71+
})
72+
73+
t.Run("WorkspaceProxy", func(t *testing.T) {
74+
t.Parallel()
75+
76+
var (
77+
db = dbfake.New()
78+
user = dbgen.User(t, db, database.User{})
79+
_, userToken = dbgen.APIKey(t, db, database.APIKey{
80+
UserID: user.ID,
81+
ExpiresAt: database.Now().AddDate(0, 0, 1),
82+
})
83+
proxy, proxyToken = dbgen.WorkspaceProxy(t, db, database.WorkspaceProxy{})
84+
85+
r = httptest.NewRequest("GET", "/", nil)
86+
rw = httptest.NewRecorder()
87+
)
88+
r.Header.Set(codersdk.SessionTokenHeader, userToken)
89+
r.Header.Set(httpmw.WorkspaceProxyAuthTokenHeader, fmt.Sprintf("%s:%s", proxy.ID, proxyToken))
90+
91+
httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{
92+
DB: db,
93+
RedirectToLogin: false,
94+
})(
95+
httpmw.ExtractWorkspaceProxy(httpmw.ExtractWorkspaceProxyConfig{
96+
DB: db,
97+
})(
98+
httpmw.RequireAPIKeyOrWorkspaceProxyAuth()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
99+
rw.WriteHeader(http.StatusOK)
100+
})))).
101+
ServeHTTP(rw, r)
102+
103+
res := rw.Result()
104+
defer res.Body.Close()
105+
dump, err := httputil.DumpResponse(res, true)
106+
require.NoError(t, err)
107+
t.Log(string(dump))
108+
109+
require.Equal(t, http.StatusBadRequest, rw.Code)
110+
})
111+
112+
t.Run("Both", func(t *testing.T) {
113+
t.Parallel()
114+
115+
var (
116+
db = dbfake.New()
117+
proxy, token = dbgen.WorkspaceProxy(t, db, database.WorkspaceProxy{})
118+
119+
r = httptest.NewRequest("GET", "/", nil)
120+
rw = httptest.NewRecorder()
121+
)
122+
r.Header.Set(httpmw.WorkspaceProxyAuthTokenHeader, fmt.Sprintf("%s:%s", proxy.ID, token))
123+
124+
var called int64
125+
httpmw.ExtractWorkspaceProxy(httpmw.ExtractWorkspaceProxyConfig{
126+
DB: db,
127+
})(
128+
httpmw.RequireAPIKeyOrWorkspaceProxyAuth()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
129+
atomic.AddInt64(&called, 1)
130+
rw.WriteHeader(http.StatusOK)
131+
}))).
132+
ServeHTTP(rw, r)
133+
134+
res := rw.Result()
135+
defer res.Body.Close()
136+
dump, err := httputil.DumpResponse(res, true)
137+
require.NoError(t, err)
138+
t.Log(string(dump))
139+
140+
require.Equal(t, http.StatusOK, rw.Code)
141+
require.Equal(t, int64(1), atomic.LoadInt64(&called))
142+
})
143+
}

0 commit comments

Comments
 (0)