Skip to content

Commit 1d35967

Browse files
committed
pr comments
1 parent bc6dec0 commit 1d35967

File tree

4 files changed

+58
-44
lines changed

4 files changed

+58
-44
lines changed

coderd/apikey.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ func (api *API) validateAPIKeyLifetime(lifetime time.Duration) error {
374374
}
375375

376376
func (api *API) createAPIKey(ctx context.Context, params apikey.CreateParams) (*http.Cookie, *database.APIKey, error) {
377-
secret, key, err := apikey.Generate(params)
377+
key, sessionToken, err := apikey.Generate(params)
378378
if err != nil {
379379
return nil, nil, xerrors.Errorf("generate API key: %w", err)
380380
}
@@ -390,7 +390,7 @@ func (api *API) createAPIKey(ctx context.Context, params apikey.CreateParams) (*
390390

391391
return &http.Cookie{
392392
Name: codersdk.SessionTokenCookie,
393-
Value: secret,
393+
Value: sessionToken,
394394
Path: "/",
395395
HttpOnly: true,
396396
SameSite: http.SameSiteLaxMode,

coderd/apikey/apikey.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ type CreateParams struct {
3131
// Generate generates an API key, returning the key as a string as well as the
3232
// database representation. It is the responsibility of the caller to insert it
3333
// into the database.
34-
func Generate(params CreateParams) (string, database.InsertAPIKeyParams, error) {
34+
func Generate(params CreateParams) (database.InsertAPIKeyParams, string, error) {
3535
keyID, keySecret, err := generateKey()
3636
if err != nil {
37-
return "", database.InsertAPIKeyParams{}, xerrors.Errorf("generate API key: %w", err)
37+
return database.InsertAPIKeyParams{}, "", xerrors.Errorf("generate API key: %w", err)
3838
}
3939

4040
hashed := sha256.Sum256([]byte(keySecret))
@@ -67,12 +67,12 @@ func Generate(params CreateParams) (string, database.InsertAPIKeyParams, error)
6767
switch scope {
6868
case database.APIKeyScopeAll, database.APIKeyScopeApplicationConnect:
6969
default:
70-
return "", database.InsertAPIKeyParams{}, xerrors.Errorf("invalid API key scope: %q", scope)
70+
return database.InsertAPIKeyParams{}, "", xerrors.Errorf("invalid API key scope: %q", scope)
7171
}
7272

73-
keyStr := fmt.Sprintf("%s-%s", keyID, keySecret)
73+
token := fmt.Sprintf("%s-%s", keyID, keySecret)
7474

75-
return keyStr, database.InsertAPIKeyParams{
75+
return database.InsertAPIKeyParams{
7676
ID: keyID,
7777
UserID: params.UserID,
7878
LifetimeSeconds: params.LifetimeSeconds,
@@ -91,7 +91,7 @@ func Generate(params CreateParams) (string, database.InsertAPIKeyParams, error)
9191
LoginType: params.LoginType,
9292
Scope: scope,
9393
TokenName: params.TokenName,
94-
}, nil
94+
}, token, nil
9595
}
9696

9797
// generateKey a new ID and secret for an API key.

coderd/apikey/apikey_test.go

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"time"
88

99
"github.com/google/uuid"
10+
"github.com/stretchr/testify/assert"
1011
"github.com/stretchr/testify/require"
1112

1213
"github.com/coder/coder/cli/clibase"
@@ -100,7 +101,7 @@ func TestGenerate(t *testing.T) {
100101
t.Run(tc.name, func(t *testing.T) {
101102
t.Parallel()
102103

103-
keystr, key, err := apikey.Generate(tc.params)
104+
key, keystr, err := apikey.Generate(tc.params)
104105
if tc.fail {
105106
require.Error(t, err)
106107
return
@@ -117,46 +118,46 @@ func TestGenerate(t *testing.T) {
117118

118119
// Assert that the hashed secret is correct.
119120
hashed := sha256.Sum256([]byte(keytokens[1]))
120-
require.ElementsMatch(t, hashed, key.HashedSecret[:])
121+
assert.ElementsMatch(t, hashed, key.HashedSecret[:])
121122

122-
require.Equal(t, tc.params.UserID, key.UserID)
123-
require.WithinDuration(t, database.Now(), key.CreatedAt, time.Second*5)
124-
require.WithinDuration(t, database.Now(), key.UpdatedAt, time.Second*5)
123+
assert.Equal(t, tc.params.UserID, key.UserID)
124+
assert.WithinDuration(t, database.Now(), key.CreatedAt, time.Second*5)
125+
assert.WithinDuration(t, database.Now(), key.UpdatedAt, time.Second*5)
125126

126127
if tc.params.LifetimeSeconds > 0 {
127-
require.Equal(t, tc.params.LifetimeSeconds, key.LifetimeSeconds)
128+
assert.Equal(t, tc.params.LifetimeSeconds, key.LifetimeSeconds)
128129
} else if !tc.params.ExpiresAt.IsZero() {
129130
// Should not be a delta greater than 5 seconds.
130-
require.InDelta(t, time.Until(tc.params.ExpiresAt).Seconds(), key.LifetimeSeconds, 5)
131+
assert.InDelta(t, time.Until(tc.params.ExpiresAt).Seconds(), key.LifetimeSeconds, 5)
131132
} else {
132-
require.Equal(t, int64(tc.params.DeploymentValues.SessionDuration.Value().Seconds()), key.LifetimeSeconds)
133+
assert.Equal(t, int64(tc.params.DeploymentValues.SessionDuration.Value().Seconds()), key.LifetimeSeconds)
133134
}
134135

135136
if !tc.params.ExpiresAt.IsZero() {
136-
require.Equal(t, tc.params.ExpiresAt.UTC(), key.ExpiresAt)
137+
assert.Equal(t, tc.params.ExpiresAt.UTC(), key.ExpiresAt)
137138
} else if tc.params.LifetimeSeconds > 0 {
138-
require.WithinDuration(t, database.Now().Add(time.Duration(tc.params.LifetimeSeconds)), key.ExpiresAt, time.Second*5)
139+
assert.WithinDuration(t, database.Now().Add(time.Duration(tc.params.LifetimeSeconds)), key.ExpiresAt, time.Second*5)
139140
} else {
140-
require.WithinDuration(t, database.Now().Add(tc.params.DeploymentValues.SessionDuration.Value()), key.ExpiresAt, time.Second*5)
141+
assert.WithinDuration(t, database.Now().Add(tc.params.DeploymentValues.SessionDuration.Value()), key.ExpiresAt, time.Second*5)
141142
}
142143

143144
if tc.params.RemoteAddr != "" {
144-
require.Equal(t, tc.params.RemoteAddr, key.IPAddress.IPNet.IP.String())
145+
assert.Equal(t, tc.params.RemoteAddr, key.IPAddress.IPNet.IP.String())
145146
} else {
146-
require.Equal(t, "0.0.0.0", key.IPAddress.IPNet.IP.String())
147+
assert.Equal(t, "0.0.0.0", key.IPAddress.IPNet.IP.String())
147148
}
148149

149150
if tc.params.Scope != "" {
150-
require.Equal(t, tc.params.Scope, key.Scope)
151+
assert.Equal(t, tc.params.Scope, key.Scope)
151152
} else {
152-
require.Equal(t, database.APIKeyScopeAll, key.Scope)
153+
assert.Equal(t, database.APIKeyScopeAll, key.Scope)
153154
}
154155

155156
if tc.params.TokenName != "" {
156-
require.Equal(t, tc.params.TokenName, key.TokenName)
157+
assert.Equal(t, tc.params.TokenName, key.TokenName)
157158
}
158159
if tc.params.LoginType != "" {
159-
require.Equal(t, tc.params.LoginType, key.LoginType)
160+
assert.Equal(t, tc.params.LoginType, key.LoginType)
160161
}
161162
})
162163
}

coderd/provisionerdserver/provisionerdserver.go

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
203203
return nil, failJob(fmt.Sprintf("regenerate session token: %s", err))
204204
}
205205
case database.WorkspaceTransitionStop, database.WorkspaceTransitionDelete:
206-
err = server.deleteSessionToken(ctx, workspace)
206+
err = deleteSessionToken(ctx, server.Database, workspace)
207207
if err != nil {
208208
return nil, failJob(fmt.Sprintf("delete session token: %s", err))
209209
}
@@ -1432,7 +1432,7 @@ func workspaceSessionTokenName(workspace database.Workspace) string {
14321432
}
14331433

14341434
func (server *Server) regenerateSessionToken(ctx context.Context, user database.User, workspace database.Workspace) (string, error) {
1435-
secret, newkey, err := apikey.Generate(apikey.CreateParams{
1435+
newkey, sessionToken, err := apikey.Generate(apikey.CreateParams{
14361436
UserID: user.ID,
14371437
LoginType: user.LoginType,
14381438
DeploymentValues: server.DeploymentValues,
@@ -1443,30 +1443,43 @@ func (server *Server) regenerateSessionToken(ctx context.Context, user database.
14431443
return "", xerrors.Errorf("generate API key: %w", err)
14441444
}
14451445

1446-
err = server.deleteSessionToken(ctx, workspace)
1447-
if err != nil {
1448-
return "", xerrors.Errorf("delete session token: %w", err)
1449-
}
1446+
err = server.Database.InTx(func(tx database.Store) error {
1447+
err := deleteSessionToken(ctx, tx, workspace)
1448+
if err != nil {
1449+
return xerrors.Errorf("delete session token: %w", err)
1450+
}
14501451

1451-
_, err = server.Database.InsertAPIKey(ctx, newkey)
1452+
_, err = tx.InsertAPIKey(ctx, newkey)
1453+
if err != nil {
1454+
return xerrors.Errorf("insert API key: %w", err)
1455+
}
1456+
return nil
1457+
}, nil)
14521458
if err != nil {
1453-
return "", xerrors.Errorf("insert API key: %w", err)
1459+
return "", xerrors.Errorf("create API key: %w", err)
14541460
}
14551461

1456-
return secret, nil
1462+
return sessionToken, nil
14571463
}
14581464

1459-
func (server *Server) deleteSessionToken(ctx context.Context, workspace database.Workspace) error {
1460-
key, err := server.Database.GetAPIKeyByName(ctx, database.GetAPIKeyByNameParams{
1461-
UserID: workspace.OwnerID,
1462-
TokenName: workspaceSessionTokenName(workspace),
1463-
})
1464-
if err == nil {
1465-
err = server.Database.DeleteAPIKeyByID(ctx, key.ID)
1466-
}
1465+
func deleteSessionToken(ctx context.Context, db database.Store, workspace database.Workspace) error {
1466+
err := db.InTx(func(tx database.Store) error {
1467+
key, err := tx.GetAPIKeyByName(ctx, database.GetAPIKeyByNameParams{
1468+
UserID: workspace.OwnerID,
1469+
TokenName: workspaceSessionTokenName(workspace),
1470+
})
1471+
if err == nil {
1472+
err = tx.DeleteAPIKeyByID(ctx, key.ID)
1473+
}
14671474

1468-
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
1469-
return xerrors.Errorf("get api key by name: %w", err)
1475+
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
1476+
return xerrors.Errorf("get api key by name: %w", err)
1477+
}
1478+
1479+
return nil
1480+
}, nil)
1481+
if err != nil {
1482+
return xerrors.Errorf("in tx: %w", err)
14701483
}
14711484

14721485
return nil

0 commit comments

Comments
 (0)