Skip to content

Commit 5c6974e

Browse files
authored
feat: implement provisioner auth middleware and proper org params (coder#12330)
* feat: provisioner auth in mw to allow ExtractOrg Step to enable org scoped provisioner daemons * chore: handle default org handling for provisioner daemons
1 parent 926fd7f commit 5c6974e

File tree

11 files changed

+201
-30
lines changed

11 files changed

+201
-30
lines changed

coderd/database/dbauthz/dbauthz.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,9 @@ var (
170170
rbac.ResourceWorkspaceBuild.Type: {rbac.ActionRead, rbac.ActionUpdate, rbac.ActionDelete},
171171
rbac.ResourceUserData.Type: {rbac.ActionRead, rbac.ActionUpdate},
172172
rbac.ResourceAPIKey.Type: {rbac.WildcardSymbol},
173+
// When org scoped provisioner credentials are implemented,
174+
// this can be reduced to read a specific org.
175+
rbac.ResourceOrganization.Type: {rbac.ActionRead},
173176
}),
174177
Org: map[string][]rbac.Permission{},
175178
User: []rbac.Permission{},

coderd/httpmw/actor.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,32 @@ func RequireAPIKeyOrWorkspaceAgent() func(http.Handler) http.Handler {
6464
})
6565
}
6666
}
67+
68+
// RequireAPIKeyOrProvisionerDaemonAuth is middleware that should be inserted
69+
// after optional ExtractAPIKey and ExtractProvisionerDaemonAuthenticated
70+
// middlewares to ensure one of the two authentication methods is provided.
71+
//
72+
// If both are provided, an error is returned to avoid misuse.
73+
func RequireAPIKeyOrProvisionerDaemonAuth() func(http.Handler) http.Handler {
74+
return func(next http.Handler) http.Handler {
75+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
76+
_, hasAPIKey := APIKeyOptional(r)
77+
hasProvisionerDaemon := ProvisionerDaemonAuthenticated(r)
78+
79+
if hasAPIKey && hasProvisionerDaemon {
80+
httpapi.Write(r.Context(), w, http.StatusBadRequest, codersdk.Response{
81+
Message: "API key and external provisioner authentication provided, but only one is allowed",
82+
})
83+
return
84+
}
85+
if !hasAPIKey && !hasProvisionerDaemon {
86+
httpapi.Write(r.Context(), w, http.StatusUnauthorized, codersdk.Response{
87+
Message: "API key or external provisioner authentication required, but none provided",
88+
})
89+
return
90+
}
91+
92+
next.ServeHTTP(w, r)
93+
})
94+
}
95+
}

coderd/httpmw/organizationparam.go

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,26 +53,41 @@ func ExtractOrganizationParam(db database.Store) func(http.Handler) http.Handler
5353
}
5454

5555
var organization database.Organization
56-
var err error
57-
// Try by name or uuid.
58-
id, err := uuid.Parse(arg)
59-
if err == nil {
60-
organization, err = db.GetOrganizationByID(ctx, id)
56+
var dbErr error
57+
58+
// If the name is exactly "default", then we fetch the default
59+
// organization. This is a special case to make it easier
60+
// for single org deployments.
61+
//
62+
// arg == uuid.Nil.String() should be a temporary workaround for
63+
// legacy provisioners that don't provide an organization ID.
64+
// This prevents a breaking change.
65+
// TODO: This change was added March 2024. Nil uuid returning the
66+
// default org should be removed some number of months after
67+
// that date.
68+
if arg == codersdk.DefaultOrganization || arg == uuid.Nil.String() {
69+
organization, dbErr = db.GetDefaultOrganization(ctx)
6170
} else {
62-
organization, err = db.GetOrganizationByName(ctx, arg)
71+
// Try by name or uuid.
72+
id, err := uuid.Parse(arg)
73+
if err == nil {
74+
organization, dbErr = db.GetOrganizationByID(ctx, id)
75+
} else {
76+
organization, dbErr = db.GetOrganizationByName(ctx, arg)
77+
}
6378
}
64-
if httpapi.Is404Error(err) {
79+
if httpapi.Is404Error(dbErr) {
6580
httpapi.ResourceNotFound(rw)
6681
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
6782
Message: fmt.Sprintf("Organization %q not found.", arg),
6883
Detail: "Provide either the organization id or name.",
6984
})
7085
return
7186
}
72-
if err != nil {
87+
if dbErr != nil {
7388
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
7489
Message: fmt.Sprintf("Internal error fetching organization %q.", arg),
75-
Detail: err.Error(),
90+
Detail: dbErr.Error(),
7691
})
7792
return
7893
}

