diff --git a/coderd/telemetry/telemetry_test.go b/coderd/telemetry/telemetry_test.go index 6f97ce8a1270b..7de4c98e07fa8 100644 --- a/coderd/telemetry/telemetry_test.go +++ b/coderd/telemetry/telemetry_test.go @@ -1,6 +1,7 @@ package telemetry_test import ( + "context" "database/sql" "encoding/json" "net/http" @@ -115,7 +116,7 @@ func TestTelemetry(t *testing.T) { _ = dbgen.WorkspaceAgentMemoryResourceMonitor(t, db, database.WorkspaceAgentMemoryResourceMonitor{}) _ = dbgen.WorkspaceAgentVolumeResourceMonitor(t, db, database.WorkspaceAgentVolumeResourceMonitor{}) - _, snapshot := collectSnapshot(t, db, nil) + _, snapshot := collectSnapshot(ctx, t, db, nil) require.Len(t, snapshot.ProvisionerJobs, 1) require.Len(t, snapshot.Licenses, 1) require.Len(t, snapshot.Templates, 1) @@ -168,17 +169,19 @@ func TestTelemetry(t *testing.T) { }) t.Run("HashedEmail", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitMedium) db := dbmem.New() _ = dbgen.User(t, db, database.User{ Email: "kyle@coder.com", }) - _, snapshot := collectSnapshot(t, db, nil) + _, snapshot := collectSnapshot(ctx, t, db, nil) require.Len(t, snapshot.Users, 1) require.Equal(t, snapshot.Users[0].EmailHashed, "bb44bf07cf9a2db0554bba63a03d822c927deae77df101874496df5a6a3e896d@coder.com") }) t.Run("HashedModule", func(t *testing.T) { t.Parallel() db, _ := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitMedium) pj := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{}) _ = dbgen.WorkspaceModule(t, db, database.WorkspaceModule{ JobID: pj.ID, @@ -190,7 +193,7 @@ func TestTelemetry(t *testing.T) { Source: "https://internal-url.com/some-module", Version: "1.0.0", }) - _, snapshot := collectSnapshot(t, db, nil) + _, snapshot := collectSnapshot(ctx, t, db, nil) require.Len(t, snapshot.WorkspaceModules, 2) modules := snapshot.WorkspaceModules sort.Slice(modules, func(i, j int) bool { @@ -286,11 +289,11 @@ func TestTelemetry(t *testing.T) { db, _ := dbtestutil.NewDB(t) // 1. No org sync settings - deployment, _ := collectSnapshot(t, db, nil) + deployment, _ := collectSnapshot(ctx, t, db, nil) require.False(t, *deployment.IDPOrgSync) // 2. Org sync settings set in server flags - deployment, _ = collectSnapshot(t, db, func(opts telemetry.Options) telemetry.Options { + deployment, _ = collectSnapshot(ctx, t, db, func(opts telemetry.Options) telemetry.Options { opts.DeploymentConfig = &codersdk.DeploymentValues{ OIDC: codersdk.OIDCConfig{ OrganizationField: "organizations", @@ -312,7 +315,7 @@ func TestTelemetry(t *testing.T) { AssignDefault: true, }) require.NoError(t, err) - deployment, _ = collectSnapshot(t, db, nil) + deployment, _ = collectSnapshot(ctx, t, db, nil) require.True(t, *deployment.IDPOrgSync) }) } @@ -320,8 +323,9 @@ func TestTelemetry(t *testing.T) { // nolint:paralleltest func TestTelemetryInstallSource(t *testing.T) { t.Setenv("CODER_TELEMETRY_INSTALL_SOURCE", "aws_marketplace") + ctx := testutil.Context(t, testutil.WaitMedium) db := dbmem.New() - deployment, _ := collectSnapshot(t, db, nil) + deployment, _ := collectSnapshot(ctx, t, db, nil) require.Equal(t, "aws_marketplace", deployment.InstallSource) } @@ -436,7 +440,7 @@ func TestRecordTelemetryStatus(t *testing.T) { } } -func mockTelemetryServer(t *testing.T) (*url.URL, chan *telemetry.Deployment, chan *telemetry.Snapshot) { +func mockTelemetryServer(ctx context.Context, t *testing.T) (*url.URL, chan *telemetry.Deployment, chan *telemetry.Snapshot) { t.Helper() deployment := make(chan *telemetry.Deployment, 64) snapshot := make(chan *telemetry.Snapshot, 64) @@ -446,7 +450,11 @@ func mockTelemetryServer(t *testing.T) (*url.URL, chan *telemetry.Deployment, ch dd := &telemetry.Deployment{} err := json.NewDecoder(r.Body).Decode(dd) require.NoError(t, err) - deployment <- dd + ok := testutil.AssertSend(ctx, t, deployment, dd) + if !ok { + w.WriteHeader(http.StatusInternalServerError) + return + } // Ensure the header is sent only after deployment is sent w.WriteHeader(http.StatusAccepted) }) @@ -455,7 +463,11 @@ func mockTelemetryServer(t *testing.T) (*url.URL, chan *telemetry.Deployment, ch ss := &telemetry.Snapshot{} err := json.NewDecoder(r.Body).Decode(ss) require.NoError(t, err) - snapshot <- ss + ok := testutil.AssertSend(ctx, t, snapshot, ss) + if !ok { + w.WriteHeader(http.StatusInternalServerError) + return + } // Ensure the header is sent only after snapshot is sent w.WriteHeader(http.StatusAccepted) }) @@ -467,10 +479,15 @@ func mockTelemetryServer(t *testing.T) (*url.URL, chan *telemetry.Deployment, ch return serverURL, deployment, snapshot } -func collectSnapshot(t *testing.T, db database.Store, addOptionsFn func(opts telemetry.Options) telemetry.Options) (*telemetry.Deployment, *telemetry.Snapshot) { +func collectSnapshot( + ctx context.Context, + t *testing.T, + db database.Store, + addOptionsFn func(opts telemetry.Options) telemetry.Options, +) (*telemetry.Deployment, *telemetry.Snapshot) { t.Helper() - serverURL, deployment, snapshot := mockTelemetryServer(t) + serverURL, deployment, snapshot := mockTelemetryServer(ctx, t) options := telemetry.Options{ Database: db, @@ -485,5 +502,6 @@ func collectSnapshot(t *testing.T, db database.Store, addOptionsFn func(opts tel reporter, err := telemetry.New(options) require.NoError(t, err) t.Cleanup(reporter.Close) - return <-deployment, <-snapshot + + return testutil.RequireReceive(ctx, t, deployment), testutil.RequireReceive(ctx, t, snapshot) } diff --git a/testutil/chan.go b/testutil/chan.go index a6766a1a49053..3a06f03ab4a02 100644 --- a/testutil/chan.go +++ b/testutil/chan.go @@ -55,3 +55,61 @@ func RequireSend[A any](ctx context.Context, t testing.TB, c chan<- A, a A) { // OK! } } + +// SoftTryReceive will attempt to receive a value from the chan and return it. If +// the context expires before a value can be received, it will mark the test as +// failed but continue execution. If the channel is closed, the zero value of the +// channel type will be returned. +// The second return value indicates whether the receive was successful. In +// particular, if the channel is closed, the second return value will be true. +// +// Safety: can be called from any goroutine. +func SoftTryReceive[A any](ctx context.Context, t testing.TB, c <-chan A) (A, bool) { + t.Helper() + select { + case <-ctx.Done(): + t.Error("timeout") + var a A + return a, false + case a := <-c: + return a, true + } +} + +// AssertReceive will receive a value from the chan and return it. If the +// context expires or the channel is closed before a value can be received, +// it will mark the test as failed but continue execution. +// The second return value indicates whether the receive was successful. +// +// Safety: can be called from any goroutine. +func AssertReceive[A any](ctx context.Context, t testing.TB, c <-chan A) (A, bool) { + t.Helper() + select { + case <-ctx.Done(): + t.Error("timeout") + var a A + return a, false + case a, ok := <-c: + if !ok { + t.Error("channel closed") + } + return a, ok + } +} + +// AssertSend will send the given value over the chan and then return. If +// the context expires before the send succeeds, it will mark the test as failed +// but continue execution. +// The second return value indicates whether the send was successful. +// +// Safety: can be called from any goroutine. +func AssertSend[A any](ctx context.Context, t testing.TB, c chan<- A, a A) bool { + t.Helper() + select { + case <-ctx.Done(): + t.Error("timeout") + return false + case c <- a: + return true + } +}