Skip to content

fix(enterprise/cli): correctly set default tags for PSK auth #9436

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Sep 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions coderd/provisionerdserver/provisionertags_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package provisionerdserver_test

import (
"encoding/json"
"testing"

"github.com/google/uuid"
"github.com/stretchr/testify/require"

"github.com/coder/coder/v2/coderd/provisionerdserver"
)

func TestMutateTags(t *testing.T) {
t.Parallel()

testUserID := uuid.New()

for _, tt := range []struct {
name string
userID uuid.UUID
tags map[string]string
want map[string]string
}{
{
name: "nil tags",
userID: uuid.Nil,
tags: nil,
want: map[string]string{
provisionerdserver.TagScope: provisionerdserver.ScopeOrganization,
},
},
{
name: "empty tags",
userID: uuid.Nil,
tags: map[string]string{},
want: map[string]string{
provisionerdserver.TagScope: provisionerdserver.ScopeOrganization,
},
},
{
name: "user scope",
tags: map[string]string{provisionerdserver.TagScope: provisionerdserver.ScopeUser},
userID: testUserID,
want: map[string]string{
provisionerdserver.TagScope: provisionerdserver.ScopeUser,
provisionerdserver.TagOwner: testUserID.String(),
},
},
{
name: "organization scope",
tags: map[string]string{provisionerdserver.TagScope: provisionerdserver.ScopeOrganization},
userID: testUserID,
want: map[string]string{
provisionerdserver.TagScope: provisionerdserver.ScopeOrganization,
},
},
{
name: "invalid scope",
tags: map[string]string{provisionerdserver.TagScope: "360noscope"},
userID: testUserID,
want: map[string]string{
provisionerdserver.TagScope: provisionerdserver.ScopeOrganization,
},
},
} {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
// make a copy of the map because the function under test
// mutates the map
bytes, err := json.Marshal(tt.tags)
require.NoError(t, err)
var tags map[string]string
err = json.Unmarshal(bytes, &tags)
require.NoError(t, err)
got := provisionerdserver.MutateTags(tt.userID, tags)
require.Equal(t, tt.want, got)
})
}
}
19 changes: 18 additions & 1 deletion enterprise/cli/provisionerdaemons.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/coder/coder/v2/cli/clibase"
"github.com/coder/coder/v2/cli/cliui"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/provisionerdserver"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/provisioner/terraform"
"github.com/coder/coder/v2/provisionerd"
Expand Down Expand Up @@ -65,6 +66,23 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd {
return err
}

logger := slog.Make(sloghuman.Sink(inv.Stderr))
if ok, _ := inv.ParsedFlags().GetBool("verbose"); ok {
logger = logger.Leveled(slog.LevelDebug)
}

if len(tags) != 0 {
logger.Info(ctx, "note: tagged provisioners can currently pick up jobs from untagged templates")
logger.Info(ctx, "see https://github.com/coder/coder/issues/6442 for details")
}

// When authorizing with a PSK, we automatically scope the provisionerd
// to organization. Scoping to user with PSK auth is not a valid configuration.
if preSharedKey != "" {
logger.Info(ctx, "psk auth automatically sets tag "+provisionerdserver.TagScope+"="+provisionerdserver.ScopeOrganization)
tags[provisionerdserver.TagScope] = provisionerdserver.ScopeOrganization
}

err = os.MkdirAll(cacheDir, 0o700)
if err != nil {
return xerrors.Errorf("mkdir %q: %w", cacheDir, err)
Expand All @@ -82,7 +100,6 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd {
_ = terraformServer.Close()
}()

logger := slog.Make(sloghuman.Sink(inv.Stderr))
errCh := make(chan error, 1)
go func() {
defer cancel()
Expand Down
21 changes: 21 additions & 0 deletions enterprise/coderd/provisionerdaemons.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ func (p *provisionerDaemonAuth) authorize(r *http.Request, tags map[string]strin
if p.psk != "" {
psk := r.Header.Get(codersdk.ProvisionerDaemonPSK)
if subtle.ConstantTimeCompare([]byte(p.psk), []byte(psk)) == 1 {
// If using PSK auth, the daemon is, by definition, scoped to the organization.
tags[provisionerdserver.TagScope] = provisionerdserver.ScopeOrganization
return tags, true
}
}
Expand Down Expand Up @@ -172,10 +174,12 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)

tags, authorized := api.provisionerDaemonAuth.authorize(r, tags)
if !authorized {
api.Logger.Warn(ctx, "unauthorized provisioner daemon serve request", slog.F("tags", tags))
httpapi.Write(ctx, rw, http.StatusForbidden,
codersdk.Response{Message: "You aren't allowed to create provisioner daemons"})
return
}
api.Logger.Debug(ctx, "provisioner authorized", slog.F("tags", tags))

provisioners := make([]database.ProvisionerType, 0)
for p := range provisionersMap {
Expand All @@ -188,6 +192,11 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
}

name := namesgenerator.GetRandomName(1)
log := api.Logger.With(
slog.F("name", name),
slog.F("provisioners", provisioners),
slog.F("tags", tags),
)
daemon, err := api.Database.InsertProvisionerDaemon(ctx, database.InsertProvisionerDaemonParams{
ID: uuid.New(),
CreatedAt: database.Now(),
Expand All @@ -196,6 +205,9 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
Tags: tags,
})
if err != nil {
if !xerrors.Is(err, context.Canceled) {
log.Error(ctx, "write provisioner daemon", slog.Error(err))
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error writing provisioner daemon.",
Detail: err.Error(),
Expand All @@ -205,6 +217,9 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)

rawTags, err := json.Marshal(daemon.Tags)
if err != nil {
if !xerrors.Is(err, context.Canceled) {
log.Error(ctx, "marshal provisioner tags", slog.Error(err))
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error marshaling daemon tags.",
Detail: err.Error(),
Expand All @@ -222,6 +237,9 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
CompressionMode: websocket.CompressionDisabled,
})
if err != nil {
if !xerrors.Is(err, context.Canceled) {
log.Error(ctx, "accept provisioner websocket conn", slog.Error(err))
}
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Internal error accepting websocket connection.",
Detail: err.Error(),
Expand Down Expand Up @@ -267,6 +285,9 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
},
)
if err != nil {
if !xerrors.Is(err, context.Canceled) {
log.Error(ctx, "create provisioner daemon server", slog.Error(err))
}
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("create provisioner daemon server: %s", err))
return
}
Expand Down