coderd/httpmw/organizationparam_test.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,5 +208,24 @@ func TestOrganizationParam(t *testing.T) {
208208
res = rw.Result()
209209
defer res.Body.Close()
210210
require.Equal(t, http.StatusOK, res.StatusCode, "by name")
211+
212+
// Try by 'default'
213+
chi.RouteContext(r.Context()).URLParams.Add("organization", codersdk.DefaultOrganization)
214+
chi.RouteContext(r.Context()).URLParams.Add("user", user.ID.String())
215+
rtr.ServeHTTP(rw, r)
216+
res = rw.Result()
217+
defer res.Body.Close()
218+
require.Equal(t, http.StatusOK, res.StatusCode, "by default keyword")
219+
220+
// Try by legacy
221+
// TODO: This can be removed when legacy nil uuids are no longer supported.
222+
// This is a temporary measure to ensure as legacy provisioners use
223+
// nil uuids as the org id and expect the default org.
224+
chi.RouteContext(r.Context()).URLParams.Add("organization", uuid.Nil.String())
225+
chi.RouteContext(r.Context()).URLParams.Add("user", user.ID.String())
226+
rtr.ServeHTTP(rw, r)
227+
res = rw.Result()
228+
defer res.Body.Close()
229+
require.Equal(t, http.StatusOK, res.StatusCode, "by nil uuid (legacy)")
211230
})
212231
}

coderd/httpmw/provisionerdaemon.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
package httpmw
2+
3+
import (
4+
"context"
5+
"crypto/subtle"
6+
"net/http"
7+
8+
"golang.org/x/xerrors"
9+
10+
"github.com/coder/coder/v2/coderd/database"
11+
"github.com/coder/coder/v2/coderd/database/dbauthz"
12+
"github.com/coder/coder/v2/coderd/httpapi"
13+
"github.com/coder/coder/v2/codersdk"
14+
)
15+
16+
type provisionerDaemonContextKey struct{}
17+
18+
func ProvisionerDaemonAuthenticated(r *http.Request) bool {
19+
proxy, ok := r.Context().Value(provisionerDaemonContextKey{}).(bool)
20+
return ok && proxy
21+
}
22+
23+
type ExtractProvisionerAuthConfig struct {
24+
DB database.Store
25+
Optional bool
26+
}
27+
28+
func ExtractProvisionerDaemonAuthenticated(opts ExtractProvisionerAuthConfig, psk string) func(next http.Handler) http.Handler {
29+
return func(next http.Handler) http.Handler {
30+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
31+
ctx := r.Context()
32+
33+
handleOptional := func(code int, response codersdk.Response) {
34+
if opts.Optional {
35+
next.ServeHTTP(w, r)
36+
return
37+
}
38+
httpapi.Write(ctx, w, code, response)
39+
}
40+
41+
if psk == "" {
42+
// No psk means external provisioner daemons are not allowed.
43+
// So their auth is not valid.
44+
handleOptional(http.StatusBadRequest, codersdk.Response{
45+
Message: "External provisioner daemons not enabled",
46+
})
47+
return
48+
}
49+
50+
token := r.Header.Get(codersdk.ProvisionerDaemonPSK)
51+
if token == "" {
52+
handleOptional(http.StatusUnauthorized, codersdk.Response{
53+
Message: "provisioner daemon auth token required",
54+
})
55+
return
56+
}
57+
58+
if subtle.ConstantTimeCompare([]byte(token), []byte(psk)) != 1 {
59+
handleOptional(http.StatusUnauthorized, codersdk.Response{
60+
Message: "provisioner daemon auth token invalid",
61+
})
62+
return
63+
}
64+
65+
// The PSK does not indicate a specific provisioner daemon. So just
66+
// store a boolean so the caller can check if the request is from an
67+
// authenticated provisioner daemon.
68+
ctx = context.WithValue(ctx, provisionerDaemonContextKey{}, true)
69+
// nolint:gocritic // Authenticating as a provisioner daemon.
70+
ctx = dbauthz.AsProvisionerd(ctx)
71+
subj, ok := dbauthz.ActorFromContext(ctx)
72+
if !ok {
73+
// This should never happen
74+
httpapi.InternalServerError(w, xerrors.New("developer error: ExtractProvisionerDaemonAuth missing rbac actor"))
75+
}
76+
77+
// Use the same subject for the userAuthKey
78+
ctx = context.WithValue(ctx, userAuthKey{}, Authorization{
79+
Actor: subj,
80+
ActorName: "provisioner_daemon",
81+
})
82+
83+
next.ServeHTTP(w, r.WithContext(ctx))
84+
})
85+
}
86+
}

