Skip to content

Commit 95b51ed

Browse files
committed
fix(provisionerd): correctly mutate default tags for PSK auth
1 parent a910e93 commit 95b51ed

File tree

4 files changed

+113
-2
lines changed

4 files changed

+113
-2
lines changed

coderd/provisionerdserver/provisionertags.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,9 @@ func MutateTags(userID uuid.UUID, tags map[string]string) map[string]string {
2424
}
2525
switch tags[TagScope] {
2626
case ScopeUser:
27-
tags[TagOwner] = userID.String()
27+
if userID != uuid.Nil {
28+
tags[TagOwner] = userID.String()
29+
}
2830
case ScopeOrganization:
2931
default:
3032
tags[TagScope] = ScopeOrganization
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
package provisionerdserver_test
2+
3+
import (
4+
"encoding/json"
5+
"testing"
6+
7+
"github.com/google/uuid"
8+
"github.com/stretchr/testify/require"
9+
10+
"github.com/coder/coder/v2/coderd/provisionerdserver"
11+
)
12+
13+
func TestMutateTags(t *testing.T) {
14+
t.Parallel()
15+
16+
testUserID := uuid.New()
17+
18+
for _, tt := range []struct {
19+
name string
20+
userID uuid.UUID
21+
tags map[string]string
22+
want map[string]string
23+
}{
24+
{
25+
name: "nil tags",
26+
userID: uuid.Nil,
27+
tags: nil,
28+
want: map[string]string{
29+
provisionerdserver.TagScope: provisionerdserver.ScopeOrganization,
30+
},
31+
},
32+
{
33+
name: "empty tags",
34+
userID: uuid.Nil,
35+
tags: map[string]string{},
36+
want: map[string]string{
37+
provisionerdserver.TagScope: provisionerdserver.ScopeOrganization,
38+
},
39+
},
40+
{
41+
name: "user scope",
42+
tags: map[string]string{provisionerdserver.TagScope: provisionerdserver.ScopeUser},
43+
userID: testUserID,
44+
want: map[string]string{
45+
provisionerdserver.TagScope: provisionerdserver.ScopeUser,
46+
provisionerdserver.TagOwner: testUserID.String(),
47+
},
48+
},
49+
{
50+
name: "organization scope",
51+
tags: map[string]string{provisionerdserver.TagScope: provisionerdserver.ScopeOrganization},
52+
userID: testUserID,
53+
want: map[string]string{
54+
provisionerdserver.TagScope: provisionerdserver.ScopeOrganization,
55+
},
56+
},
57+
{
58+
name: "invalid scope",
59+
tags: map[string]string{provisionerdserver.TagScope: "360noscope"},
60+
userID: testUserID,
61+
want: map[string]string{
62+
provisionerdserver.TagScope: provisionerdserver.ScopeOrganization,
63+
},
64+
},
65+
} {
66+
tt := tt
67+
t.Run(tt.name, func(t *testing.T) {
68+
t.Parallel()
69+
// make a copy of the map because the function under test
70+
// mutates the map
71+
bytes, err := json.Marshal(tt.tags)
72+
require.NoError(t, err)
73+
var tags map[string]string
74+
err = json.Unmarshal(bytes, &tags)
75+
require.NoError(t, err)
76+
got := provisionerdserver.MutateTags(tt.userID, tags)
77+
require.Equal(t, tt.want, got)
78+
})
79+
}
80+
}

enterprise/cli/provisionerdaemons.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd {
8383
}()
8484

8585
logger := slog.Make(sloghuman.Sink(inv.Stderr))
86+
if ok, _ := inv.ParsedFlags().GetBool("verbose"); ok {
87+
logger = logger.Leveled(slog.LevelDebug)
88+
}
8689
errCh := make(chan error, 1)
8790
go func() {
8891
defer cancel()

enterprise/coderd/provisionerdaemons.go

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ func (p *provisionerDaemonAuth) authorize(r *http.Request, tags map[string]strin
117117
if p.psk != "" {
118118
psk := r.Header.Get(codersdk.ProvisionerDaemonPSK)
119119
if subtle.ConstantTimeCompare([]byte(p.psk), []byte(psk)) == 1 {
120-
return tags, true
120+
return provisionerdserver.MutateTags(uuid.Nil, tags), true
121121
}
122122
}
123123
return nil, false
@@ -172,10 +172,12 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
172172

173173
tags, authorized := api.provisionerDaemonAuth.authorize(r, tags)
174174
if !authorized {
175+
api.Logger.Warn(ctx, "unauthorized provisioner daemon serve request", slog.F("tags", tags))
175176
httpapi.Write(ctx, rw, http.StatusForbidden,
176177
codersdk.Response{Message: "You aren't allowed to create provisioner daemons"})
177178
return
178179
}
180+
api.Logger.Debug(ctx, "provisioner authorized", slog.F("tags", tags))
179181

180182
provisioners := make([]database.ProvisionerType, 0)
181183
for p := range provisionersMap {
@@ -196,6 +198,12 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
196198
Tags: tags,
197199
})
198200
if err != nil {
201+
api.Logger.Error(ctx, "write provisioner daemon",
202+
slog.F("name", name),
203+
slog.F("provisioners", provisioners),
204+
slog.F("tags", tags),
205+
slog.Error(err),
206+
)
199207
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
200208
Message: "Internal error writing provisioner daemon.",
201209
Detail: err.Error(),
@@ -205,6 +213,12 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
205213

206214
rawTags, err := json.Marshal(daemon.Tags)
207215
if err != nil {
216+
api.Logger.Error(ctx, "marshal provisioner tags",
217+
slog.F("name", name),
218+
slog.F("provisioners", provisioners),
219+
slog.F("tags", tags),
220+
slog.Error(err),
221+
)
208222
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
209223
Message: "Internal error marshaling daemon tags.",
210224
Detail: err.Error(),
@@ -222,6 +236,12 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
222236
CompressionMode: websocket.CompressionDisabled,
223237
})
224238
if err != nil {
239+
api.Logger.Error(ctx, "accept provisioner websocket conn",
240+
slog.F("name", name),
241+
slog.F("provisioners", provisioners),
242+
slog.F("tags", tags),
243+
slog.Error(err),
244+
)
225245
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
226246
Message: "Internal error accepting websocket connection.",
227247
Detail: err.Error(),
@@ -267,6 +287,12 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
267287
},
268288
)
269289
if err != nil {
290+
api.Logger.Error(ctx, "create provisioner daemon server",
291+
slog.F("name", name),
292+
slog.F("provisioners", provisioners),
293+
slog.F("tags", tags),
294+
slog.Error(err),
295+
)
270296
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("create provisioner daemon server: %s", err))
271297
return
272298
}

0 commit comments

Comments
 (0)