From 36bfcbd9838b800c9e401908587c7775d0323333 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Mon, 18 Dec 2023 18:13:18 +0000 Subject: [PATCH 1/7] chore: add agentapi tests --- coderd/agentapi/api.go | 38 +-- coderd/agentapi/apps.go | 8 +- coderd/agentapi/apps_test.go | 253 ++++++++++++++ coderd/agentapi/lifecycle.go | 27 +- coderd/agentapi/lifecycle_test.go | 460 ++++++++++++++++++++++++++ coderd/agentapi/logs.go | 40 ++- coderd/agentapi/logs_test.go | 427 ++++++++++++++++++++++++ coderd/agentapi/manifest.go | 54 ++- coderd/agentapi/manifest_test.go | 392 ++++++++++++++++++++++ coderd/agentapi/servicebanner_test.go | 84 +++++ coderd/database/db2sdk/db2sdk.go | 21 +- coderd/workspaceagentsrpc.go | 17 +- 12 files changed, 1729 insertions(+), 92 deletions(-) create mode 100644 coderd/agentapi/apps_test.go create mode 100644 coderd/agentapi/lifecycle_test.go create mode 100644 coderd/agentapi/logs_test.go create mode 100644 coderd/agentapi/manifest_test.go create mode 100644 coderd/agentapi/servicebanner_test.go diff --git a/coderd/agentapi/api.go b/coderd/agentapi/api.go index 57cb859aafe2a..a97e76efb16be 100644 --- a/coderd/agentapi/api.go +++ b/coderd/agentapi/api.go @@ -25,7 +25,6 @@ import ( "github.com/coder/coder/v2/coderd/schedule" "github.com/coder/coder/v2/coderd/tracing" "github.com/coder/coder/v2/codersdk/agentsdk" - "github.com/coder/coder/v2/tailnet" ) const AgentAPIVersionDRPC = "2.0" @@ -58,21 +57,18 @@ type Options struct { Database database.Store Pubsub pubsub.Pubsub DerpMapFn func() *tailcfg.DERPMap - TailnetCoordinator *atomic.Pointer[tailnet.Coordinator] TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore] StatsBatcher *batchstats.Batcher PublishWorkspaceUpdateFn func(ctx context.Context, workspaceID uuid.UUID) PublishWorkspaceAgentLogsUpdateFn func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) - AccessURL *url.URL - AppHostname string - AgentInactiveDisconnectTimeout time.Duration - 
AgentFallbackTroubleshootingURL string - AgentStatsRefreshInterval time.Duration - DisableDirectConnections bool - DerpForceWebSockets bool - DerpMapUpdateFrequency time.Duration - ExternalAuthConfigs []*externalauth.Config + AccessURL *url.URL + AppHostname string + AgentStatsRefreshInterval time.Duration + DisableDirectConnections bool + DerpForceWebSockets bool + DerpMapUpdateFrequency time.Duration + ExternalAuthConfigs []*externalauth.Config // Optional: // WorkspaceID avoids a future lookup to find the workspace ID by setting @@ -89,17 +85,15 @@ func New(opts Options) *API { } api.ManifestAPI = &ManifestAPI{ - AccessURL: opts.AccessURL, - AppHostname: opts.AppHostname, - AgentInactiveDisconnectTimeout: opts.AgentInactiveDisconnectTimeout, - AgentFallbackTroubleshootingURL: opts.AgentFallbackTroubleshootingURL, - ExternalAuthConfigs: opts.ExternalAuthConfigs, - DisableDirectConnections: opts.DisableDirectConnections, - DerpForceWebSockets: opts.DerpForceWebSockets, - AgentFn: api.agent, - Database: opts.Database, - DerpMapFn: opts.DerpMapFn, - TailnetCoordinator: opts.TailnetCoordinator, + AccessURL: opts.AccessURL, + AppHostname: opts.AppHostname, + ExternalAuthConfigs: opts.ExternalAuthConfigs, + DisableDirectConnections: opts.DisableDirectConnections, + DerpForceWebSockets: opts.DerpForceWebSockets, + AgentFn: api.agent, + WorkspaceIDFn: api.workspaceID, + Database: opts.Database, + DerpMapFn: opts.DerpMapFn, } api.ServiceBannerAPI = &ServiceBannerAPI{ diff --git a/coderd/agentapi/apps.go b/coderd/agentapi/apps.go index 1346d7a9b4bcb..7e8bda1262426 100644 --- a/coderd/agentapi/apps.go +++ b/coderd/agentapi/apps.go @@ -90,9 +90,11 @@ func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.Bat } } - err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent) - if err != nil { - return nil, xerrors.Errorf("publish workspace update: %w", err) + if a.PublishWorkspaceUpdateFn != nil && len(newApps) > 0 { + err = a.PublishWorkspaceUpdateFn(ctx, 
&workspaceAgent) + if err != nil { + return nil, xerrors.Errorf("publish workspace update: %w", err) + } } return &agentproto.BatchUpdateAppHealthResponse{}, nil } diff --git a/coderd/agentapi/apps_test.go b/coderd/agentapi/apps_test.go new file mode 100644 index 0000000000000..8d0b802063bfe --- /dev/null +++ b/coderd/agentapi/apps_test.go @@ -0,0 +1,253 @@ +package agentapi_test + +import ( + "context" + "sync/atomic" + "testing" + + "github.com/golang/mock/gomock" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "cdr.dev/slog/sloggers/slogtest" + + agentproto "github.com/coder/coder/v2/agent/proto" + "github.com/coder/coder/v2/coderd/agentapi" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbmock" +) + +func TestBatchUpdateAppHealths(t *testing.T) { + t.Parallel() + + var ( + agent = database.WorkspaceAgent{ + ID: uuid.New(), + } + app1 = database.WorkspaceApp{ + ID: uuid.New(), + AgentID: agent.ID, + Slug: "code-server-1", + DisplayName: "code-server 1", + HealthcheckUrl: "http://localhost:3000", + Health: database.WorkspaceAppHealthInitializing, + } + app2 = database.WorkspaceApp{ + ID: uuid.New(), + AgentID: agent.ID, + Slug: "code-server-2", + DisplayName: "code-server 2", + HealthcheckUrl: "http://localhost:3001", + Health: database.WorkspaceAppHealthHealthy, + } + ) + + t.Run("OK", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app1, app2}, nil) + dbM.EXPECT().UpdateWorkspaceAppHealthByID(gomock.Any(), database.UpdateWorkspaceAppHealthByIDParams{ + ID: app1.ID, + Health: database.WorkspaceAppHealthHealthy, + }).Return(nil) + dbM.EXPECT().UpdateWorkspaceAppHealthByID(gomock.Any(), database.UpdateWorkspaceAppHealthByIDParams{ + ID: app2.ID, + Health: database.WorkspaceAppHealthUnhealthy, + }).Return(nil) + + var publishCalled int64 + api := 
&agentapi.AppsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + atomic.AddInt64(&publishCalled, 1) + return nil + }, + } + + // Set both to healthy, only one should be updated in the DB. + resp, err := api.BatchUpdateAppHealths(context.Background(), &agentproto.BatchUpdateAppHealthRequest{ + Updates: []*agentproto.BatchUpdateAppHealthRequest_HealthUpdate{ + { + Id: app1.ID[:], + Health: agentproto.AppHealth_HEALTHY, + }, + { + Id: app2.ID[:], + Health: agentproto.AppHealth_UNHEALTHY, + }, + }, + }) + require.NoError(t, err) + require.Equal(t, &agentproto.BatchUpdateAppHealthResponse{}, resp) + + require.EqualValues(t, 1, atomic.LoadInt64(&publishCalled)) + }) + + t.Run("Unchanged", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app1, app2}, nil) + + var publishCalled int64 + api := &agentapi.AppsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + atomic.AddInt64(&publishCalled, 1) + return nil + }, + } + + // Set both to their current status, neither should be updated in the + // DB. 
+ resp, err := api.BatchUpdateAppHealths(context.Background(), &agentproto.BatchUpdateAppHealthRequest{ + Updates: []*agentproto.BatchUpdateAppHealthRequest_HealthUpdate{ + { + Id: app1.ID[:], + Health: agentproto.AppHealth_INITIALIZING, + }, + { + Id: app2.ID[:], + Health: agentproto.AppHealth_HEALTHY, + }, + }, + }) + require.NoError(t, err) + require.Equal(t, &agentproto.BatchUpdateAppHealthResponse{}, resp) + + require.EqualValues(t, 0, atomic.LoadInt64(&publishCalled)) + }) + + t.Run("Empty", func(t *testing.T) { + t.Parallel() + + // No DB queries are made if there are no updates to process. + dbM := dbmock.NewMockStore(gomock.NewController(t)) + + var publishCalled int64 + api := &agentapi.AppsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + atomic.AddInt64(&publishCalled, 1) + return nil + }, + } + + // Do nothing. + resp, err := api.BatchUpdateAppHealths(context.Background(), &agentproto.BatchUpdateAppHealthRequest{ + Updates: []*agentproto.BatchUpdateAppHealthRequest_HealthUpdate{}, + }) + require.NoError(t, err) + require.Equal(t, &agentproto.BatchUpdateAppHealthResponse{}, resp) + + require.EqualValues(t, 0, atomic.LoadInt64(&publishCalled)) + }) + + t.Run("AppNoHealthcheck", func(t *testing.T) { + t.Parallel() + + app3 := database.WorkspaceApp{ + ID: uuid.New(), + AgentID: agent.ID, + Slug: "code-server-3", + DisplayName: "code-server 3", + } + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app3}, nil) + + api := &agentapi.AppsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: nil, + } + + // Set app3 to healthy, should error. 
+ resp, err := api.BatchUpdateAppHealths(context.Background(), &agentproto.BatchUpdateAppHealthRequest{ + Updates: []*agentproto.BatchUpdateAppHealthRequest_HealthUpdate{ + { + Id: app3.ID[:], + Health: agentproto.AppHealth_HEALTHY, + }, + }, + }) + require.Error(t, err) + require.ErrorContains(t, err, "does not have healthchecks enabled") + require.Nil(t, resp) + }) + + t.Run("UnknownApp", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app1, app2}, nil) + + api := &agentapi.AppsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: nil, + } + + // Set an unknown app to healthy, should error. + id := uuid.New() + resp, err := api.BatchUpdateAppHealths(context.Background(), &agentproto.BatchUpdateAppHealthRequest{ + Updates: []*agentproto.BatchUpdateAppHealthRequest_HealthUpdate{ + { + Id: id[:], + Health: agentproto.AppHealth_HEALTHY, + }, + }, + }) + require.Error(t, err) + require.ErrorContains(t, err, "not found") + require.Nil(t, resp) + }) + + t.Run("InvalidHealth", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app1, app2}, nil) + + api := &agentapi.AppsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: nil, + } + + // Set an unknown app to healthy, should error. 
+ resp, err := api.BatchUpdateAppHealths(context.Background(), &agentproto.BatchUpdateAppHealthRequest{ + Updates: []*agentproto.BatchUpdateAppHealthRequest_HealthUpdate{ + { + Id: app1.ID[:], + Health: -999, + }, + }, + }) + require.Error(t, err) + require.ErrorContains(t, err, "unknown health status") + require.Nil(t, resp) + }) +} diff --git a/coderd/agentapi/lifecycle.go b/coderd/agentapi/lifecycle.go index d909d35eb8f4a..662d0c0c2e28e 100644 --- a/coderd/agentapi/lifecycle.go +++ b/coderd/agentapi/lifecycle.go @@ -3,6 +3,7 @@ package agentapi import ( "context" "database/sql" + "time" "github.com/google/uuid" "golang.org/x/mod/semver" @@ -21,6 +22,15 @@ type LifecycleAPI struct { Database database.Store Log slog.Logger PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent) error + + TimeNowFn func() time.Time // defaults to dbtime.Now() +} + +func (a *LifecycleAPI) now() time.Time { + if a.TimeNowFn != nil { + return a.TimeNowFn() + } + return dbtime.Now() } func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.UpdateLifecycleRequest) (*agentproto.Lifecycle, error) { @@ -68,7 +78,7 @@ func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.Upda changedAt := req.Lifecycle.ChangedAt.AsTime() if changedAt.IsZero() { - changedAt = dbtime.Now() + changedAt = a.now() req.Lifecycle.ChangedAt = timestamppb.New(changedAt) } dbChangedAt := sql.NullTime{Time: changedAt, Valid: true} @@ -78,8 +88,13 @@ func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.Upda switch lifecycleState { case database.WorkspaceAgentLifecycleStateStarting: startedAt = dbChangedAt - readyAt.Valid = false // This agent is re-starting, so it's not ready yet. + // This agent is (re)starting, so it's not ready yet. 
+ readyAt.Time = time.Time{} + readyAt.Valid = false case database.WorkspaceAgentLifecycleStateReady, database.WorkspaceAgentLifecycleStateStartError: + if !startedAt.Valid { + startedAt = dbChangedAt + } readyAt = dbChangedAt } @@ -97,9 +112,11 @@ func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.Upda return nil, xerrors.Errorf("update workspace agent lifecycle state: %w", err) } - err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent) - if err != nil { - return nil, xerrors.Errorf("publish workspace update: %w", err) + if a.PublishWorkspaceUpdateFn != nil { + err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent) + if err != nil { + return nil, xerrors.Errorf("publish workspace update: %w", err) + } } return req.Lifecycle, nil diff --git a/coderd/agentapi/lifecycle_test.go b/coderd/agentapi/lifecycle_test.go new file mode 100644 index 0000000000000..9029a84b955eb --- /dev/null +++ b/coderd/agentapi/lifecycle_test.go @@ -0,0 +1,460 @@ +package agentapi_test + +import ( + "context" + "database/sql" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/timestamppb" + + "cdr.dev/slog/sloggers/slogtest" + agentproto "github.com/coder/coder/v2/agent/proto" + "github.com/coder/coder/v2/coderd/agentapi" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/coderd/database/dbtime" +) + +func TestUpdateLifecycle(t *testing.T) { + t.Parallel() + + someTime, err := time.Parse(time.RFC3339, "2023-01-01T00:00:00Z") + require.NoError(t, err) + someTime = dbtime.Time(someTime) + now := dbtime.Now() + + var ( + workspaceID = uuid.New() + agentCreated = database.WorkspaceAgent{ + ID: uuid.New(), + LifecycleState: database.WorkspaceAgentLifecycleStateCreated, + StartedAt: sql.NullTime{Valid: false}, + ReadyAt: sql.NullTime{Valid: false}, + } + 
agentStarting = database.WorkspaceAgent{ + ID: uuid.New(), + LifecycleState: database.WorkspaceAgentLifecycleStateStarting, + StartedAt: sql.NullTime{Valid: true, Time: someTime}, + ReadyAt: sql.NullTime{Valid: false}, + } + ) + + t.Run("OKStarting", func(t *testing.T) { + t.Parallel() + + lifecycle := &agentproto.Lifecycle{ + State: agentproto.Lifecycle_STARTING, + ChangedAt: timestamppb.New(now), + } + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + dbM.EXPECT().UpdateWorkspaceAgentLifecycleStateByID(gomock.Any(), database.UpdateWorkspaceAgentLifecycleStateByIDParams{ + ID: agentCreated.ID, + LifecycleState: database.WorkspaceAgentLifecycleStateStarting, + StartedAt: sql.NullTime{ + Time: now, + Valid: true, + }, + ReadyAt: sql.NullTime{Valid: false}, + }).Return(nil) + + var publishCalled int64 + api := &agentapi.LifecycleAPI{ + AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { + return agentCreated, nil + }, + WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) { + return workspaceID, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent) error { + atomic.AddInt64(&publishCalled, 1) + return nil + }, + } + + resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{ + Lifecycle: lifecycle, + }) + require.NoError(t, err) + require.Equal(t, lifecycle, resp) + require.Equal(t, int64(1), atomic.LoadInt64(&publishCalled)) + }) + + t.Run("OKReadying", func(t *testing.T) { + t.Parallel() + + lifecycle := &agentproto.Lifecycle{ + State: agentproto.Lifecycle_READY, + ChangedAt: timestamppb.New(now), + } + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + dbM.EXPECT().UpdateWorkspaceAgentLifecycleStateByID(gomock.Any(), database.UpdateWorkspaceAgentLifecycleStateByIDParams{ + ID: agentStarting.ID, + LifecycleState: database.WorkspaceAgentLifecycleStateReady, + StartedAt: 
agentStarting.StartedAt, + ReadyAt: sql.NullTime{ + Time: now, + Valid: true, + }, + }).Return(nil) + + api := &agentapi.LifecycleAPI{ + AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { + return agentStarting, nil + }, + WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) { + return workspaceID, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + // Test that nil publish fn works. + PublishWorkspaceUpdateFn: nil, + } + + resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{ + Lifecycle: lifecycle, + }) + require.NoError(t, err) + require.Equal(t, lifecycle, resp) + }) + + // This test jumps from CREATING to READY, skipping STARTED. Both the + // StartedAt and ReadyAt fields should be set. + t.Run("OKStraightToReady", func(t *testing.T) { + t.Parallel() + + lifecycle := &agentproto.Lifecycle{ + State: agentproto.Lifecycle_READY, + ChangedAt: timestamppb.New(now), + } + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + dbM.EXPECT().UpdateWorkspaceAgentLifecycleStateByID(gomock.Any(), database.UpdateWorkspaceAgentLifecycleStateByIDParams{ + ID: agentCreated.ID, + LifecycleState: database.WorkspaceAgentLifecycleStateReady, + StartedAt: sql.NullTime{ + Time: now, + Valid: true, + }, + ReadyAt: sql.NullTime{ + Time: now, + Valid: true, + }, + }).Return(nil) + + var publishCalled int64 + api := &agentapi.LifecycleAPI{ + AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { + return agentCreated, nil + }, + WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) { + return workspaceID, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent) error { + atomic.AddInt64(&publishCalled, 1) + return nil + }, + } + + resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{ + Lifecycle: lifecycle, + }) + 
require.NoError(t, err) + require.Equal(t, lifecycle, resp) + require.Equal(t, int64(1), atomic.LoadInt64(&publishCalled)) + }) + + t.Run("NoTimeSpecified", func(t *testing.T) { + t.Parallel() + + lifecycle := &agentproto.Lifecycle{ + State: agentproto.Lifecycle_READY, + // Zero time + ChangedAt: timestamppb.New(time.Time{}), + } + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + + now := dbtime.Now() + dbM.EXPECT().UpdateWorkspaceAgentLifecycleStateByID(gomock.Any(), database.UpdateWorkspaceAgentLifecycleStateByIDParams{ + ID: agentCreated.ID, + LifecycleState: database.WorkspaceAgentLifecycleStateReady, + StartedAt: sql.NullTime{ + Time: now, + Valid: true, + }, + ReadyAt: sql.NullTime{ + Time: now, + Valid: true, + }, + }) + + api := &agentapi.LifecycleAPI{ + AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { + return agentCreated, nil + }, + WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) { + return workspaceID, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: nil, + TimeNowFn: func() time.Time { + return now + }, + } + + resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{ + Lifecycle: lifecycle, + }) + require.NoError(t, err) + require.Equal(t, lifecycle, resp) + }) + + t.Run("AllStates", func(t *testing.T) { + t.Parallel() + + agent := database.WorkspaceAgent{ + ID: uuid.New(), + LifecycleState: database.WorkspaceAgentLifecycleState(""), + StartedAt: sql.NullTime{Valid: false}, + ReadyAt: sql.NullTime{Valid: false}, + } + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + + var publishCalled int64 + api := &agentapi.LifecycleAPI{ + AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) { + return workspaceID, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + 
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent) error { + atomic.AddInt64(&publishCalled, 1) + return nil + }, + } + + states := []agentproto.Lifecycle_State{ + agentproto.Lifecycle_CREATED, + agentproto.Lifecycle_STARTING, + agentproto.Lifecycle_START_TIMEOUT, + agentproto.Lifecycle_START_ERROR, + agentproto.Lifecycle_READY, + agentproto.Lifecycle_SHUTTING_DOWN, + agentproto.Lifecycle_SHUTDOWN_TIMEOUT, + agentproto.Lifecycle_SHUTDOWN_ERROR, + agentproto.Lifecycle_OFF, + } + for i, state := range states { + t.Log("state", state) + now := now.Add(time.Hour * time.Duration(i)) + lifecycle := &agentproto.Lifecycle{ + State: state, + ChangedAt: timestamppb.New(now), + } + + expectedStartedAt := agent.StartedAt + expectedReadyAt := agent.ReadyAt + if state == agentproto.Lifecycle_STARTING { + expectedStartedAt = sql.NullTime{Valid: true, Time: now} + } + if state == agentproto.Lifecycle_READY || state == agentproto.Lifecycle_START_ERROR { + expectedReadyAt = sql.NullTime{Valid: true, Time: now} + } + + dbM.EXPECT().UpdateWorkspaceAgentLifecycleStateByID(gomock.Any(), database.UpdateWorkspaceAgentLifecycleStateByIDParams{ + ID: agent.ID, + LifecycleState: database.WorkspaceAgentLifecycleState(strings.ToLower(state.String())), + StartedAt: expectedStartedAt, + ReadyAt: expectedReadyAt, + }).Times(1).Return(nil) + + resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{ + Lifecycle: lifecycle, + }) + require.NoError(t, err) + require.Equal(t, lifecycle, resp) + require.Equal(t, int64(i+1), atomic.LoadInt64(&publishCalled)) + + // For future iterations: + agent.StartedAt = expectedStartedAt + agent.ReadyAt = expectedReadyAt + } + }) + + t.Run("UnknownLifecycleState", func(t *testing.T) { + t.Parallel() + + lifecycle := &agentproto.Lifecycle{ + State: -999, + ChangedAt: timestamppb.New(now), + } + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + + var publishCalled int64 + api := 
&agentapi.LifecycleAPI{ + AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { + return agentCreated, nil + }, + WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) { + return workspaceID, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent) error { + atomic.AddInt64(&publishCalled, 1) + return nil + }, + } + + resp, err := api.UpdateLifecycle(context.Background(), &agentproto.UpdateLifecycleRequest{ + Lifecycle: lifecycle, + }) + require.Error(t, err) + require.ErrorContains(t, err, "unknown lifecycle state") + require.Nil(t, resp) + require.Equal(t, int64(0), atomic.LoadInt64(&publishCalled)) + }) +} + +func TestUpdateStartup(t *testing.T) { + t.Parallel() + + var ( + workspaceID = uuid.New() + agent = database.WorkspaceAgent{ + ID: uuid.New(), + } + ) + + t.Run("OK", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + + api := &agentapi.LifecycleAPI{ + AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) { + return workspaceID, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + // Not used by UpdateStartup. 
+ PublishWorkspaceUpdateFn: nil, + } + + startup := &agentproto.Startup{ + Version: "v1.2.3", + ExpandedDirectory: "/path/to/expanded/dir", + Subsystems: []agentproto.Startup_Subsystem{ + agentproto.Startup_ENVBOX, + agentproto.Startup_ENVBUILDER, + agentproto.Startup_EXECTRACE, + }, + } + + dbM.EXPECT().UpdateWorkspaceAgentStartupByID(gomock.Any(), database.UpdateWorkspaceAgentStartupByIDParams{ + ID: agent.ID, + Version: startup.Version, + ExpandedDirectory: startup.ExpandedDirectory, + Subsystems: []database.WorkspaceAgentSubsystem{ + database.WorkspaceAgentSubsystemEnvbox, + database.WorkspaceAgentSubsystemEnvbuilder, + database.WorkspaceAgentSubsystemExectrace, + }, + APIVersion: agentapi.AgentAPIVersionDRPC, + }).Return(nil) + + resp, err := api.UpdateStartup(context.Background(), &agentproto.UpdateStartupRequest{ + Startup: startup, + }) + require.NoError(t, err) + require.Equal(t, startup, resp) + }) + + t.Run("BadVersion", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + + api := &agentapi.LifecycleAPI{ + AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) { + return workspaceID, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + // Not used by UpdateStartup. 
+ PublishWorkspaceUpdateFn: nil, + } + + startup := &agentproto.Startup{ + Version: "asdf", + ExpandedDirectory: "/path/to/expanded/dir", + Subsystems: []agentproto.Startup_Subsystem{}, + } + + resp, err := api.UpdateStartup(context.Background(), &agentproto.UpdateStartupRequest{ + Startup: startup, + }) + require.Error(t, err) + require.ErrorContains(t, err, "invalid agent semver version") + require.Nil(t, resp) + }) + + t.Run("BadSubsystem", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + + api := &agentapi.LifecycleAPI{ + AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + WorkspaceIDFn: func(ctx context.Context, agent *database.WorkspaceAgent) (uuid.UUID, error) { + return workspaceID, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + // Not used by UpdateStartup. + PublishWorkspaceUpdateFn: nil, + } + + startup := &agentproto.Startup{ + Version: "v1.2.3", + ExpandedDirectory: "/path/to/expanded/dir", + Subsystems: []agentproto.Startup_Subsystem{ + agentproto.Startup_ENVBOX, + -999, + }, + } + + resp, err := api.UpdateStartup(context.Background(), &agentproto.UpdateStartupRequest{ + Startup: startup, + }) + require.Error(t, err) + require.ErrorContains(t, err, "invalid agent subsystem") + require.Nil(t, resp) + }) +} diff --git a/coderd/agentapi/logs.go b/coderd/agentapi/logs.go index 7d34b41e13201..cb3a920b9a63b 100644 --- a/coderd/agentapi/logs.go +++ b/coderd/agentapi/logs.go @@ -2,6 +2,7 @@ package agentapi import ( "context" + "time" "github.com/google/uuid" "golang.org/x/xerrors" @@ -19,6 +20,15 @@ type LogsAPI struct { Log slog.Logger PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent) error PublishWorkspaceAgentLogsUpdateFn func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) + + TimeNowFn func() time.Time // defaults to dbtime.Now() +} + +func (a *LogsAPI) now() time.Time { + if a.TimeNowFn != nil { + 
return a.TimeNowFn() + } + return dbtime.Now() } func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCreateLogsRequest) (*agentproto.BatchCreateLogsResponse, error) { @@ -26,6 +36,9 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea if err != nil { return nil, err } + if workspaceAgent.LogsOverflowed { + return nil, xerrors.New("workspace agent logs overflowed") + } if len(req.Logs) == 0 { return &agentproto.BatchCreateLogsResponse{}, nil @@ -42,7 +55,7 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea // Use the external log source externalSources, err := a.Database.InsertWorkspaceAgentLogSources(ctx, database.InsertWorkspaceAgentLogSourcesParams{ WorkspaceAgentID: workspaceAgent.ID, - CreatedAt: dbtime.Now(), + CreatedAt: a.now(), ID: []uuid.UUID{agentsdk.ExternalLogSourceID}, DisplayName: []string{"External"}, Icon: []string{"/emojis/1f310.png"}, @@ -88,7 +101,7 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea logs, err := a.Database.InsertWorkspaceAgentLogs(ctx, database.InsertWorkspaceAgentLogsParams{ AgentID: workspaceAgent.ID, - CreatedAt: dbtime.Now(), + CreatedAt: a.now(), Output: output, Level: level, LogSourceID: logSourceID, @@ -98,9 +111,6 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea if !database.IsWorkspaceAgentLogsLimitError(err) { return nil, xerrors.Errorf("insert workspace agent logs: %w", err) } - if workspaceAgent.LogsOverflowed { - return nil, xerrors.New("workspace agent logs overflowed") - } err := a.Database.UpdateWorkspaceAgentLogOverflowByID(ctx, database.UpdateWorkspaceAgentLogOverflowByIDParams{ ID: workspaceAgent.ID, LogsOverflowed: true, @@ -112,21 +122,25 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea a.Log.Warn(ctx, "failed to update workspace agent log overflow", slog.Error(err)) } - err = a.PublishWorkspaceUpdateFn(ctx, 
&workspaceAgent) - if err != nil { - return nil, xerrors.Errorf("publish workspace update: %w", err) + if a.PublishWorkspaceUpdateFn != nil { + err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent) + if err != nil { + return nil, xerrors.Errorf("publish workspace update: %w", err) + } } return nil, xerrors.New("workspace agent log limit exceeded") } // Publish by the lowest log ID inserted so the log stream will fetch // everything from that point. - lowestLogID := logs[0].ID - a.PublishWorkspaceAgentLogsUpdateFn(ctx, workspaceAgent.ID, agentsdk.LogsNotifyMessage{ - CreatedAfter: lowestLogID - 1, - }) + if a.PublishWorkspaceAgentLogsUpdateFn != nil { + lowestLogID := logs[0].ID + a.PublishWorkspaceAgentLogsUpdateFn(ctx, workspaceAgent.ID, agentsdk.LogsNotifyMessage{ + CreatedAfter: lowestLogID - 1, + }) + } - if workspaceAgent.LogsLength == 0 { + if workspaceAgent.LogsLength == 0 && a.PublishWorkspaceUpdateFn != nil { // If these are the first logs being appended, we publish a UI update // to notify the UI that logs are now available. 
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent) diff --git a/coderd/agentapi/logs_test.go b/coderd/agentapi/logs_test.go new file mode 100644 index 0000000000000..1d4261a0191ea --- /dev/null +++ b/coderd/agentapi/logs_test.go @@ -0,0 +1,427 @@ +package agentapi_test + +import ( + "context" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/google/uuid" + "github.com/lib/pq" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/timestamppb" + + "cdr.dev/slog/sloggers/slogtest" + agentproto "github.com/coder/coder/v2/agent/proto" + "github.com/coder/coder/v2/coderd/agentapi" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/codersdk/agentsdk" +) + +func TestBatchCreateLogs(t *testing.T) { + t.Parallel() + + var ( + agent = database.WorkspaceAgent{ + ID: uuid.New(), + } + logSource = database.WorkspaceAgentLogSource{ + WorkspaceAgentID: agent.ID, + CreatedAt: dbtime.Now(), + ID: uuid.New(), + } + ) + + t.Run("OK", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + + var publishWorkspaceUpdateCalled int64 + var publishWorkspaceAgentLogsUpdateCalled int64 + now := dbtime.Now() + api := &agentapi.LogsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + atomic.AddInt64(&publishWorkspaceUpdateCalled, 1) + return nil + }, + PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) { + atomic.AddInt64(&publishWorkspaceAgentLogsUpdateCalled, 1) + + // Check the message content, should be for -1 since the lowest + // log we inserted was 0. 
+ require.Equal(t, agentsdk.LogsNotifyMessage{CreatedAfter: -1}, msg) + }, + TimeNowFn: func() time.Time { return now }, + } + + req := &agentproto.BatchCreateLogsRequest{ + LogSourceId: logSource.ID[:], + Logs: []*agentproto.Log{ + { + CreatedAt: timestamppb.New(now), + Level: agentproto.Log_TRACE, + Output: "log line 1", + }, + { + CreatedAt: timestamppb.New(now.Add(time.Hour)), + Level: agentproto.Log_DEBUG, + Output: "log line 2", + }, + { + CreatedAt: timestamppb.New(now.Add(2 * time.Hour)), + Level: agentproto.Log_INFO, + Output: "log line 3", + }, + { + CreatedAt: timestamppb.New(now.Add(3 * time.Hour)), + Level: agentproto.Log_WARN, + Output: "log line 4", + }, + { + CreatedAt: timestamppb.New(now.Add(4 * time.Hour)), + Level: agentproto.Log_ERROR, + Output: "log line 5", + }, + { + CreatedAt: timestamppb.New(now.Add(5 * time.Hour)), + Level: -999, // defaults to INFO + Output: "log line 6", + }, + }, + } + + // Craft expected DB request and response dynamically. + insertWorkspaceAgentLogsParams := database.InsertWorkspaceAgentLogsParams{ + AgentID: agent.ID, + LogSourceID: logSource.ID, + CreatedAt: now, + Output: make([]string, len(req.Logs)), + Level: make([]database.LogLevel, len(req.Logs)), + OutputLength: 0, + } + insertWorkspaceAgentLogsReturn := make([]database.WorkspaceAgentLog, len(req.Logs)) + for i, logEntry := range req.Logs { + insertWorkspaceAgentLogsParams.Output[i] = logEntry.Output + level := database.LogLevelInfo + if logEntry.Level >= 0 { + level = database.LogLevel(strings.ToLower(logEntry.Level.String())) + } + insertWorkspaceAgentLogsParams.Level[i] = level + insertWorkspaceAgentLogsParams.OutputLength += int32(len(logEntry.Output)) + + insertWorkspaceAgentLogsReturn[i] = database.WorkspaceAgentLog{ + AgentID: agent.ID, + CreatedAt: logEntry.CreatedAt.AsTime(), + ID: int64(i), + Output: logEntry.Output, + Level: insertWorkspaceAgentLogsParams.Level[i], + LogSourceID: logSource.ID, + } + } + + 
dbM.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), insertWorkspaceAgentLogsParams).Return(insertWorkspaceAgentLogsReturn, nil) + + resp, err := api.BatchCreateLogs(context.Background(), req) + require.NoError(t, err) + require.Equal(t, &agentproto.BatchCreateLogsResponse{}, resp) + require.EqualValues(t, 1, atomic.LoadInt64(&publishWorkspaceUpdateCalled)) + require.EqualValues(t, 1, atomic.LoadInt64(&publishWorkspaceAgentLogsUpdateCalled)) + }) + + t.Run("NoWorkspacePublishIfNotFirstLogs", func(t *testing.T) { + t.Parallel() + + agentWithLogs := agent + agentWithLogs.LogsLength = 1 + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + + var publishWorkspaceUpdateCalled int64 + var publishWorkspaceAgentLogsUpdateCalled int64 + api := &agentapi.LogsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agentWithLogs, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + atomic.AddInt64(&publishWorkspaceUpdateCalled, 1) + return nil + }, + PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) { + atomic.AddInt64(&publishWorkspaceAgentLogsUpdateCalled, 1) + }, + } + + // Don't really care about the DB call. 
+ dbM.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), gomock.Any()).Return([]database.WorkspaceAgentLog{ + { + ID: 1, + }, + }, nil) + + resp, err := api.BatchCreateLogs(context.Background(), &agentproto.BatchCreateLogsRequest{ + LogSourceId: logSource.ID[:], + Logs: []*agentproto.Log{ + { + CreatedAt: timestamppb.New(dbtime.Now()), + Level: agentproto.Log_INFO, + Output: "hello world", + }, + }, + }) + require.NoError(t, err) + require.Equal(t, &agentproto.BatchCreateLogsResponse{}, resp) + require.EqualValues(t, 0, atomic.LoadInt64(&publishWorkspaceUpdateCalled)) + require.EqualValues(t, 1, atomic.LoadInt64(&publishWorkspaceAgentLogsUpdateCalled)) + }) + + t.Run("AlreadyOverflowed", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + + overflowedAgent := agent + overflowedAgent.LogsOverflowed = true + + var publishWorkspaceUpdateCalled int64 + var publishWorkspaceAgentLogsUpdateCalled int64 + api := &agentapi.LogsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return overflowedAgent, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + atomic.AddInt64(&publishWorkspaceUpdateCalled, 1) + return nil + }, + PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) { + atomic.AddInt64(&publishWorkspaceAgentLogsUpdateCalled, 1) + }, + } + + resp, err := api.BatchCreateLogs(context.Background(), &agentproto.BatchCreateLogsRequest{ + LogSourceId: logSource.ID[:], + Logs: []*agentproto.Log{}, + }) + require.Error(t, err) + require.ErrorContains(t, err, "workspace agent logs overflowed") + require.Nil(t, resp) + require.EqualValues(t, 0, atomic.LoadInt64(&publishWorkspaceUpdateCalled)) + require.EqualValues(t, 0, atomic.LoadInt64(&publishWorkspaceAgentLogsUpdateCalled)) + }) + + t.Run("InvalidLogSourceID", func(t *testing.T) { + t.Parallel() + + 
dbM := dbmock.NewMockStore(gomock.NewController(t)) + + api := &agentapi.LogsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + // Test that they are ignored when nil. + PublishWorkspaceUpdateFn: nil, + PublishWorkspaceAgentLogsUpdateFn: nil, + } + + resp, err := api.BatchCreateLogs(context.Background(), &agentproto.BatchCreateLogsRequest{ + LogSourceId: []byte("invalid"), + Logs: []*agentproto.Log{ + {}, // need at least 1 log + }, + }) + require.Error(t, err) + require.ErrorContains(t, err, "parse log source ID") + require.Nil(t, resp) + }) + + t.Run("UseExternalLogSourceID", func(t *testing.T) { + t.Parallel() + + now := dbtime.Now() + req := &agentproto.BatchCreateLogsRequest{ + LogSourceId: uuid.Nil[:], // defaults to "external" + Logs: []*agentproto.Log{ + { + CreatedAt: timestamppb.New(now), + Level: agentproto.Log_INFO, + Output: "hello world", + }, + }, + } + dbInsertParams := database.InsertWorkspaceAgentLogsParams{ + AgentID: agent.ID, + LogSourceID: agentsdk.ExternalLogSourceID, + CreatedAt: now, + Output: []string{"hello world"}, + Level: []database.LogLevel{database.LogLevelInfo}, + OutputLength: int32(len(req.Logs[0].Output)), + } + dbInsertRes := []database.WorkspaceAgentLog{ + { + AgentID: agent.ID, + CreatedAt: now, + ID: 1, + Output: "hello world", + Level: database.LogLevelInfo, + LogSourceID: agentsdk.ExternalLogSourceID, + }, + } + + t.Run("Create", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + + var publishWorkspaceUpdateCalled int64 + var publishWorkspaceAgentLogsUpdateCalled int64 + api := &agentapi.LogsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + atomic.AddInt64(&publishWorkspaceUpdateCalled, 1) + return nil 
+ }, + PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) { + atomic.AddInt64(&publishWorkspaceAgentLogsUpdateCalled, 1) + }, + TimeNowFn: func() time.Time { return now }, + } + + dbM.EXPECT().InsertWorkspaceAgentLogSources(gomock.Any(), database.InsertWorkspaceAgentLogSourcesParams{ + WorkspaceAgentID: agent.ID, + CreatedAt: now, + ID: []uuid.UUID{agentsdk.ExternalLogSourceID}, + DisplayName: []string{"External"}, + Icon: []string{"/emojis/1f310.png"}, + }).Return([]database.WorkspaceAgentLogSource{ + { + // only the ID field is used + ID: agentsdk.ExternalLogSourceID, + }, + }, nil) + dbM.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), dbInsertParams).Return(dbInsertRes, nil) + + resp, err := api.BatchCreateLogs(context.Background(), req) + require.NoError(t, err) + require.Equal(t, &agentproto.BatchCreateLogsResponse{}, resp) + require.EqualValues(t, 1, atomic.LoadInt64(&publishWorkspaceUpdateCalled)) + require.EqualValues(t, 1, atomic.LoadInt64(&publishWorkspaceAgentLogsUpdateCalled)) + }) + + t.Run("Exists", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + + var publishWorkspaceUpdateCalled int64 + var publishWorkspaceAgentLogsUpdateCalled int64 + api := &agentapi.LogsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + atomic.AddInt64(&publishWorkspaceUpdateCalled, 1) + return nil + }, + PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) { + atomic.AddInt64(&publishWorkspaceAgentLogsUpdateCalled, 1) + }, + TimeNowFn: func() time.Time { return now }, + } + + // Return a unique violation error to simulate the log source + // already existing. This should be handled gracefully. 
+ logSourceInsertErr := &pq.Error{ + Code: pq.ErrorCode("23505"), // unique_violation + Constraint: string(database.UniqueWorkspaceAgentLogSourcesPkey), + } + dbM.EXPECT().InsertWorkspaceAgentLogSources(gomock.Any(), database.InsertWorkspaceAgentLogSourcesParams{ + WorkspaceAgentID: agent.ID, + CreatedAt: now, + ID: []uuid.UUID{agentsdk.ExternalLogSourceID}, + DisplayName: []string{"External"}, + Icon: []string{"/emojis/1f310.png"}, + }).Return([]database.WorkspaceAgentLogSource{}, logSourceInsertErr) + + dbM.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), dbInsertParams).Return(dbInsertRes, nil) + + resp, err := api.BatchCreateLogs(context.Background(), req) + require.NoError(t, err) + require.Equal(t, &agentproto.BatchCreateLogsResponse{}, resp) + require.EqualValues(t, 1, atomic.LoadInt64(&publishWorkspaceUpdateCalled)) + require.EqualValues(t, 1, atomic.LoadInt64(&publishWorkspaceAgentLogsUpdateCalled)) + }) + }) + + t.Run("Overflow", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + + var publishWorkspaceUpdateCalled int64 + var publishWorkspaceAgentLogsUpdateCalled int64 + api := &agentapi.LogsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + Log: slogtest.Make(t, nil), + PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + atomic.AddInt64(&publishWorkspaceUpdateCalled, 1) + return nil + }, + PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) { + atomic.AddInt64(&publishWorkspaceAgentLogsUpdateCalled, 1) + }, + } + + // Don't really care about the DB call params, just want to return an + // error. + dbErr := &pq.Error{ + Constraint: "max_logs_length", + Table: "workspace_agents", + } + dbM.EXPECT().InsertWorkspaceAgentLogs(gomock.Any(), gomock.Any()).Return(nil, dbErr) + + // Should also update the workspace agent. 
+ dbM.EXPECT().UpdateWorkspaceAgentLogOverflowByID(gomock.Any(), database.UpdateWorkspaceAgentLogOverflowByIDParams{ + ID: agent.ID, + LogsOverflowed: true, + }).Return(nil) + + resp, err := api.BatchCreateLogs(context.Background(), &agentproto.BatchCreateLogsRequest{ + LogSourceId: logSource.ID[:], + Logs: []*agentproto.Log{ + { + CreatedAt: timestamppb.New(dbtime.Now()), + Level: agentproto.Log_INFO, + Output: "hello world", + }, + }, + }) + require.Error(t, err) + require.Nil(t, resp) + require.EqualValues(t, 1, atomic.LoadInt64(&publishWorkspaceUpdateCalled)) + require.EqualValues(t, 0, atomic.LoadInt64(&publishWorkspaceAgentLogsUpdateCalled)) + }) +} diff --git a/coderd/agentapi/manifest.go b/coderd/agentapi/manifest.go index 7304899ceb02c..4b4ea7e7b64c7 100644 --- a/coderd/agentapi/manifest.go +++ b/coderd/agentapi/manifest.go @@ -6,7 +6,6 @@ import ( "fmt" "net/url" "strings" - "sync/atomic" "time" "github.com/google/uuid" @@ -26,18 +25,16 @@ import ( ) type ManifestAPI struct { - AccessURL *url.URL - AppHostname string - AgentInactiveDisconnectTimeout time.Duration - AgentFallbackTroubleshootingURL string - ExternalAuthConfigs []*externalauth.Config - DisableDirectConnections bool - DerpForceWebSockets bool - - AgentFn func(context.Context) (database.WorkspaceAgent, error) - Database database.Store - DerpMapFn func() *tailcfg.DERPMap - TailnetCoordinator *atomic.Pointer[tailnet.Coordinator] + AccessURL *url.URL + AppHostname string + ExternalAuthConfigs []*externalauth.Config + DisableDirectConnections bool + DerpForceWebSockets bool + + AgentFn func(context.Context) (database.WorkspaceAgent, error) + WorkspaceIDFn func(context.Context, *database.WorkspaceAgent) (uuid.UUID, error) + Database database.Store + DerpMapFn func() *tailcfg.DERPMap } func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifestRequest) (*agentproto.Manifest, error) { @@ -45,21 +42,15 @@ func (a *ManifestAPI) GetManifest(ctx context.Context, _ 
*agentproto.GetManifest if err != nil { return nil, err } - - apiAgent, err := db2sdk.WorkspaceAgent( - a.DerpMapFn(), *a.TailnetCoordinator.Load(), workspaceAgent, nil, nil, nil, a.AgentInactiveDisconnectTimeout, - a.AgentFallbackTroubleshootingURL, - ) + workspaceID, err := a.WorkspaceIDFn(ctx, &workspaceAgent) if err != nil { - return nil, xerrors.Errorf("converting workspace agent: %w", err) + return nil, err } var ( dbApps []database.WorkspaceApp scripts []database.WorkspaceAgentScript metadata []database.WorkspaceAgentMetadatum - resource database.WorkspaceResource - build database.WorkspaceBuild workspace database.Workspace owner database.User ) @@ -80,20 +71,12 @@ func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifest eg.Go(func() (err error) { metadata, err = a.Database.GetWorkspaceAgentMetadata(ctx, database.GetWorkspaceAgentMetadataParams{ WorkspaceAgentID: workspaceAgent.ID, - Keys: nil, + Keys: nil, // all }) return err }) eg.Go(func() (err error) { - resource, err = a.Database.GetWorkspaceResourceByID(ctx, workspaceAgent.ResourceID) - if err != nil { - return xerrors.Errorf("getting resource by id: %w", err) - } - build, err = a.Database.GetWorkspaceBuildByJobID(ctx, resource.JobID) - if err != nil { - return xerrors.Errorf("getting workspace build by job id: %w", err) - } - workspace, err = a.Database.GetWorkspaceByID(ctx, build.WorkspaceID) + workspace, err = a.Database.GetWorkspaceByID(ctx, workspaceID) if err != nil { return xerrors.Errorf("getting workspace by id: %w", err) } @@ -122,6 +105,11 @@ func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifest vscodeProxyURI += fmt.Sprintf(":%s", a.AccessURL.Port()) } + envs, err := db2sdk.WorkspaceAgentEnvironment(workspaceAgent) + if err != nil { + return nil, err + } + var gitAuthConfigs uint32 for _, cfg := range a.ExternalAuthConfigs { if codersdk.EnhancedExternalAuthProvider(cfg.Type).Git() { @@ -139,8 +127,8 @@ func (a *ManifestAPI) GetManifest(ctx 
context.Context, _ *agentproto.GetManifest OwnerUsername: owner.Username, WorkspaceId: workspace.ID[:], GitAuthConfigs: gitAuthConfigs, - EnvironmentVariables: apiAgent.EnvironmentVariables, - Directory: apiAgent.Directory, + EnvironmentVariables: envs, + Directory: workspaceAgent.Directory, VsCodePortProxyUri: vscodeProxyURI, MotdPath: workspaceAgent.MOTDFile, DisableDirectConnections: a.DisableDirectConnections, diff --git a/coderd/agentapi/manifest_test.go b/coderd/agentapi/manifest_test.go new file mode 100644 index 0000000000000..008618c12c57b --- /dev/null +++ b/coderd/agentapi/manifest_test.go @@ -0,0 +1,392 @@ +package agentapi_test + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "net/url" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/durationpb" + "tailscale.com/tailcfg" + + agentproto "github.com/coder/coder/v2/agent/proto" + "github.com/coder/coder/v2/coderd/agentapi" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/externalauth" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/tailnet" +) + +func TestGetManifest(t *testing.T) { + t.Parallel() + + someTime, err := time.Parse(time.RFC3339, "2023-01-01T00:00:00Z") + require.NoError(t, err) + someTime = dbtime.Time(someTime) + + expectedEnvVars := map[string]string{ + "FOO": "bar", + "COOL_ENV": "dean was here", + } + expectedEnvVarsJSON, err := json.Marshal(expectedEnvVars) + require.NoError(t, err) + + var ( + owner = database.User{ + ID: uuid.New(), + Username: "cool-user", + } + workspace = database.Workspace{ + ID: uuid.New(), + OwnerID: owner.ID, + Name: "cool-workspace", + } + agent = database.WorkspaceAgent{ + ID: uuid.New(), + Name: "cool-agent", + EnvironmentVariables: 
pqtype.NullRawMessage{ + RawMessage: expectedEnvVarsJSON, + Valid: true, + }, + Directory: "/cool/dir", + MOTDFile: "/cool/motd", + } + apps = []database.WorkspaceApp{ + { + ID: uuid.New(), + Url: sql.NullString{String: "http://localhost:1234", Valid: true}, + External: false, + Slug: "cool-app-1", + DisplayName: "app 1", + Command: sql.NullString{String: "cool command", Valid: true}, + Icon: "/icon.png", + Subdomain: true, + SharingLevel: database.AppSharingLevelAuthenticated, + Health: database.WorkspaceAppHealthHealthy, + HealthcheckUrl: "http://localhost:1234/health", + HealthcheckInterval: 10, + HealthcheckThreshold: 3, + }, + { + ID: uuid.New(), + Url: sql.NullString{String: "http://google.com", Valid: true}, + External: true, + Slug: "google", + DisplayName: "Literally Google", + Command: sql.NullString{Valid: false}, + Icon: "/google.png", + Subdomain: false, + SharingLevel: database.AppSharingLevelPublic, + Health: database.WorkspaceAppHealthDisabled, + }, + { + ID: uuid.New(), + Url: sql.NullString{String: "http://localhost:4321", Valid: true}, + External: true, + Slug: "cool-app-2", + DisplayName: "another COOL app", + Command: sql.NullString{Valid: false}, + Icon: "", + Subdomain: false, + SharingLevel: database.AppSharingLevelOwner, + Health: database.WorkspaceAppHealthUnhealthy, + HealthcheckUrl: "http://localhost:4321/health", + HealthcheckInterval: 20, + HealthcheckThreshold: 5, + }, + } + scripts = []database.WorkspaceAgentScript{ + { + WorkspaceAgentID: agent.ID, + LogSourceID: uuid.New(), + LogPath: "/cool/log/path/1", + Script: "cool script 1", + Cron: "30 2 * * *", + StartBlocksLogin: true, + RunOnStart: true, + RunOnStop: false, + TimeoutSeconds: 60, + }, + { + WorkspaceAgentID: agent.ID, + LogSourceID: uuid.New(), + LogPath: "/cool/log/path/2", + Script: "cool script 2", + Cron: "", + StartBlocksLogin: false, + RunOnStart: false, + RunOnStop: true, + TimeoutSeconds: 30, + }, + } + metadata = []database.WorkspaceAgentMetadatum{ + { + 
WorkspaceAgentID: agent.ID, + DisplayName: "cool metadata 1", + Key: "cool-key-1", + Script: "cool script 1", + Value: "cool value 1", + Error: "", + Timeout: int64(time.Minute), + Interval: int64(time.Minute), + CollectedAt: someTime, + }, + { + WorkspaceAgentID: agent.ID, + DisplayName: "cool metadata 2", + Key: "cool-key-2", + Script: "cool script 2", + Value: "cool value 2", + Error: "some uncool error", + Timeout: int64(5 * time.Second), + Interval: int64(20 * time.Minute), + CollectedAt: someTime.Add(time.Hour), + }, + } + derpMapFn = func() *tailcfg.DERPMap { + return &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 1: {RegionName: "cool region"}, + }, + } + } + ) + + // These are done manually to ensure the conversion logic matches what a + // human expects. + var ( + protoApps = []*agentproto.WorkspaceApp{ + { + Id: apps[0].ID[:], + Url: apps[0].Url.String, + External: apps[0].External, + Slug: apps[0].Slug, + DisplayName: apps[0].DisplayName, + Command: apps[0].Command.String, + Icon: apps[0].Icon, + Subdomain: apps[0].Subdomain, + SubdomainName: fmt.Sprintf("%s--%s--%s--%s", apps[0].Slug, agent.Name, workspace.Name, owner.Username), + SharingLevel: agentproto.WorkspaceApp_AUTHENTICATED, + Healthcheck: &agentproto.WorkspaceApp_Healthcheck{ + Url: apps[0].HealthcheckUrl, + Interval: durationpb.New(time.Duration(apps[0].HealthcheckInterval) * time.Second), + Threshold: apps[0].HealthcheckThreshold, + }, + Health: agentproto.WorkspaceApp_HEALTHY, + }, + { + Id: apps[1].ID[:], + Url: apps[1].Url.String, + External: apps[1].External, + Slug: apps[1].Slug, + DisplayName: apps[1].DisplayName, + Command: apps[1].Command.String, + Icon: apps[1].Icon, + Subdomain: false, + SubdomainName: "", + SharingLevel: agentproto.WorkspaceApp_PUBLIC, + Healthcheck: &agentproto.WorkspaceApp_Healthcheck{ + Url: "", + Interval: durationpb.New(0), + Threshold: 0, + }, + Health: agentproto.WorkspaceApp_DISABLED, + }, + { + Id: apps[2].ID[:], + Url: apps[2].Url.String, + 
External: apps[2].External, + Slug: apps[2].Slug, + DisplayName: apps[2].DisplayName, + Command: apps[2].Command.String, + Icon: apps[2].Icon, + Subdomain: false, + SubdomainName: "", + SharingLevel: agentproto.WorkspaceApp_OWNER, + Healthcheck: &agentproto.WorkspaceApp_Healthcheck{ + Url: apps[2].HealthcheckUrl, + Interval: durationpb.New(time.Duration(apps[2].HealthcheckInterval) * time.Second), + Threshold: apps[2].HealthcheckThreshold, + }, + Health: agentproto.WorkspaceApp_UNHEALTHY, + }, + } + protoScripts = []*agentproto.WorkspaceAgentScript{ + { + LogSourceId: scripts[0].LogSourceID[:], + LogPath: scripts[0].LogPath, + Script: scripts[0].Script, + Cron: scripts[0].Cron, + RunOnStart: scripts[0].RunOnStart, + RunOnStop: scripts[0].RunOnStop, + StartBlocksLogin: scripts[0].StartBlocksLogin, + Timeout: durationpb.New(time.Duration(scripts[0].TimeoutSeconds) * time.Second), + }, + { + LogSourceId: scripts[1].LogSourceID[:], + LogPath: scripts[1].LogPath, + Script: scripts[1].Script, + Cron: scripts[1].Cron, + RunOnStart: scripts[1].RunOnStart, + RunOnStop: scripts[1].RunOnStop, + StartBlocksLogin: scripts[1].StartBlocksLogin, + Timeout: durationpb.New(time.Duration(scripts[1].TimeoutSeconds) * time.Second), + }, + } + protoMetadata = []*agentproto.WorkspaceAgentMetadata_Description{ + { + DisplayName: metadata[0].DisplayName, + Key: metadata[0].Key, + Script: metadata[0].Script, + Interval: durationpb.New(time.Duration(metadata[0].Interval)), + Timeout: durationpb.New(time.Duration(metadata[0].Timeout)), + }, + { + DisplayName: metadata[1].DisplayName, + Key: metadata[1].Key, + Script: metadata[1].Script, + Interval: durationpb.New(time.Duration(metadata[1].Interval)), + Timeout: durationpb.New(time.Duration(metadata[1].Timeout)), + }, + } + ) + + t.Run("OK", func(t *testing.T) { + t.Parallel() + + mDB := dbmock.NewMockStore(gomock.NewController(t)) + + api := &agentapi.ManifestAPI{ + AccessURL: &url.URL{Scheme: "https", Host: "example.com"}, + AppHostname: 
"*--apps.example.com", + ExternalAuthConfigs: []*externalauth.Config{ + {Type: string(codersdk.EnhancedExternalAuthProviderGitHub)}, + {Type: "some-provider"}, + {Type: string(codersdk.EnhancedExternalAuthProviderGitLab)}, + }, + DisableDirectConnections: true, + DerpForceWebSockets: true, + + AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + WorkspaceIDFn: func(ctx context.Context, _ *database.WorkspaceAgent) (uuid.UUID, error) { + return workspace.ID, nil + }, + Database: mDB, + DerpMapFn: derpMapFn, + } + + mDB.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return(apps, nil) + mDB.EXPECT().GetWorkspaceAgentScriptsByAgentIDs(gomock.Any(), []uuid.UUID{agent.ID}).Return(scripts, nil) + mDB.EXPECT().GetWorkspaceAgentMetadata(gomock.Any(), database.GetWorkspaceAgentMetadataParams{ + WorkspaceAgentID: agent.ID, + Keys: nil, // all + }).Return(metadata, nil) + mDB.EXPECT().GetWorkspaceByID(gomock.Any(), workspace.ID).Return(workspace, nil) + mDB.EXPECT().GetUserByID(gomock.Any(), workspace.OwnerID).Return(owner, nil) + + got, err := api.GetManifest(context.Background(), &agentproto.GetManifestRequest{}) + require.NoError(t, err) + + expected := &agentproto.Manifest{ + AgentId: agent.ID[:], + OwnerUsername: owner.Username, + WorkspaceId: workspace.ID[:], + GitAuthConfigs: 2, // two "enhanced" external auth configs + EnvironmentVariables: expectedEnvVars, + Directory: agent.Directory, + VsCodePortProxyUri: fmt.Sprintf("https://{{port}}--%s--%s--%s--apps.example.com", agent.Name, workspace.Name, owner.Username), + MotdPath: agent.MOTDFile, + DisableDirectConnections: true, + DerpForceWebsockets: true, + // tailnet.DERPMapToProto() is extensively tested elsewhere, so it's + // not necessary to manually recreate a big DERP map here like we + // did for apps and metadata. 
+ DerpMap: tailnet.DERPMapToProto(derpMapFn()), + Scripts: protoScripts, + Apps: protoApps, + Metadata: protoMetadata, + } + + // Log got and expected with spew. + // t.Log("got:\n" + spew.Sdump(got)) + // t.Log("expected:\n" + spew.Sdump(expected)) + + require.Equal(t, expected, got) + }) + + t.Run("NoAppHostname", func(t *testing.T) { + t.Parallel() + + mDB := dbmock.NewMockStore(gomock.NewController(t)) + + api := &agentapi.ManifestAPI{ + AccessURL: &url.URL{Scheme: "https", Host: "example.com"}, + AppHostname: "", + ExternalAuthConfigs: []*externalauth.Config{ + {Type: string(codersdk.EnhancedExternalAuthProviderGitHub)}, + {Type: "some-provider"}, + {Type: string(codersdk.EnhancedExternalAuthProviderGitLab)}, + }, + DisableDirectConnections: true, + DerpForceWebSockets: true, + + AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + WorkspaceIDFn: func(ctx context.Context, _ *database.WorkspaceAgent) (uuid.UUID, error) { + return workspace.ID, nil + }, + Database: mDB, + DerpMapFn: derpMapFn, + } + + mDB.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return(apps, nil) + mDB.EXPECT().GetWorkspaceAgentScriptsByAgentIDs(gomock.Any(), []uuid.UUID{agent.ID}).Return(scripts, nil) + mDB.EXPECT().GetWorkspaceAgentMetadata(gomock.Any(), database.GetWorkspaceAgentMetadataParams{ + WorkspaceAgentID: agent.ID, + Keys: nil, // all + }).Return(metadata, nil) + mDB.EXPECT().GetWorkspaceByID(gomock.Any(), workspace.ID).Return(workspace, nil) + mDB.EXPECT().GetUserByID(gomock.Any(), workspace.OwnerID).Return(owner, nil) + + got, err := api.GetManifest(context.Background(), &agentproto.GetManifestRequest{}) + require.NoError(t, err) + + expected := &agentproto.Manifest{ + AgentId: agent.ID[:], + OwnerUsername: owner.Username, + WorkspaceId: workspace.ID[:], + GitAuthConfigs: 2, // two "enhanced" external auth configs + EnvironmentVariables: expectedEnvVars, + Directory: agent.Directory, + VsCodePortProxyUri: 
"https://example.com", + MotdPath: agent.MOTDFile, + DisableDirectConnections: true, + DerpForceWebsockets: true, + // tailnet.DERPMapToProto() is extensively tested elsewhere, so it's + // not necessary to manually recreate a big DERP map here like we + // did for apps and metadata. + DerpMap: tailnet.DERPMapToProto(derpMapFn()), + Scripts: protoScripts, + Apps: protoApps, + Metadata: protoMetadata, + } + + // Log got and expected with spew. + // t.Log("got:\n" + spew.Sdump(got)) + // t.Log("expected:\n" + spew.Sdump(expected)) + + require.Equal(t, expected, got) + }) +} diff --git a/coderd/agentapi/servicebanner_test.go b/coderd/agentapi/servicebanner_test.go new file mode 100644 index 0000000000000..f7a860a96b70e --- /dev/null +++ b/coderd/agentapi/servicebanner_test.go @@ -0,0 +1,84 @@ +package agentapi_test + +import ( + "context" + "database/sql" + "encoding/json" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" + + agentproto "github.com/coder/coder/v2/agent/proto" + "github.com/coder/coder/v2/coderd/agentapi" + "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/codersdk" +) + +func TestGetServiceBanner(t *testing.T) { + t.Parallel() + + t.Run("OK", func(t *testing.T) { + t.Parallel() + + cfg := codersdk.ServiceBannerConfig{ + Enabled: true, + Message: "hello world", + BackgroundColor: "#000000", + } + cfgJSON, err := json.Marshal(cfg) + require.NoError(t, err) + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + dbM.EXPECT().GetServiceBanner(gomock.Any()).Return(string(cfgJSON), nil) + + api := &agentapi.ServiceBannerAPI{ + Database: dbM, + } + + resp, err := api.GetServiceBanner(context.Background(), &agentproto.GetServiceBannerRequest{}) + require.NoError(t, err) + + require.Equal(t, &agentproto.ServiceBanner{ + Enabled: cfg.Enabled, + Message: cfg.Message, + BackgroundColor: cfg.BackgroundColor, + }, resp) + }) + + t.Run("None", func(t *testing.T) { + t.Parallel() + + dbM := 
dbmock.NewMockStore(gomock.NewController(t)) + dbM.EXPECT().GetServiceBanner(gomock.Any()).Return("", sql.ErrNoRows) + + api := &agentapi.ServiceBannerAPI{ + Database: dbM, + } + + resp, err := api.GetServiceBanner(context.Background(), &agentproto.GetServiceBannerRequest{}) + require.NoError(t, err) + + require.Equal(t, &agentproto.ServiceBanner{ + Enabled: false, + Message: "", + BackgroundColor: "", + }, resp) + }) + + t.Run("BadJSON", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + dbM.EXPECT().GetServiceBanner(gomock.Any()).Return("hi", nil) + + api := &agentapi.ServiceBannerAPI{ + Database: dbM, + } + + resp, err := api.GetServiceBanner(context.Background(), &agentproto.GetServiceBannerRequest{}) + require.Error(t, err) + require.ErrorContains(t, err, "unmarshal json") + require.Nil(t, resp) + }) +} diff --git a/coderd/database/db2sdk/db2sdk.go b/coderd/database/db2sdk/db2sdk.go index ccf67ea98dd9a..8631bc7c5164b 100644 --- a/coderd/database/db2sdk/db2sdk.go +++ b/coderd/database/db2sdk/db2sdk.go @@ -237,16 +237,25 @@ func convertDisplayApps(apps []database.DisplayApp) []codersdk.DisplayApp { return dapps } +func WorkspaceAgentEnvironment(workspaceAgent database.WorkspaceAgent) (map[string]string, error) { + var envs map[string]string + if workspaceAgent.EnvironmentVariables.Valid { + err := json.Unmarshal(workspaceAgent.EnvironmentVariables.RawMessage, &envs) + if err != nil { + return nil, xerrors.Errorf("unmarshal environment variables: %w", err) + } + } + + return envs, nil +} + func WorkspaceAgent(derpMap *tailcfg.DERPMap, coordinator tailnet.Coordinator, dbAgent database.WorkspaceAgent, apps []codersdk.WorkspaceApp, scripts []codersdk.WorkspaceAgentScript, logSources []codersdk.WorkspaceAgentLogSource, agentInactiveDisconnectTimeout time.Duration, agentFallbackTroubleshootingURL string, ) (codersdk.WorkspaceAgent, error) { - var envs map[string]string - if dbAgent.EnvironmentVariables.Valid { - err := 
json.Unmarshal(dbAgent.EnvironmentVariables.RawMessage, &envs) - if err != nil { - return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal env vars: %w", err) - } + envs, err := WorkspaceAgentEnvironment(dbAgent) + if err != nil { + return codersdk.WorkspaceAgent{}, err } troubleshootingURL := agentFallbackTroubleshootingURL if dbAgent.TroubleshootingURL != "" { diff --git a/coderd/workspaceagentsrpc.go b/coderd/workspaceagentsrpc.go index 9b4987867e40a..d33eb10163ebb 100644 --- a/coderd/workspaceagentsrpc.go +++ b/coderd/workspaceagentsrpc.go @@ -107,21 +107,18 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) { Database: api.Database, Pubsub: api.Pubsub, DerpMapFn: api.DERPMap, - TailnetCoordinator: &api.TailnetCoordinator, TemplateScheduleStore: api.TemplateScheduleStore, StatsBatcher: api.statsBatcher, PublishWorkspaceUpdateFn: api.publishWorkspaceUpdate, PublishWorkspaceAgentLogsUpdateFn: api.publishWorkspaceAgentLogsUpdate, - AccessURL: api.AccessURL, - AppHostname: api.AppHostname, - AgentInactiveDisconnectTimeout: api.AgentInactiveDisconnectTimeout, - AgentFallbackTroubleshootingURL: api.DeploymentValues.AgentFallbackTroubleshootingURL.String(), - AgentStatsRefreshInterval: api.AgentStatsRefreshInterval, - DisableDirectConnections: api.DeploymentValues.DERP.Config.BlockDirect.Value(), - DerpForceWebSockets: api.DeploymentValues.DERP.Config.ForceWebSockets.Value(), - DerpMapUpdateFrequency: api.Options.DERPMapUpdateFrequency, - ExternalAuthConfigs: api.ExternalAuthConfigs, + AccessURL: api.AccessURL, + AppHostname: api.AppHostname, + AgentStatsRefreshInterval: api.AgentStatsRefreshInterval, + DisableDirectConnections: api.DeploymentValues.DERP.Config.BlockDirect.Value(), + DerpForceWebSockets: api.DeploymentValues.DERP.Config.ForceWebSockets.Value(), + DerpMapUpdateFrequency: api.Options.DERPMapUpdateFrequency, + ExternalAuthConfigs: api.ExternalAuthConfigs, // Optional: WorkspaceID: build.WorkspaceID, // saves the extra 
lookup later From 24f6614db4c8da30117df61e18cf4d45a3d91173 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Tue, 19 Dec 2023 12:12:10 +0000 Subject: [PATCH 2/7] Remaining tests --- coderd/agentapi/activitybump.go | 7 +- coderd/agentapi/api.go | 16 +- coderd/agentapi/metadata.go | 75 ++++-- coderd/agentapi/metadata_test.go | 275 +++++++++++++++++++ coderd/agentapi/stats.go | 61 +++-- coderd/agentapi/stats_test.go | 375 ++++++++++++++++++++++++++ coderd/agentapi/tailnet_test.go | 184 +++++++++++++ coderd/database/dbauthz/setup_test.go | 2 +- coderd/workspaceagents.go | 30 ++- 9 files changed, 954 insertions(+), 71 deletions(-) create mode 100644 coderd/agentapi/metadata_test.go create mode 100644 coderd/agentapi/stats_test.go create mode 100644 coderd/agentapi/tailnet_test.go diff --git a/coderd/agentapi/activitybump.go b/coderd/agentapi/activitybump.go index ab0797d6126bb..90afaf7e36111 100644 --- a/coderd/agentapi/activitybump.go +++ b/coderd/agentapi/activitybump.go @@ -41,13 +41,14 @@ func ActivityBumpWorkspace(ctx context.Context, log slog.Logger, db database.Sto // low priority operations fail first. ctx, cancel := context.WithTimeout(ctx, time.Second*15) defer cancel() - if err := db.ActivityBumpWorkspace(ctx, database.ActivityBumpWorkspaceParams{ + err := db.ActivityBumpWorkspace(ctx, database.ActivityBumpWorkspaceParams{ NextAutostart: nextAutostart.UTC(), WorkspaceID: workspaceID, - }); err != nil { + }) + if err != nil { if !xerrors.Is(err, context.Canceled) && !database.IsQueryCanceledError(err) { // Bump will fail if the context is canceled, but this is ok. 
- log.Error(ctx, "bump failed", slog.Error(err), + log.Error(ctx, "activity bump failed", slog.Error(err), slog.F("workspace_id", workspaceID), ) } diff --git a/coderd/agentapi/api.go b/coderd/agentapi/api.go index a97e76efb16be..cbabd9dd273c4 100644 --- a/coderd/agentapi/api.go +++ b/coderd/agentapi/api.go @@ -17,7 +17,6 @@ import ( "cdr.dev/slog" agentproto "github.com/coder/coder/v2/agent/proto" - "github.com/coder/coder/v2/coderd/batchstats" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/pubsub" "github.com/coder/coder/v2/coderd/externalauth" @@ -58,7 +57,7 @@ type Options struct { Pubsub pubsub.Pubsub DerpMapFn func() *tailcfg.DERPMap TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore] - StatsBatcher *batchstats.Batcher + StatsBatcher StatsBatcher // *batchstats.Batcher PublishWorkspaceUpdateFn func(ctx context.Context, workspaceID uuid.UUID) PublishWorkspaceAgentLogsUpdateFn func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) @@ -201,20 +200,15 @@ func (a *API) workspaceID(ctx context.Context, agent *database.WorkspaceAgent) ( agent = &agnt } - resource, err := a.opts.Database.GetWorkspaceResourceByID(ctx, agent.ResourceID) + getWorkspaceAgentByIDRow, err := a.opts.Database.GetWorkspaceByAgentID(ctx, agent.ID) if err != nil { - return uuid.Nil, xerrors.Errorf("get workspace agent resource by id %q: %w", agent.ResourceID, err) - } - - build, err := a.opts.Database.GetWorkspaceBuildByJobID(ctx, resource.JobID) - if err != nil { - return uuid.Nil, xerrors.Errorf("get workspace build by job id %q: %w", resource.JobID, err) + return uuid.Nil, xerrors.Errorf("get workspace by agent id %q: %w", agent.ID, err) } a.mu.Lock() - a.cachedWorkspaceID = build.WorkspaceID + a.cachedWorkspaceID = getWorkspaceAgentByIDRow.Workspace.ID a.mu.Unlock() - return build.WorkspaceID, nil + return getWorkspaceAgentByIDRow.Workspace.ID, nil } func (a *API) publishWorkspaceUpdate(ctx 
context.Context, agent *database.WorkspaceAgent) error { diff --git a/coderd/agentapi/metadata.go b/coderd/agentapi/metadata.go index a3bf24b2036fc..0c3e0c8630b01 100644 --- a/coderd/agentapi/metadata.go +++ b/coderd/agentapi/metadata.go @@ -12,6 +12,7 @@ import ( "cdr.dev/slog" agentproto "github.com/coder/coder/v2/agent/proto" "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/database/pubsub" ) @@ -20,14 +21,26 @@ type MetadataAPI struct { Database database.Store Pubsub pubsub.Pubsub Log slog.Logger + + TimeNowFn func() time.Time // defaults to dbtime.Now() +} + +func (a *MetadataAPI) now() time.Time { + if a.TimeNowFn != nil { + return a.TimeNowFn() + } + return dbtime.Now() } func (a *MetadataAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.BatchUpdateMetadataRequest) (*agentproto.BatchUpdateMetadataResponse, error) { const ( - // maxValueLen is set to 2048 to stay under the 8000 byte Postgres - // NOTIFY limit. Since both value and error can be set, the real payload - // limit is 2 * 2048 * 4/3 = 5461 bytes + a few - // hundred bytes for JSON syntax, key names, and metadata. + // maxAllKeysLen is the maximum length of all metadata keys. This is + // 6144 to stay below the Postgres NOTIFY limit of 8000 bytes, with some + // headroom for the timestamp and JSON encoding. Any values that would + // exceed this limit are discarded (the rest are still inserted) and an + // error is returned. 
+ maxAllKeysLen = 6144 // 1024 * 6 + maxValueLen = 2048 maxErrorLen = maxValueLen ) @@ -37,18 +50,36 @@ func (a *MetadataAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.B return nil, err } - collectedAt := time.Now() - dbUpdate := database.UpdateWorkspaceAgentMetadataParams{ - WorkspaceAgentID: workspaceAgent.ID, - Key: make([]string, 0, len(req.Metadata)), - Value: make([]string, 0, len(req.Metadata)), - Error: make([]string, 0, len(req.Metadata)), - CollectedAt: make([]time.Time, 0, len(req.Metadata)), - } - + var ( + collectedAt = a.now() + allKeysLen = 0 + dbUpdate = database.UpdateWorkspaceAgentMetadataParams{ + WorkspaceAgentID: workspaceAgent.ID, + // These need to be `make(x, 0, len(req.Metadata))` instead of + // `make(x, len(req.Metadata))` because we may not insert all + // metadata if the keys are large. + Key: make([]string, 0, len(req.Metadata)), + Value: make([]string, 0, len(req.Metadata)), + Error: make([]string, 0, len(req.Metadata)), + CollectedAt: make([]time.Time, 0, len(req.Metadata)), + } + ) for _, md := range req.Metadata { metadataError := md.Result.Error + allKeysLen += len(md.Key) + if allKeysLen > maxAllKeysLen { + // We still insert the rest of the metadata, and we return an error + // after the insert. + a.Log.Warn( + ctx, "discarded extra agent metadata due to excessive key length", + slog.F("collected_at", collectedAt), + slog.F("all_keys_len", allKeysLen), + slog.F("max_all_keys_len", maxAllKeysLen), + ) + break + } + // We overwrite the error if the provided payload is too long. 
if len(md.Result.Value) > maxValueLen { metadataError = fmt.Sprintf("value of %d bytes exceeded %d bytes", len(md.Result.Value), maxValueLen) @@ -71,12 +102,16 @@ func (a *MetadataAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.B a.Log.Debug( ctx, "accepted metadata report", slog.F("collected_at", collectedAt), - slog.F("original_collected_at", collectedAt), slog.F("key", md.Key), slog.F("value", ellipse(md.Result.Value, 16)), ) } + err = a.Database.UpdateWorkspaceAgentMetadata(ctx, dbUpdate) + if err != nil { + return nil, xerrors.Errorf("update workspace agent metadata in database: %w", err) + } + payload, err := json.Marshal(WorkspaceAgentMetadataChannelPayload{ CollectedAt: collectedAt, Keys: dbUpdate.Key, @@ -84,17 +119,17 @@ func (a *MetadataAPI) BatchUpdateMetadata(ctx context.Context, req *agentproto.B if err != nil { return nil, xerrors.Errorf("marshal workspace agent metadata channel payload: %w", err) } - - err = a.Database.UpdateWorkspaceAgentMetadata(ctx, dbUpdate) - if err != nil { - return nil, xerrors.Errorf("update workspace agent metadata in database: %w", err) - } - err = a.Pubsub.Publish(WatchWorkspaceAgentMetadataChannel(workspaceAgent.ID), payload) if err != nil { return nil, xerrors.Errorf("publish workspace agent metadata: %w", err) } + // If the metadata keys were too large, we return an error so the agent can + // log it. 
+ if allKeysLen > maxAllKeysLen { + return nil, xerrors.Errorf("metadata keys of %d bytes exceeded %d bytes", allKeysLen, maxAllKeysLen) + } + return &agentproto.BatchUpdateMetadataResponse{}, nil } diff --git a/coderd/agentapi/metadata_test.go b/coderd/agentapi/metadata_test.go new file mode 100644 index 0000000000000..71abcc0e9f46a --- /dev/null +++ b/coderd/agentapi/metadata_test.go @@ -0,0 +1,275 @@ +package agentapi_test + +import ( + "context" + "encoding/json" + "sync/atomic" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/timestamppb" + + "cdr.dev/slog/sloggers/slogtest" + + agentproto "github.com/coder/coder/v2/agent/proto" + "github.com/coder/coder/v2/coderd/agentapi" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/database/pubsub" +) + +func TestBatchUpdateMetadata(t *testing.T) { + t.Parallel() + + agent := database.WorkspaceAgent{ + ID: uuid.New(), + } + + t.Run("OK", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + pub := pubsub.NewInMemory() + + now := dbtime.Now() + req := &agentproto.BatchUpdateMetadataRequest{ + Metadata: []*agentproto.Metadata{ + { + Key: "awesome key", + Result: &agentproto.WorkspaceAgentMetadata_Result{ + CollectedAt: timestamppb.New(now.Add(-10 * time.Second)), + Age: 10, + Value: "awesome value", + Error: "", + }, + }, + { + Key: "uncool key", + Result: &agentproto.WorkspaceAgentMetadata_Result{ + CollectedAt: timestamppb.New(now.Add(-3 * time.Second)), + Age: 3, + Value: "", + Error: "uncool value", + }, + }, + }, + } + + dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), database.UpdateWorkspaceAgentMetadataParams{ + WorkspaceAgentID: agent.ID, + Key: []string{req.Metadata[0].Key, req.Metadata[1].Key}, + Value: 
[]string{req.Metadata[0].Result.Value, req.Metadata[1].Result.Value}, + Error: []string{req.Metadata[0].Result.Error, req.Metadata[1].Result.Error}, + // The value from the agent is ignored. + CollectedAt: []time.Time{now, now}, + }).Return(nil) + + api := &agentapi.MetadataAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + Pubsub: pub, + Log: slogtest.Make(t, nil), + TimeNowFn: func() time.Time { + return now + }, + } + + // Watch the pubsub for events. + var ( + eventCount int64 + gotEvent agentapi.WorkspaceAgentMetadataChannelPayload + ) + cancel, err := pub.Subscribe(agentapi.WatchWorkspaceAgentMetadataChannel(agent.ID), func(ctx context.Context, message []byte) { + if atomic.AddInt64(&eventCount, 1) > 1 { + return + } + require.NoError(t, json.Unmarshal(message, &gotEvent)) + }) + require.NoError(t, err) + defer cancel() + + resp, err := api.BatchUpdateMetadata(context.Background(), req) + require.NoError(t, err) + require.Equal(t, &agentproto.BatchUpdateMetadataResponse{}, resp) + + require.Equal(t, int64(1), atomic.LoadInt64(&eventCount)) + require.Equal(t, agentapi.WorkspaceAgentMetadataChannelPayload{ + CollectedAt: now, + Keys: []string{req.Metadata[0].Key, req.Metadata[1].Key}, + }, gotEvent) + }) + + t.Run("ExceededLength", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + pub := pubsub.NewInMemory() + + almostLongValue := "" + for i := 0; i < 2048; i++ { + almostLongValue += "a" + } + + now := dbtime.Now() + req := &agentproto.BatchUpdateMetadataRequest{ + Metadata: []*agentproto.Metadata{ + { + Key: "almost long value", + Result: &agentproto.WorkspaceAgentMetadata_Result{ + Value: almostLongValue, + }, + }, + { + Key: "too long value", + Result: &agentproto.WorkspaceAgentMetadata_Result{ + Value: almostLongValue + "a", + }, + }, + { + Key: "almost long error", + Result: &agentproto.WorkspaceAgentMetadata_Result{ + Error: almostLongValue, + }, 
+ }, + { + Key: "too long error", + Result: &agentproto.WorkspaceAgentMetadata_Result{ + Error: almostLongValue + "a", + }, + }, + }, + } + + dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), database.UpdateWorkspaceAgentMetadataParams{ + WorkspaceAgentID: agent.ID, + Key: []string{req.Metadata[0].Key, req.Metadata[1].Key, req.Metadata[2].Key, req.Metadata[3].Key}, + Value: []string{ + almostLongValue, + almostLongValue, // truncated + "", + "", + }, + Error: []string{ + "", + "value of 2049 bytes exceeded 2048 bytes", + almostLongValue, + "error of 2049 bytes exceeded 2048 bytes", // replaced + }, + // The value from the agent is ignored. + CollectedAt: []time.Time{now, now, now, now}, + }).Return(nil) + + api := &agentapi.MetadataAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + Pubsub: pub, + Log: slogtest.Make(t, nil), + TimeNowFn: func() time.Time { + return now + }, + } + + resp, err := api.BatchUpdateMetadata(context.Background(), req) + require.NoError(t, err) + require.Equal(t, &agentproto.BatchUpdateMetadataResponse{}, resp) + }) + + t.Run("KeysTooLong", func(t *testing.T) { + t.Parallel() + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + pub := pubsub.NewInMemory() + + now := dbtime.Now() + req := &agentproto.BatchUpdateMetadataRequest{ + Metadata: []*agentproto.Metadata{ + { + Key: "key 1", + Result: &agentproto.WorkspaceAgentMetadata_Result{ + Value: "value 1", + }, + }, + { + Key: "key 2", + Result: &agentproto.WorkspaceAgentMetadata_Result{ + Value: "value 2", + }, + }, + { + Key: func() string { + key := "key 3 " + for i := 0; i < (6144 - len("key 1") - len("key 2") - len("key 3") - 1); i++ { + key += "a" + } + return key + }(), + Result: &agentproto.WorkspaceAgentMetadata_Result{ + Value: "value 3", + }, + }, + { + Key: "a", // should be ignored + Result: &agentproto.WorkspaceAgentMetadata_Result{ + Value: "value 4", + }, + }, + }, + } + + 
dbM.EXPECT().UpdateWorkspaceAgentMetadata(gomock.Any(), database.UpdateWorkspaceAgentMetadataParams{ + WorkspaceAgentID: agent.ID, + // No key 4. + Key: []string{req.Metadata[0].Key, req.Metadata[1].Key, req.Metadata[2].Key}, + Value: []string{req.Metadata[0].Result.Value, req.Metadata[1].Result.Value, req.Metadata[2].Result.Value}, + Error: []string{req.Metadata[0].Result.Error, req.Metadata[1].Result.Error, req.Metadata[2].Result.Error}, + // The value from the agent is ignored. + CollectedAt: []time.Time{now, now, now}, + }).Return(nil) + + api := &agentapi.MetadataAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + Pubsub: pub, + Log: slogtest.Make(t, nil), + TimeNowFn: func() time.Time { + return now + }, + } + + // Watch the pubsub for events. + var ( + eventCount int64 + gotEvent agentapi.WorkspaceAgentMetadataChannelPayload + ) + cancel, err := pub.Subscribe(agentapi.WatchWorkspaceAgentMetadataChannel(agent.ID), func(ctx context.Context, message []byte) { + if atomic.AddInt64(&eventCount, 1) > 1 { + return + } + require.NoError(t, json.Unmarshal(message, &gotEvent)) + }) + require.NoError(t, err) + defer cancel() + + resp, err := api.BatchUpdateMetadata(context.Background(), req) + require.Error(t, err) + require.Equal(t, "metadata keys of 6145 bytes exceeded 6144 bytes", err.Error()) + require.Nil(t, resp) + + require.Equal(t, int64(1), atomic.LoadInt64(&eventCount)) + require.Equal(t, agentapi.WorkspaceAgentMetadataChannelPayload{ + CollectedAt: now, + // No key 4. 
+ Keys: []string{req.Metadata[0].Key, req.Metadata[1].Key, req.Metadata[2].Key}, + }, gotEvent) + }) +} diff --git a/coderd/agentapi/stats.go b/coderd/agentapi/stats.go index 05c1b744f2a9a..4a0047bd63564 100644 --- a/coderd/agentapi/stats.go +++ b/coderd/agentapi/stats.go @@ -9,57 +9,71 @@ import ( "golang.org/x/xerrors" "google.golang.org/protobuf/types/known/durationpb" + "github.com/google/uuid" + "cdr.dev/slog" agentproto "github.com/coder/coder/v2/agent/proto" "github.com/coder/coder/v2/coderd/autobuild" - "github.com/coder/coder/v2/coderd/batchstats" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/prometheusmetrics" "github.com/coder/coder/v2/coderd/schedule" ) +type StatsBatcher interface { + Add(now time.Time, agentID uuid.UUID, templateID uuid.UUID, userID uuid.UUID, workspaceID uuid.UUID, st *agentproto.Stats) error +} + type StatsAPI struct { AgentFn func(context.Context) (database.WorkspaceAgent, error) Database database.Store Log slog.Logger - StatsBatcher *batchstats.Batcher + StatsBatcher StatsBatcher TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore] AgentStatsRefreshInterval time.Duration UpdateAgentMetricsFn func(ctx context.Context, labels prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric) + + TimeNowFn func() time.Time // defaults to dbtime.Now() } -func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsRequest) (*agentproto.UpdateStatsResponse, error) { - workspaceAgent, err := a.AgentFn(ctx) - if err != nil { - return nil, err - } - row, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID) - if err != nil { - return nil, xerrors.Errorf("get workspace by agent ID %q: %w", workspaceAgent.ID, err) +func (a *StatsAPI) now() time.Time { + if a.TimeNowFn != nil { + return a.TimeNowFn() } - workspace := row.Workspace + return dbtime.Now() +} +func (a *StatsAPI) UpdateStats(ctx context.Context, 
req *agentproto.UpdateStatsRequest) (*agentproto.UpdateStatsResponse, error) { + // An empty stat means it's just looking for the report interval. res := &agentproto.UpdateStatsResponse{ ReportInterval: durationpb.New(a.AgentStatsRefreshInterval), } - - // An empty stat means it's just looking for the report interval. - if len(req.Stats.ConnectionsByProto) == 0 { + if req.Stats == nil || len(req.Stats.ConnectionsByProto) == 0 { return res, nil } + workspaceAgent, err := a.AgentFn(ctx) + if err != nil { + return nil, err + } + getWorkspaceAgentByIDRow, err := a.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID) + if err != nil { + return nil, xerrors.Errorf("get workspace by agent ID %q: %w", workspaceAgent.ID, err) + } + workspace := getWorkspaceAgentByIDRow.Workspace a.Log.Debug(ctx, "read stats report", slog.F("interval", a.AgentStatsRefreshInterval), slog.F("workspace_id", workspace.ID), slog.F("payload", req), ) + now := a.now() if req.Stats.ConnectionCount > 0 { var nextAutostart time.Time if workspace.AutostartSchedule.String != "" { templateSchedule, err := (*(a.TemplateScheduleStore.Load())).Get(ctx, a.Database, workspace.TemplateID) - // If the template schedule fails to load, just default to bumping without the next trasition and log it. + // If the template schedule fails to load, just default to bumping + // without the next transition and log it. 
if err != nil { a.Log.Warn(ctx, "failed to load template schedule bumping activity, defaulting to bumping by 60min", slog.F("workspace_id", workspace.ID), @@ -67,7 +81,7 @@ func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsR slog.Error(err), ) } else { - next, allowed := autobuild.NextAutostartSchedule(time.Now(), workspace.AutostartSchedule.String, templateSchedule) + next, allowed := autobuild.NextAutostartSchedule(now, workspace.AutostartSchedule.String, templateSchedule) if allowed { nextAutostart = next } @@ -76,13 +90,12 @@ func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsR ActivityBumpWorkspace(ctx, a.Log.Named("activity_bump"), a.Database, workspace.ID, nextAutostart) } - now := dbtime.Now() - var errGroup errgroup.Group errGroup.Go(func() error { - if err := a.StatsBatcher.Add(time.Now(), workspaceAgent.ID, workspace.TemplateID, workspace.OwnerID, workspace.ID, req.Stats); err != nil { - a.Log.Error(ctx, "failed to add stats to batcher", slog.Error(err)) - return xerrors.Errorf("can't insert workspace agent stat: %w", err) + err := a.StatsBatcher.Add(now, workspaceAgent.ID, workspace.TemplateID, workspace.OwnerID, workspace.ID, req.Stats) + if err != nil { + a.Log.Error(ctx, "add agent stats to batcher", slog.Error(err)) + return xerrors.Errorf("insert workspace agent stats batch: %w", err) } return nil }) @@ -92,7 +105,7 @@ func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsR LastUsedAt: now, }) if err != nil { - return xerrors.Errorf("can't update workspace LastUsedAt: %w", err) + return xerrors.Errorf("update workspace LastUsedAt: %w", err) } return nil }) @@ -100,14 +113,14 @@ func (a *StatsAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsR errGroup.Go(func() error { user, err := a.Database.GetUserByID(ctx, workspace.OwnerID) if err != nil { - return xerrors.Errorf("can't get user: %w", err) + return xerrors.Errorf("get user: %w", err) } 
a.UpdateAgentMetricsFn(ctx, prometheusmetrics.AgentMetricLabels{ Username: user.Username, WorkspaceName: workspace.Name, AgentName: workspaceAgent.Name, - TemplateName: row.TemplateName, + TemplateName: getWorkspaceAgentByIDRow.TemplateName, }, req.Stats.Metrics) return nil }) diff --git a/coderd/agentapi/stats_test.go b/coderd/agentapi/stats_test.go new file mode 100644 index 0000000000000..409c24df60b43 --- /dev/null +++ b/coderd/agentapi/stats_test.go @@ -0,0 +1,375 @@ +package agentapi_test + +import ( + "context" + "database/sql" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/golang/mock/gomock" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/durationpb" + + agentproto "github.com/coder/coder/v2/agent/proto" + "github.com/coder/coder/v2/coderd/agentapi" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/prometheusmetrics" + "github.com/coder/coder/v2/coderd/schedule" +) + +type statsBatcher struct { + mu sync.Mutex + + called int64 + lastTime time.Time + lastAgentID uuid.UUID + lastTemplateID uuid.UUID + lastUserID uuid.UUID + lastWorkspaceID uuid.UUID + lastStats *agentproto.Stats +} + +var _ agentapi.StatsBatcher = &statsBatcher{} + +func (b *statsBatcher) Add(now time.Time, agentID uuid.UUID, templateID uuid.UUID, userID uuid.UUID, workspaceID uuid.UUID, st *agentproto.Stats) error { + b.mu.Lock() + defer b.mu.Unlock() + b.called++ + b.lastTime = now + b.lastAgentID = agentID + b.lastTemplateID = templateID + b.lastUserID = userID + b.lastWorkspaceID = workspaceID + b.lastStats = st + return nil +} + +func TestUpdateStates(t *testing.T) { + t.Parallel() + + var ( + user = database.User{ + ID: uuid.New(), + Username: "bill", + } + template = database.Template{ + ID: uuid.New(), + Name: "tpl", + } + workspace = 
database.Workspace{ + ID: uuid.New(), + OwnerID: user.ID, + TemplateID: template.ID, + Name: "xyz", + } + agent = database.WorkspaceAgent{ + ID: uuid.New(), + Name: "abc", + } + ) + + t.Run("OK", func(t *testing.T) { + t.Parallel() + + var ( + now = dbtime.Now() + dbM = dbmock.NewMockStore(gomock.NewController(t)) + templateScheduleStore = schedule.MockTemplateScheduleStore{ + GetFn: func(context.Context, database.Store, uuid.UUID) (schedule.TemplateScheduleOptions, error) { + panic("should not be called") + }, + SetFn: func(context.Context, database.Store, database.Template, schedule.TemplateScheduleOptions) (database.Template, error) { + panic("not implemented") + }, + } + batcher = &statsBatcher{} + updateAgentMetricsFnCalled int64 + + req = &agentproto.UpdateStatsRequest{ + Stats: &agentproto.Stats{ + ConnectionsByProto: map[string]int64{ + "tcp": 1, + "dean": 2, + }, + ConnectionCount: 3, + ConnectionMedianLatencyMs: 23, + RxPackets: 120, + RxBytes: 1000, + TxPackets: 130, + TxBytes: 2000, + SessionCountVscode: 1, + SessionCountJetbrains: 2, + SessionCountReconnectingPty: 3, + SessionCountSsh: 4, + Metrics: []*agentproto.Stats_Metric{ + { + Name: "awesome metric", + Value: 42, + }, + { + Name: "uncool metric", + Value: 0, + }, + }, + }, + } + ) + api := agentapi.StatsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + StatsBatcher: batcher, + TemplateScheduleStore: templateScheduleStorePtr(templateScheduleStore), + AgentStatsRefreshInterval: 10 * time.Second, + UpdateAgentMetricsFn: func(ctx context.Context, labels prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric) { + atomic.AddInt64(&updateAgentMetricsFnCalled, 1) + assert.Equal(t, prometheusmetrics.AgentMetricLabels{ + Username: user.Username, + WorkspaceName: workspace.Name, + AgentName: agent.Name, + TemplateName: template.Name, + }, labels) + assert.Equal(t, req.Stats.Metrics, metrics) + }, + TimeNowFn: func() 
time.Time { + return now + }, + } + + // Workspace gets fetched. + dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Return(database.GetWorkspaceByAgentIDRow{ + Workspace: workspace, + TemplateName: template.Name, + }, nil) + + // We expect an activity bump because ConnectionCount > 0. + dbM.EXPECT().ActivityBumpWorkspace(gomock.Any(), database.ActivityBumpWorkspaceParams{ + WorkspaceID: workspace.ID, + NextAutostart: time.Time{}.UTC(), + }).Return(nil) + + // Workspace last used at gets bumped. + dbM.EXPECT().UpdateWorkspaceLastUsedAt(gomock.Any(), database.UpdateWorkspaceLastUsedAtParams{ + ID: workspace.ID, + LastUsedAt: now, + }).Return(nil) + + // User gets fetched to hit the UpdateAgentMetricsFn. + dbM.EXPECT().GetUserByID(gomock.Any(), user.ID).Return(user, nil) + + resp, err := api.UpdateStats(context.Background(), req) + require.NoError(t, err) + require.Equal(t, &agentproto.UpdateStatsResponse{ + ReportInterval: durationpb.New(10 * time.Second), + }, resp) + + batcher.mu.Lock() + defer batcher.mu.Unlock() + require.Equal(t, int64(1), batcher.called) + require.Equal(t, now, batcher.lastTime) + require.Equal(t, agent.ID, batcher.lastAgentID) + require.Equal(t, template.ID, batcher.lastTemplateID) + require.Equal(t, user.ID, batcher.lastUserID) + require.Equal(t, workspace.ID, batcher.lastWorkspaceID) + require.Equal(t, req.Stats, batcher.lastStats) + }) + + t.Run("ConnectionCountZero", func(t *testing.T) { + t.Parallel() + + var ( + now = dbtime.Now() + dbM = dbmock.NewMockStore(gomock.NewController(t)) + templateScheduleStore = schedule.MockTemplateScheduleStore{ + GetFn: func(context.Context, database.Store, uuid.UUID) (schedule.TemplateScheduleOptions, error) { + panic("should not be called") + }, + SetFn: func(context.Context, database.Store, database.Template, schedule.TemplateScheduleOptions) (database.Template, error) { + panic("not implemented") + }, + } + batcher = &statsBatcher{} + + req = &agentproto.UpdateStatsRequest{ + Stats: 
&agentproto.Stats{ + ConnectionsByProto: map[string]int64{ + "tcp": 1, + }, + ConnectionCount: 0, + ConnectionMedianLatencyMs: 23, + }, + } + ) + api := agentapi.StatsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + StatsBatcher: batcher, + TemplateScheduleStore: templateScheduleStorePtr(templateScheduleStore), + AgentStatsRefreshInterval: 10 * time.Second, + // Ignored when nil. + UpdateAgentMetricsFn: nil, + TimeNowFn: func() time.Time { + return now + }, + } + + // Workspace gets fetched. + dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Return(database.GetWorkspaceByAgentIDRow{ + Workspace: workspace, + TemplateName: template.Name, + }, nil) + + // Workspace last used at gets bumped. + dbM.EXPECT().UpdateWorkspaceLastUsedAt(gomock.Any(), database.UpdateWorkspaceLastUsedAtParams{ + ID: workspace.ID, + LastUsedAt: now, + }).Return(nil) + + _, err := api.UpdateStats(context.Background(), req) + require.NoError(t, err) + }) + + t.Run("NoConnectionsByProto", func(t *testing.T) { + t.Parallel() + + var ( + dbM = dbmock.NewMockStore(gomock.NewController(t)) + req = &agentproto.UpdateStatsRequest{ + Stats: &agentproto.Stats{ + ConnectionsByProto: map[string]int64{}, // len() == 0 + }, + } + ) + api := agentapi.StatsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + StatsBatcher: nil, // should not be called + TemplateScheduleStore: nil, // should not be called + AgentStatsRefreshInterval: 10 * time.Second, + UpdateAgentMetricsFn: nil, // should not be called + TimeNowFn: func() time.Time { + panic("should not be called") + }, + } + + resp, err := api.UpdateStats(context.Background(), req) + require.NoError(t, err) + require.Equal(t, &agentproto.UpdateStatsResponse{ + ReportInterval: durationpb.New(10 * time.Second), + }, resp) + }) + + t.Run("AutostartAwareBump", func(t *testing.T) { + t.Parallel() + + // Use a workspace with 
an autostart schedule. + workspace := workspace + workspace.AutostartSchedule = sql.NullString{ + String: "CRON_TZ=Australia/Sydney 0 8 * * *", + Valid: true, + } + + // Use a custom time for now which would trigger the autostart aware + // bump. + now, err := time.Parse("2006-01-02 15:04:05 -0700 MST", "2023-12-19 07:30:00 +1100 AEDT") + require.NoError(t, err) + now = dbtime.Time(now) + nextAutostart := now.Add(30 * time.Minute).UTC() // always sent to DB as UTC + + var ( + dbM = dbmock.NewMockStore(gomock.NewController(t)) + templateScheduleStore = schedule.MockTemplateScheduleStore{ + GetFn: func(context.Context, database.Store, uuid.UUID) (schedule.TemplateScheduleOptions, error) { + return schedule.TemplateScheduleOptions{ + UserAutostartEnabled: true, + AutostartRequirement: schedule.TemplateAutostartRequirement{ + DaysOfWeek: 0b01111111, // every day + }, + }, nil + }, + SetFn: func(context.Context, database.Store, database.Template, schedule.TemplateScheduleOptions) (database.Template, error) { + panic("not implemented") + }, + } + batcher = &statsBatcher{} + updateAgentMetricsFnCalled int64 + + req = &agentproto.UpdateStatsRequest{ + Stats: &agentproto.Stats{ + ConnectionsByProto: map[string]int64{ + "tcp": 1, + "dean": 2, + }, + ConnectionCount: 3, + }, + } + ) + api := agentapi.StatsAPI{ + AgentFn: func(context.Context) (database.WorkspaceAgent, error) { + return agent, nil + }, + Database: dbM, + StatsBatcher: batcher, + TemplateScheduleStore: templateScheduleStorePtr(templateScheduleStore), + AgentStatsRefreshInterval: 15 * time.Second, + UpdateAgentMetricsFn: func(ctx context.Context, labels prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric) { + atomic.AddInt64(&updateAgentMetricsFnCalled, 1) + assert.Equal(t, prometheusmetrics.AgentMetricLabels{ + Username: user.Username, + WorkspaceName: workspace.Name, + AgentName: agent.Name, + TemplateName: template.Name, + }, labels) + assert.Equal(t, req.Stats.Metrics, metrics) + }, + 
TimeNowFn: func() time.Time { + return now + }, + } + + // Workspace gets fetched. + dbM.EXPECT().GetWorkspaceByAgentID(gomock.Any(), agent.ID).Return(database.GetWorkspaceByAgentIDRow{ + Workspace: workspace, + TemplateName: template.Name, + }, nil) + + // We expect an activity bump because ConnectionCount > 0. However, the + // next autostart time will be set on the bump. + dbM.EXPECT().ActivityBumpWorkspace(gomock.Any(), database.ActivityBumpWorkspaceParams{ + WorkspaceID: workspace.ID, + NextAutostart: nextAutostart, + }).Return(nil) + + // Workspace last used at gets bumped. + dbM.EXPECT().UpdateWorkspaceLastUsedAt(gomock.Any(), database.UpdateWorkspaceLastUsedAtParams{ + ID: workspace.ID, + LastUsedAt: now, + }).Return(nil) + + // User gets fetched to hit the UpdateAgentMetricsFn. + dbM.EXPECT().GetUserByID(gomock.Any(), user.ID).Return(user, nil) + + resp, err := api.UpdateStats(context.Background(), req) + require.NoError(t, err) + require.Equal(t, &agentproto.UpdateStatsResponse{ + ReportInterval: durationpb.New(15 * time.Second), + }, resp) + }) +} + +func templateScheduleStorePtr(store schedule.TemplateScheduleStore) *atomic.Pointer[schedule.TemplateScheduleStore] { + var ptr atomic.Pointer[schedule.TemplateScheduleStore] + ptr.Store(&store) + return &ptr +} diff --git a/coderd/agentapi/tailnet_test.go b/coderd/agentapi/tailnet_test.go new file mode 100644 index 0000000000000..dae2a2c7ebe99 --- /dev/null +++ b/coderd/agentapi/tailnet_test.go @@ -0,0 +1,184 @@ +package agentapi_test + +import ( + "context" + "testing" + "time" + + "golang.org/x/xerrors" + "storj.io/drpc" + "tailscale.com/tailcfg" + + "github.com/stretchr/testify/require" + + agentproto "github.com/coder/coder/v2/agent/proto" + "github.com/coder/coder/v2/coderd/agentapi" + "github.com/coder/coder/v2/tailnet" + tailnetproto "github.com/coder/coder/v2/tailnet/proto" +) + +type fakeDERPMapStream struct { + drpc.Stream // to fake implement unused members + + ctx context.Context + closeFn 
func() error + sendFn func(*tailnetproto.DERPMap) error +} + +var _ agentproto.DRPCAgent_StreamDERPMapsStream = &fakeDERPMapStream{} + +func (s *fakeDERPMapStream) Context() context.Context { + return s.ctx +} + +func (s *fakeDERPMapStream) Close() error { + return s.closeFn() +} + +func (s *fakeDERPMapStream) Send(m *tailnetproto.DERPMap) error { + return s.sendFn(m) +} + +func TestStreamDERPMaps(t *testing.T) { + t.Parallel() + + t.Run("OK", func(t *testing.T) { + t.Parallel() + + derpMap := tailcfg.DERPMap{} + api := &agentapi.TailnetAPI{ + Ctx: context.Background(), + DerpMapFn: func() *tailcfg.DERPMap { + derp := (&derpMap).Clone() + return derp + }, + DerpMapUpdateFrequency: time.Millisecond, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + closed := make(chan struct{}) + maps := make(chan *tailnetproto.DERPMap, 10) + stream := &fakeDERPMapStream{ + ctx: ctx, + closeFn: func() error { + select { + case <-ctx.Done(): + default: + t.Fatal("expected context to be canceled before close") + } + close(closed) + return nil + }, + sendFn: func(m *tailnetproto.DERPMap) error { + if m == nil { + t.Fatal("expected non-nil map") + } + maps <- m + return nil + }, + } + + errCh := make(chan error) + go func() { + // Request isn't used. + errCh <- api.StreamDERPMaps(nil, stream) + }() + + // Initial map. + gotMap := <-maps + require.Equal(t, tailnet.DERPMapToProto(&derpMap), gotMap) + + // Update the map, should get an update. + derpMap.Regions = map[int]*tailcfg.DERPRegion{ + 1: {}, + } + gotMap = <-maps + require.Equal(t, tailnet.DERPMapToProto(&derpMap), gotMap) + + // Update the map again, should get an update. + derpMap.Regions = nil + gotMap = <-maps + require.Equal(t, tailnet.DERPMapToProto(&derpMap), gotMap) + + // Cancel the stream, should return the fn. 
+ cancel() + <-closed + require.NoError(t, <-errCh) + }) + + t.Run("SendFailure", func(t *testing.T) { + t.Parallel() + + api := &agentapi.TailnetAPI{ + Ctx: context.Background(), + DerpMapFn: func() *tailcfg.DERPMap { + return &tailcfg.DERPMap{} + }, + DerpMapUpdateFrequency: time.Millisecond, + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + stream := &fakeDERPMapStream{ + ctx: ctx, + closeFn: func() error { + return nil + }, + sendFn: func(m *tailnetproto.DERPMap) error { + return xerrors.New("test error") + }, + } + + err := api.StreamDERPMaps(nil, stream) + require.Error(t, err) + require.ErrorContains(t, err, "send derp map") + require.ErrorContains(t, err, "test error") + }) + + t.Run("GlobalContextCanceled", func(t *testing.T) { + t.Parallel() + + globalCtx, globalCtxCancel := context.WithCancel(context.Background()) + api := &agentapi.TailnetAPI{ + Ctx: globalCtx, + DerpMapFn: func() *tailcfg.DERPMap { + return &tailcfg.DERPMap{} + }, + DerpMapUpdateFrequency: time.Hour, // long time to make sure ctx cancels are quick + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + maps := make(chan *tailnetproto.DERPMap, 10) + stream := &fakeDERPMapStream{ + ctx: ctx, + closeFn: func() error { + return nil + }, + sendFn: func(m *tailnetproto.DERPMap) error { + if m == nil { + t.Fatal("expected non-nil map") + } + maps <- m + return nil + }, + } + + errCh := make(chan error) + go func() { + // Request isn't used. + errCh <- api.StreamDERPMaps(nil, stream) + }() + + // Initial map. + <-maps + + // Cancel the global context, should return the fn. 
+ globalCtxCancel() + require.NoError(t, <-errCh) + }) +} diff --git a/coderd/database/dbauthz/setup_test.go b/coderd/database/dbauthz/setup_test.go index 3c54d8be4e345..403d23d508213 100644 --- a/coderd/database/dbauthz/setup_test.go +++ b/coderd/database/dbauthz/setup_test.go @@ -28,7 +28,7 @@ import ( "github.com/coder/coder/v2/coderd/util/slice" ) -var errMatchAny = errors.New("match any error") +var errMatchAny = xerrors.New("match any error") var skipMethods = map[string]string{ "InTx": "Not relevant", diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 04428aed28e17..dea25099e0c47 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -156,18 +156,24 @@ func (api *API) workspaceAgentManifest(rw http.ResponseWriter, r *http.Request) // As this API becomes deprecated, use the new protobuf API and convert the // types back to the SDK types. manifestAPI := &agentapi.ManifestAPI{ - AccessURL: api.AccessURL, - AppHostname: api.AppHostname, - AgentInactiveDisconnectTimeout: api.AgentInactiveDisconnectTimeout, - AgentFallbackTroubleshootingURL: api.DeploymentValues.AgentFallbackTroubleshootingURL.String(), - ExternalAuthConfigs: api.ExternalAuthConfigs, - DisableDirectConnections: api.DeploymentValues.DERP.Config.BlockDirect.Value(), - DerpForceWebSockets: api.DeploymentValues.DERP.Config.ForceWebSockets.Value(), - - AgentFn: func(_ context.Context) (database.WorkspaceAgent, error) { return workspaceAgent, nil }, - Database: api.Database, - DerpMapFn: api.DERPMap, - TailnetCoordinator: &api.TailnetCoordinator, + AccessURL: api.AccessURL, + AppHostname: api.AppHostname, + ExternalAuthConfigs: api.ExternalAuthConfigs, + DisableDirectConnections: api.DeploymentValues.DERP.Config.BlockDirect.Value(), + DerpForceWebSockets: api.DeploymentValues.DERP.Config.ForceWebSockets.Value(), + + AgentFn: func(_ context.Context) (database.WorkspaceAgent, error) { return workspaceAgent, nil }, + WorkspaceIDFn: func(ctx context.Context, wa 
*database.WorkspaceAgent) (uuid.UUID, error) { + // Sadly this results in a double query, but it's only temporary for + // now. + ws, err := api.Database.GetWorkspaceByAgentID(ctx, wa.ID) + if err != nil { + return uuid.Nil, err + } + return ws.Workspace.ID, nil + }, + Database: api.Database, + DerpMapFn: api.DERPMap, } manifest, err := manifestAPI.GetManifest(ctx, &agentproto.GetManifestRequest{}) if err != nil { From a7784dd7d3a6c33513408363eac5556ab52caef5 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Tue, 19 Dec 2023 12:52:41 +0000 Subject: [PATCH 3/7] chore: convert startup API to agentapi --- coderd/agentapi/api.go | 18 ++----- coderd/agentapi/apps.go | 8 +-- coderd/agentapi/apps_test.go | 30 ++++++++--- coderd/agentapi/lifecycle.go | 32 +++++++++--- coderd/agentapi/lifecycle_test.go | 12 ++--- coderd/agentapi/logs.go | 18 +++---- coderd/agentapi/logs_test.go | 42 ++++++++++----- coderd/workspaceagents.go | 85 ++++++++++++++----------------- coderd/workspaceagents_test.go | 6 ++- 9 files changed, 143 insertions(+), 108 deletions(-) diff --git a/coderd/agentapi/api.go b/coderd/agentapi/api.go index cbabd9dd273c4..8b217574af77b 100644 --- a/coderd/agentapi/api.go +++ b/coderd/agentapi/api.go @@ -114,14 +114,15 @@ func New(opts Options) *API { WorkspaceIDFn: api.workspaceID, Database: opts.Database, Log: opts.Log, - PublishWorkspaceUpdateFn: api.publishWorkspaceUpdate, + PublishWorkspaceUpdateFn: opts.PublishWorkspaceUpdateFn, } api.AppsAPI = &AppsAPI{ AgentFn: api.agent, + WorkspaceIDFn: api.workspaceID, Database: opts.Database, Log: opts.Log, - PublishWorkspaceUpdateFn: api.publishWorkspaceUpdate, + PublishWorkspaceUpdateFn: opts.PublishWorkspaceUpdateFn, } api.MetadataAPI = &MetadataAPI{ @@ -133,9 +134,10 @@ func New(opts Options) *API { api.LogsAPI = &LogsAPI{ AgentFn: api.agent, + WorkspaceIDFn: api.workspaceID, Database: opts.Database, Log: opts.Log, - PublishWorkspaceUpdateFn: api.publishWorkspaceUpdate, + PublishWorkspaceUpdateFn: 
opts.PublishWorkspaceUpdateFn, PublishWorkspaceAgentLogsUpdateFn: opts.PublishWorkspaceAgentLogsUpdateFn, } @@ -210,13 +212,3 @@ func (a *API) workspaceID(ctx context.Context, agent *database.WorkspaceAgent) ( a.mu.Unlock() return getWorkspaceAgentByIDRow.Workspace.ID, nil } - -func (a *API) publishWorkspaceUpdate(ctx context.Context, agent *database.WorkspaceAgent) error { - workspaceID, err := a.workspaceID(ctx, agent) - if err != nil { - return err - } - - a.opts.PublishWorkspaceUpdateFn(ctx, workspaceID) - return nil -} diff --git a/coderd/agentapi/apps.go b/coderd/agentapi/apps.go index 7e8bda1262426..e1117928aec9b 100644 --- a/coderd/agentapi/apps.go +++ b/coderd/agentapi/apps.go @@ -13,9 +13,10 @@ import ( type AppsAPI struct { AgentFn func(context.Context) (database.WorkspaceAgent, error) + WorkspaceIDFn func(context.Context, *database.WorkspaceAgent) (uuid.UUID, error) Database database.Store Log slog.Logger - PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent) error + PublishWorkspaceUpdateFn func(context.Context, uuid.UUID) } func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.BatchUpdateAppHealthRequest) (*agentproto.BatchUpdateAppHealthResponse, error) { @@ -91,10 +92,11 @@ func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.Bat } if a.PublishWorkspaceUpdateFn != nil && len(newApps) > 0 { - err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent) + workspaceID, err := a.WorkspaceIDFn(ctx, &workspaceAgent) if err != nil { - return nil, xerrors.Errorf("publish workspace update: %w", err) + return nil, err } + a.PublishWorkspaceUpdateFn(ctx, workspaceID) } return &agentproto.BatchUpdateAppHealthResponse{}, nil } diff --git a/coderd/agentapi/apps_test.go b/coderd/agentapi/apps_test.go index 8d0b802063bfe..bceac4480a004 100644 --- a/coderd/agentapi/apps_test.go +++ b/coderd/agentapi/apps_test.go @@ -21,7 +21,8 @@ func TestBatchUpdateAppHealths(t *testing.T) { t.Parallel() var ( - agent = 
database.WorkspaceAgent{ + workspaceID = uuid.New() + agent = database.WorkspaceAgent{ ID: uuid.New(), } app1 = database.WorkspaceApp{ @@ -61,11 +62,13 @@ func TestBatchUpdateAppHealths(t *testing.T) { AgentFn: func(context.Context) (database.WorkspaceAgent, error) { return agent, nil }, + WorkspaceIDFn: func(context.Context, *database.WorkspaceAgent) (uuid.UUID, error) { + return workspaceID, nil + }, Database: dbM, Log: slogtest.Make(t, nil), - PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + PublishWorkspaceUpdateFn: func(context.Context, uuid.UUID) { atomic.AddInt64(&publishCalled, 1) - return nil }, } @@ -99,11 +102,13 @@ func TestBatchUpdateAppHealths(t *testing.T) { AgentFn: func(context.Context) (database.WorkspaceAgent, error) { return agent, nil }, + WorkspaceIDFn: func(context.Context, *database.WorkspaceAgent) (uuid.UUID, error) { + return workspaceID, nil + }, Database: dbM, Log: slogtest.Make(t, nil), - PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + PublishWorkspaceUpdateFn: func(context.Context, uuid.UUID) { atomic.AddInt64(&publishCalled, 1) - return nil }, } @@ -138,11 +143,13 @@ func TestBatchUpdateAppHealths(t *testing.T) { AgentFn: func(context.Context) (database.WorkspaceAgent, error) { return agent, nil }, + WorkspaceIDFn: func(context.Context, *database.WorkspaceAgent) (uuid.UUID, error) { + return workspaceID, nil + }, Database: dbM, Log: slogtest.Make(t, nil), - PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + PublishWorkspaceUpdateFn: func(context.Context, uuid.UUID) { atomic.AddInt64(&publishCalled, 1) - return nil }, } @@ -173,6 +180,9 @@ func TestBatchUpdateAppHealths(t *testing.T) { AgentFn: func(context.Context) (database.WorkspaceAgent, error) { return agent, nil }, + WorkspaceIDFn: func(context.Context, *database.WorkspaceAgent) (uuid.UUID, error) { + return workspaceID, nil + }, Database: dbM, Log: 
slogtest.Make(t, nil), PublishWorkspaceUpdateFn: nil, @@ -202,6 +212,9 @@ func TestBatchUpdateAppHealths(t *testing.T) { AgentFn: func(context.Context) (database.WorkspaceAgent, error) { return agent, nil }, + WorkspaceIDFn: func(context.Context, *database.WorkspaceAgent) (uuid.UUID, error) { + return workspaceID, nil + }, Database: dbM, Log: slogtest.Make(t, nil), PublishWorkspaceUpdateFn: nil, @@ -232,6 +245,9 @@ func TestBatchUpdateAppHealths(t *testing.T) { AgentFn: func(context.Context) (database.WorkspaceAgent, error) { return agent, nil }, + WorkspaceIDFn: func(context.Context, *database.WorkspaceAgent) (uuid.UUID, error) { + return workspaceID, nil + }, Database: dbM, Log: slogtest.Make(t, nil), PublishWorkspaceUpdateFn: nil, diff --git a/coderd/agentapi/lifecycle.go b/coderd/agentapi/lifecycle.go index 662d0c0c2e28e..c8a298c84d5b3 100644 --- a/coderd/agentapi/lifecycle.go +++ b/coderd/agentapi/lifecycle.go @@ -3,6 +3,7 @@ package agentapi import ( "context" "database/sql" + "sort" "time" "github.com/google/uuid" @@ -16,12 +17,26 @@ import ( "github.com/coder/coder/v2/coderd/database/dbtime" ) +type WorkspaceAgentAPIVersionContextKey struct{} + +func WorkspaceAgentAPIVersion(ctx context.Context) string { + v, ok := ctx.Value(WorkspaceAgentAPIVersionContextKey{}).(string) + if !ok { + return AgentAPIVersionDRPC + } + return v +} + +func SetWorkspaceAgentAPIVersion(ctx context.Context, version string) context.Context { + return context.WithValue(ctx, WorkspaceAgentAPIVersionContextKey{}, version) +} + type LifecycleAPI struct { AgentFn func(context.Context) (database.WorkspaceAgent, error) WorkspaceIDFn func(context.Context, *database.WorkspaceAgent) (uuid.UUID, error) Database database.Store Log slog.Logger - PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent) error + PublishWorkspaceUpdateFn func(context.Context, uuid.UUID) TimeNowFn func() time.Time // defaults to dbtime.Now() } @@ -113,10 +128,7 @@ func (a *LifecycleAPI) 
UpdateLifecycle(ctx context.Context, req *agentproto.Upda } if a.PublishWorkspaceUpdateFn != nil { - err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent) - if err != nil { - return nil, xerrors.Errorf("publish workspace update: %w", err) - } + a.PublishWorkspaceUpdateFn(ctx, workspaceID) } return req.Lifecycle, nil @@ -165,12 +177,20 @@ func (a *LifecycleAPI) UpdateStartup(ctx context.Context, req *agentproto.Update } } + // Sort subsystems. + sort.Slice(dbSubsystems, func(i, j int) bool { + return dbSubsystems[i] < dbSubsystems[j] + }) + + // Get API version from context (or default to DRPC). This is only used when + // shimming an old version to this version. + apiVersion := WorkspaceAgentAPIVersion(ctx) err = a.Database.UpdateWorkspaceAgentStartupByID(ctx, database.UpdateWorkspaceAgentStartupByIDParams{ ID: workspaceAgent.ID, Version: req.Startup.Version, ExpandedDirectory: req.Startup.ExpandedDirectory, Subsystems: dbSubsystems, - APIVersion: AgentAPIVersionDRPC, + APIVersion: apiVersion, }) if err != nil { return nil, xerrors.Errorf("update workspace agent startup in database: %w", err) diff --git a/coderd/agentapi/lifecycle_test.go b/coderd/agentapi/lifecycle_test.go index 9029a84b955eb..040d722e9e9f6 100644 --- a/coderd/agentapi/lifecycle_test.go +++ b/coderd/agentapi/lifecycle_test.go @@ -74,9 +74,8 @@ func TestUpdateLifecycle(t *testing.T) { }, Database: dbM, Log: slogtest.Make(t, nil), - PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent) error { + PublishWorkspaceUpdateFn: func(context.Context, uuid.UUID) { atomic.AddInt64(&publishCalled, 1) - return nil }, } @@ -161,9 +160,8 @@ func TestUpdateLifecycle(t *testing.T) { }, Database: dbM, Log: slogtest.Make(t, nil), - PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent) error { + PublishWorkspaceUpdateFn: func(context.Context, uuid.UUID) { atomic.AddInt64(&publishCalled, 1) - return nil }, } @@ -244,9 +242,8 @@ func TestUpdateLifecycle(t 
*testing.T) { }, Database: dbM, Log: slogtest.Make(t, nil), - PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent) error { + PublishWorkspaceUpdateFn: func(context.Context, uuid.UUID) { atomic.AddInt64(&publishCalled, 1) - return nil }, } @@ -318,9 +315,8 @@ func TestUpdateLifecycle(t *testing.T) { }, Database: dbM, Log: slogtest.Make(t, nil), - PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent) error { + PublishWorkspaceUpdateFn: func(context.Context, uuid.UUID) { atomic.AddInt64(&publishCalled, 1) - return nil }, } diff --git a/coderd/agentapi/logs.go b/coderd/agentapi/logs.go index cb3a920b9a63b..c5fd3682809fd 100644 --- a/coderd/agentapi/logs.go +++ b/coderd/agentapi/logs.go @@ -16,9 +16,10 @@ import ( type LogsAPI struct { AgentFn func(context.Context) (database.WorkspaceAgent, error) + WorkspaceIDFn func(context.Context, *database.WorkspaceAgent) (uuid.UUID, error) Database database.Store Log slog.Logger - PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent) error + PublishWorkspaceUpdateFn func(context.Context, uuid.UUID) PublishWorkspaceAgentLogsUpdateFn func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) TimeNowFn func() time.Time // defaults to dbtime.Now() @@ -48,6 +49,11 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea return nil, xerrors.Errorf("parse log source ID %q: %w", req.LogSourceId, err) } + workspaceID, err := a.WorkspaceIDFn(ctx, &workspaceAgent) + if err != nil { + return nil, err + } + // This is to support the legacy API where the log source ID was // not provided in the request body. We default to the external // log source in this case. 
@@ -123,10 +129,7 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea } if a.PublishWorkspaceUpdateFn != nil { - err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent) - if err != nil { - return nil, xerrors.Errorf("publish workspace update: %w", err) - } + a.PublishWorkspaceUpdateFn(ctx, workspaceID) } return nil, xerrors.New("workspace agent log limit exceeded") } @@ -143,10 +146,7 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea if workspaceAgent.LogsLength == 0 && a.PublishWorkspaceUpdateFn != nil { // If these are the first logs being appended, we publish a UI update // to notify the UI that logs are now available. - err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent) - if err != nil { - return nil, xerrors.Errorf("publish workspace update: %w", err) - } + a.PublishWorkspaceUpdateFn(ctx, workspaceID) } return &agentproto.BatchCreateLogsResponse{}, nil diff --git a/coderd/agentapi/logs_test.go b/coderd/agentapi/logs_test.go index 1d4261a0191ea..4f2ec7f58a6cd 100644 --- a/coderd/agentapi/logs_test.go +++ b/coderd/agentapi/logs_test.go @@ -26,7 +26,8 @@ func TestBatchCreateLogs(t *testing.T) { t.Parallel() var ( - agent = database.WorkspaceAgent{ + workspaceID = uuid.New() + agent = database.WorkspaceAgent{ ID: uuid.New(), } logSource = database.WorkspaceAgentLogSource{ @@ -48,11 +49,13 @@ func TestBatchCreateLogs(t *testing.T) { AgentFn: func(context.Context) (database.WorkspaceAgent, error) { return agent, nil }, + WorkspaceIDFn: func(context.Context, *database.WorkspaceAgent) (uuid.UUID, error) { + return workspaceID, nil + }, Database: dbM, Log: slogtest.Make(t, nil), - PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + PublishWorkspaceUpdateFn: func(context.Context, uuid.UUID) { atomic.AddInt64(&publishWorkspaceUpdateCalled, 1) - return nil }, PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg 
agentsdk.LogsNotifyMessage) { atomic.AddInt64(&publishWorkspaceAgentLogsUpdateCalled, 1) @@ -152,11 +155,13 @@ func TestBatchCreateLogs(t *testing.T) { AgentFn: func(context.Context) (database.WorkspaceAgent, error) { return agentWithLogs, nil }, + WorkspaceIDFn: func(context.Context, *database.WorkspaceAgent) (uuid.UUID, error) { + return workspaceID, nil + }, Database: dbM, Log: slogtest.Make(t, nil), - PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + PublishWorkspaceUpdateFn: func(context.Context, uuid.UUID) { atomic.AddInt64(&publishWorkspaceUpdateCalled, 1) - return nil }, PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) { atomic.AddInt64(&publishWorkspaceAgentLogsUpdateCalled, 1) @@ -200,11 +205,13 @@ func TestBatchCreateLogs(t *testing.T) { AgentFn: func(context.Context) (database.WorkspaceAgent, error) { return overflowedAgent, nil }, + WorkspaceIDFn: func(context.Context, *database.WorkspaceAgent) (uuid.UUID, error) { + return workspaceID, nil + }, Database: dbM, Log: slogtest.Make(t, nil), - PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + PublishWorkspaceUpdateFn: func(context.Context, uuid.UUID) { atomic.AddInt64(&publishWorkspaceUpdateCalled, 1) - return nil }, PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) { atomic.AddInt64(&publishWorkspaceAgentLogsUpdateCalled, 1) @@ -231,6 +238,9 @@ func TestBatchCreateLogs(t *testing.T) { AgentFn: func(context.Context) (database.WorkspaceAgent, error) { return agent, nil }, + WorkspaceIDFn: func(context.Context, *database.WorkspaceAgent) (uuid.UUID, error) { + return workspaceID, nil + }, Database: dbM, Log: slogtest.Make(t, nil), // Test that they are ignored when nil. 
@@ -293,11 +303,13 @@ func TestBatchCreateLogs(t *testing.T) { AgentFn: func(context.Context) (database.WorkspaceAgent, error) { return agent, nil }, + WorkspaceIDFn: func(context.Context, *database.WorkspaceAgent) (uuid.UUID, error) { + return workspaceID, nil + }, Database: dbM, Log: slogtest.Make(t, nil), - PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + PublishWorkspaceUpdateFn: func(context.Context, uuid.UUID) { atomic.AddInt64(&publishWorkspaceUpdateCalled, 1) - return nil }, PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) { atomic.AddInt64(&publishWorkspaceAgentLogsUpdateCalled, 1) @@ -337,11 +349,13 @@ func TestBatchCreateLogs(t *testing.T) { AgentFn: func(context.Context) (database.WorkspaceAgent, error) { return agent, nil }, + WorkspaceIDFn: func(context.Context, *database.WorkspaceAgent) (uuid.UUID, error) { + return workspaceID, nil + }, Database: dbM, Log: slogtest.Make(t, nil), - PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + PublishWorkspaceUpdateFn: func(context.Context, uuid.UUID) { atomic.AddInt64(&publishWorkspaceUpdateCalled, 1) - return nil }, PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) { atomic.AddInt64(&publishWorkspaceAgentLogsUpdateCalled, 1) @@ -384,11 +398,13 @@ func TestBatchCreateLogs(t *testing.T) { AgentFn: func(context.Context) (database.WorkspaceAgent, error) { return agent, nil }, + WorkspaceIDFn: func(context.Context, *database.WorkspaceAgent) (uuid.UUID, error) { + return workspaceID, nil + }, Database: dbM, Log: slogtest.Make(t, nil), - PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { + PublishWorkspaceUpdateFn: func(context.Context, uuid.UUID) { atomic.AddInt64(&publishWorkspaceUpdateCalled, 1) - return nil }, PublishWorkspaceAgentLogsUpdateFn: func(ctx 
context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) { atomic.AddInt64(&publishWorkspaceAgentLogsUpdateCalled, 1) diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index dea25099e0c47..e1ddf5721dc0e 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -23,7 +23,6 @@ import ( "github.com/sqlc-dev/pqtype" "golang.org/x/exp/maps" "golang.org/x/exp/slices" - "golang.org/x/mod/semver" "golang.org/x/sync/errgroup" "golang.org/x/xerrors" "nhooyr.io/websocket" @@ -252,16 +251,23 @@ const AgentAPIVersionREST = "1.0" func (api *API) postWorkspaceAgentStartup(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() workspaceAgent := httpmw.WorkspaceAgent(r) - apiAgent, err := db2sdk.WorkspaceAgent( - api.DERPMap(), *api.TailnetCoordinator.Load(), workspaceAgent, nil, nil, nil, api.AgentInactiveDisconnectTimeout, - api.DeploymentValues.AgentFallbackTroubleshootingURL.String(), - ) - if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Internal error reading workspace agent.", - Detail: err.Error(), - }) - return + + // As this API becomes deprecated, use the new protobuf API and convert the + // types back to the SDK types. 
+ lifecycleAPI := &agentapi.LifecycleAPI{ + AgentFn: func(_ context.Context) (database.WorkspaceAgent, error) { return workspaceAgent, nil }, + WorkspaceIDFn: func(ctx context.Context, wa *database.WorkspaceAgent) (uuid.UUID, error) { + ws, err := api.Database.GetWorkspaceByAgentID(ctx, wa.ID) + if err != nil { + return uuid.Nil, err + } + return ws.Workspace.ID, nil + }, + Database: api.Database, + Log: api.Logger, + PublishWorkspaceUpdateFn: func(ctx context.Context, workspaceID uuid.UUID) { + api.publishWorkspaceUpdate(ctx, workspaceID) + }, } var req agentsdk.PostStartupRequest @@ -269,51 +275,36 @@ func (api *API) postWorkspaceAgentStartup(rw http.ResponseWriter, r *http.Reques return } - api.Logger.Debug( - ctx, - "post workspace agent version", - slog.F("agent_id", apiAgent.ID), - slog.F("agent_version", req.Version), - slog.F("remote_addr", r.RemoteAddr), - ) - - if !semver.IsValid(req.Version) { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Invalid workspace agent version provided.", - Detail: fmt.Sprintf("invalid semver version: %q", req.Version), - }) - return - } - - // Validate subsystems. - seen := make(map[codersdk.AgentSubsystem]bool) - for _, s := range req.Subsystems { - if !s.Valid() { + // Convert subsystems. 
+ protoSubsystems := make([]agentproto.Startup_Subsystem, len(req.Subsystems)) + for i, s := range req.Subsystems { + switch s { + case codersdk.AgentSubsystemEnvbox: + protoSubsystems[i] = agentproto.Startup_ENVBOX + case codersdk.AgentSubsystemEnvbuilder: + protoSubsystems[i] = agentproto.Startup_ENVBUILDER + case codersdk.AgentSubsystemExectrace: + protoSubsystems[i] = agentproto.Startup_EXECTRACE + default: httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ Message: "Invalid workspace agent subsystem provided.", Detail: fmt.Sprintf("invalid subsystem: %q", s), }) return } - if seen[s] { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Invalid workspace agent subsystem provided.", - Detail: fmt.Sprintf("duplicate subsystem: %q", s), - }) - return - } - seen[s] = true } - if err := api.Database.UpdateWorkspaceAgentStartupByID(ctx, database.UpdateWorkspaceAgentStartupByIDParams{ - ID: apiAgent.ID, - Version: req.Version, - ExpandedDirectory: req.ExpandedDirectory, - Subsystems: convertWorkspaceAgentSubsystems(req.Subsystems), - APIVersion: AgentAPIVersionREST, - }); err != nil { + ctx = agentapi.SetWorkspaceAgentAPIVersion(ctx, AgentAPIVersionREST) + _, err := lifecycleAPI.UpdateStartup(ctx, &agentproto.UpdateStartupRequest{ + Startup: &agentproto.Startup{ + Version: req.Version, + ExpandedDirectory: req.ExpandedDirectory, + Subsystems: protoSubsystems, + }, + }) + if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Error setting agent version", + Message: "Internal error updating workspace agent startup.", Detail: err.Error(), }) return diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index 5232b71113ea9..c71ad9e3b84a4 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -1431,12 +1431,14 @@ func TestWorkspaceAgent_Startup(t *testing.T) { ctx := testutil.Context(t, testutil.WaitMedium) err := 
agentClient.PostStartup(ctx, agentsdk.PostStartupRequest{ - Version: "1.2.3", + Version: "1.2.3", // missing "v" }) require.Error(t, err) cerr, ok := codersdk.AsError(err) require.True(t, ok) - require.Equal(t, http.StatusBadRequest, cerr.StatusCode()) + // This is supposed to be a 400, but during the deprecation phase it + // will be a 500 due to it calling the proto API. + require.Equal(t, http.StatusInternalServerError, cerr.StatusCode()) }) } From 3b9d3659f7f090fcc4a008ac6fb10af0d2934765 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Tue, 19 Dec 2023 12:55:34 +0000 Subject: [PATCH 4/7] avoid race in test --- coderd/agentapi/tailnet_test.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/coderd/agentapi/tailnet_test.go b/coderd/agentapi/tailnet_test.go index dae2a2c7ebe99..974d6eb2d707e 100644 --- a/coderd/agentapi/tailnet_test.go +++ b/coderd/agentapi/tailnet_test.go @@ -2,6 +2,7 @@ package agentapi_test import ( "context" + "sync" "testing" "time" @@ -45,10 +46,13 @@ func TestStreamDERPMaps(t *testing.T) { t.Run("OK", func(t *testing.T) { t.Parallel() + derpMapMu := sync.Mutex{} derpMap := tailcfg.DERPMap{} api := &agentapi.TailnetAPI{ Ctx: context.Background(), DerpMapFn: func() *tailcfg.DERPMap { + derpMapMu.Lock() + defer derpMapMu.Unlock() derp := (&derpMap).Clone() return derp }, @@ -91,14 +95,18 @@ func TestStreamDERPMaps(t *testing.T) { require.Equal(t, tailnet.DERPMapToProto(&derpMap), gotMap) // Update the map, should get an update. + derpMapMu.Lock() derpMap.Regions = map[int]*tailcfg.DERPRegion{ 1: {}, } + derpMapMu.Unlock() gotMap = <-maps require.Equal(t, tailnet.DERPMapToProto(&derpMap), gotMap) // Update the map again, should get an update. 
+ derpMapMu.Lock() derpMap.Regions = nil + derpMapMu.Unlock() gotMap = <-maps require.Equal(t, tailnet.DERPMapToProto(&derpMap), gotMap) From 0b1aa1237b9cfcb7a3002c1fdc979b6c5d913f66 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Thu, 25 Jan 2024 12:47:15 +0000 Subject: [PATCH 5/7] PR comments --- coderd/agentapi/apps_test.go | 23 ++- coderd/agentapi/lifecycle_test.go | 29 ++-- coderd/agentapi/logs_test.go | 78 +++++------ coderd/agentapi/manifest_test.go | 8 +- coderd/agentapi/metadata_test.go | 35 ++--- coderd/agentapi/servicebanner_test.go | 2 +- coderd/agentapi/stats_test.go | 14 +- coderd/agentapi/tailnet_test.go | 192 -------------------------- 8 files changed, 99 insertions(+), 282 deletions(-) delete mode 100644 coderd/agentapi/tailnet_test.go diff --git a/coderd/agentapi/apps_test.go b/coderd/agentapi/apps_test.go index 8d0b802063bfe..c774c6777b32a 100644 --- a/coderd/agentapi/apps_test.go +++ b/coderd/agentapi/apps_test.go @@ -2,12 +2,11 @@ package agentapi_test import ( "context" - "sync/atomic" "testing" - "github.com/golang/mock/gomock" "github.com/google/uuid" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" "cdr.dev/slog/sloggers/slogtest" @@ -56,7 +55,7 @@ func TestBatchUpdateAppHealths(t *testing.T) { Health: database.WorkspaceAppHealthUnhealthy, }).Return(nil) - var publishCalled int64 + publishCalled := false api := &agentapi.AppsAPI{ AgentFn: func(context.Context) (database.WorkspaceAgent, error) { return agent, nil @@ -64,12 +63,12 @@ func TestBatchUpdateAppHealths(t *testing.T) { Database: dbM, Log: slogtest.Make(t, nil), PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { - atomic.AddInt64(&publishCalled, 1) + publishCalled = true return nil }, } - // Set both to healthy, only one should be updated in the DB. + // Set one to healthy, set another to unhealthy. 
resp, err := api.BatchUpdateAppHealths(context.Background(), &agentproto.BatchUpdateAppHealthRequest{ Updates: []*agentproto.BatchUpdateAppHealthRequest_HealthUpdate{ { @@ -85,7 +84,7 @@ func TestBatchUpdateAppHealths(t *testing.T) { require.NoError(t, err) require.Equal(t, &agentproto.BatchUpdateAppHealthResponse{}, resp) - require.EqualValues(t, 1, atomic.LoadInt64(&publishCalled)) + require.True(t, publishCalled) }) t.Run("Unchanged", func(t *testing.T) { @@ -94,7 +93,7 @@ func TestBatchUpdateAppHealths(t *testing.T) { dbM := dbmock.NewMockStore(gomock.NewController(t)) dbM.EXPECT().GetWorkspaceAppsByAgentID(gomock.Any(), agent.ID).Return([]database.WorkspaceApp{app1, app2}, nil) - var publishCalled int64 + publishCalled := false api := &agentapi.AppsAPI{ AgentFn: func(context.Context) (database.WorkspaceAgent, error) { return agent, nil @@ -102,7 +101,7 @@ func TestBatchUpdateAppHealths(t *testing.T) { Database: dbM, Log: slogtest.Make(t, nil), PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { - atomic.AddInt64(&publishCalled, 1) + publishCalled = true return nil }, } @@ -124,7 +123,7 @@ func TestBatchUpdateAppHealths(t *testing.T) { require.NoError(t, err) require.Equal(t, &agentproto.BatchUpdateAppHealthResponse{}, resp) - require.EqualValues(t, 0, atomic.LoadInt64(&publishCalled)) + require.False(t, publishCalled) }) t.Run("Empty", func(t *testing.T) { @@ -133,7 +132,7 @@ func TestBatchUpdateAppHealths(t *testing.T) { // No DB queries are made if there are no updates to process. 
dbM := dbmock.NewMockStore(gomock.NewController(t)) - var publishCalled int64 + publishCalled := false api := &agentapi.AppsAPI{ AgentFn: func(context.Context) (database.WorkspaceAgent, error) { return agent, nil @@ -141,7 +140,7 @@ func TestBatchUpdateAppHealths(t *testing.T) { Database: dbM, Log: slogtest.Make(t, nil), PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { - atomic.AddInt64(&publishCalled, 1) + publishCalled = true return nil }, } @@ -153,7 +152,7 @@ func TestBatchUpdateAppHealths(t *testing.T) { require.NoError(t, err) require.Equal(t, &agentproto.BatchUpdateAppHealthResponse{}, resp) - require.EqualValues(t, 0, atomic.LoadInt64(&publishCalled)) + require.False(t, publishCalled) }) t.Run("AppNoHealthcheck", func(t *testing.T) { diff --git a/coderd/agentapi/lifecycle_test.go b/coderd/agentapi/lifecycle_test.go index 9029a84b955eb..855ff9329acc9 100644 --- a/coderd/agentapi/lifecycle_test.go +++ b/coderd/agentapi/lifecycle_test.go @@ -8,9 +8,9 @@ import ( "testing" "time" - "github.com/golang/mock/gomock" "github.com/google/uuid" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" "google.golang.org/protobuf/types/known/timestamppb" "cdr.dev/slog/sloggers/slogtest" @@ -64,7 +64,7 @@ func TestUpdateLifecycle(t *testing.T) { ReadyAt: sql.NullTime{Valid: false}, }).Return(nil) - var publishCalled int64 + publishCalled := false api := &agentapi.LifecycleAPI{ AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { return agentCreated, nil @@ -75,7 +75,7 @@ func TestUpdateLifecycle(t *testing.T) { Database: dbM, Log: slogtest.Make(t, nil), PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent) error { - atomic.AddInt64(&publishCalled, 1) + publishCalled = true return nil }, } @@ -85,7 +85,7 @@ func TestUpdateLifecycle(t *testing.T) { }) require.NoError(t, err) require.Equal(t, lifecycle, resp) - require.Equal(t, int64(1), atomic.LoadInt64(&publishCalled)) + 
require.True(t, publishCalled) }) t.Run("OKReadying", func(t *testing.T) { @@ -151,7 +151,7 @@ func TestUpdateLifecycle(t *testing.T) { }, }).Return(nil) - var publishCalled int64 + publishCalled := false api := &agentapi.LifecycleAPI{ AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { return agentCreated, nil @@ -162,7 +162,7 @@ func TestUpdateLifecycle(t *testing.T) { Database: dbM, Log: slogtest.Make(t, nil), PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent) error { - atomic.AddInt64(&publishCalled, 1) + publishCalled = true return nil }, } @@ -172,7 +172,7 @@ func TestUpdateLifecycle(t *testing.T) { }) require.NoError(t, err) require.Equal(t, lifecycle, resp) - require.Equal(t, int64(1), atomic.LoadInt64(&publishCalled)) + require.True(t, publishCalled) }) t.Run("NoTimeSpecified", func(t *testing.T) { @@ -263,19 +263,20 @@ func TestUpdateLifecycle(t *testing.T) { } for i, state := range states { t.Log("state", state) - now := now.Add(time.Hour * time.Duration(i)) + // Use a time after the last state change to ensure ordering. 
+ stateNow := now.Add(time.Hour * time.Duration(i)) lifecycle := &agentproto.Lifecycle{ State: state, - ChangedAt: timestamppb.New(now), + ChangedAt: timestamppb.New(stateNow), } expectedStartedAt := agent.StartedAt expectedReadyAt := agent.ReadyAt if state == agentproto.Lifecycle_STARTING { - expectedStartedAt = sql.NullTime{Valid: true, Time: now} + expectedStartedAt = sql.NullTime{Valid: true, Time: stateNow} } if state == agentproto.Lifecycle_READY || state == agentproto.Lifecycle_START_ERROR { - expectedReadyAt = sql.NullTime{Valid: true, Time: now} + expectedReadyAt = sql.NullTime{Valid: true, Time: stateNow} } dbM.EXPECT().UpdateWorkspaceAgentLifecycleStateByID(gomock.Any(), database.UpdateWorkspaceAgentLifecycleStateByIDParams{ @@ -308,7 +309,7 @@ func TestUpdateLifecycle(t *testing.T) { dbM := dbmock.NewMockStore(gomock.NewController(t)) - var publishCalled int64 + publishCalled := false api := &agentapi.LifecycleAPI{ AgentFn: func(ctx context.Context) (database.WorkspaceAgent, error) { return agentCreated, nil @@ -319,7 +320,7 @@ func TestUpdateLifecycle(t *testing.T) { Database: dbM, Log: slogtest.Make(t, nil), PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent) error { - atomic.AddInt64(&publishCalled, 1) + publishCalled = true return nil }, } @@ -330,7 +331,7 @@ func TestUpdateLifecycle(t *testing.T) { require.Error(t, err) require.ErrorContains(t, err, "unknown lifecycle state") require.Nil(t, resp) - require.Equal(t, int64(0), atomic.LoadInt64(&publishCalled)) + require.False(t, publishCalled) }) } diff --git a/coderd/agentapi/logs_test.go b/coderd/agentapi/logs_test.go index 1d4261a0191ea..66fbaa005d625 100644 --- a/coderd/agentapi/logs_test.go +++ b/coderd/agentapi/logs_test.go @@ -3,14 +3,14 @@ package agentapi_test import ( "context" "strings" - "sync/atomic" "testing" "time" - "github.com/golang/mock/gomock" "github.com/google/uuid" "github.com/lib/pq" + "github.com/stretchr/testify/assert" 
"github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" "google.golang.org/protobuf/types/known/timestamppb" "cdr.dev/slog/sloggers/slogtest" @@ -41,8 +41,8 @@ func TestBatchCreateLogs(t *testing.T) { dbM := dbmock.NewMockStore(gomock.NewController(t)) - var publishWorkspaceUpdateCalled int64 - var publishWorkspaceAgentLogsUpdateCalled int64 + publishWorkspaceUpdateCalled := false + publishWorkspaceAgentLogsUpdateCalled := false now := dbtime.Now() api := &agentapi.LogsAPI{ AgentFn: func(context.Context) (database.WorkspaceAgent, error) { @@ -51,15 +51,15 @@ func TestBatchCreateLogs(t *testing.T) { Database: dbM, Log: slogtest.Make(t, nil), PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { - atomic.AddInt64(&publishWorkspaceUpdateCalled, 1) + publishWorkspaceUpdateCalled = true return nil }, PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) { - atomic.AddInt64(&publishWorkspaceAgentLogsUpdateCalled, 1) + publishWorkspaceAgentLogsUpdateCalled = true // Check the message content, should be for -1 since the lowest // log we inserted was 0. 
- require.Equal(t, agentsdk.LogsNotifyMessage{CreatedAfter: -1}, msg) + assert.Equal(t, agentsdk.LogsNotifyMessage{CreatedAfter: -1}, msg) }, TimeNowFn: func() time.Time { return now }, } @@ -134,8 +134,8 @@ func TestBatchCreateLogs(t *testing.T) { resp, err := api.BatchCreateLogs(context.Background(), req) require.NoError(t, err) require.Equal(t, &agentproto.BatchCreateLogsResponse{}, resp) - require.EqualValues(t, 1, atomic.LoadInt64(&publishWorkspaceUpdateCalled)) - require.EqualValues(t, 1, atomic.LoadInt64(&publishWorkspaceAgentLogsUpdateCalled)) + require.True(t, publishWorkspaceUpdateCalled) + require.True(t, publishWorkspaceAgentLogsUpdateCalled) }) t.Run("NoWorkspacePublishIfNotFirstLogs", func(t *testing.T) { @@ -146,8 +146,8 @@ func TestBatchCreateLogs(t *testing.T) { dbM := dbmock.NewMockStore(gomock.NewController(t)) - var publishWorkspaceUpdateCalled int64 - var publishWorkspaceAgentLogsUpdateCalled int64 + publishWorkspaceUpdateCalled := false + publishWorkspaceAgentLogsUpdateCalled := false api := &agentapi.LogsAPI{ AgentFn: func(context.Context) (database.WorkspaceAgent, error) { return agentWithLogs, nil @@ -155,11 +155,11 @@ func TestBatchCreateLogs(t *testing.T) { Database: dbM, Log: slogtest.Make(t, nil), PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { - atomic.AddInt64(&publishWorkspaceUpdateCalled, 1) + publishWorkspaceUpdateCalled = true return nil }, PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) { - atomic.AddInt64(&publishWorkspaceAgentLogsUpdateCalled, 1) + publishWorkspaceAgentLogsUpdateCalled = true }, } @@ -182,8 +182,8 @@ func TestBatchCreateLogs(t *testing.T) { }) require.NoError(t, err) require.Equal(t, &agentproto.BatchCreateLogsResponse{}, resp) - require.EqualValues(t, 0, atomic.LoadInt64(&publishWorkspaceUpdateCalled)) - require.EqualValues(t, 1, atomic.LoadInt64(&publishWorkspaceAgentLogsUpdateCalled)) + 
require.False(t, publishWorkspaceUpdateCalled) + require.True(t, publishWorkspaceAgentLogsUpdateCalled) }) t.Run("AlreadyOverflowed", func(t *testing.T) { @@ -194,8 +194,8 @@ func TestBatchCreateLogs(t *testing.T) { overflowedAgent := agent overflowedAgent.LogsOverflowed = true - var publishWorkspaceUpdateCalled int64 - var publishWorkspaceAgentLogsUpdateCalled int64 + publishWorkspaceUpdateCalled := false + publishWorkspaceAgentLogsUpdateCalled := false api := &agentapi.LogsAPI{ AgentFn: func(context.Context) (database.WorkspaceAgent, error) { return overflowedAgent, nil @@ -203,11 +203,11 @@ func TestBatchCreateLogs(t *testing.T) { Database: dbM, Log: slogtest.Make(t, nil), PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { - atomic.AddInt64(&publishWorkspaceUpdateCalled, 1) + publishWorkspaceUpdateCalled = true return nil }, PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) { - atomic.AddInt64(&publishWorkspaceAgentLogsUpdateCalled, 1) + publishWorkspaceAgentLogsUpdateCalled = true }, } @@ -218,8 +218,8 @@ func TestBatchCreateLogs(t *testing.T) { require.Error(t, err) require.ErrorContains(t, err, "workspace agent logs overflowed") require.Nil(t, resp) - require.EqualValues(t, 0, atomic.LoadInt64(&publishWorkspaceUpdateCalled)) - require.EqualValues(t, 0, atomic.LoadInt64(&publishWorkspaceAgentLogsUpdateCalled)) + require.False(t, publishWorkspaceUpdateCalled) + require.False(t, publishWorkspaceAgentLogsUpdateCalled) }) t.Run("InvalidLogSourceID", func(t *testing.T) { @@ -287,8 +287,8 @@ func TestBatchCreateLogs(t *testing.T) { dbM := dbmock.NewMockStore(gomock.NewController(t)) - var publishWorkspaceUpdateCalled int64 - var publishWorkspaceAgentLogsUpdateCalled int64 + publishWorkspaceUpdateCalled := false + publishWorkspaceAgentLogsUpdateCalled := false api := &agentapi.LogsAPI{ AgentFn: func(context.Context) (database.WorkspaceAgent, error) { return 
agent, nil @@ -296,11 +296,11 @@ func TestBatchCreateLogs(t *testing.T) { Database: dbM, Log: slogtest.Make(t, nil), PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { - atomic.AddInt64(&publishWorkspaceUpdateCalled, 1) + publishWorkspaceUpdateCalled = true return nil }, PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) { - atomic.AddInt64(&publishWorkspaceAgentLogsUpdateCalled, 1) + publishWorkspaceAgentLogsUpdateCalled = true }, TimeNowFn: func() time.Time { return now }, } @@ -322,8 +322,8 @@ func TestBatchCreateLogs(t *testing.T) { resp, err := api.BatchCreateLogs(context.Background(), req) require.NoError(t, err) require.Equal(t, &agentproto.BatchCreateLogsResponse{}, resp) - require.EqualValues(t, 1, atomic.LoadInt64(&publishWorkspaceUpdateCalled)) - require.EqualValues(t, 1, atomic.LoadInt64(&publishWorkspaceAgentLogsUpdateCalled)) + require.True(t, publishWorkspaceUpdateCalled) + require.True(t, publishWorkspaceAgentLogsUpdateCalled) }) t.Run("Exists", func(t *testing.T) { @@ -331,8 +331,8 @@ func TestBatchCreateLogs(t *testing.T) { dbM := dbmock.NewMockStore(gomock.NewController(t)) - var publishWorkspaceUpdateCalled int64 - var publishWorkspaceAgentLogsUpdateCalled int64 + publishWorkspaceUpdateCalled := false + publishWorkspaceAgentLogsUpdateCalled := false api := &agentapi.LogsAPI{ AgentFn: func(context.Context) (database.WorkspaceAgent, error) { return agent, nil @@ -340,11 +340,11 @@ func TestBatchCreateLogs(t *testing.T) { Database: dbM, Log: slogtest.Make(t, nil), PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { - atomic.AddInt64(&publishWorkspaceUpdateCalled, 1) + publishWorkspaceUpdateCalled = true return nil }, PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) { - atomic.AddInt64(&publishWorkspaceAgentLogsUpdateCalled, 1) + 
publishWorkspaceAgentLogsUpdateCalled = true }, TimeNowFn: func() time.Time { return now }, } @@ -368,8 +368,8 @@ func TestBatchCreateLogs(t *testing.T) { resp, err := api.BatchCreateLogs(context.Background(), req) require.NoError(t, err) require.Equal(t, &agentproto.BatchCreateLogsResponse{}, resp) - require.EqualValues(t, 1, atomic.LoadInt64(&publishWorkspaceUpdateCalled)) - require.EqualValues(t, 1, atomic.LoadInt64(&publishWorkspaceAgentLogsUpdateCalled)) + require.True(t, publishWorkspaceUpdateCalled) + require.True(t, publishWorkspaceAgentLogsUpdateCalled) }) }) @@ -378,8 +378,8 @@ func TestBatchCreateLogs(t *testing.T) { dbM := dbmock.NewMockStore(gomock.NewController(t)) - var publishWorkspaceUpdateCalled int64 - var publishWorkspaceAgentLogsUpdateCalled int64 + publishWorkspaceUpdateCalled := false + publishWorkspaceAgentLogsUpdateCalled := false api := &agentapi.LogsAPI{ AgentFn: func(context.Context) (database.WorkspaceAgent, error) { return agent, nil @@ -387,11 +387,11 @@ func TestBatchCreateLogs(t *testing.T) { Database: dbM, Log: slogtest.Make(t, nil), PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error { - atomic.AddInt64(&publishWorkspaceUpdateCalled, 1) + publishWorkspaceUpdateCalled = true return nil }, PublishWorkspaceAgentLogsUpdateFn: func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage) { - atomic.AddInt64(&publishWorkspaceAgentLogsUpdateCalled, 1) + publishWorkspaceAgentLogsUpdateCalled = true }, } @@ -421,7 +421,7 @@ func TestBatchCreateLogs(t *testing.T) { }) require.Error(t, err) require.Nil(t, resp) - require.EqualValues(t, 1, atomic.LoadInt64(&publishWorkspaceUpdateCalled)) - require.EqualValues(t, 0, atomic.LoadInt64(&publishWorkspaceAgentLogsUpdateCalled)) + require.True(t, publishWorkspaceUpdateCalled) + require.False(t, publishWorkspaceAgentLogsUpdateCalled) }) } diff --git a/coderd/agentapi/manifest_test.go b/coderd/agentapi/manifest_test.go index 
008618c12c57b..575bc353f7c53 100644 --- a/coderd/agentapi/manifest_test.go +++ b/coderd/agentapi/manifest_test.go @@ -9,10 +9,10 @@ import ( "testing" "time" - "github.com/golang/mock/gomock" "github.com/google/uuid" "github.com/sqlc-dev/pqtype" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" "google.golang.org/protobuf/types/known/durationpb" "tailscale.com/tailcfg" @@ -300,8 +300,10 @@ func TestGetManifest(t *testing.T) { expected := &agentproto.Manifest{ AgentId: agent.ID[:], + AgentName: agent.Name, OwnerUsername: owner.Username, WorkspaceId: workspace.ID[:], + WorkspaceName: workspace.Name, GitAuthConfigs: 2, // two "enhanced" external auth configs EnvironmentVariables: expectedEnvVars, Directory: agent.Directory, @@ -365,12 +367,14 @@ func TestGetManifest(t *testing.T) { expected := &agentproto.Manifest{ AgentId: agent.ID[:], + AgentName: agent.Name, OwnerUsername: owner.Username, WorkspaceId: workspace.ID[:], + WorkspaceName: workspace.Name, GitAuthConfigs: 2, // two "enhanced" external auth configs EnvironmentVariables: expectedEnvVars, Directory: agent.Directory, - VsCodePortProxyUri: "https://example.com", + VsCodePortProxyUri: "", // empty with no AppHost MotdPath: agent.MOTDFile, DisableDirectConnections: true, DerpForceWebsockets: true, diff --git a/coderd/agentapi/metadata_test.go b/coderd/agentapi/metadata_test.go index 71abcc0e9f46a..f116eb82ce541 100644 --- a/coderd/agentapi/metadata_test.go +++ b/coderd/agentapi/metadata_test.go @@ -7,9 +7,9 @@ import ( "testing" "time" - "github.com/golang/mock/gomock" "github.com/google/uuid" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" "google.golang.org/protobuf/types/known/timestamppb" "cdr.dev/slog/sloggers/slogtest" @@ -22,6 +22,19 @@ import ( "github.com/coder/coder/v2/coderd/database/pubsub" ) +type fakePublisher struct { + // Nil pointer to pass interface check. 
+ pubsub.Pubsub + publishes [][]byte +} + +var _ pubsub.Pubsub = &fakePublisher{} + +func (f *fakePublisher) Publish(channel string, message []byte) error { + f.publishes = append(f.publishes, message) + return nil +} + func TestBatchUpdateMetadata(t *testing.T) { t.Parallel() @@ -33,7 +46,7 @@ func TestBatchUpdateMetadata(t *testing.T) { t.Parallel() dbM := dbmock.NewMockStore(gomock.NewController(t)) - pub := pubsub.NewInMemory() + pub := &fakePublisher{} now := dbtime.Now() req := &agentproto.BatchUpdateMetadataRequest{ @@ -80,25 +93,13 @@ func TestBatchUpdateMetadata(t *testing.T) { }, } - // Watch the pubsub for events. - var ( - eventCount int64 - gotEvent agentapi.WorkspaceAgentMetadataChannelPayload - ) - cancel, err := pub.Subscribe(agentapi.WatchWorkspaceAgentMetadataChannel(agent.ID), func(ctx context.Context, message []byte) { - if atomic.AddInt64(&eventCount, 1) > 1 { - return - } - require.NoError(t, json.Unmarshal(message, &gotEvent)) - }) - require.NoError(t, err) - defer cancel() - resp, err := api.BatchUpdateMetadata(context.Background(), req) require.NoError(t, err) require.Equal(t, &agentproto.BatchUpdateMetadataResponse{}, resp) - require.Equal(t, int64(1), atomic.LoadInt64(&eventCount)) + require.Equal(t, 1, len(pub.publishes)) + var gotEvent agentapi.WorkspaceAgentMetadataChannelPayload + require.NoError(t, json.Unmarshal(pub.publishes[0], &gotEvent)) require.Equal(t, agentapi.WorkspaceAgentMetadataChannelPayload{ CollectedAt: now, Keys: []string{req.Metadata[0].Key, req.Metadata[1].Key}, diff --git a/coderd/agentapi/servicebanner_test.go b/coderd/agentapi/servicebanner_test.go index f7a860a96b70e..902af7395e54d 100644 --- a/coderd/agentapi/servicebanner_test.go +++ b/coderd/agentapi/servicebanner_test.go @@ -6,8 +6,8 @@ import ( "encoding/json" "testing" - "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" agentproto "github.com/coder/coder/v2/agent/proto" 
"github.com/coder/coder/v2/coderd/agentapi" diff --git a/coderd/agentapi/stats_test.go b/coderd/agentapi/stats_test.go index 409c24df60b43..a26e7fbf6ae7a 100644 --- a/coderd/agentapi/stats_test.go +++ b/coderd/agentapi/stats_test.go @@ -8,10 +8,10 @@ import ( "testing" "time" - "github.com/golang/mock/gomock" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" "google.golang.org/protobuf/types/known/durationpb" agentproto "github.com/coder/coder/v2/agent/proto" @@ -89,7 +89,7 @@ func TestUpdateStates(t *testing.T) { }, } batcher = &statsBatcher{} - updateAgentMetricsFnCalled int64 + updateAgentMetricsFnCalled = false req = &agentproto.UpdateStatsRequest{ Stats: &agentproto.Stats{ @@ -129,7 +129,7 @@ func TestUpdateStates(t *testing.T) { TemplateScheduleStore: templateScheduleStorePtr(templateScheduleStore), AgentStatsRefreshInterval: 10 * time.Second, UpdateAgentMetricsFn: func(ctx context.Context, labels prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric) { - atomic.AddInt64(&updateAgentMetricsFnCalled, 1) + updateAgentMetricsFnCalled = true assert.Equal(t, prometheusmetrics.AgentMetricLabels{ Username: user.Username, WorkspaceName: workspace.Name, @@ -179,6 +179,8 @@ func TestUpdateStates(t *testing.T) { require.Equal(t, user.ID, batcher.lastUserID) require.Equal(t, workspace.ID, batcher.lastWorkspaceID) require.Equal(t, req.Stats, batcher.lastStats) + + require.True(t, updateAgentMetricsFnCalled) }) t.Run("ConnectionCountZero", func(t *testing.T) { @@ -303,7 +305,7 @@ func TestUpdateStates(t *testing.T) { }, } batcher = &statsBatcher{} - updateAgentMetricsFnCalled int64 + updateAgentMetricsFnCalled = false req = &agentproto.UpdateStatsRequest{ Stats: &agentproto.Stats{ @@ -324,7 +326,7 @@ func TestUpdateStates(t *testing.T) { TemplateScheduleStore: templateScheduleStorePtr(templateScheduleStore), AgentStatsRefreshInterval: 15 * time.Second, UpdateAgentMetricsFn: 
func(ctx context.Context, labels prometheusmetrics.AgentMetricLabels, metrics []*agentproto.Stats_Metric) { - atomic.AddInt64(&updateAgentMetricsFnCalled, 1) + updateAgentMetricsFnCalled = true assert.Equal(t, prometheusmetrics.AgentMetricLabels{ Username: user.Username, WorkspaceName: workspace.Name, @@ -365,6 +367,8 @@ func TestUpdateStates(t *testing.T) { require.Equal(t, &agentproto.UpdateStatsResponse{ ReportInterval: durationpb.New(15 * time.Second), }, resp) + + require.True(t, updateAgentMetricsFnCalled) }) } diff --git a/coderd/agentapi/tailnet_test.go b/coderd/agentapi/tailnet_test.go deleted file mode 100644 index 974d6eb2d707e..0000000000000 --- a/coderd/agentapi/tailnet_test.go +++ /dev/null @@ -1,192 +0,0 @@ -package agentapi_test - -import ( - "context" - "sync" - "testing" - "time" - - "golang.org/x/xerrors" - "storj.io/drpc" - "tailscale.com/tailcfg" - - "github.com/stretchr/testify/require" - - agentproto "github.com/coder/coder/v2/agent/proto" - "github.com/coder/coder/v2/coderd/agentapi" - "github.com/coder/coder/v2/tailnet" - tailnetproto "github.com/coder/coder/v2/tailnet/proto" -) - -type fakeDERPMapStream struct { - drpc.Stream // to fake implement unused members - - ctx context.Context - closeFn func() error - sendFn func(*tailnetproto.DERPMap) error -} - -var _ agentproto.DRPCAgent_StreamDERPMapsStream = &fakeDERPMapStream{} - -func (s *fakeDERPMapStream) Context() context.Context { - return s.ctx -} - -func (s *fakeDERPMapStream) Close() error { - return s.closeFn() -} - -func (s *fakeDERPMapStream) Send(m *tailnetproto.DERPMap) error { - return s.sendFn(m) -} - -func TestStreamDERPMaps(t *testing.T) { - t.Parallel() - - t.Run("OK", func(t *testing.T) { - t.Parallel() - - derpMapMu := sync.Mutex{} - derpMap := tailcfg.DERPMap{} - api := &agentapi.TailnetAPI{ - Ctx: context.Background(), - DerpMapFn: func() *tailcfg.DERPMap { - derpMapMu.Lock() - defer derpMapMu.Unlock() - derp := (&derpMap).Clone() - return derp - }, - 
DerpMapUpdateFrequency: time.Millisecond, - } - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - closed := make(chan struct{}) - maps := make(chan *tailnetproto.DERPMap, 10) - stream := &fakeDERPMapStream{ - ctx: ctx, - closeFn: func() error { - select { - case <-ctx.Done(): - default: - t.Fatal("expected context to be canceled before close") - } - close(closed) - return nil - }, - sendFn: func(m *tailnetproto.DERPMap) error { - if m == nil { - t.Fatal("expected non-nil map") - } - maps <- m - return nil - }, - } - - errCh := make(chan error) - go func() { - // Request isn't used. - errCh <- api.StreamDERPMaps(nil, stream) - }() - - // Initial map. - gotMap := <-maps - require.Equal(t, tailnet.DERPMapToProto(&derpMap), gotMap) - - // Update the map, should get an update. - derpMapMu.Lock() - derpMap.Regions = map[int]*tailcfg.DERPRegion{ - 1: {}, - } - derpMapMu.Unlock() - gotMap = <-maps - require.Equal(t, tailnet.DERPMapToProto(&derpMap), gotMap) - - // Update the map again, should get an update. - derpMapMu.Lock() - derpMap.Regions = nil - derpMapMu.Unlock() - gotMap = <-maps - require.Equal(t, tailnet.DERPMapToProto(&derpMap), gotMap) - - // Cancel the stream, should return the fn. 
- cancel() - <-closed - require.NoError(t, <-errCh) - }) - - t.Run("SendFailure", func(t *testing.T) { - t.Parallel() - - api := &agentapi.TailnetAPI{ - Ctx: context.Background(), - DerpMapFn: func() *tailcfg.DERPMap { - return &tailcfg.DERPMap{} - }, - DerpMapUpdateFrequency: time.Millisecond, - } - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - stream := &fakeDERPMapStream{ - ctx: ctx, - closeFn: func() error { - return nil - }, - sendFn: func(m *tailnetproto.DERPMap) error { - return xerrors.New("test error") - }, - } - - err := api.StreamDERPMaps(nil, stream) - require.Error(t, err) - require.ErrorContains(t, err, "send derp map") - require.ErrorContains(t, err, "test error") - }) - - t.Run("GlobalContextCanceled", func(t *testing.T) { - t.Parallel() - - globalCtx, globalCtxCancel := context.WithCancel(context.Background()) - api := &agentapi.TailnetAPI{ - Ctx: globalCtx, - DerpMapFn: func() *tailcfg.DERPMap { - return &tailcfg.DERPMap{} - }, - DerpMapUpdateFrequency: time.Hour, // long time to make sure ctx cancels are quick - } - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - maps := make(chan *tailnetproto.DERPMap, 10) - stream := &fakeDERPMapStream{ - ctx: ctx, - closeFn: func() error { - return nil - }, - sendFn: func(m *tailnetproto.DERPMap) error { - if m == nil { - t.Fatal("expected non-nil map") - } - maps <- m - return nil - }, - } - - errCh := make(chan error) - go func() { - // Request isn't used. - errCh <- api.StreamDERPMaps(nil, stream) - }() - - // Initial map. - <-maps - - // Cancel the global context, should return the fn. - globalCtxCancel() - require.NoError(t, <-errCh) - }) -} From 08b67808aa8fd6d1d0ee8ac96362495107110da0 Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Thu, 25 Jan 2024 13:02:42 +0000 Subject: [PATCH 6/7] fixup! 
PR comments --- coderd/workspaceagentsrpc.go | 1 + 1 file changed, 1 insertion(+) diff --git a/coderd/workspaceagentsrpc.go b/coderd/workspaceagentsrpc.go index 163d55c9f2553..8e02fc878e243 100644 --- a/coderd/workspaceagentsrpc.go +++ b/coderd/workspaceagentsrpc.go @@ -128,6 +128,7 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) { Database: api.Database, Pubsub: api.Pubsub, DerpMapFn: api.DERPMap, + TailnetCoordinator: &api.TailnetCoordinator, TemplateScheduleStore: api.TemplateScheduleStore, StatsBatcher: api.statsBatcher, PublishWorkspaceUpdateFn: api.publishWorkspaceUpdate, From 7a165ddf12bc5a6acc3657abfe8e359e3e162f4a Mon Sep 17 00:00:00 2001 From: Dean Sheather Date: Fri, 26 Jan 2024 06:52:24 +0000 Subject: [PATCH 7/7] fixup! PR comments --- coderd/agentapi/metadata_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/coderd/agentapi/metadata_test.go b/coderd/agentapi/metadata_test.go index f116eb82ce541..c3d0ec5528ea8 100644 --- a/coderd/agentapi/metadata_test.go +++ b/coderd/agentapi/metadata_test.go @@ -30,7 +30,7 @@ type fakePublisher struct { var _ pubsub.Pubsub = &fakePublisher{} -func (f *fakePublisher) Publish(channel string, message []byte) error { +func (f *fakePublisher) Publish(_ string, message []byte) error { f.publishes = append(f.publishes, message) return nil }