Skip to content

Commit c3fb1b3

Browse files
authored
feat: add owner_oidc_access_token to coder_workspace data source (#6042)
See the discussion in Discord here: https://discord.com/channels/747933592273027093/1071182088490987542/1071182088490987542 Related provider PR: coder/terraform-provider-coder#91
1 parent ca067cf commit c3fb1b3

File tree

8 files changed

+297
-124
lines changed

8 files changed

+297
-124
lines changed

coderd/coderd.go

+1
Original file line numberDiff line numberDiff line change
@@ -835,6 +835,7 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, debounce ti
835835
err = proto.DRPCRegisterProvisionerDaemon(mux, &provisionerdserver.Server{
836836
AccessURL: api.AccessURL,
837837
ID: daemon.ID,
838+
OIDCConfig: api.OIDCConfig,
838839
Database: api.Database,
839840
Pubsub: api.Pubsub,
840841
Provisioners: daemon.Provisioners,

coderd/provisionerdserver/provisionerdserver.go

+66-9
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"github.com/tabbed/pqtype"
2020
"golang.org/x/exp/maps"
2121
"golang.org/x/exp/slices"
22+
"golang.org/x/oauth2"
2223
"golang.org/x/xerrors"
2324
protobuf "google.golang.org/protobuf/proto"
2425

@@ -27,6 +28,7 @@ import (
2728
"github.com/coder/coder/coderd/audit"
2829
"github.com/coder/coder/coderd/database"
2930
"github.com/coder/coder/coderd/database/dbauthz"
31+
"github.com/coder/coder/coderd/httpmw"
3032
"github.com/coder/coder/coderd/parameter"
3133
"github.com/coder/coder/coderd/schedule"
3234
"github.com/coder/coder/coderd/telemetry"
@@ -58,6 +60,7 @@ type Server struct {
5860
TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore]
5961

6062
AcquireJobDebounce time.Duration
63+
OIDCConfig httpmw.OAuth2Config
6164
}
6265

6366
// AcquireJob queries the database to lock a job.
@@ -168,6 +171,14 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
168171
return nil, failJob(fmt.Sprintf("publish workspace update: %s", err))
169172
}
170173

174+
var workspaceOwnerOIDCAccessToken string
175+
if server.OIDCConfig != nil {
176+
workspaceOwnerOIDCAccessToken, err = obtainOIDCAccessToken(ctx, server.Database, server.OIDCConfig, owner.ID)
177+
if err != nil {
178+
return nil, failJob(fmt.Sprintf("obtain OIDC access token: %s", err))
179+
}
180+
}
181+
171182
// Compute parameters for the workspace to consume.
172183
parameters, err := parameter.Compute(ctx, server.Database, parameter.ComputeScope{
173184
TemplateImportJobID: templateVersion.JobID,
@@ -208,15 +219,16 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
208219
RichParameterValues: convertRichParameterValues(workspaceBuildParameters),
209220
VariableValues: asVariableValues(templateVariables),
210221
Metadata: &sdkproto.Provision_Metadata{
211-
CoderUrl: server.AccessURL.String(),
212-
WorkspaceTransition: transition,
213-
WorkspaceName: workspace.Name,
214-
WorkspaceOwner: owner.Username,
215-
WorkspaceOwnerEmail: owner.Email,
216-
WorkspaceId: workspace.ID.String(),
217-
WorkspaceOwnerId: owner.ID.String(),
218-
TemplateName: template.Name,
219-
TemplateVersion: templateVersion.Name,
222+
CoderUrl: server.AccessURL.String(),
223+
WorkspaceTransition: transition,
224+
WorkspaceName: workspace.Name,
225+
WorkspaceOwner: owner.Username,
226+
WorkspaceOwnerEmail: owner.Email,
227+
WorkspaceOwnerOidcAccessToken: workspaceOwnerOIDCAccessToken,
228+
WorkspaceId: workspace.ID.String(),
229+
WorkspaceOwnerId: owner.ID.String(),
230+
TemplateName: template.Name,
231+
TemplateVersion: templateVersion.Name,
220232
},
221233
},
222234
}
@@ -1295,6 +1307,51 @@ func InsertWorkspaceResource(ctx context.Context, db database.Store, jobID uuid.
12951307
return nil
12961308
}
12971309

1310+
// obtainOIDCAccessToken returns a valid OpenID Connect access token
1311+
// for the user if it's able to obtain one, otherwise it returns an empty string.
1312+
func obtainOIDCAccessToken(ctx context.Context, db database.Store, oidcConfig httpmw.OAuth2Config, userID uuid.UUID) (string, error) {
1313+
link, err := db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{
1314+
UserID: userID,
1315+
LoginType: database.LoginTypeOIDC,
1316+
})
1317+
if errors.Is(err, sql.ErrNoRows) {
1318+
err = nil
1319+
}
1320+
if err != nil {
1321+
return "", xerrors.Errorf("get owner oidc link: %w", err)
1322+
}
1323+
1324+
if link.OAuthExpiry.Before(database.Now()) && !link.OAuthExpiry.IsZero() && link.OAuthRefreshToken != "" {
1325+
token, err := oidcConfig.TokenSource(ctx, &oauth2.Token{
1326+
AccessToken: link.OAuthAccessToken,
1327+
RefreshToken: link.OAuthRefreshToken,
1328+
Expiry: link.OAuthExpiry,
1329+
}).Token()
1330+
if err != nil {
1331+
// If OIDC fails to refresh, we return an empty string and don't fail.
1332+
// There isn't a way to hard-opt in to OIDC from a template, so we don't
1333+
// want to fail builds if users haven't authenticated for a while or something.
1334+
return "", nil
1335+
}
1336+
link.OAuthAccessToken = token.AccessToken
1337+
link.OAuthRefreshToken = token.RefreshToken
1338+
link.OAuthExpiry = token.Expiry
1339+
1340+
link, err = db.UpdateUserLink(ctx, database.UpdateUserLinkParams{
1341+
UserID: userID,
1342+
LoginType: database.LoginTypeOIDC,
1343+
OAuthAccessToken: link.OAuthAccessToken,
1344+
OAuthRefreshToken: link.OAuthRefreshToken,
1345+
OAuthExpiry: link.OAuthExpiry,
1346+
})
1347+
if err != nil {
1348+
return "", xerrors.Errorf("update user link: %w", err)
1349+
}
1350+
}
1351+
1352+
return link.OAuthAccessToken, nil
1353+
}
1354+
12981355
func convertValidationTypeSystem(typeSystem sdkproto.ParameterSchema_TypeSystem) (database.ParameterTypeSystem, error) {
12991356
switch typeSystem {
13001357
case sdkproto.ParameterSchema_None:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
package provisionerdserver
2+
3+
import (
4+
"context"
5+
"testing"
6+
"time"
7+
8+
"github.com/google/uuid"
9+
"github.com/stretchr/testify/require"
10+
"golang.org/x/oauth2"
11+
12+
"github.com/coder/coder/coderd/database"
13+
"github.com/coder/coder/coderd/database/dbfake"
14+
"github.com/coder/coder/coderd/database/dbgen"
15+
)
16+
17+
func TestObtainOIDCAccessToken(t *testing.T) {
18+
t.Parallel()
19+
ctx := context.Background()
20+
t.Run("NoToken", func(t *testing.T) {
21+
t.Parallel()
22+
db := dbfake.New()
23+
_, err := obtainOIDCAccessToken(ctx, db, nil, uuid.Nil)
24+
require.NoError(t, err)
25+
})
26+
t.Run("InvalidConfig", func(t *testing.T) {
27+
// We still want OIDC to succeed even if exchanging the token fails.
28+
t.Parallel()
29+
db := dbfake.New()
30+
user := dbgen.User(t, db, database.User{})
31+
dbgen.UserLink(t, db, database.UserLink{
32+
UserID: user.ID,
33+
LoginType: database.LoginTypeOIDC,
34+
OAuthExpiry: database.Now().Add(-time.Hour),
35+
})
36+
_, err := obtainOIDCAccessToken(ctx, db, &oauth2.Config{}, user.ID)
37+
require.NoError(t, err)
38+
})
39+
t.Run("Exchange", func(t *testing.T) {
40+
t.Parallel()
41+
db := dbfake.New()
42+
user := dbgen.User(t, db, database.User{})
43+
dbgen.UserLink(t, db, database.UserLink{
44+
UserID: user.ID,
45+
LoginType: database.LoginTypeOIDC,
46+
OAuthExpiry: database.Now().Add(-time.Hour),
47+
})
48+
_, err := obtainOIDCAccessToken(ctx, db, &oauth2Config{
49+
tokenSource: func() (*oauth2.Token, error) {
50+
return &oauth2.Token{
51+
AccessToken: "token",
52+
}, nil
53+
},
54+
}, user.ID)
55+
require.NoError(t, err)
56+
link, err := db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{
57+
UserID: user.ID,
58+
LoginType: database.LoginTypeOIDC,
59+
})
60+
require.NoError(t, err)
61+
require.Equal(t, "token", link.OAuthAccessToken)
62+
})
63+
}
64+
65+
type oauth2Config struct {
66+
tokenSource oauth2TokenSource
67+
}
68+
69+
func (o *oauth2Config) TokenSource(context.Context, *oauth2.Token) oauth2.TokenSource {
70+
return o.tokenSource
71+
}
72+
73+
func (*oauth2Config) AuthCodeURL(string, ...oauth2.AuthCodeOption) string {
74+
return ""
75+
}
76+
77+
func (*oauth2Config) Exchange(context.Context, string, ...oauth2.AuthCodeOption) (*oauth2.Token, error) {
78+
return &oauth2.Token{}, nil
79+
}
80+
81+
type oauth2TokenSource func() (*oauth2.Token, error)
82+
83+
func (o oauth2TokenSource) Token() (*oauth2.Token, error) {
84+
return o()
85+
}

coderd/provisionerdserver/provisionerdserver_test.go

+18-9
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111

1212
"github.com/google/uuid"
1313
"github.com/stretchr/testify/require"
14+
"golang.org/x/oauth2"
1415

1516
"cdr.dev/slog/sloggers/slogtest"
1617
"github.com/coder/coder/coderd/audit"
@@ -100,6 +101,12 @@ func TestAcquireJob(t *testing.T) {
100101
ctx := context.Background()
101102

102103
user := dbgen.User(t, srv.Database, database.User{})
104+
link := dbgen.UserLink(t, srv.Database, database.UserLink{
105+
LoginType: database.LoginTypeOIDC,
106+
UserID: user.ID,
107+
OAuthExpiry: database.Now().Add(time.Hour),
108+
OAuthAccessToken: "access-token",
109+
})
103110
template := dbgen.Template(t, srv.Database, database.Template{
104111
Name: "template",
105112
Provisioner: database.ProvisionerTypeEcho,
@@ -208,15 +215,16 @@ func TestAcquireJob(t *testing.T) {
208215
},
209216
},
210217
Metadata: &sdkproto.Provision_Metadata{
211-
CoderUrl: srv.AccessURL.String(),
212-
WorkspaceTransition: sdkproto.WorkspaceTransition_START,
213-
WorkspaceName: workspace.Name,
214-
WorkspaceOwner: user.Username,
215-
WorkspaceOwnerEmail: user.Email,
216-
WorkspaceId: workspace.ID.String(),
217-
WorkspaceOwnerId: user.ID.String(),
218-
TemplateName: template.Name,
219-
TemplateVersion: version.Name,
218+
CoderUrl: srv.AccessURL.String(),
219+
WorkspaceTransition: sdkproto.WorkspaceTransition_START,
220+
WorkspaceName: workspace.Name,
221+
WorkspaceOwner: user.Username,
222+
WorkspaceOwnerEmail: user.Email,
223+
WorkspaceOwnerOidcAccessToken: link.OAuthAccessToken,
224+
WorkspaceId: workspace.ID.String(),
225+
WorkspaceOwnerId: user.ID.String(),
226+
TemplateName: template.Name,
227+
TemplateVersion: version.Name,
220228
},
221229
},
222230
})
@@ -1152,6 +1160,7 @@ func setup(t *testing.T, ignoreLogErrors bool) *provisionerdserver.Server {
11521160
return &provisionerdserver.Server{
11531161
ID: uuid.New(),
11541162
Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: ignoreLogErrors}),
1163+
OIDCConfig: &oauth2.Config{},
11551164
AccessURL: &url.URL{},
11561165
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho},
11571166
Database: db,

