diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 82f26c31da3e6..1dab5dbf96172 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -245,6 +245,7 @@ var ( rbac.ResourceOrganization.Type: {policy.ActionCreate, policy.ActionRead}, rbac.ResourceOrganizationMember.Type: {policy.ActionCreate}, rbac.ResourceProvisionerDaemon.Type: {policy.ActionCreate, policy.ActionUpdate}, + rbac.ResourceProvisionerKeys.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionDelete}, rbac.ResourceUser.Type: rbac.ResourceUser.AvailableActions(), rbac.ResourceWorkspaceDormant.Type: {policy.ActionUpdate, policy.ActionDelete, policy.ActionWorkspaceStop}, rbac.ResourceWorkspace.Type: {policy.ActionUpdate, policy.ActionDelete, policy.ActionWorkspaceStart, policy.ActionWorkspaceStop, policy.ActionSSH}, diff --git a/coderd/httpmw/csrf.go b/coderd/httpmw/csrf.go index 2bb0dd0a20037..e868019bac23b 100644 --- a/coderd/httpmw/csrf.go +++ b/coderd/httpmw/csrf.go @@ -93,6 +93,13 @@ func CSRF(secureCookie bool) func(next http.Handler) http.Handler { return true } + if r.Header.Get(codersdk.ProvisionerDaemonKey) != "" { + // If present, the provisioner daemon also is providing an api key + // that will make them exempt from CSRF. But this is still useful + // for enumerating the external auths. + return true + } + // If the X-CSRF-TOKEN header is set, we can exempt the func if it's valid. // This is the CSRF check. sent := r.Header.Get("X-CSRF-TOKEN") diff --git a/coderd/httpmw/provisionerdaemon.go b/coderd/httpmw/provisionerdaemon.go index d0fbfe0e6bcf4..243af82598ff8 100644 --- a/coderd/httpmw/provisionerdaemon.go +++ b/coderd/httpmw/provisionerdaemon.go @@ -8,6 +8,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/httpapi" + "github.com/coder/coder/v2/coderd/provisionerkey" "github.com/coder/coder/v2/codersdk" ) @@ -19,11 +20,13 @@ func ProvisionerDaemonAuthenticated(r *http.Request) bool { } type ExtractProvisionerAuthConfig struct { - DB database.Store - Optional bool + DB database.Store + Optional bool + PSK string + MultiOrgEnabled bool } -func ExtractProvisionerDaemonAuthenticated(opts ExtractProvisionerAuthConfig, psk string) func(next http.Handler) http.Handler { +func ExtractProvisionerDaemonAuthenticated(opts ExtractProvisionerAuthConfig) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -36,37 +39,103 @@ func ExtractProvisionerDaemonAuthenticated(opts ExtractProvisionerAuthConfig, ps httpapi.Write(ctx, w, code, response) } - if psk == "" { - // No psk means external provisioner daemons are not allowed. - // So their auth is not valid. + if !opts.MultiOrgEnabled { + if opts.PSK == "" { + handleOptional(http.StatusUnauthorized, codersdk.Response{ + Message: "External provisioner daemons not enabled", + }) + return + } + + fallbackToPSK(ctx, opts.PSK, next, w, r, handleOptional) + return + } + + psk := r.Header.Get(codersdk.ProvisionerDaemonPSK) + key := r.Header.Get(codersdk.ProvisionerDaemonKey) + if key == "" { + if opts.PSK == "" { + handleOptional(http.StatusUnauthorized, codersdk.Response{ + Message: "provisioner daemon key required", + }) + return + } + + fallbackToPSK(ctx, opts.PSK, next, w, r, handleOptional) + return + } + if psk != "" { handleOptional(http.StatusBadRequest, codersdk.Response{ - Message: "External provisioner daemons not enabled", + Message: "provisioner daemon key and psk provided, but only one is allowed", }) return } - token := r.Header.Get(codersdk.ProvisionerDaemonPSK) - if token == "" { + id, keyValue, err := provisionerkey.Parse(key) + if err != nil { handleOptional(http.StatusUnauthorized, codersdk.Response{ - Message: "provisioner daemon auth token required", + Message: "provisioner daemon key invalid", + }) + return + } + + // nolint:gocritic // System must check if the provisioner key is valid. + pk, err := opts.DB.GetProvisionerKeyByID(dbauthz.AsSystemRestricted(ctx), id) + if err != nil { + if httpapi.Is404Error(err) { + handleOptional(http.StatusUnauthorized, codersdk.Response{ + Message: "provisioner daemon key invalid", + }) + return + } + + handleOptional(http.StatusInternalServerError, codersdk.Response{ + Message: "get provisioner daemon key: " + err.Error(), }) return } - if subtle.ConstantTimeCompare([]byte(token), []byte(psk)) != 1 { + if provisionerkey.Compare(pk.HashedSecret, provisionerkey.HashSecret(keyValue)) { handleOptional(http.StatusUnauthorized, codersdk.Response{ - Message: "provisioner daemon auth token invalid", + Message: "provisioner daemon key invalid", }) return } - // The PSK does not indicate a specific provisioner daemon. So just + // The provisioner key does not indicate a specific provisioner daemon. So just // store a boolean so the caller can check if the request is from an // authenticated provisioner daemon. ctx = context.WithValue(ctx, provisionerDaemonContextKey{}, true) + // store key used to authenticate the request + ctx = context.WithValue(ctx, provisionerKeyAuthContextKey{}, pk) // nolint:gocritic // Authenticating as a provisioner daemon. ctx = dbauthz.AsProvisionerd(ctx) next.ServeHTTP(w, r.WithContext(ctx)) }) } } + +type provisionerKeyAuthContextKey struct{} + +func ProvisionerKeyAuthOptional(r *http.Request) (database.ProvisionerKey, bool) { + user, ok := r.Context().Value(provisionerKeyAuthContextKey{}).(database.ProvisionerKey) + return user, ok +} + +func fallbackToPSK(ctx context.Context, psk string, next http.Handler, w http.ResponseWriter, r *http.Request, handleOptional func(code int, response codersdk.Response)) { + token := r.Header.Get(codersdk.ProvisionerDaemonPSK) + if subtle.ConstantTimeCompare([]byte(token), []byte(psk)) != 1 { + handleOptional(http.StatusUnauthorized, codersdk.Response{ + Message: "provisioner daemon psk invalid", + }) + return + } + + // The PSK does not indicate a specific provisioner daemon. So just + // store a boolean so the caller can check if the request is from an + // authenticated provisioner daemon. + ctx = context.WithValue(ctx, provisionerDaemonContextKey{}, true) + // nolint:gocritic // Authenticating as a provisioner daemon. + ctx = dbauthz.AsProvisionerd(ctx) + next.ServeHTTP(w, r.WithContext(ctx)) +} diff --git a/coderd/provisionerkey/provisionerkey.go b/coderd/provisionerkey/provisionerkey.go index 4df23125be2d3..70354c140b73e 100644 --- a/coderd/provisionerkey/provisionerkey.go +++ b/coderd/provisionerkey/provisionerkey.go @@ -2,7 +2,9 @@ package provisionerkey import ( "crypto/sha256" + "crypto/subtle" "fmt" + "strings" "github.com/google/uuid" "golang.org/x/xerrors" @@ -18,7 +20,7 @@ func New(organizationID uuid.UUID, name string) (database.InsertProvisionerKeyPa if err != nil { return database.InsertProvisionerKeyParams{}, "", xerrors.Errorf("generate token: %w", err) } - hashedSecret := sha256.Sum256([]byte(secret)) + hashedSecret := HashSecret(secret) token := fmt.Sprintf("%s:%s", id, secret) return database.InsertProvisionerKeyParams{ @@ -26,6 +28,29 @@ func New(organizationID uuid.UUID, name string) (database.InsertProvisionerKeyPa CreatedAt: dbtime.Now(), OrganizationID: organizationID, Name: name, - HashedSecret: hashedSecret[:], + HashedSecret: hashedSecret, }, token, nil } + +func Parse(token string) (uuid.UUID, string, error) { + parts := strings.Split(token, ":") + if len(parts) != 2 { + return uuid.UUID{}, "", xerrors.Errorf("invalid token format") + } + + id, err := uuid.Parse(parts[0]) + if err != nil { + return uuid.UUID{}, "", xerrors.Errorf("parse id: %w", err) + } + + return id, parts[1], nil +} + +func HashSecret(secret string) []byte { + h := sha256.Sum256([]byte(secret)) + return h[:] +} + +func Compare(a []byte, b []byte) bool { + return subtle.ConstantTimeCompare(a, b) != 1 +} diff --git a/codersdk/client.go b/codersdk/client.go index f1ac87981759b..cf013a25c3ce8 100644 --- a/codersdk/client.go +++ b/codersdk/client.go @@ -79,6 +79,9 @@ const ( // ProvisionerDaemonPSK contains the authentication pre-shared key for an external provisioner daemon ProvisionerDaemonPSK = "Coder-Provisioner-Daemon-PSK" + // ProvisionerDaemonKey contains the authentication key for an external provisioner daemon + ProvisionerDaemonKey = "Coder-Provisioner-Daemon-Key" + // BuildVersionHeader contains build information of Coder. BuildVersionHeader = "X-Coder-Build-Version" diff --git a/codersdk/provisionerdaemons.go b/codersdk/provisionerdaemons.go index d6a8ba1e6f2fe..e8be78525d6e6 100644 --- a/codersdk/provisionerdaemons.go +++ b/codersdk/provisionerdaemons.go @@ -189,6 +189,8 @@ type ServeProvisionerDaemonRequest struct { Tags map[string]string `json:"tags"` // PreSharedKey is an authentication key to use on the API instead of the normal session token from the client. PreSharedKey string `json:"pre_shared_key"` + // ProvisionerKey is an authentication key to use on the API instead of the normal session token from the client. + ProvisionerKey string `json:"provisioner_key"` } // ServeProvisionerDaemon returns the gRPC service for a provisioner daemon @@ -223,8 +225,15 @@ func (c *Client) ServeProvisionerDaemon(ctx context.Context, req ServeProvisione headers := http.Header{} headers.Set(BuildVersionHeader, buildinfo.Version()) - if req.PreSharedKey == "" { - // use session token if we don't have a PSK. + + if req.ProvisionerKey != "" { + headers.Set(ProvisionerDaemonKey, req.ProvisionerKey) + } + if req.PreSharedKey != "" { + headers.Set(ProvisionerDaemonPSK, req.PreSharedKey) + } + if req.ProvisionerKey == "" && req.PreSharedKey == "" { + // use session token if we don't have a PSK or provisioner key. jar, err := cookiejar.New(nil) if err != nil { return nil, xerrors.Errorf("create cookie jar: %w", err) @@ -234,8 +243,6 @@ func (c *Client) ServeProvisionerDaemon(ctx context.Context, req ServeProvisione Value: c.SessionToken(), }}) httpClient.Jar = jar - } else { - headers.Set(ProvisionerDaemonPSK, req.PreSharedKey) } conn, res, err := websocket.Dial(ctx, serverURL.String(), &websocket.DialOptions{ diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 784695a7ac2e3..05263f339b10a 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -109,6 +109,7 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { provisionerDaemonAuth: &provisionerDaemonAuth{ psk: options.ProvisionerDaemonPSK, authorizer: options.Authorizer, + db: options.Database, }, } // This must happen before coderd initialization! @@ -284,9 +285,11 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { api.provisionerDaemonsEnabledMW, apiKeyMiddlewareOptional, httpmw.ExtractProvisionerDaemonAuthenticated(httpmw.ExtractProvisionerAuthConfig{ - DB: api.Database, - Optional: true, - }, api.ProvisionerDaemonPSK), + DB: api.Database, + Optional: true, + PSK: api.ProvisionerDaemonPSK, + MultiOrgEnabled: api.AGPL.Experiments.Enabled(codersdk.ExperimentMultiOrganization), + }), // Either a user auth or provisioner auth is required // to move forward. httpmw.RequireAPIKeyOrProvisionerDaemonAuth(), diff --git a/enterprise/coderd/provisionerdaemons.go b/enterprise/coderd/provisionerdaemons.go index e74f2821092b9..4f9748f2d265b 100644 --- a/enterprise/coderd/provisionerdaemons.go +++ b/enterprise/coderd/provisionerdaemons.go @@ -79,36 +79,58 @@ func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) { type provisionerDaemonAuth struct { psk string + db database.Store authorizer rbac.Authorizer } -// authorize returns mutated tags and true if the given HTTP request is authorized to access the provisioner daemon -// protobuf API, and returns nil, false otherwise. -func (p *provisionerDaemonAuth) authorize(r *http.Request, orgID uuid.UUID, tags map[string]string) (map[string]string, bool) { +// authorize returns mutated tags if the given HTTP request is authorized to access the provisioner daemon +// protobuf API, and returns nil, err otherwise. +func (p *provisionerDaemonAuth) authorize(r *http.Request, orgID uuid.UUID, tags map[string]string) (map[string]string, error) { ctx := r.Context() - apiKey, ok := httpmw.APIKeyOptional(r) - if ok { + apiKey, apiKeyOK := httpmw.APIKeyOptional(r) + pk, pkOK := httpmw.ProvisionerKeyAuthOptional(r) + provAuth := httpmw.ProvisionerDaemonAuthenticated(r) + if !provAuth && !apiKeyOK { + return nil, xerrors.New("no API key or provisioner key provided") + } + if apiKeyOK && pkOK { + return nil, xerrors.New("Both API key and provisioner key authentication provided. Only one is allowed.") + } + + if apiKeyOK { tags = provisionersdk.MutateTags(apiKey.UserID, tags) if tags[provisionersdk.TagScope] == provisionersdk.ScopeUser { // Any authenticated user can create provisioner daemons scoped // for jobs that they own, - return tags, true + return tags, nil } ua := httpmw.UserAuthorization(r) - if err := p.authorizer.Authorize(ctx, ua, policy.ActionCreate, rbac.ResourceProvisionerDaemon.InOrg(orgID)); err == nil { - // User is allowed to create provisioner daemons - return tags, true + err := p.authorizer.Authorize(ctx, ua, policy.ActionCreate, rbac.ResourceProvisionerDaemon.InOrg(orgID)) + if err != nil { + if !provAuth { + return nil, xerrors.New("user unauthorized") + } + + // Allow fallback to PSK auth if the user is not allowed to create provisioner daemons. + // This is to preserve backwards compatibility with existing user provisioner daemons. + // If using PSK auth, the daemon is, by definition, scoped to the organization. + tags = provisionersdk.MutateTags(uuid.Nil, tags) + return tags, nil } + + // User is allowed to create provisioner daemons + return tags, nil } - // Check for PSK - provAuth := httpmw.ProvisionerDaemonAuthenticated(r) - if provAuth { - // If using PSK auth, the daemon is, by definition, scoped to the organization. - tags = provisionersdk.MutateTags(uuid.Nil, tags) - return tags, true + if pkOK { + if pk.OrganizationID != orgID { + return nil, xerrors.New("provisioner key unauthorized") + } } - return nil, false + + // If using provisioner key / PSK auth, the daemon is, by definition, scoped to the organization. + tags = provisionersdk.MutateTags(uuid.Nil, tags) + return tags, nil } // Serves the provisioner daemon protobuf API over a WebSocket. @@ -171,12 +193,13 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) api.Logger.Warn(ctx, "unnamed provisioner daemon") } - tags, authorized := api.provisionerDaemonAuth.authorize(r, organization.ID, tags) - if !authorized { - api.Logger.Warn(ctx, "unauthorized provisioner daemon serve request", slog.F("tags", tags)) + tags, err := api.provisionerDaemonAuth.authorize(r, organization.ID, tags) + if err != nil { + api.Logger.Warn(ctx, "unauthorized provisioner daemon serve request", slog.F("tags", tags), slog.Error(err)) httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{ Message: fmt.Sprintf("You aren't allowed to create provisioner daemons with scope %q", tags[provisionersdk.TagScope]), + Detail: err.Error(), }, ) return @@ -209,7 +232,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) ) authCtx := ctx - if r.Header.Get(codersdk.ProvisionerDaemonPSK) != "" { + if r.Header.Get(codersdk.ProvisionerDaemonPSK) != "" || r.Header.Get(codersdk.ProvisionerDaemonKey) != "" { //nolint:gocritic // PSK auth means no actor in request, // so use system restricted. authCtx = dbauthz.AsSystemRestricted(ctx) diff --git a/enterprise/coderd/provisionerdaemons_test.go b/enterprise/coderd/provisionerdaemons_test.go index c7c256f041c8b..139a97199ee92 100644 --- a/enterprise/coderd/provisionerdaemons_test.go +++ b/enterprise/coderd/provisionerdaemons_test.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "strings" "testing" "github.com/google/uuid" @@ -18,6 +19,8 @@ import ( "github.com/coder/coder/v2/buildinfo" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/provisionerkey" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/codersdk" @@ -552,6 +555,174 @@ func TestProvisionerDaemonServe(t *testing.T) { require.NoError(t, err) require.Len(t, daemons, 0) }) + + t.Run("ProvisionerKeyAuth", func(t *testing.T) { + t.Parallel() + + insertParams, token, err := provisionerkey.New(uuid.Nil, "dont-TEST-me") + require.NoError(t, err) + + tcs := []struct { + name string + psk string + multiOrgFeatureEnabled bool + multiOrgExperimentEnabled bool + insertParams database.InsertProvisionerKeyParams + requestProvisionerKey string + requestPSK string + errStatusCode int + }{ + { + name: "MultiOrgDisabledPSKAuthOK", + psk: "provisionersftw", + requestPSK: "provisionersftw", + }, + { + name: "MultiOrgExperimentDisabledPSKAuthOK", + multiOrgFeatureEnabled: true, + psk: "provisionersftw", + requestPSK: "provisionersftw", + }, + { + name: "MultiOrgFeatureDisabledPSKAuthOK", + multiOrgExperimentEnabled: true, + psk: "provisionersftw", + requestPSK: "provisionersftw", + }, + { + name: "MultiOrgEnabledPSKAuthOK", + psk: "provisionersftw", + multiOrgFeatureEnabled: true, + multiOrgExperimentEnabled: true, + requestPSK: "provisionersftw", + }, + { + name: "MultiOrgEnabledKeyAuthOK", + psk: "provisionersftw", + multiOrgFeatureEnabled: true, + multiOrgExperimentEnabled: true, + insertParams: insertParams, + requestProvisionerKey: token, + }, + { + name: "MultiOrgEnabledPSKAuthDisabled", + multiOrgFeatureEnabled: true, + multiOrgExperimentEnabled: true, + requestPSK: "provisionersftw", + errStatusCode: http.StatusUnauthorized, + }, + { + name: "WrongKey", + multiOrgFeatureEnabled: true, + multiOrgExperimentEnabled: true, + insertParams: insertParams, + requestProvisionerKey: "provisionersftw", + errStatusCode: http.StatusUnauthorized, + }, + { + name: "IdOKKeyValueWrong", + multiOrgFeatureEnabled: true, + multiOrgExperimentEnabled: true, + insertParams: insertParams, + requestProvisionerKey: insertParams.ID.String() + ":" + "wrong", + errStatusCode: http.StatusUnauthorized, + }, + { + name: "IdWrongKeyValueOK", + multiOrgFeatureEnabled: true, + multiOrgExperimentEnabled: true, + insertParams: insertParams, + requestProvisionerKey: uuid.NewString() + ":" + token, + errStatusCode: http.StatusUnauthorized, + }, + { + name: "KeyValueOnly", + multiOrgFeatureEnabled: true, + multiOrgExperimentEnabled: true, + insertParams: insertParams, + requestProvisionerKey: strings.Split(token, ":")[1], + errStatusCode: http.StatusUnauthorized, + }, + { + name: "KeyAndPSK", + multiOrgFeatureEnabled: true, + multiOrgExperimentEnabled: true, + psk: "provisionersftw", + insertParams: insertParams, + requestProvisionerKey: token, + requestPSK: "provisionersftw", + errStatusCode: http.StatusUnauthorized, + }, + { + name: "None", + multiOrgFeatureEnabled: true, + multiOrgExperimentEnabled: true, + psk: "provisionersftw", + insertParams: insertParams, + errStatusCode: http.StatusUnauthorized, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + features := license.Features{ + codersdk.FeatureExternalProvisionerDaemons: 1, + } + if tc.multiOrgFeatureEnabled { + features[codersdk.FeatureMultipleOrganizations] = 1 + } + dv := coderdtest.DeploymentValues(t) + if tc.multiOrgExperimentEnabled { + dv.Experiments.Append(string(codersdk.ExperimentMultiOrganization)) + } + client, db, user := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{ + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: features, + }, + ProvisionerDaemonPSK: tc.psk, + Options: &coderdtest.Options{ + DeploymentValues: dv, + }, + }) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + if tc.insertParams.Name != "" { + tc.insertParams.OrganizationID = user.OrganizationID + // nolint:gocritic // test + _, err := db.InsertProvisionerKey(dbauthz.AsSystemRestricted(ctx), tc.insertParams) + require.NoError(t, err) + } + + another := codersdk.New(client.URL) + srv, err := another.ServeProvisionerDaemon(ctx, codersdk.ServeProvisionerDaemonRequest{ + ID: uuid.New(), + Name: testutil.MustRandString(t, 63), + Organization: user.OrganizationID, + Provisioners: []codersdk.ProvisionerType{ + codersdk.ProvisionerTypeEcho, + }, + Tags: map[string]string{ + provisionersdk.TagScope: provisionersdk.ScopeOrganization, + }, + PreSharedKey: tc.requestPSK, + ProvisionerKey: tc.requestProvisionerKey, + }) + if tc.errStatusCode != 0 { + require.Error(t, err) + var apiError *codersdk.Error + require.ErrorAs(t, err, &apiError) + require.Equal(t, http.StatusUnauthorized, apiError.StatusCode()) + return + } + + require.NoError(t, err) + err = srv.DRPCConn().Close() + require.NoError(t, err) + }) + } + }) } func TestGetProvisionerDaemons(t *testing.T) {