coderd/organizations.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,13 @@ func (api *API) postOrganizations(rw http.ResponseWriter, r *http.Request) {
5050
return
5151
}
5252

53+
if req.Name == codersdk.DefaultOrganization {
54+
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
55+
Message: fmt.Sprintf("Organization name %q is reserved.", codersdk.DefaultOrganization),
56+
})
57+
return
58+
}
59+
5360
_, err := api.Database.GetOrganizationByName(ctx, req.Name)
5461
if err == nil {
5562
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{

codersdk/organizations.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ import (
1111
"golang.org/x/xerrors"
1212
)
1313

14+
// DefaultOrganization is used as a replacement for the default organization.
15+
var DefaultOrganization = "default"
16+
1417
type ProvisionerStorageMethod string
1518

1619
const (

codersdk/provisionerdaemons.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,8 @@ type ServeProvisionerDaemonRequest struct {
179179
ID uuid.UUID `json:"id" format:"uuid"`
180180
// Name is the human-readable unique identifier for the daemon.
181181
Name string `json:"name" example:"my-cool-provisioner-daemon"`
182-
// Organization is the organization for the URL. At present provisioner daemons ARE NOT scoped to organizations
183-
// and so the organization ID is optional.
182+
// Organization is the organization for the URL. If no orgID is provided,
183+
// then it is assumed to use the default organization.
184184
Organization uuid.UUID `json:"organization" format:"uuid"`
185185
// Provisioners is a list of provisioner types hosted by the provisioner daemon
186186
Provisioners []ProvisionerType `json:"provisioners"`
@@ -194,7 +194,12 @@ type ServeProvisionerDaemonRequest struct {
194194
// implementation. The context is during dial, not during the lifetime of the
195195
// client. Client should be closed after use.
196196
func (c *Client) ServeProvisionerDaemon(ctx context.Context, req ServeProvisionerDaemonRequest) (proto.DRPCProvisionerDaemonClient, error) {
197-
serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", req.Organization))
197+
orgParam := req.Organization.String()
198+
if req.Organization == uuid.Nil {
199+
orgParam = DefaultOrganization
200+
}
201+
202+
serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", orgParam))
198203
if err != nil {
199204
return nil, xerrors.Errorf("parse url: %w", err)
200205
}

enterprise/coderd/coderd.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,15 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
292292
r.Route("/organizations/{organization}/provisionerdaemons", func(r chi.Router) {
293293
r.Use(
294294
api.provisionerDaemonsEnabledMW,
295+
apiKeyMiddlewareOptional,
296+
httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{
297+
DB: api.Database,
298+
Optional: true,
299+
}, api.ProvisionerDaemonPSK),
300+
// Either a user auth or provisioner auth is required
301+
// to move forward.
302+
httpmw.RequireAPIKeyOrProvisionerDaemonAuth(),
303+
httpmw.ExtractOrganizationParam(api.Database),
295304
)
296305
r.With(apiKeyMiddleware).Get("/", api.provisionerDaemons)
297306
r.With(apiKeyMiddlewareOptional).Get("/serve", api.provisionerDaemonServe)

enterprise/coderd/provisionerdaemons.go

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package coderd
22

33
import (
44
"context"
5-
"crypto/subtle"
65
"database/sql"
76
"errors"
87
"fmt"
@@ -86,11 +85,8 @@ func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) {
8685
})
8786
return
8887
}
89-
apiDaemons := make([]codersdk.ProvisionerDaemon, 0)
90-
for _, daemon := range daemons {
91-
apiDaemons = append(apiDaemons, db2sdk.ProvisionerDaemon(daemon))
92-
}
93-
httpapi.Write(ctx, rw, http.StatusOK, apiDaemons)
88+
89+
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.List(daemons, db2sdk.ProvisionerDaemon))
9490
}
9591

9692
type provisionerDaemonAuth struct {
@@ -118,13 +114,11 @@ func (p *provisionerDaemonAuth) authorize(r *http.Request, tags map[string]strin
118114
}
119115

120116
// Check for PSK
121-
if p.psk != "" {
122-
psk := r.Header.Get(codersdk.ProvisionerDaemonPSK)
123-
if subtle.ConstantTimeCompare([]byte(p.psk), []byte(psk)) == 1 {
124-
// If using PSK auth, the daemon is, by definition, scoped to the organization.
125-
tags = provisionersdk.MutateTags(uuid.Nil, tags)
126-
return tags, true
127-
}
117+
provAuth := httpmw.ProvisionerDaemonAuthenticated(r)
118+
if provAuth {
119+
// If using PSK auth, the daemon is, by definition, scoped to the organization.
120+
tags = provisionersdk.MutateTags(uuid.Nil, tags)
121+
return tags, true
128122
}
129123
return nil, false
130124
}

enterprise/coderd/provisionerdaemons_test.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
350350

351351
t.Run("PSK_daily_cost", func(t *testing.T) {
352352
t.Parallel()
353+
const provPSK = `provisionersftw`
353354
client, user := coderdenttest.New(t, &coderdenttest.Options{
354355
UserWorkspaceQuota: 10,
355356
LicenseOptions: &coderdenttest.LicenseOptions{
@@ -358,7 +359,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
358359
codersdk.FeatureTemplateRBAC: 1,
359360
},
360361
},
361-
ProvisionerDaemonPSK: "provisionersftw",
362+
ProvisionerDaemonPSK: provPSK,
362363
})
363364
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
364365
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
@@ -397,7 +398,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
397398
Tags: map[string]string{
398399
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
399400
},
400-
PreSharedKey: "provisionersftw",
401+
PreSharedKey: provPSK,
401402
})
402403
}, &provisionerd.Options{
403404
Logger: logger.Named("provisionerd"),
@@ -480,7 +481,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
480481
require.Error(t, err)
481482
var apiError *codersdk.Error
482483
require.ErrorAs(t, err, &apiError)
483-
require.Equal(t, http.StatusForbidden, apiError.StatusCode())
484+
require.Equal(t, http.StatusUnauthorized, apiError.StatusCode())
484485

485486
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
486487
require.NoError(t, err)
@@ -514,7 +515,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
514515
require.Error(t, err)
515516
var apiError *codersdk.Error
516517
require.ErrorAs(t, err, &apiError)
517-
require.Equal(t, http.StatusForbidden, apiError.StatusCode())
518+
require.Equal(t, http.StatusUnauthorized, apiError.StatusCode())
518519

519520
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
520521
require.NoError(t, err)
@@ -548,7 +549,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
548549
require.Error(t, err)
549550
var apiError *codersdk.Error
550551
require.ErrorAs(t, err, &apiError)
551-
require.Equal(t, http.StatusForbidden, apiError.StatusCode())
552+
require.Equal(t, http.StatusUnauthorized, apiError.StatusCode())
552553

553554
daemons, err := client.ProvisionerDaemons(ctx) //nolint:gocritic // Test assertion.
554555
require.NoError(t, err)

0 commit comments

Comments
 (0)