Skip to content

Commit bc9fdd1

Browse files
authored
fix(enterprise/cli): correctly set default tags for PSK auth (#9436)
* provisionerd: unconditionally set tag scope to org for psk auth * provisionerd: add unit tests for MutateTags * cli: add some informational logging around provisionerd tags * cli: respect CODER_VERBOSE when initializing logger
1 parent 8ee6178 commit bc9fdd1

File tree

3 files changed

+119
-1
lines changed

3 files changed

+119
-1
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
package provisionerdserver_test
2+
3+
import (
4+
"encoding/json"
5+
"testing"
6+
7+
"github.com/google/uuid"
8+
"github.com/stretchr/testify/require"
9+
10+
"github.com/coder/coder/v2/coderd/provisionerdserver"
11+
)
12+
13+
func TestMutateTags(t *testing.T) {
14+
t.Parallel()
15+
16+
testUserID := uuid.New()
17+
18+
for _, tt := range []struct {
19+
name string
20+
userID uuid.UUID
21+
tags map[string]string
22+
want map[string]string
23+
}{
24+
{
25+
name: "nil tags",
26+
userID: uuid.Nil,
27+
tags: nil,
28+
want: map[string]string{
29+
provisionerdserver.TagScope: provisionerdserver.ScopeOrganization,
30+
},
31+
},
32+
{
33+
name: "empty tags",
34+
userID: uuid.Nil,
35+
tags: map[string]string{},
36+
want: map[string]string{
37+
provisionerdserver.TagScope: provisionerdserver.ScopeOrganization,
38+
},
39+
},
40+
{
41+
name: "user scope",
42+
tags: map[string]string{provisionerdserver.TagScope: provisionerdserver.ScopeUser},
43+
userID: testUserID,
44+
want: map[string]string{
45+
provisionerdserver.TagScope: provisionerdserver.ScopeUser,
46+
provisionerdserver.TagOwner: testUserID.String(),
47+
},
48+
},
49+
{
50+
name: "organization scope",
51+
tags: map[string]string{provisionerdserver.TagScope: provisionerdserver.ScopeOrganization},
52+
userID: testUserID,
53+
want: map[string]string{
54+
provisionerdserver.TagScope: provisionerdserver.ScopeOrganization,
55+
},
56+
},
57+
{
58+
name: "invalid scope",
59+
tags: map[string]string{provisionerdserver.TagScope: "360noscope"},
60+
userID: testUserID,
61+
want: map[string]string{
62+
provisionerdserver.TagScope: provisionerdserver.ScopeOrganization,
63+
},
64+
},
65+
} {
66+
tt := tt
67+
t.Run(tt.name, func(t *testing.T) {
68+
t.Parallel()
69+
// make a copy of the map because the function under test
70+
// mutates the map
71+
bytes, err := json.Marshal(tt.tags)
72+
require.NoError(t, err)
73+
var tags map[string]string
74+
err = json.Unmarshal(bytes, &tags)
75+
require.NoError(t, err)
76+
got := provisionerdserver.MutateTags(tt.userID, tags)
77+
require.Equal(t, tt.want, got)
78+
})
79+
}
80+
}

enterprise/cli/provisionerdaemons.go

+18-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/coder/coder/v2/cli/clibase"
1616
"github.com/coder/coder/v2/cli/cliui"
1717
"github.com/coder/coder/v2/coderd/database"
18+
"github.com/coder/coder/v2/coderd/provisionerdserver"
1819
"github.com/coder/coder/v2/codersdk"
1920
"github.com/coder/coder/v2/provisioner/terraform"
2021
"github.com/coder/coder/v2/provisionerd"
@@ -65,6 +66,23 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd {
6566
return err
6667
}
6768

69+
logger := slog.Make(sloghuman.Sink(inv.Stderr))
70+
if ok, _ := inv.ParsedFlags().GetBool("verbose"); ok {
71+
logger = logger.Leveled(slog.LevelDebug)
72+
}
73+
74+
if len(tags) != 0 {
75+
logger.Info(ctx, "note: tagged provisioners can currently pick up jobs from untagged templates")
76+
logger.Info(ctx, "see https://github.com/coder/coder/issues/6442 for details")
77+
}
78+
79+
// When authorizing with a PSK, we automatically scope the provisionerd
80+
// to organization. Scoping to user with PSK auth is not a valid configuration.
81+
if preSharedKey != "" {
82+
logger.Info(ctx, "psk auth automatically sets tag "+provisionerdserver.TagScope+"="+provisionerdserver.ScopeOrganization)
83+
tags[provisionerdserver.TagScope] = provisionerdserver.ScopeOrganization
84+
}
85+
6886
err = os.MkdirAll(cacheDir, 0o700)
6987
if err != nil {
7088
return xerrors.Errorf("mkdir %q: %w", cacheDir, err)
@@ -82,7 +100,6 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd {
82100
_ = terraformServer.Close()
83101
}()
84102

85-
logger := slog.Make(sloghuman.Sink(inv.Stderr))
86103
errCh := make(chan error, 1)
87104
go func() {
88105
defer cancel()

enterprise/coderd/provisionerdaemons.go

+21
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ func (p *provisionerDaemonAuth) authorize(r *http.Request, tags map[string]strin
117117
if p.psk != "" {
118118
psk := r.Header.Get(codersdk.ProvisionerDaemonPSK)
119119
if subtle.ConstantTimeCompare([]byte(p.psk), []byte(psk)) == 1 {
120+
// If using PSK auth, the daemon is, by definition, scoped to the organization.
121+
tags[provisionerdserver.TagScope] = provisionerdserver.ScopeOrganization
120122
return tags, true
121123
}
122124
}
@@ -172,10 +174,12 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
172174

173175
tags, authorized := api.provisionerDaemonAuth.authorize(r, tags)
174176
if !authorized {
177+
api.Logger.Warn(ctx, "unauthorized provisioner daemon serve request", slog.F("tags", tags))
175178
httpapi.Write(ctx, rw, http.StatusForbidden,
176179
codersdk.Response{Message: "You aren't allowed to create provisioner daemons"})
177180
return
178181
}
182+
api.Logger.Debug(ctx, "provisioner authorized", slog.F("tags", tags))
179183

180184
provisioners := make([]database.ProvisionerType, 0)
181185
for p := range provisionersMap {
@@ -188,6 +192,11 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
188192
}
189193

190194
name := namesgenerator.GetRandomName(1)
195+
log := api.Logger.With(
196+
slog.F("name", name),
197+
slog.F("provisioners", provisioners),
198+
slog.F("tags", tags),
199+
)
191200
daemon, err := api.Database.InsertProvisionerDaemon(ctx, database.InsertProvisionerDaemonParams{
192201
ID: uuid.New(),
193202
CreatedAt: database.Now(),
@@ -196,6 +205,9 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
196205
Tags: tags,
197206
})
198207
if err != nil {
208+
if !xerrors.Is(err, context.Canceled) {
209+
log.Error(ctx, "write provisioner daemon", slog.Error(err))
210+
}
199211
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
200212
Message: "Internal error writing provisioner daemon.",
201213
Detail: err.Error(),
@@ -205,6 +217,9 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
205217

206218
rawTags, err := json.Marshal(daemon.Tags)
207219
if err != nil {
220+
if !xerrors.Is(err, context.Canceled) {
221+
log.Error(ctx, "marshal provisioner tags", slog.Error(err))
222+
}
208223
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
209224
Message: "Internal error marshaling daemon tags.",
210225
Detail: err.Error(),
@@ -222,6 +237,9 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
222237
CompressionMode: websocket.CompressionDisabled,
223238
})
224239
if err != nil {
240+
if !xerrors.Is(err, context.Canceled) {
241+
log.Error(ctx, "accept provisioner websocket conn", slog.Error(err))
242+
}
225243
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
226244
Message: "Internal error accepting websocket connection.",
227245
Detail: err.Error(),
@@ -267,6 +285,9 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
267285
},
268286
)
269287
if err != nil {
288+
if !xerrors.Is(err, context.Canceled) {
289+
log.Error(ctx, "create provisioner daemon server", slog.Error(err))
290+
}
270291
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("create provisioner daemon server: %s", err))
271292
return
272293
}

0 commit comments

Comments
 (0)