Skip to content

Commit 2d90bff

Browse files
committed
Add organization parameter query
1 parent d7f68c5 commit 2d90bff

File tree

2 files changed

+233
-0
lines changed

2 files changed

+233
-0
lines changed

httpmw/organizationparam.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
package httpmw
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"errors"
7+
"fmt"
8+
"net/http"
9+
10+
"github.com/go-chi/chi"
11+
12+
"github.com/coder/coder/database"
13+
"github.com/coder/coder/httpapi"
14+
)
15+
16+
type organizationParamContextKey struct{}
17+
type organizationMemberParamContextKey struct{}
18+
19+
// OrganizationParam returns the organization from the ExtractOrganizationParam handler.
20+
func OrganizationParam(r *http.Request) database.Organization {
21+
organization, ok := r.Context().Value(organizationParamContextKey{}).(database.Organization)
22+
if !ok {
23+
panic("developer error: organization param middleware not provided")
24+
}
25+
return organization
26+
}
27+
28+
// OrganizationMemberParam returns the organization membership that allowed the query
29+
// from the ExtractOrganizationParam handler.
30+
func OrganizationMemberParam(r *http.Request) database.OrganizationMember {
31+
organizationMember, ok := r.Context().Value(organizationMemberParamContextKey{}).(database.OrganizationMember)
32+
if !ok {
33+
panic("developer error: organization param middleware not provided")
34+
}
35+
return organizationMember
36+
}
37+
38+
// ExtractOrganizationParam grabs an organization and user membership from the "organization" URL parameter.
39+
// This middleware requires the API key middleware for authentication.
40+
func ExtractOrganizationParam(db database.Store) func(http.Handler) http.Handler {
41+
return func(next http.Handler) http.Handler {
42+
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
43+
apiKey := APIKey(r)
44+
organizationName := chi.URLParam(r, "organization")
45+
if organizationName == "" {
46+
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
47+
Message: "organization name must be provided",
48+
})
49+
return
50+
}
51+
organization, err := db.GetOrganizationByName(r.Context(), organizationName)
52+
if errors.Is(err, sql.ErrNoRows) {
53+
httpapi.Write(rw, http.StatusNotFound, httpapi.Response{
54+
Message: fmt.Sprintf("organization %q does not exist", organizationName),
55+
})
56+
return
57+
}
58+
if err != nil {
59+
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
60+
Message: fmt.Sprintf("get organization: %s", err.Error()),
61+
})
62+
return
63+
}
64+
organizationMember, err := db.GetOrganizationMemberByUserID(r.Context(), database.GetOrganizationMemberByUserIDParams{
65+
OrganizationID: organization.ID,
66+
UserID: apiKey.UserID,
67+
})
68+
if errors.Is(err, sql.ErrNoRows) {
69+
httpapi.Write(rw, http.StatusUnauthorized, httpapi.Response{
70+
Message: "not a member of the organization",
71+
})
72+
return
73+
}
74+
if err != nil {
75+
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
76+
Message: fmt.Sprintf("get organization member: %s", err.Error()),
77+
})
78+
return
79+
}
80+
81+
ctx := context.WithValue(r.Context(), organizationParamContextKey{}, organization)
82+
ctx = context.WithValue(ctx, organizationMemberParamContextKey{}, organizationMember)
83+
next.ServeHTTP(rw, r.WithContext(ctx))
84+
})
85+
}
86+
}

