Skip to content

chore: remove duplicate validate calls on same oauth token #11598

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions coderd/coderd.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ type Options struct {
// under the enterprise license, and can't be imported into AGPL.
ParseLicenseClaims func(rawJWT string) (email string, trial bool, err error)
AllowWorkspaceRenames bool

// NewTicker is used for unit tests to replace "time.NewTicker".
NewTicker func(duration time.Duration) (tick <-chan time.Time, done func())
}

// @title Coder API
Expand All @@ -208,6 +211,12 @@ func New(options *Options) *API {
if options == nil {
options = &Options{}
}
if options.NewTicker == nil {
options.NewTicker = func(duration time.Duration) (tick <-chan time.Time, done func()) {
ticker := time.NewTicker(duration)
return ticker.C, ticker.Stop
}
}
Comment on lines +214 to +219
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is OK for now but we should probably swap this out for something better later on 🕰️

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. Some fake clock solution would be ideal, then we can use across all tests more generally.


// Safety check: if we're not running a unit test, we *must* have a Prometheus registry.
if options.PrometheusRegistry == nil && flag.Lookup("test.v") == nil {
Expand Down
2 changes: 2 additions & 0 deletions coderd/coderdtest/coderdtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ type Options struct {

WorkspaceAppsStatsCollectorOptions workspaceapps.StatsCollectorOptions
AllowWorkspaceRenames bool
NewTicker func(duration time.Duration) (<-chan time.Time, func())
}

// New constructs a codersdk client connected to an in-memory API instance.
Expand Down Expand Up @@ -451,6 +452,7 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
StatsBatcher: options.StatsBatcher,
WorkspaceAppsStatsCollectorOptions: options.WorkspaceAppsStatsCollectorOptions,
AllowWorkspaceRenames: options.AllowWorkspaceRenames,
NewTicker: options.NewTicker,
}
}

Expand Down
17 changes: 14 additions & 3 deletions coderd/workspaceagents.go
Original file line number Diff line number Diff line change
Expand Up @@ -2051,13 +2051,14 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
if listen {
// Since we're ticking frequently and this sign-in operation is rare,
// we are OK with polling to avoid the complexity of pubsub.
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
ticker, done := api.NewTicker(time.Second)
defer done()
var previousToken database.ExternalAuthLink
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
case <-ticker:
}
externalAuthLink, err := api.Database.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{
ProviderID: externalAuthConfig.ID,
Expand All @@ -2081,6 +2082,15 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
if externalAuthLink.OAuthExpiry.Before(dbtime.Now()) && !externalAuthLink.OAuthExpiry.IsZero() {
continue
}

// Only attempt to revalidate an oauth token if it has actually changed.
// No point in trying to validate the same token over and over again.
if previousToken.OAuthAccessToken == externalAuthLink.OAuthAccessToken &&
previousToken.OAuthRefreshToken == externalAuthLink.OAuthRefreshToken &&
previousToken.OAuthExpiry == externalAuthLink.OAuthExpiry {
continue
}

valid, _, err := externalAuthConfig.ValidateToken(ctx, externalAuthLink.OAuthAccessToken)
if err != nil {
api.Logger.Warn(ctx, "failed to validate external auth token",
Expand All @@ -2089,6 +2099,7 @@ func (api *API) workspaceAgentsExternalAuth(rw http.ResponseWriter, r *http.Requ
slog.Error(err),
)
}
previousToken = externalAuthLink
if !valid {
continue
}
Expand Down
94 changes: 94 additions & 0 deletions coderd/workspaceagents_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,15 @@ import (
"github.com/coder/coder/v2/agent/agenttest"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbmem"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/externalauth"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/agentsdk"
Expand Down Expand Up @@ -1536,3 +1539,94 @@ func TestWorkspaceAgent_UpdatedDERP(t *testing.T) {
require.True(t, ok)
require.Equal(t, []int{2}, conn2.DERPMap().RegionIDs())
}

func TestWorkspaceAgentExternalAuthListen(t *testing.T) {
t.Parallel()

// ValidateURLSpam acts as a workspace calling GIT_ASK_PASS which
// will wait until the external auth token is valid. The issue is we spam
// the validate endpoint with requests until the token is valid. We do this
// even if the token has not changed. We are calling validate with the
// same inputs expecting a different result (insanity?). To reduce our
// api rate limit usage, we should do nothing if the inputs have not
// changed.
//
// Note that an expired oauth token is already skipped, so this really
// only covers the case of a revoked token.
t.Run("ValidateURLSpam", func(t *testing.T) {
t.Parallel()

const providerID = "fake-idp"

// Count all the times we call validate
validateCalls := 0
fake := oidctest.NewFakeIDP(t, oidctest.WithServing(), oidctest.WithMiddlewares(func(handler http.Handler) http.Handler {
return http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Count all the validate calls
if strings.Contains(r.URL.Path, "/external-auth-validate/") {
validateCalls++
}
handler.ServeHTTP(w, r)
}))
}))

ticks := make(chan time.Time)
// setup
ownerClient, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{
NewTicker: func(duration time.Duration) (<-chan time.Time, func()) {
return ticks, func() {}
},
ExternalAuthConfigs: []*externalauth.Config{
fake.ExternalAuthConfig(t, providerID, nil, func(cfg *externalauth.Config) {
cfg.Type = codersdk.EnhancedExternalAuthProviderGitHub.String()
}),
},
})
first := coderdtest.CreateFirstUser(t, ownerClient)
tmpDir := t.TempDir()
client, user := coderdtest.CreateAnotherUser(t, ownerClient, first.OrganizationID)

r := dbfake.WorkspaceBuild(t, db, database.Workspace{
OrganizationID: first.OrganizationID,
OwnerID: user.ID,
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
agents[0].Directory = tmpDir
return agents
}).Do()

agentClient := agentsdk.New(client.URL)
agentClient.SetSessionToken(r.AgentToken)

// We need to include an invalid oauth token that is not expired.
dbgen.ExternalAuthLink(t, db, database.ExternalAuthLink{
ProviderID: providerID,
UserID: user.ID,
CreatedAt: dbtime.Now(),
UpdatedAt: dbtime.Now(),
OAuthAccessToken: "invalid",
OAuthRefreshToken: "bad",
OAuthExpiry: dbtime.Now().Add(time.Hour),
})

ctx, cancel := context.WithCancel(testutil.Context(t, testutil.WaitShort))
go func() {
// The request that will block and fire off validate calls.
_, err := agentClient.ExternalAuth(ctx, agentsdk.ExternalAuthRequest{
ID: providerID,
Match: "",
Listen: true,
})
assert.Error(t, err, "this should fail")
}()

// Send off 10 ticks to cause 10 validate calls
for i := 0; i < 10; i++ {
ticks <- time.Now()
}
cancel()
// We expect only 1
// In a failed test, you will likely see 9, as the last one
// gets cancelled.
require.Equal(t, 1, validateCalls, "validate calls duplicated on same token")
})
}