Skip to content

Commit 03ee639

Browse files
authored
chore: remove duplicate validate calls on same oauth token (#11598)
* chore: remove duplicate validate calls on same oauth token
1 parent 8181c9f commit 03ee639

File tree

4 files changed

+119
-3
lines changed

4 files changed

+119
-3
lines changed

coderd/coderd.go

+9
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,9 @@ type Options struct {
184184
// under the enterprise license, and can't be imported into AGPL.
185185
ParseLicenseClaims func(rawJWT string) (email string, trial bool, err error)
186186
AllowWorkspaceRenames bool
187+
188+
// NewTicker is used for unit tests to replace "time.NewTicker".
189+
NewTicker func(duration time.Duration) (tick <-chan time.Time, done func())
187190
}
188191

189192
// @title Coder API
@@ -208,6 +211,12 @@ func New(options *Options) *API {
208211
if options == nil {
209212
options = &Options{}
210213
}
214+
if options.NewTicker == nil {
215+
options.NewTicker = func(duration time.Duration) (tick <-chan time.Time, done func()) {
216+
ticker := time.NewTicker(duration)
217+
return ticker.C, ticker.Stop
218+
}
219+
}
211220

212221
// Safety check: if we're not running a unit test, we *must* have a Prometheus registry.
213222
if options.PrometheusRegistry == nil && flag.Lookup("test.v") == nil {

coderd/coderdtest/coderdtest.go

+2
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ type Options struct {
145145

146146
WorkspaceAppsStatsCollectorOptions workspaceapps.StatsCollectorOptions
147147
AllowWorkspaceRenames bool
148+
NewTicker func(duration time.Duration) (<-chan time.Time, func())
148149
}
149150

150151
// New constructs a codersdk client connected to an in-memory API instance.
@@ -451,6 +452,7 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
451452
StatsBatcher: options.StatsBatcher,
452453
WorkspaceAppsStatsCollectorOptions: options.WorkspaceAppsStatsCollectorOptions,
453454
AllowWorkspaceRenames: options.AllowWorkspaceRenames,
455+
NewTicker: options.NewTicker,
454456
}
455457
}
456458

coderd/workspaceagents.go

+14-3
Original file line numberDiff line numberDiff line change
@@ -2051,13 +2051,14 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
20512051
if listen {
20522052
// Since we're ticking frequently and this sign-in operation is rare,
20532053
// we are OK with polling to avoid the complexity of pubsub.
2054-
ticker := time.NewTicker(time.Second)
2055-
defer ticker.Stop()
2054+
ticker, done := api.NewTicker(time.Second)
2055+
defer done()
2056+
var previousToken database.ExternalAuthLink
20562057
for {
20572058
select {
20582059
case <-ctx.Done():
20592060
return
2060-
case <-ticker.C:
2061+
case <-ticker:
20612062
}
20622063
externalAuthLink, err := api.Database.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{
20632064
ProviderID: externalAuthConfig.ID,
@@ -2081,6 +2082,15 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
20812082
if externalAuthLink.OAuthExpiry.Before(dbtime.Now()) && !externalAuthLink.OAuthExpiry.IsZero() {
20822083
continue
20832084
}
2085+
2086+
// Only attempt to revalidate an oauth token if it has actually changed.
2087+
// No point in trying to validate the same token over and over again.
2088+
if previousToken.OAuthAccessToken == externalAuthLink.OAuthAccessToken &&
2089+
previousToken.OAuthRefreshToken == externalAuthLink.OAuthRefreshToken &&
2090+
previousToken.OAuthExpiry == externalAuthLink.OAuthExpiry {
2091+
continue
2092+
}
2093+
20842094
valid, _, err := externalAuthConfig.ValidateToken(ctx, externalAuthLink.OAuthAccessToken)
20852095
if err != nil {
20862096
api.Logger.Warn(ctx, "failed to validate external auth token",
@@ -2089,6 +2099,7 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
20892099
slog.Error(err),
20902100
)
20912101
}
2102+
previousToken = externalAuthLink
20922103
if !valid {
20932104
continue
20942105
}

coderd/workspaceagents_test.go

+94
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,15 @@ import (
2525
"github.com/coder/coder/v2/agent/agenttest"
2626
"github.com/coder/coder/v2/coderd"
2727
"github.com/coder/coder/v2/coderd/coderdtest"
28+
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
2829
"github.com/coder/coder/v2/coderd/database"
2930
"github.com/coder/coder/v2/coderd/database/dbauthz"
3031
"github.com/coder/coder/v2/coderd/database/dbfake"
32+
"github.com/coder/coder/v2/coderd/database/dbgen"
3133
"github.com/coder/coder/v2/coderd/database/dbmem"
3234
"github.com/coder/coder/v2/coderd/database/dbtime"
3335
"github.com/coder/coder/v2/coderd/database/pubsub"
36+
"github.com/coder/coder/v2/coderd/externalauth"
3437
"github.com/coder/coder/v2/coderd/rbac"
3538
"github.com/coder/coder/v2/codersdk"
3639
"github.com/coder/coder/v2/codersdk/agentsdk"
@@ -1536,3 +1539,94 @@ func TestWorkspaceAgent_UpdatedDERP(t *testing.T) {
15361539
require.True(t, ok)
15371540
require.Equal(t, []int{2}, conn2.DERPMap().RegionIDs())
15381541
}
1542+
1543+
func TestWorkspaceAgentExternalAuthListen(t *testing.T) {
1544+
t.Parallel()
1545+
1546+
// ValidateURLSpam acts as a workspace calling GIT_ASK_PASS which
1547+
// will wait until the external auth token is valid. The issue is we spam
1548+
// the validate endpoint with requests until the token is valid. We do this
1549+
// even if the token has not changed. We are calling validate with the
1550+
// same inputs expecting a different result (insanity?). To reduce our
1551+
// api rate limit usage, we should do nothing if the inputs have not
1552+
// changed.
1553+
//
1554+
// Note that an expired oauth token is already skipped, so this really
1555+
// only covers the case of a revoked token.
1556+
t.Run("ValidateURLSpam", func(t *testing.T) {
1557+
t.Parallel()
1558+
1559+
const providerID = "fake-idp"
1560+
1561+
// Count all the times we call validate
1562+
validateCalls := 0
1563+
fake := oidctest.NewFakeIDP(t, oidctest.WithServing(), oidctest.WithMiddlewares(func(handler http.Handler) http.Handler {
1564+
return http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1565+
// Count all the validate calls
1566+
if strings.Contains(r.URL.Path, "/external-auth-validate/") {
1567+
validateCalls++
1568+
}
1569+
handler.ServeHTTP(w, r)
1570+
}))
1571+
}))
1572+
1573+
ticks := make(chan time.Time)
1574+
// setup
1575+
ownerClient, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{
1576+
NewTicker: func(duration time.Duration) (<-chan time.Time, func()) {
1577+
return ticks, func() {}
1578+
},
1579+
ExternalAuthConfigs: []*externalauth.Config{
1580+
fake.ExternalAuthConfig(t, providerID, nil, func(cfg *externalauth.Config) {
1581+
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
1582+
}),
1583+
},
1584+
})
1585+
first := coderdtest.CreateFirstUser(t, ownerClient)
1586+
tmpDir := t.TempDir()
1587+
client, user := coderdtest.CreateAnotherUser(t, ownerClient, first.OrganizationID)
1588+
1589+
r := dbfake.WorkspaceBuild(t, db, database.Workspace{
1590+
OrganizationID: first.OrganizationID,
1591+
OwnerID: user.ID,
1592+
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
1593+
agents[0].Directory = tmpDir
1594+
return agents
1595+
}).Do()
1596+
1597+
agentClient := agentsdk.New(client.URL)
1598+
agentClient.SetSessionToken(r.AgentToken)
1599+
1600+
// We need to include an invalid oauth token that is not expired.
1601+
dbgen.ExternalAuthLink(t, db, database.ExternalAuthLink{
1602+
ProviderID: providerID,
1603+
UserID: user.ID,
1604+
CreatedAt: dbtime.Now(),
1605+
UpdatedAt: dbtime.Now(),
1606+
OAuthAccessToken: "invalid",
1607+
OAuthRefreshToken: "bad",
1608+
OAuthExpiry: dbtime.Now().Add(time.Hour),
1609+
})
1610+
1611+
ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort))
1612+
go func() {
1613+
// The request that will block and fire off validate calls.
1614+
_, err := agentClient.ExternalAuth(ctx, agentsdk.ExternalAuthRequest{
1615+
ID: providerID,
1616+
Match: "",
1617+
Listen: true,
1618+
})
1619+
assert.Error(t, err, "this should fail")
1620+
}()
1621+
1622+
// Send off 10 ticks to cause 10 validate calls
1623+
for i := 0; i < 10; i++ {
1624+
ticks <- time.Now()
1625+
}
1626+
cancel()
1627+
// We expect only 1
1628+
// In a failed test, you will likely see 9, as the last one
1629+
// gets cancelled.
1630+
require.Equal(t, 1, validateCalls, "validate calls duplicated on same token")
1631+
})
1632+
}

0 commit comments

Comments
 (0)