httpmw/organizationparam_test.go

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
package httpmw_test
2+
3+
import (
4+
"context"
5+
"crypto/sha256"
6+
"fmt"
7+
"net/http"
8+
"net/http/httptest"
9+
"testing"
10+
"time"
11+
12+
"github.com/go-chi/chi"
13+
"github.com/google/uuid"
14+
"github.com/stretchr/testify/require"
15+
16+
"github.com/coder/coder/cryptorand"
17+
"github.com/coder/coder/database"
18+
"github.com/coder/coder/database/databasefake"
19+
"github.com/coder/coder/httpmw"
20+
)
21+
22+
func TestOrganizationParam(t *testing.T) {
23+
t.Parallel()
24+
25+
setupAuthentication := func(db database.Store, r *http.Request) database.User {
26+
var (
27+
id, secret = randomAPIKeyParts()
28+
hashed = sha256.Sum256([]byte(secret))
29+
)
30+
r.AddCookie(&http.Cookie{
31+
Name: httpmw.AuthCookie,
32+
Value: fmt.Sprintf("%s-%s", id, secret),
33+
})
34+
userID, err := cryptorand.String(16)
35+
require.NoError(t, err)
36+
username, err := cryptorand.String(8)
37+
require.NoError(t, err)
38+
user, err := db.InsertUser(r.Context(), database.InsertUserParams{
39+
ID: userID,
40+
Email: "testaccount@coder.com",
41+
Name: "example",
42+
LoginType: database.LoginTypeBuiltIn,
43+
HashedPassword: hashed[:],
44+
Username: username,
45+
CreatedAt: database.Now(),
46+
UpdatedAt: database.Now(),
47+
})
48+
require.NoError(t, err)
49+
_, err = db.InsertAPIKey(r.Context(), database.InsertAPIKeyParams{
50+
ID: id,
51+
UserID: user.ID,
52+
HashedSecret: hashed[:],
53+
LastUsed: database.Now(),
54+
ExpiresAt: database.Now().Add(time.Minute),
55+
})
56+
require.NoError(t, err)
57+
return user
58+
}
59+
60+
t.Run("None", func(t *testing.T) {
61+
var (
62+
db = databasefake.New()
63+
r = httptest.NewRequest("GET", "/", nil)
64+
rw = httptest.NewRecorder()
65+
_ = setupAuthentication(db, r)
66+
)
67+
httpmw.ExtractAPIKey(db, nil)(httpmw.ExtractOrganizationParam(db)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
68+
}))).ServeHTTP(rw, r)
69+
res := rw.Result()
70+
defer res.Body.Close()
71+
require.Equal(t, http.StatusBadRequest, res.StatusCode)
72+
})
73+
74+
t.Run("NotFound", func(t *testing.T) {
75+
var (
76+
db = databasefake.New()
77+
r = httptest.NewRequest("GET", "/", nil)
78+
rw = httptest.NewRecorder()
79+
_ = setupAuthentication(db, r)
80+
)
81+
routeContext := chi.NewRouteContext()
82+
routeContext.URLParams.Add("organization", "example")
83+
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeContext))
84+
httpmw.ExtractAPIKey(db, nil)(httpmw.ExtractOrganizationParam(db)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
85+
}))).ServeHTTP(rw, r)
86+
res := rw.Result()
87+
defer res.Body.Close()
88+
require.Equal(t, http.StatusNotFound, res.StatusCode)
89+
})
90+
91+
t.Run("NotInOrganization", func(t *testing.T) {
92+
var (
93+
db = databasefake.New()
94+
r = httptest.NewRequest("GET", "/", nil)
95+
rw = httptest.NewRecorder()
96+
_ = setupAuthentication(db, r)
97+
)
98+
organization, err := db.InsertOrganization(r.Context(), database.InsertOrganizationParams{
99+
ID: uuid.NewString(),
100+
Name: "test",
101+
CreatedAt: database.Now(),
102+
UpdatedAt: database.Now(),
103+
})
104+
require.NoError(t, err)
105+
routeContext := chi.NewRouteContext()
106+
routeContext.URLParams.Add("organization", organization.Name)
107+
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeContext))
108+
httpmw.ExtractAPIKey(db, nil)(httpmw.ExtractOrganizationParam(db)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
109+
}))).ServeHTTP(rw, r)
110+
res := rw.Result()
111+
defer res.Body.Close()
112+
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
113+
})
114+
115+
t.Run("Success", func(t *testing.T) {
116+
var (
117+
db = databasefake.New()
118+
r = httptest.NewRequest("GET", "/", nil)
119+
rw = httptest.NewRecorder()
120+
user = setupAuthentication(db, r)
121+
)
122+
organization, err := db.InsertOrganization(r.Context(), database.InsertOrganizationParams{
123+
ID: uuid.NewString(),
124+
Name: "test",
125+
CreatedAt: database.Now(),
126+
UpdatedAt: database.Now(),
127+
})
128+
require.NoError(t, err)
129+
_, err = db.InsertOrganizationMember(r.Context(), database.InsertOrganizationMemberParams{
130+
OrganizationID: organization.ID,
131+
UserID: user.ID,
132+
CreatedAt: database.Now(),
133+
UpdatedAt: database.Now(),
134+
})
135+
require.NoError(t, err)
136+
routeContext := chi.NewRouteContext()
137+
routeContext.URLParams.Add("organization", organization.Name)
138+
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, routeContext))
139+
httpmw.ExtractAPIKey(db, nil)(httpmw.ExtractOrganizationParam(db)(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
140+
_ = httpmw.OrganizationParam(r)
141+
_ = httpmw.OrganizationMemberParam(r)
142+
}))).ServeHTTP(rw, r)
143+
res := rw.Result()
144+
defer res.Body.Close()
145+
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
146+
})
147+
}

0 commit comments

Comments
 (0)