enterprise/coderd/provisionerdaemons.go

+6
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,14 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
217217
return
218218
}
219219
mux := drpcmux.New()
220+
gitAuthProviders := make([]string, 0, len(api.GitAuthConfigs))
221+
for _, cfg := range api.GitAuthConfigs {
222+
gitAuthProviders = append(gitAuthProviders, cfg.ID)
223+
}
220224
err = proto.DRPCRegisterProvisionerDaemon(mux, &provisionerdserver.Server{
221225
AccessURL: api.AccessURL,
226+
GitAuthProviders: gitAuthProviders,
227+
OIDCConfig: api.OIDCConfig,
222228
ID: daemon.ID,
223229
Database: api.Database,
224230
Pubsub: api.Pubsub,

provisioner/terraform/provision.go

+1
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ func provisionEnv(config *proto.Provision_Config, params []*proto.ParameterValue
213213
"CODER_WORKSPACE_NAME="+config.Metadata.WorkspaceName,
214214
"CODER_WORKSPACE_OWNER="+config.Metadata.WorkspaceOwner,
215215
"CODER_WORKSPACE_OWNER_EMAIL="+config.Metadata.WorkspaceOwnerEmail,
216+
"CODER_WORKSPACE_OWNER_OIDC_ACCESS_TOKEN="+config.Metadata.WorkspaceOwnerOidcAccessToken,
216217
"CODER_WORKSPACE_ID="+config.Metadata.WorkspaceId,
217218
"CODER_WORKSPACE_OWNER_ID="+config.Metadata.WorkspaceOwnerId,
218219
)

0 commit comments

Comments
 (0)