diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 6a768fa9b4dfd..d12b9aba23863 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -1680,6 +1680,10 @@ func (q *querier) GetProvisionerJobsCreatedAfter(ctx context.Context, createdAt return q.db.GetProvisionerJobsCreatedAfter(ctx, createdAt) } +func (q *querier) GetProvisionerKeyByHashedSecret(ctx context.Context, hashedSecret []byte) (database.ProvisionerKey, error) { + return fetch(q.log, q.auth, q.db.GetProvisionerKeyByHashedSecret)(ctx, hashedSecret) +} + func (q *querier) GetProvisionerKeyByID(ctx context.Context, id uuid.UUID) (database.ProvisionerKey, error) { return fetch(q.log, q.auth, q.db.GetProvisionerKeyByID)(ctx, id) } diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 6514d2f0dfeb0..0ec7d2b17fb9c 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -1825,6 +1825,11 @@ func (s *MethodTestSuite) TestProvisionerKeys() { pk := dbgen.ProvisionerKey(s.T(), db, database.ProvisionerKey{OrganizationID: org.ID}) check.Args(pk.ID).Asserts(pk, policy.ActionRead).Returns(pk) })) + s.Run("GetProvisionerKeyByHashedSecret", s.Subtest(func(db database.Store, check *expects) { + org := dbgen.Organization(s.T(), db, database.Organization{}) + pk := dbgen.ProvisionerKey(s.T(), db, database.ProvisionerKey{OrganizationID: org.ID, HashedSecret: []byte("foo")}) + check.Args([]byte("foo")).Asserts(pk, policy.ActionRead).Returns(pk) + })) s.Run("GetProvisionerKeyByName", s.Subtest(func(db database.Store, check *expects) { org := dbgen.Organization(s.T(), db, database.Organization{}) pk := dbgen.ProvisionerKey(s.T(), db, database.ProvisionerKey{OrganizationID: org.ID}) diff --git a/coderd/database/dbmem/dbmem.go b/coderd/database/dbmem/dbmem.go index 827d99a2c14df..8d1088616f6bc 100644 --- a/coderd/database/dbmem/dbmem.go +++ b/coderd/database/dbmem/dbmem.go @@ -3240,6 +3240,19 @@ func (q *FakeQuerier) GetProvisionerJobsCreatedAfter(_ context.Context, after ti return jobs, nil } +func (q *FakeQuerier) GetProvisionerKeyByHashedSecret(_ context.Context, hashedSecret []byte) (database.ProvisionerKey, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + for _, key := range q.provisionerKeys { + if bytes.Equal(key.HashedSecret, hashedSecret) { + return key, nil + } + } + + return database.ProvisionerKey{}, sql.ErrNoRows +} + func (q *FakeQuerier) GetProvisionerKeyByID(_ context.Context, id uuid.UUID) (database.ProvisionerKey, error) { q.mutex.RLock() defer q.mutex.RUnlock() diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index e6642da53974f..f987d0505653b 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -914,6 +914,13 @@ func (m metricsStore) GetProvisionerJobsCreatedAfter(ctx context.Context, create return jobs, err } +func (m metricsStore) GetProvisionerKeyByHashedSecret(ctx context.Context, hashedSecret []byte) (database.ProvisionerKey, error) { + start := time.Now() + r0, r1 := m.s.GetProvisionerKeyByHashedSecret(ctx, hashedSecret) + m.queryLatencies.WithLabelValues("GetProvisionerKeyByHashedSecret").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m metricsStore) GetProvisionerKeyByID(ctx context.Context, id uuid.UUID) (database.ProvisionerKey, error) { start := time.Now() r0, r1 := m.s.GetProvisionerKeyByID(ctx, id) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 8517a7a8e5f21..78cd95a69cde5 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -1840,6 +1840,21 @@ func (mr *MockStoreMockRecorder) GetProvisionerJobsCreatedAfter(arg0, arg1 any) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProvisionerJobsCreatedAfter", reflect.TypeOf((*MockStore)(nil).GetProvisionerJobsCreatedAfter), arg0, arg1) } +// GetProvisionerKeyByHashedSecret mocks base method. +func (m *MockStore) GetProvisionerKeyByHashedSecret(arg0 context.Context, arg1 []byte) (database.ProvisionerKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetProvisionerKeyByHashedSecret", arg0, arg1) + ret0, _ := ret[0].(database.ProvisionerKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetProvisionerKeyByHashedSecret indicates an expected call of GetProvisionerKeyByHashedSecret. +func (mr *MockStoreMockRecorder) GetProvisionerKeyByHashedSecret(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProvisionerKeyByHashedSecret", reflect.TypeOf((*MockStore)(nil).GetProvisionerKeyByHashedSecret), arg0, arg1) +} + // GetProvisionerKeyByID mocks base method. func (m *MockStore) GetProvisionerKeyByID(arg0 context.Context, arg1 uuid.UUID) (database.ProvisionerKey, error) { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 78ebf958739d6..9d0494813e306 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -186,6 +186,7 @@ type sqlcQuerier interface { GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID) ([]ProvisionerJob, error) GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, ids []uuid.UUID) ([]GetProvisionerJobsByIDsWithQueuePositionRow, error) GetProvisionerJobsCreatedAfter(ctx context.Context, createdAt time.Time) ([]ProvisionerJob, error) + GetProvisionerKeyByHashedSecret(ctx context.Context, hashedSecret []byte) (ProvisionerKey, error) GetProvisionerKeyByID(ctx context.Context, id uuid.UUID) (ProvisionerKey, error) GetProvisionerKeyByName(ctx context.Context, arg GetProvisionerKeyByNameParams) (ProvisionerKey, error) GetProvisionerLogsAfterID(ctx context.Context, arg GetProvisionerLogsAfterIDParams) ([]ProvisionerJobLog, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index f383f2e7c0d5d..2e3a5c9892d40 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -5531,6 +5531,29 @@ func (q *sqlQuerier) DeleteProvisionerKey(ctx context.Context, id uuid.UUID) err return err } +const getProvisionerKeyByHashedSecret = `-- name: GetProvisionerKeyByHashedSecret :one +SELECT + id, created_at, organization_id, name, hashed_secret, tags +FROM + provisioner_keys +WHERE + hashed_secret = $1 +` + +func (q *sqlQuerier) GetProvisionerKeyByHashedSecret(ctx context.Context, hashedSecret []byte) (ProvisionerKey, error) { + row := q.db.QueryRowContext(ctx, getProvisionerKeyByHashedSecret, hashedSecret) + var i ProvisionerKey + err := row.Scan( + &i.ID, + &i.CreatedAt, + &i.OrganizationID, + &i.Name, + &i.HashedSecret, + &i.Tags, + ) + return i, err +} + const getProvisionerKeyByID = `-- name: GetProvisionerKeyByID :one SELECT id, created_at, organization_id, name, hashed_secret, tags diff --git a/coderd/database/queries/provisionerkeys.sql b/coderd/database/queries/provisionerkeys.sql index ac41eb2d444d2..cb4c763f1061e 100644 --- a/coderd/database/queries/provisionerkeys.sql +++ b/coderd/database/queries/provisionerkeys.sql @@ -19,6 +19,14 @@ FROM WHERE id = $1; +-- name: GetProvisionerKeyByHashedSecret :one +SELECT + * +FROM + provisioner_keys +WHERE + hashed_secret = $1; + -- name: GetProvisionerKeyByName :one SELECT * diff --git a/coderd/httpmw/provisionerdaemon.go b/coderd/httpmw/provisionerdaemon.go index 243af82598ff8..cac4aa0cba0a9 100644 --- a/coderd/httpmw/provisionerdaemon.go +++ b/coderd/httpmw/provisionerdaemon.go @@ -71,16 +71,17 @@ func ExtractProvisionerDaemonAuthenticated(opts ExtractProvisionerAuthConfig) fu return } - id, keyValue, err := provisionerkey.Parse(key) + err := provisionerkey.Validate(key) if err != nil { - handleOptional(http.StatusUnauthorized, codersdk.Response{ + handleOptional(http.StatusBadRequest, codersdk.Response{ Message: "provisioner daemon key invalid", + Detail: err.Error(), }) return } - + hashedKey := provisionerkey.HashSecret(key) // nolint:gocritic // System must check if the provisioner key is valid. - pk, err := opts.DB.GetProvisionerKeyByID(dbauthz.AsSystemRestricted(ctx), id) + pk, err := opts.DB.GetProvisionerKeyByHashedSecret(dbauthz.AsSystemRestricted(ctx), hashedKey) if err != nil { if httpapi.Is404Error(err) { handleOptional(http.StatusUnauthorized, codersdk.Response{ @@ -90,12 +91,13 @@ func ExtractProvisionerDaemonAuthenticated(opts ExtractProvisionerAuthConfig) fu } handleOptional(http.StatusInternalServerError, codersdk.Response{ - Message: "get provisioner daemon key: " + err.Error(), + Message: "get provisioner daemon key", + Detail: err.Error(), }) return } - if provisionerkey.Compare(pk.HashedSecret, provisionerkey.HashSecret(keyValue)) { + if provisionerkey.Compare(pk.HashedSecret, hashedKey) { handleOptional(http.StatusUnauthorized, codersdk.Response{ Message: "provisioner daemon key invalid", }) diff --git a/coderd/provisionerkey/provisionerkey.go b/coderd/provisionerkey/provisionerkey.go index 5be3658f6a5be..bfd70fb0295e0 100644 --- a/coderd/provisionerkey/provisionerkey.go +++ b/coderd/provisionerkey/provisionerkey.go @@ -3,8 +3,6 @@ package provisionerkey import ( "crypto/sha256" "crypto/subtle" - "fmt" - "strings" "github.com/google/uuid" "golang.org/x/xerrors" @@ -14,41 +12,36 @@ import ( "github.com/coder/coder/v2/cryptorand" ) +const ( + secretLength = 43 +) + func New(organizationID uuid.UUID, name string, tags map[string]string) (database.InsertProvisionerKeyParams, string, error) { - id := uuid.New() - secret, err := cryptorand.HexString(64) + secret, err := cryptorand.String(secretLength) if err != nil { - return database.InsertProvisionerKeyParams{}, "", xerrors.Errorf("generate token: %w", err) + return database.InsertProvisionerKeyParams{}, "", xerrors.Errorf("generate secret: %w", err) } - hashedSecret := HashSecret(secret) - token := fmt.Sprintf("%s:%s", id, secret) if tags == nil { tags = map[string]string{} } return database.InsertProvisionerKeyParams{ - ID: id, + ID: uuid.New(), CreatedAt: dbtime.Now(), OrganizationID: organizationID, Name: name, - HashedSecret: hashedSecret, + HashedSecret: HashSecret(secret), Tags: tags, - }, token, nil + }, secret, nil } -func Parse(token string) (uuid.UUID, string, error) { - parts := strings.Split(token, ":") - if len(parts) != 2 { - return uuid.UUID{}, "", xerrors.Errorf("invalid token format") - } - - id, err := uuid.Parse(parts[0]) - if err != nil { - return uuid.UUID{}, "", xerrors.Errorf("parse id: %w", err) +func Validate(token string) error { + if len(token) != secretLength { + return xerrors.Errorf("must be %d characters", secretLength) } - return id, parts[1], nil + return nil } func HashSecret(secret string) []byte { diff --git a/enterprise/cli/provisionerdaemonstart.go b/enterprise/cli/provisionerdaemonstart.go index b0dfff227dbe3..f92b0126c46a7 100644 --- a/enterprise/cli/provisionerdaemonstart.go +++ b/enterprise/cli/provisionerdaemonstart.go @@ -122,9 +122,9 @@ func (r *RootCmd) provisionerDaemonStart() *serpent.Command { if len(rawTags) > 0 { return xerrors.New("cannot provide tags when using provisioner key") } - _, _, err := provisionerkey.Parse(provisionerKey) + err = provisionerkey.Validate(provisionerKey) if err != nil { - return xerrors.Errorf("parse provisioner key: %w", err) + return xerrors.Errorf("validate provisioner key: %w", err) } } diff --git a/enterprise/cli/provisionerkeys_test.go b/enterprise/cli/provisionerkeys_test.go index 5b62b1e9d46fd..47df45ed98596 100644 --- a/enterprise/cli/provisionerkeys_test.go +++ b/enterprise/cli/provisionerkeys_test.go @@ -4,11 +4,11 @@ import ( "strings" "testing" - "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/coder/coder/v2/cli/clitest" "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/coderd/provisionerkey" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" @@ -58,10 +58,7 @@ func TestProvisionerKeys(t *testing.T) { _ = pty.ReadLine(ctx) key := pty.ReadLine(ctx) require.NotEmpty(t, key) - parts := strings.Split(key, ":") - require.Len(t, parts, 2, "expected 2 parts") - _, err = uuid.Parse(parts[0]) - require.NoError(t, err, "expected token to be a uuid") + require.NoError(t, provisionerkey.Validate(key)) inv, conf = newCLI( t, diff --git a/enterprise/coderd/provisionerdaemons_test.go b/enterprise/coderd/provisionerdaemons_test.go index 451ff2249a15d..a3cf9a23cc75e 100644 --- a/enterprise/coderd/provisionerdaemons_test.go +++ b/enterprise/coderd/provisionerdaemons_test.go @@ -6,7 +6,6 @@ import ( "fmt" "io" "net/http" - "strings" "testing" "github.com/google/uuid" @@ -612,36 +611,12 @@ func TestProvisionerDaemonServe(t *testing.T) { errStatusCode: http.StatusUnauthorized, }, { - name: "WrongKey", + name: "InvalidKey", multiOrgFeatureEnabled: true, multiOrgExperimentEnabled: true, insertParams: insertParams, requestProvisionerKey: "provisionersftw", - errStatusCode: http.StatusUnauthorized, - }, - { - name: "IdOKKeyValueWrong", - multiOrgFeatureEnabled: true, - multiOrgExperimentEnabled: true, - insertParams: insertParams, - requestProvisionerKey: insertParams.ID.String() + ":" + "wrong", - errStatusCode: http.StatusUnauthorized, - }, - { - name: "IdWrongKeyValueOK", - multiOrgFeatureEnabled: true, - multiOrgExperimentEnabled: true, - insertParams: insertParams, - requestProvisionerKey: uuid.NewString() + ":" + token, - errStatusCode: http.StatusUnauthorized, - }, - { - name: "KeyValueOnly", - multiOrgFeatureEnabled: true, - multiOrgExperimentEnabled: true, - insertParams: insertParams, - requestProvisionerKey: strings.Split(token, ":")[1], - errStatusCode: http.StatusUnauthorized, + errStatusCode: http.StatusBadRequest, }, { name: "KeyAndPSK",