Skip to content

Commit a4cecbe

Browse files
committed
use response struct
1 parent bb7ec5e commit a4cecbe

File tree

1 file changed

+41
-16
lines changed

1 file changed

+41
-16
lines changed

enterprise/coderd/provisionerdaemons.go

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,12 @@ func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) {
7474
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.List(daemons, db2sdk.ProvisionerDaemon))
7575
}
7676

77+
type provisiionerDaemonAuthResponse struct {
78+
keyID uuid.UUID
79+
orgID uuid.UUID
80+
tags map[string]string
81+
}
82+
7783
type provisionerDaemonAuth struct {
7884
psk string
7985
db database.Store
@@ -82,68 +88,85 @@ type provisionerDaemonAuth struct {
8288

8389
// authorize returns mutated tags if the given HTTP request is authorized to access the provisioner daemon
8490
// protobuf API, and returns nil, err otherwise.
85-
func (p *provisionerDaemonAuth) authorize(r *http.Request, org database.Organization, tags map[string]string) (uuid.UUID, uuid.UUID, map[string]string, error) {
91+
func (p *provisionerDaemonAuth) authorize(r *http.Request, org database.Organization, tags map[string]string) (provisiionerDaemonAuthResponse, error) {
8692
ctx := r.Context()
8793
apiKey, apiKeyOK := httpmw.APIKeyOptional(r)
8894
pk, pkOK := httpmw.ProvisionerKeyAuthOptional(r)
8995
provAuth := httpmw.ProvisionerDaemonAuthenticated(r)
9096
if !provAuth && !apiKeyOK {
91-
return uuid.Nil, uuid.Nil, nil, xerrors.New("no API key or provisioner key provided")
97+
return provisiionerDaemonAuthResponse{}, xerrors.New("no API key or provisioner key provided")
9298
}
9399
if apiKeyOK && pkOK {
94-
return uuid.Nil, uuid.Nil, nil, xerrors.New("Both API key and provisioner key authentication provided. Only one is allowed.")
100+
return provisiionerDaemonAuthResponse{}, xerrors.New("Both API key and provisioner key authentication provided. Only one is allowed.")
95101
}
96102

97103
// Provisioner Key Auth
98104
if pkOK {
99105
if tags != nil && !maps.Equal(tags, map[string]string{}) {
100-
return uuid.Nil, uuid.Nil, nil, xerrors.New("tags are not allowed when using a provisioner key")
106+
return provisiionerDaemonAuthResponse{}, xerrors.New("tags are not allowed when using a provisioner key")
101107
}
102108

103109
// If using provisioner key / PSK auth, the daemon is, by definition, scoped to the organization.
104110
// Use the provisioner key tags here.
105111
tags = provisionersdk.MutateTags(uuid.Nil, pk.Tags)
106-
return pk.ID, pk.OrganizationID, tags, nil
112+
return provisiionerDaemonAuthResponse{
113+
keyID: pk.ID,
114+
orgID: pk.OrganizationID,
115+
tags: tags,
116+
}, nil
107117
}
108118

109119
// PSK Auth
110120
if provAuth {
111121
if !org.IsDefault {
112-
return uuid.Nil, uuid.Nil, nil, xerrors.Errorf("PSK auth is only allowed for the default organization '%s'", org.Name)
122+
return provisiionerDaemonAuthResponse{}, xerrors.Errorf("PSK auth is only allowed for the default organization '%s'", org.Name)
113123
}
114124

115125
pskKey, err := uuid.Parse(codersdk.ProvisionerKeyIDPSK)
116126
if err != nil {
117-
return uuid.Nil, uuid.Nil, nil, xerrors.Errorf("parse psk provisioner key id: %w", err)
127+
return provisiionerDaemonAuthResponse{}, xerrors.Errorf("parse psk provisioner key id: %w", err)
118128
}
119129

120130
tags = provisionersdk.MutateTags(uuid.Nil, tags)
121-
return pskKey, org.ID, tags, nil
131+
132+
return provisiionerDaemonAuthResponse{
133+
keyID: pskKey,
134+
orgID: org.ID,
135+
tags: tags,
136+
}, nil
122137
}
123138

124139
// User Auth
125140
if !apiKeyOK {
126-
return uuid.Nil, uuid.Nil, nil, xerrors.New("no API key provided")
141+
return provisiionerDaemonAuthResponse{}, xerrors.New("no API key provided")
127142
}
128143

129144
userKey, err := uuid.Parse(codersdk.ProvisionerKeyIDUserAuth)
130145
if err != nil {
131-
return uuid.Nil, uuid.Nil, nil, xerrors.Errorf("parse user provisioner key id: %w", err)
146+
return provisiionerDaemonAuthResponse{}, xerrors.Errorf("parse user provisioner key id: %w", err)
132147
}
133148

134149
tags = provisionersdk.MutateTags(apiKey.UserID, tags)
135150
if tags[provisionersdk.TagScope] == provisionersdk.ScopeUser {
136151
// Any authenticated user can create provisioner daemons scoped
137152
// for jobs that they own,
138-
return userKey, org.ID, tags, nil
153+
return provisiionerDaemonAuthResponse{
154+
keyID: userKey,
155+
orgID: org.ID,
156+
tags: tags,
157+
}, nil
139158
}
140159
ua := httpmw.UserAuthorization(r)
141160
err = p.authorizer.Authorize(ctx, ua, policy.ActionCreate, rbac.ResourceProvisionerDaemon.InOrg(org.ID))
142161
if err != nil {
143-
return uuid.Nil, uuid.Nil, nil, xerrors.New("user unauthorized")
162+
return provisiionerDaemonAuthResponse{}, xerrors.New("user unauthorized")
144163
}
145164

146-
return userKey, org.ID, tags, nil
165+
return provisiionerDaemonAuthResponse{
166+
keyID: userKey,
167+
orgID: org.ID,
168+
tags: tags,
169+
}, nil
147170
}
148171

149172
// Serves the provisioner daemon protobuf API over a WebSocket.
@@ -205,7 +228,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
205228
api.Logger.Warn(ctx, "unnamed provisioner daemon")
206229
}
207230

208-
keyID, orgID, tags, err := api.provisionerDaemonAuth.authorize(r, httpmw.OrganizationParam(r), tags)
231+
authRes, err := api.provisionerDaemonAuth.authorize(r, httpmw.OrganizationParam(r), tags)
209232
if err != nil {
210233
api.Logger.Warn(ctx, "unauthorized provisioner daemon serve request", slog.F("tags", tags), slog.Error(err))
211234
httpapi.Write(ctx, rw, http.StatusForbidden,
@@ -216,6 +239,8 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
216239
)
217240
return
218241
}
242+
tags = authRes.tags
243+
219244
api.Logger.Debug(ctx, "provisioner authorized", slog.F("tags", tags))
220245
if err := provisionerdserver.Tags(tags).Valid(); err != nil {
221246
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
@@ -277,8 +302,8 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
277302
LastSeenAt: sql.NullTime{Time: now, Valid: true},
278303
Version: versionHdrVal,
279304
APIVersion: apiVersion,
280-
OrganizationID: orgID,
281-
KeyID: keyID,
305+
OrganizationID: authRes.orgID,
306+
KeyID: authRes.keyID,
282307
})
283308
if err != nil {
284309
if !xerrors.Is(err, context.Canceled) {

0 commit comments

Comments
 (0)