diff --git a/agent/agent.go b/agent/agent.go index 528506e020b5e..fe0f6a7b0d36e 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -32,6 +32,7 @@ import ( "golang.org/x/xerrors" "tailscale.com/net/speedtest" "tailscale.com/tailcfg" + "tailscale.com/types/netlogtype" "cdr.dev/slog" "github.com/coder/coder/agent/usershell" @@ -98,7 +99,6 @@ func New(options Options) io.Closer { exchangeToken: options.ExchangeToken, filesystem: options.Filesystem, tempDir: options.TempDir, - stats: &Stats{}, } server.init(ctx) return server @@ -126,7 +126,6 @@ type agent struct { sshServer *ssh.Server network *tailnet.Conn - stats *Stats } // runLoop attempts to start the agent in a retry loop. @@ -238,22 +237,16 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t return nil, xerrors.New("closed") } network, err := tailnet.NewConn(&tailnet.Options{ - Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.TailnetIP, 128)}, - DERPMap: derpMap, - Logger: a.logger.Named("tailnet"), + Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.TailnetIP, 128)}, + DERPMap: derpMap, + Logger: a.logger.Named("tailnet"), + EnableTrafficStats: true, }) if err != nil { a.closeMutex.Unlock() return nil, xerrors.Errorf("create tailnet: %w", err) } a.network = network - network.SetForwardTCPCallback(func(conn net.Conn, listenerExists bool) net.Conn { - if listenerExists { - // If a listener already exists, we would double-wrap the conn. - return conn - } - return a.stats.wrapConn(conn) - }) a.connCloseWait.Add(4) a.closeMutex.Unlock() @@ -268,7 +261,7 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t if err != nil { return } - go a.sshServer.HandleConn(a.stats.wrapConn(conn)) + go a.sshServer.HandleConn(conn) } }() @@ -284,7 +277,6 @@ func (a *agent) createTailnet(ctx context.Context, derpMap *tailcfg.DERPMap) (*t a.logger.Debug(ctx, "accept pty failed", slog.Error(err)) return } - conn = a.stats.wrapConn(conn) // This cannot use a JSON decoder, since that can // buffer additional data that is required for the PTY. rawLen := make([]byte, 2) @@ -523,7 +515,13 @@ func (a *agent) init(ctx context.Context) { go a.runLoop(ctx) cl, err := a.client.AgentReportStats(ctx, a.logger, func() *codersdk.AgentStats { - return a.stats.Copy() + stats := map[netlogtype.Connection]netlogtype.Counts{} + a.closeMutex.Lock() + if a.network != nil { + stats = a.network.ExtractTrafficStats() + } + a.closeMutex.Unlock() + return convertAgentStats(stats) }) if err != nil { a.logger.Error(ctx, "report stats", slog.Error(err)) @@ -537,6 +535,23 @@ func (a *agent) init(ctx context.Context) { }() } +func convertAgentStats(counts map[netlogtype.Connection]netlogtype.Counts) *codersdk.AgentStats { + stats := &codersdk.AgentStats{ + ConnsByProto: map[string]int64{}, + NumConns: int64(len(counts)), + } + + for conn, count := range counts { + stats.ConnsByProto[conn.Proto.String()]++ + stats.RxPackets += int64(count.RxPackets) + stats.RxBytes += int64(count.RxBytes) + stats.TxPackets += int64(count.TxPackets) + stats.TxBytes += int64(count.TxBytes) + } + + return stats +} + // createCommand processes raw command input with OpenSSH-like behavior. // If the rawCommand provided is empty, it will default to the users shell. // This injects environment variables specified by the user at launch too. diff --git a/agent/agent_test.go b/agent/agent_test.go index bc79c97fcd2a4..92a5fe6d0f149 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -69,10 +69,16 @@ func TestAgent(t *testing.T) { session, err := sshClient.NewSession() require.NoError(t, err) defer session.Close() + require.NoError(t, session.Run("echo test")) - assert.EqualValues(t, 1, (<-stats).NumConns) - assert.Greater(t, (<-stats).RxBytes, int64(0)) - assert.Greater(t, (<-stats).TxBytes, int64(0)) + var s *codersdk.AgentStats + require.Eventuallyf(t, func() bool { + var ok bool + s, ok = <-stats + return ok && s.NumConns > 0 && s.RxBytes > 0 && s.TxBytes > 0 + }, testutil.WaitLong, testutil.IntervalFast, + "never saw stats: %+v", s, + ) }) t.Run("ReconnectingPTY", func(t *testing.T) { @@ -97,7 +103,7 @@ func TestAgent(t *testing.T) { var s *codersdk.AgentStats require.Eventuallyf(t, func() bool { var ok bool - s, ok = (<-stats) + s, ok = <-stats return ok && s.NumConns > 0 && s.RxBytes > 0 && s.TxBytes > 0 }, testutil.WaitLong, testutil.IntervalFast, "never saw stats: %+v", s, @@ -675,7 +681,7 @@ func setupAgent(t *testing.T, metadata codersdk.WorkspaceAgentMetadata, ptyTimeo } coordinator := tailnet.NewCoordinator() agentID := uuid.New() - statsCh := make(chan *codersdk.AgentStats) + statsCh := make(chan *codersdk.AgentStats, 50) fs := afero.NewMemMapFs() closer := agent.New(agent.Options{ Client: &client{ @@ -693,9 +699,10 @@ func setupAgent(t *testing.T, metadata codersdk.WorkspaceAgentMetadata, ptyTimeo _ = closer.Close() }) conn, err := tailnet.NewConn(&tailnet.Options{ - Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, - DERPMap: metadata.DERPMap, - Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug), + Addresses: []netip.Prefix{netip.PrefixFrom(tailnet.IP(), 128)}, + DERPMap: metadata.DERPMap, + Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug), + EnableTrafficStats: true, }) require.NoError(t, err) clientConn, serverConn := net.Pipe() @@ -781,7 +788,7 @@ func (c *client) AgentReportStats(ctx context.Context, _ slog.Logger, stats func go func() { defer close(doneCh) - t := time.NewTicker(time.Millisecond * 100) + t := time.NewTicker(500 * time.Millisecond) defer t.Stop() for { select { diff --git a/agent/stats.go b/agent/stats.go deleted file mode 100644 index e7b45e576b6c2..0000000000000 --- a/agent/stats.go +++ /dev/null @@ -1,58 +0,0 @@ -package agent - -import ( - "net" - "sync/atomic" - - "github.com/coder/coder/codersdk" -) - -// statsConn wraps a net.Conn with statistics. -type statsConn struct { - *Stats - net.Conn `json:"-"` -} - -var _ net.Conn = new(statsConn) - -func (c *statsConn) Read(b []byte) (n int, err error) { - n, err = c.Conn.Read(b) - atomic.AddInt64(&c.RxBytes, int64(n)) - return n, err -} - -func (c *statsConn) Write(b []byte) (n int, err error) { - n, err = c.Conn.Write(b) - atomic.AddInt64(&c.TxBytes, int64(n)) - return n, err -} - -var _ net.Conn = new(statsConn) - -// Stats records the Agent's network connection statistics for use in -// user-facing metrics and debugging. -// Each member value must be written and read with atomic. -type Stats struct { - NumConns int64 `json:"num_comms"` - RxBytes int64 `json:"rx_bytes"` - TxBytes int64 `json:"tx_bytes"` -} - -func (s *Stats) Copy() *codersdk.AgentStats { - return &codersdk.AgentStats{ - NumConns: atomic.LoadInt64(&s.NumConns), - RxBytes: atomic.LoadInt64(&s.RxBytes), - TxBytes: atomic.LoadInt64(&s.TxBytes), - } -} - -// wrapConn returns a new connection that records statistics. -func (s *Stats) wrapConn(conn net.Conn) net.Conn { - atomic.AddInt64(&s.NumConns, 1) - cs := &statsConn{ - Stats: s, - Conn: conn, - } - - return cs -} diff --git a/coderd/activitybump.go b/coderd/activitybump.go index 5b4bd16f22817..1324e9c821896 100644 --- a/coderd/activitybump.go +++ b/coderd/activitybump.go @@ -6,6 +6,7 @@ import ( "errors" "time" + "github.com/google/uuid" "golang.org/x/xerrors" "cdr.dev/slog" @@ -14,14 +15,14 @@ import ( // activityBumpWorkspace automatically bumps the workspace's auto-off timer // if it is set to expire soon. -func activityBumpWorkspace(log slog.Logger, db database.Store, workspace database.Workspace) { +func activityBumpWorkspace(log slog.Logger, db database.Store, workspaceID uuid.UUID) { // We set a short timeout so if the app is under load, these // low priority operations fail first. ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) defer cancel() err := db.InTx(func(s database.Store) error { - build, err := s.GetLatestWorkspaceBuildByWorkspaceID(ctx, workspace.ID) + build, err := s.GetLatestWorkspaceBuildByWorkspaceID(ctx, workspaceID) if errors.Is(err, sql.ErrNoRows) { return nil } else if err != nil { @@ -65,15 +66,13 @@ func activityBumpWorkspace(log slog.Logger, db database.Store, workspace databas return nil }, nil) if err != nil { - log.Error( - ctx, "bump failed", - slog.Error(err), - slog.F("workspace_id", workspace.ID), - ) - } else { - log.Debug( - ctx, "bumped deadline from activity", - slog.F("workspace_id", workspace.ID), + log.Error(ctx, "bump failed", slog.Error(err), + slog.F("workspace_id", workspaceID), ) + return } + + log.Debug(ctx, "bumped deadline from activity", + slog.F("workspace_id", workspaceID), + ) } diff --git a/coderd/activitybump_test.go b/coderd/activitybump_test.go index f9c0736e0c7b7..ffa5e434cf6ca 100644 --- a/coderd/activitybump_test.go +++ b/coderd/activitybump_test.go @@ -8,7 +8,6 @@ import ( "github.com/stretchr/testify/require" "cdr.dev/slog/sloggers/slogtest" - "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/coderd/database" "github.com/coder/coder/codersdk" diff --git a/coderd/coderd.go b/coderd/coderd.go index 3f7d3d7211321..468aa919fe4cd 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -132,7 +132,7 @@ func New(options *Options) *API { options.APIRateLimit = 512 } if options.AgentStatsRefreshInterval == 0 { - options.AgentStatsRefreshInterval = 10 * time.Minute + options.AgentStatsRefreshInterval = 5 * time.Minute } if options.MetricsCacheRefreshInterval == 0 { options.MetricsCacheRefreshInterval = time.Hour @@ -493,7 +493,10 @@ func New(options *Options) *API { r.Get("/gitauth", api.workspaceAgentsGitAuth) r.Get("/gitsshkey", api.agentGitSSHKey) r.Get("/coordinate", api.workspaceAgentCoordinate) - r.Get("/report-stats", api.workspaceAgentReportStats) + r.Post("/report-stats", api.workspaceAgentReportStats) + // DEPRECATED in favor of the POST endpoint above. + // TODO: remove in January 2023 + r.Get("/report-stats", api.workspaceAgentReportStatsWebsocket) }) r.Route("/{workspaceagent}", func(r chi.Router) { r.Use( diff --git a/coderd/coderdtest/authorize.go b/coderd/coderdtest/authorize.go index ef308fba97dff..127c5037bda67 100644 --- a/coderd/coderdtest/authorize.go +++ b/coderd/coderdtest/authorize.go @@ -64,6 +64,7 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { "POST:/api/v2/workspaceagents/me/version": {NoAuthorize: true}, "POST:/api/v2/workspaceagents/me/app-health": {NoAuthorize: true}, "GET:/api/v2/workspaceagents/me/report-stats": {NoAuthorize: true}, + "POST:/api/v2/workspaceagents/me/report-stats": {NoAuthorize: true}, // These endpoints have more assertions. This is good, add more endpoints to assert if you can! "GET:/api/v2/organizations/{organization}": {AssertObject: rbac.ResourceOrganization.InOrg(a.Admin.OrganizationID)}, diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index 953e8db9f8ee9..23572e1ad18a2 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -30,31 +30,31 @@ func New() database.Store { return &fakeQuerier{ mutex: &sync.RWMutex{}, data: &data{ - apiKeys: make([]database.APIKey, 0), - agentStats: make([]database.AgentStat, 0), - organizationMembers: make([]database.OrganizationMember, 0), - organizations: make([]database.Organization, 0), - users: make([]database.User, 0), - gitAuthLinks: make([]database.GitAuthLink, 0), - groups: make([]database.Group, 0), - groupMembers: make([]database.GroupMember, 0), - auditLogs: make([]database.AuditLog, 0), - files: make([]database.File, 0), - gitSSHKey: make([]database.GitSSHKey, 0), - parameterSchemas: make([]database.ParameterSchema, 0), - parameterValues: make([]database.ParameterValue, 0), - provisionerDaemons: make([]database.ProvisionerDaemon, 0), - provisionerJobAgents: make([]database.WorkspaceAgent, 0), - provisionerJobLogs: make([]database.ProvisionerJobLog, 0), - provisionerJobResources: make([]database.WorkspaceResource, 0), - provisionerJobResourceMetadata: make([]database.WorkspaceResourceMetadatum, 0), - provisionerJobs: make([]database.ProvisionerJob, 0), - templateVersions: make([]database.TemplateVersion, 0), - templates: make([]database.Template, 0), - workspaceBuilds: make([]database.WorkspaceBuild, 0), - workspaceApps: make([]database.WorkspaceApp, 0), - workspaces: make([]database.Workspace, 0), - licenses: make([]database.License, 0), + apiKeys: make([]database.APIKey, 0), + agentStats: make([]database.AgentStat, 0), + organizationMembers: make([]database.OrganizationMember, 0), + organizations: make([]database.Organization, 0), + users: make([]database.User, 0), + gitAuthLinks: make([]database.GitAuthLink, 0), + groups: make([]database.Group, 0), + groupMembers: make([]database.GroupMember, 0), + auditLogs: make([]database.AuditLog, 0), + files: make([]database.File, 0), + gitSSHKey: make([]database.GitSSHKey, 0), + parameterSchemas: make([]database.ParameterSchema, 0), + parameterValues: make([]database.ParameterValue, 0), + provisionerDaemons: make([]database.ProvisionerDaemon, 0), + workspaceAgents: make([]database.WorkspaceAgent, 0), + provisionerJobLogs: make([]database.ProvisionerJobLog, 0), + workspaceResources: make([]database.WorkspaceResource, 0), + workspaceResourceMetadata: make([]database.WorkspaceResourceMetadatum, 0), + provisionerJobs: make([]database.ProvisionerJob, 0), + templateVersions: make([]database.TemplateVersion, 0), + templates: make([]database.Template, 0), + workspaceBuilds: make([]database.WorkspaceBuild, 0), + workspaceApps: make([]database.WorkspaceApp, 0), + workspaces: make([]database.Workspace, 0), + licenses: make([]database.License, 0), }, } } @@ -89,28 +89,28 @@ type data struct { userLinks []database.UserLink // New tables - agentStats []database.AgentStat - auditLogs []database.AuditLog - files []database.File - gitAuthLinks []database.GitAuthLink - gitSSHKey []database.GitSSHKey - groups []database.Group - groupMembers []database.GroupMember - parameterSchemas []database.ParameterSchema - parameterValues []database.ParameterValue - provisionerDaemons []database.ProvisionerDaemon - provisionerJobAgents []database.WorkspaceAgent - provisionerJobLogs []database.ProvisionerJobLog - provisionerJobResources []database.WorkspaceResource - provisionerJobResourceMetadata []database.WorkspaceResourceMetadatum - provisionerJobs []database.ProvisionerJob - templateVersions []database.TemplateVersion - templates []database.Template - workspaceBuilds []database.WorkspaceBuild - workspaceApps []database.WorkspaceApp - workspaces []database.Workspace - licenses []database.License - replicas []database.Replica + agentStats []database.AgentStat + auditLogs []database.AuditLog + files []database.File + gitAuthLinks []database.GitAuthLink + gitSSHKey []database.GitSSHKey + groupMembers []database.GroupMember + groups []database.Group + licenses []database.License + parameterSchemas []database.ParameterSchema + parameterValues []database.ParameterValue + provisionerDaemons []database.ProvisionerDaemon + provisionerJobLogs []database.ProvisionerJobLog + provisionerJobs []database.ProvisionerJob + replicas []database.Replica + templateVersions []database.TemplateVersion + templates []database.Template + workspaceAgents []database.WorkspaceAgent + workspaceApps []database.WorkspaceApp + workspaceBuilds []database.WorkspaceBuild + workspaceResourceMetadata []database.WorkspaceResourceMetadatum + workspaceResources []database.WorkspaceResource + workspaces []database.Workspace deploymentID string derpMeshKey string @@ -942,6 +942,52 @@ func (q *fakeQuerier) GetWorkspaceByID(_ context.Context, id uuid.UUID) (databas return database.Workspace{}, sql.ErrNoRows } +func (q *fakeQuerier) GetWorkspaceByAgentID(_ context.Context, agentID uuid.UUID) (database.Workspace, error) { + q.mutex.RLock() + defer q.mutex.RUnlock() + + var agent database.WorkspaceAgent + for _, _agent := range q.workspaceAgents { + if _agent.ID == agentID { + agent = _agent + break + } + } + if agent.ID == uuid.Nil { + return database.Workspace{}, sql.ErrNoRows + } + + var resource database.WorkspaceResource + for _, _resource := range q.workspaceResources { + if _resource.ID == agent.ResourceID { + resource = _resource + break + } + } + if resource.ID == uuid.Nil { + return database.Workspace{}, sql.ErrNoRows + } + + var build database.WorkspaceBuild + for _, _build := range q.workspaceBuilds { + if _build.JobID == resource.JobID { + build = _build + break + } + } + if build.ID == uuid.Nil { + return database.Workspace{}, sql.ErrNoRows + } + + for _, workspace := range q.workspaces { + if workspace.ID == build.WorkspaceID { + return workspace, nil + } + } + + return database.Workspace{}, sql.ErrNoRows +} + func (q *fakeQuerier) GetWorkspaceByOwnerIDAndName(_ context.Context, arg database.GetWorkspaceByOwnerIDAndNameParams) (database.Workspace, error) { q.mutex.RLock() defer q.mutex.RUnlock() @@ -1801,8 +1847,8 @@ func (q *fakeQuerier) GetWorkspaceAgentByAuthToken(_ context.Context, authToken defer q.mutex.RUnlock() // The schema sorts this by created at, so we iterate the array backwards. - for i := len(q.provisionerJobAgents) - 1; i >= 0; i-- { - agent := q.provisionerJobAgents[i] + for i := len(q.workspaceAgents) - 1; i >= 0; i-- { + agent := q.workspaceAgents[i] if agent.AuthToken == authToken { return agent, nil } @@ -1815,8 +1861,8 @@ func (q *fakeQuerier) GetWorkspaceAgentByID(_ context.Context, id uuid.UUID) (da defer q.mutex.RUnlock() // The schema sorts this by created at, so we iterate the array backwards. - for i := len(q.provisionerJobAgents) - 1; i >= 0; i-- { - agent := q.provisionerJobAgents[i] + for i := len(q.workspaceAgents) - 1; i >= 0; i-- { + agent := q.workspaceAgents[i] if agent.ID == id { return agent, nil } @@ -1829,8 +1875,8 @@ func (q *fakeQuerier) GetWorkspaceAgentByInstanceID(_ context.Context, instanceI defer q.mutex.RUnlock() // The schema sorts this by created at, so we iterate the array backwards. - for i := len(q.provisionerJobAgents) - 1; i >= 0; i-- { - agent := q.provisionerJobAgents[i] + for i := len(q.workspaceAgents) - 1; i >= 0; i-- { + agent := q.workspaceAgents[i] if agent.AuthInstanceID.Valid && agent.AuthInstanceID.String == instanceID { return agent, nil } @@ -1843,7 +1889,7 @@ func (q *fakeQuerier) GetWorkspaceAgentsByResourceIDs(_ context.Context, resourc defer q.mutex.RUnlock() workspaceAgents := make([]database.WorkspaceAgent, 0) - for _, agent := range q.provisionerJobAgents { + for _, agent := range q.workspaceAgents { for _, resourceID := range resourceIDs { if agent.ResourceID != resourceID { continue @@ -1859,7 +1905,7 @@ func (q *fakeQuerier) GetWorkspaceAgentsCreatedAfter(_ context.Context, after ti defer q.mutex.RUnlock() workspaceAgents := make([]database.WorkspaceAgent, 0) - for _, agent := range q.provisionerJobAgents { + for _, agent := range q.workspaceAgents { if agent.CreatedAt.After(after) { workspaceAgents = append(workspaceAgents, agent) } @@ -1913,7 +1959,7 @@ func (q *fakeQuerier) GetWorkspaceResourceByID(_ context.Context, id uuid.UUID) q.mutex.RLock() defer q.mutex.RUnlock() - for _, resource := range q.provisionerJobResources { + for _, resource := range q.workspaceResources { if resource.ID == id { return resource, nil } @@ -1926,7 +1972,7 @@ func (q *fakeQuerier) GetWorkspaceResourcesByJobID(_ context.Context, jobID uuid defer q.mutex.RUnlock() resources := make([]database.WorkspaceResource, 0) - for _, resource := range q.provisionerJobResources { + for _, resource := range q.workspaceResources { if resource.JobID != jobID { continue } @@ -1940,7 +1986,7 @@ func (q *fakeQuerier) GetWorkspaceResourcesByJobIDs(_ context.Context, jobIDs [] defer q.mutex.RUnlock() resources := make([]database.WorkspaceResource, 0) - for _, resource := range q.provisionerJobResources { + for _, resource := range q.workspaceResources { for _, jobID := range jobIDs { if resource.JobID != jobID { continue @@ -1956,7 +2002,7 @@ func (q *fakeQuerier) GetWorkspaceResourcesCreatedAfter(_ context.Context, after defer q.mutex.RUnlock() resources := make([]database.WorkspaceResource, 0) - for _, resource := range q.provisionerJobResources { + for _, resource := range q.workspaceResources { if resource.CreatedAt.After(after) { resources = append(resources, resource) } @@ -1978,7 +2024,7 @@ func (q *fakeQuerier) GetWorkspaceResourceMetadataCreatedAfter(ctx context.Conte defer q.mutex.RUnlock() metadata := make([]database.WorkspaceResourceMetadatum, 0) - for _, m := range q.provisionerJobResourceMetadata { + for _, m := range q.workspaceResourceMetadata { _, ok := resourceIDs[m.WorkspaceResourceID] if !ok { continue @@ -1993,7 +2039,7 @@ func (q *fakeQuerier) GetWorkspaceResourceMetadataByResourceID(_ context.Context defer q.mutex.RUnlock() metadata := make([]database.WorkspaceResourceMetadatum, 0) - for _, metadatum := range q.provisionerJobResourceMetadata { + for _, metadatum := range q.workspaceResourceMetadata { if metadatum.WorkspaceResourceID == id { metadata = append(metadata, metadatum) } @@ -2006,7 +2052,7 @@ func (q *fakeQuerier) GetWorkspaceResourceMetadataByResourceIDs(_ context.Contex defer q.mutex.RUnlock() metadata := make([]database.WorkspaceResourceMetadatum, 0) - for _, metadatum := range q.provisionerJobResourceMetadata { + for _, metadatum := range q.workspaceResourceMetadata { for _, id := range ids { if metadatum.WorkspaceResourceID == id { metadata = append(metadata, metadatum) @@ -2319,7 +2365,7 @@ func (q *fakeQuerier) InsertWorkspaceAgent(_ context.Context, arg database.Inser TroubleshootingURL: arg.TroubleshootingURL, } - q.provisionerJobAgents = append(q.provisionerJobAgents, agent) + q.workspaceAgents = append(q.workspaceAgents, agent) return agent, nil } @@ -2339,7 +2385,7 @@ func (q *fakeQuerier) InsertWorkspaceResource(_ context.Context, arg database.In Icon: arg.Icon, DailyCost: arg.DailyCost, } - q.provisionerJobResources = append(q.provisionerJobResources, resource) + q.workspaceResources = append(q.workspaceResources, resource) return resource, nil } @@ -2354,7 +2400,7 @@ func (q *fakeQuerier) InsertWorkspaceResourceMetadata(_ context.Context, arg dat Value: arg.Value, Sensitive: arg.Sensitive, } - q.provisionerJobResourceMetadata = append(q.provisionerJobResourceMetadata, metadatum) + q.workspaceResourceMetadata = append(q.workspaceResourceMetadata, metadatum) return metadatum, nil } @@ -2681,7 +2727,7 @@ func (q *fakeQuerier) UpdateWorkspaceAgentConnectionByID(_ context.Context, arg q.mutex.Lock() defer q.mutex.Unlock() - for index, agent := range q.provisionerJobAgents { + for index, agent := range q.workspaceAgents { if agent.ID != arg.ID { continue } @@ -2689,7 +2735,7 @@ func (q *fakeQuerier) UpdateWorkspaceAgentConnectionByID(_ context.Context, arg agent.LastConnectedAt = arg.LastConnectedAt agent.DisconnectedAt = arg.DisconnectedAt agent.UpdatedAt = arg.UpdatedAt - q.provisionerJobAgents[index] = agent + q.workspaceAgents[index] = agent return nil } return sql.ErrNoRows @@ -2699,13 +2745,13 @@ func (q *fakeQuerier) UpdateWorkspaceAgentVersionByID(_ context.Context, arg dat q.mutex.Lock() defer q.mutex.Unlock() - for index, agent := range q.provisionerJobAgents { + for index, agent := range q.workspaceAgents { if agent.ID != arg.ID { continue } agent.Version = arg.Version - q.provisionerJobAgents[index] = agent + q.workspaceAgents[index] = agent return nil } return sql.ErrNoRows diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 151cb0cde9f29..35457c5dc00a6 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -113,6 +113,7 @@ type sqlcQuerier interface { GetWorkspaceBuildByWorkspaceIDAndBuildNumber(ctx context.Context, arg GetWorkspaceBuildByWorkspaceIDAndBuildNumberParams) (WorkspaceBuild, error) GetWorkspaceBuildsByWorkspaceID(ctx context.Context, arg GetWorkspaceBuildsByWorkspaceIDParams) ([]WorkspaceBuild, error) GetWorkspaceBuildsCreatedAfter(ctx context.Context, createdAt time.Time) ([]WorkspaceBuild, error) + GetWorkspaceByAgentID(ctx context.Context, agentID uuid.UUID) (Workspace, error) GetWorkspaceByID(ctx context.Context, id uuid.UUID) (Workspace, error) GetWorkspaceByOwnerIDAndName(ctx context.Context, arg GetWorkspaceByOwnerIDAndNameParams) (Workspace, error) GetWorkspaceCountByUserID(ctx context.Context, ownerID uuid.UUID) (int64, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 60932097bb51b..0643eaeff6d93 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -17,7 +17,7 @@ import ( ) const deleteOldAgentStats = `-- name: DeleteOldAgentStats :exec -DELETE FROM AGENT_STATS WHERE created_at < now() - interval '30 days' +DELETE FROM agent_stats WHERE created_at < NOW() - INTERVAL '30 days' ` func (q *sqlQuerier) DeleteOldAgentStats(ctx context.Context) error { @@ -45,16 +45,17 @@ func (q *sqlQuerier) GetLatestAgentStat(ctx context.Context, agentID uuid.UUID) } const getTemplateDAUs = `-- name: GetTemplateDAUs :many -select +SELECT (created_at at TIME ZONE 'UTC')::date as date, user_id -from +FROM agent_stats -where template_id = $1 -group by +WHERE + template_id = $1 +GROUP BY date, user_id -order by - date asc +ORDER BY + date ASC ` type GetTemplateDAUsRow struct { @@ -6145,6 +6146,55 @@ func (q *sqlQuerier) InsertWorkspaceResourceMetadata(ctx context.Context, arg In return i, err } +const getWorkspaceByAgentID = `-- name: GetWorkspaceByAgentID :one +SELECT + id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at +FROM + workspaces +WHERE + workspaces.id = ( + SELECT + workspace_id + FROM + workspace_builds + WHERE + workspace_builds.job_id = ( + SELECT + job_id + FROM + workspace_resources + WHERE + workspace_resources.id = ( + SELECT + resource_id + FROM + workspace_agents + WHERE + workspace_agents.id = $1 + ) + ) + ) +` + +func (q *sqlQuerier) GetWorkspaceByAgentID(ctx context.Context, agentID uuid.UUID) (Workspace, error) { + row := q.db.QueryRowContext(ctx, getWorkspaceByAgentID, agentID) + var i Workspace + err := row.Scan( + &i.ID, + &i.CreatedAt, + &i.UpdatedAt, + &i.OwnerID, + &i.OrganizationID, + &i.TemplateID, + &i.Deleted, + &i.Name, + &i.AutostartSchedule, + &i.Ttl, + &i.LastUsedAt, + ) + return i, err +} + const getWorkspaceByID = `-- name: GetWorkspaceByID :one SELECT id, created_at, updated_at, owner_id, organization_id, template_id, deleted, name, autostart_schedule, ttl, last_used_at diff --git a/coderd/database/queries/agentstats.sql b/coderd/database/queries/agentstats.sql index ddb7d04aa0a69..4d94cd98b9f25 100644 --- a/coderd/database/queries/agentstats.sql +++ b/coderd/database/queries/agentstats.sql @@ -16,16 +16,17 @@ VALUES SELECT * FROM agent_stats WHERE agent_id = $1 ORDER BY created_at DESC LIMIT 1; -- name: GetTemplateDAUs :many -select +SELECT (created_at at TIME ZONE 'UTC')::date as date, user_id -from +FROM agent_stats -where template_id = $1 -group by +WHERE + template_id = $1 +GROUP BY date, user_id -order by - date asc; +ORDER BY + date ASC; -- name: DeleteOldAgentStats :exec -DELETE FROM AGENT_STATS WHERE created_at < now() - interval '30 days'; +DELETE FROM agent_stats WHERE created_at < NOW() - INTERVAL '30 days'; diff --git a/coderd/database/queries/workspaces.sql b/coderd/database/queries/workspaces.sql index 65815c9af9ddd..071a970a66975 100644 --- a/coderd/database/queries/workspaces.sql +++ b/coderd/database/queries/workspaces.sql @@ -8,6 +8,35 @@ WHERE LIMIT 1; +-- name: GetWorkspaceByAgentID :one +SELECT + * +FROM + workspaces +WHERE + workspaces.id = ( + SELECT + workspace_id + FROM + workspace_builds + WHERE + workspace_builds.job_id = ( + SELECT + job_id + FROM + workspace_resources + WHERE + workspace_resources.id = ( + SELECT + resource_id + FROM + workspace_agents + WHERE + workspace_agents.id = @agent_id + ) + ) + ); + -- name: GetWorkspaces :many SELECT workspaces.*, COUNT(*) OVER () as count diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 278cf7c0d8144..21b6d83b599c3 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -75,41 +75,41 @@ func (api *API) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request) }) return } - dbApps, err := api.Database.GetWorkspaceAppsByAgentID(r.Context(), workspaceAgent.ID) + dbApps, err := api.Database.GetWorkspaceAppsByAgentID(ctx, workspaceAgent.ID) if err != nil && !xerrors.Is(err, sql.ErrNoRows) { - httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{ + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching workspace agent applications.", Detail: err.Error(), }) return } - resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), workspaceAgent.ResourceID) + resource, err := api.Database.GetWorkspaceResourceByID(ctx, workspaceAgent.ResourceID) if err != nil { - httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{ + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching workspace resource.", Detail: err.Error(), }) return } - build, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID) + build, err := api.Database.GetWorkspaceBuildByJobID(ctx, resource.JobID) if err != nil { - httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{ + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching workspace build.", Detail: err.Error(), }) return } - workspace, err := api.Database.GetWorkspaceByID(r.Context(), build.WorkspaceID) + workspace, err := api.Database.GetWorkspaceByID(ctx, build.WorkspaceID) if err != nil { - httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{ + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching workspace.", Detail: err.Error(), }) return } - owner, err := api.Database.GetUserByID(r.Context(), workspace.OwnerID) + owner, err := api.Database.GetUserByID(ctx, workspace.OwnerID) if err != nil { - httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{ + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching workspace owner.", Detail: err.Error(), }) @@ -755,31 +755,69 @@ func convertWorkspaceAgent(derpMap *tailcfg.DERPMap, coordinator tailnet.Coordin func (api *API) workspaceAgentReportStats(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() - api.WebsocketWaitMutex.Lock() - api.WebsocketWaitGroup.Add(1) - api.WebsocketWaitMutex.Unlock() - defer api.WebsocketWaitGroup.Done() - workspaceAgent := httpmw.WorkspaceAgent(r) - resource, err := api.Database.GetWorkspaceResourceByID(ctx, workspaceAgent.ResourceID) + workspace, err := api.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID) if err != nil { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Failed to get workspace resource.", + Message: "Failed to get workspace.", Detail: err.Error(), }) return } - build, err := api.Database.GetWorkspaceBuildByJobID(ctx, resource.JobID) - if err != nil { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Failed to get build.", - Detail: err.Error(), + var req codersdk.AgentStats + if !httpapi.Read(ctx, rw, r, &req) { + return + } + + if req.RxBytes == 0 && req.TxBytes == 0 { + httpapi.Write(ctx, rw, http.StatusOK, codersdk.AgentStatsResponse{ + ReportInterval: api.AgentStatsRefreshInterval, }) return } - workspace, err := api.Database.GetWorkspaceByID(ctx, build.WorkspaceID) + activityBumpWorkspace(api.Logger.Named("activity_bump"), api.Database, workspace.ID) + + now := database.Now() + _, err = api.Database.InsertAgentStat(ctx, database.InsertAgentStatParams{ + ID: uuid.New(), + CreatedAt: now, + AgentID: workspaceAgent.ID, + WorkspaceID: workspace.ID, + UserID: workspace.OwnerID, + TemplateID: workspace.TemplateID, + Payload: json.RawMessage("{}"), + }) + if err != nil { + httpapi.InternalServerError(rw, err) + return + } + + err = api.Database.UpdateWorkspaceLastUsedAt(ctx, database.UpdateWorkspaceLastUsedAtParams{ + ID: workspace.ID, + LastUsedAt: now, + }) + if err != nil { + httpapi.InternalServerError(rw, err) + return + } + + httpapi.Write(ctx, rw, http.StatusOK, codersdk.AgentStatsResponse{ + ReportInterval: api.AgentStatsRefreshInterval, + }) +} + +func (api *API) workspaceAgentReportStatsWebsocket(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + api.WebsocketWaitMutex.Lock() + api.WebsocketWaitGroup.Add(1) + api.WebsocketWaitMutex.Unlock() + defer api.WebsocketWaitGroup.Done() + + workspaceAgent := httpmw.WorkspaceAgent(r) + workspace, err := api.Database.GetWorkspaceByAgentID(ctx, workspaceAgent.ID) if err != nil { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ Message: "Failed to get workspace.", @@ -861,14 +899,13 @@ func (api *API) workspaceAgentReportStats(rw http.ResponseWriter, r *http.Reques api.Logger.Debug(ctx, "read stats report", slog.F("interval", api.AgentStatsRefreshInterval), slog.F("agent", workspaceAgent.ID), - slog.F("resource", resource.ID), slog.F("workspace", workspace.ID), slog.F("update_db", updateDB), slog.F("payload", rep), ) if updateDB { - go activityBumpWorkspace(api.Logger.Named("activity_bump"), api.Database, workspace) + go activityBumpWorkspace(api.Logger.Named("activity_bump"), api.Database, workspace.ID) lastReport = rep @@ -876,7 +913,7 @@ func (api *API) workspaceAgentReportStats(rw http.ResponseWriter, r *http.Reques ID: uuid.New(), CreatedAt: database.Now(), AgentID: workspaceAgent.ID, - WorkspaceID: build.WorkspaceID, + WorkspaceID: workspace.ID, UserID: workspace.OwnerID, TemplateID: workspace.TemplateID, Payload: json.RawMessage(repJSON), @@ -888,7 +925,7 @@ func (api *API) workspaceAgentReportStats(rw http.ResponseWriter, r *http.Reques } err = api.Database.UpdateWorkspaceLastUsedAt(ctx, database.UpdateWorkspaceLastUsedAtParams{ - ID: build.WorkspaceID, + ID: workspace.ID, LastUsedAt: database.Now(), }) if err != nil { @@ -901,22 +938,23 @@ func (api *API) workspaceAgentReportStats(rw http.ResponseWriter, r *http.Reques } func (api *API) postWorkspaceAppHealth(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() workspaceAgent := httpmw.WorkspaceAgent(r) var req codersdk.PostWorkspaceAppHealthsRequest - if !httpapi.Read(r.Context(), rw, r, &req) { + if !httpapi.Read(ctx, rw, r, &req) { return } if req.Healths == nil || len(req.Healths) == 0 { - httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ Message: "Health field is empty", }) return } - apps, err := api.Database.GetWorkspaceAppsByAgentID(r.Context(), workspaceAgent.ID) + apps, err := api.Database.GetWorkspaceAppsByAgentID(ctx, workspaceAgent.ID) if err != nil { - httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{ + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Error getting agent apps", Detail: err.Error(), }) @@ -935,7 +973,7 @@ func (api *API) postWorkspaceAppHealth(rw http.ResponseWriter, r *http.Request) return nil }() if old == nil { - httpapi.Write(r.Context(), rw, http.StatusNotFound, codersdk.Response{ + httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{ Message: "Error setting workspace app health", Detail: xerrors.Errorf("workspace app name %s not found", id).Error(), }) @@ -943,7 +981,7 @@ func (api *API) postWorkspaceAppHealth(rw http.ResponseWriter, r *http.Request) } if old.HealthcheckUrl == "" { - httpapi.Write(r.Context(), rw, http.StatusNotFound, codersdk.Response{ + httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{ Message: "Error setting workspace app health", Detail: xerrors.Errorf("health checking is disabled for workspace app %s", id).Error(), }) @@ -955,7 +993,7 @@ func (api *API) postWorkspaceAppHealth(rw http.ResponseWriter, r *http.Request) case codersdk.WorkspaceAppHealthHealthy: case codersdk.WorkspaceAppHealthUnhealthy: default: - httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ Message: "Error setting workspace app health", Detail: xerrors.Errorf("workspace app health %s is not a valid value", newHealth).Error(), }) @@ -972,12 +1010,12 @@ func (api *API) postWorkspaceAppHealth(rw http.ResponseWriter, r *http.Request) } for _, app := range newApps { - err = api.Database.UpdateWorkspaceAppHealthByID(r.Context(), database.UpdateWorkspaceAppHealthByIDParams{ + err = api.Database.UpdateWorkspaceAppHealthByID(ctx, database.UpdateWorkspaceAppHealthByIDParams{ ID: app.ID, Health: app.Health, }) if err != nil { - httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{ + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Error setting workspace app health", Detail: err.Error(), }) @@ -985,33 +1023,33 @@ func (api *API) postWorkspaceAppHealth(rw http.ResponseWriter, r *http.Request) } } - resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), workspaceAgent.ResourceID) + resource, err := api.Database.GetWorkspaceResourceByID(ctx, workspaceAgent.ResourceID) if err != nil { - httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{ + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching workspace resource.", Detail: err.Error(), }) return } - job, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID) + job, err := api.Database.GetWorkspaceBuildByJobID(ctx, resource.JobID) if err != nil { - httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{ + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching workspace build.", Detail: err.Error(), }) return } - workspace, err := api.Database.GetWorkspaceByID(r.Context(), job.WorkspaceID) + workspace, err := api.Database.GetWorkspaceByID(ctx, job.WorkspaceID) if err != nil { - httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{ + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching workspace.", Detail: err.Error(), }) return } - api.publishWorkspaceUpdate(r.Context(), workspace.ID) + api.publishWorkspaceUpdate(ctx, workspace.ID) - httpapi.Write(r.Context(), rw, http.StatusOK, nil) + httpapi.Write(ctx, rw, http.StatusOK, nil) } // postWorkspaceAgentsGitAuth returns a username and password for use @@ -1101,7 +1139,7 @@ func (api *API) workspaceAgentsGitAuth(rw http.ResponseWriter, r *http.Request) defer ticker.Stop() for { select { - case <-r.Context().Done(): + case <-ctx.Done(): return case <-ticker.C: case <-authChan: diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index 316ba4a63657d..13bf080a5f307 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -1065,6 +1065,65 @@ func TestWorkspaceAgentsGitAuth(t *testing.T) { }) } +func TestWorkspaceAgentReportStats(t *testing.T) { + t.Parallel() + + t.Run("OK", func(t *testing.T) { + t.Parallel() + + client := coderdtest.New(t, &coderdtest.Options{ + IncludeProvisionerDaemon: true, + }) + user := coderdtest.CreateFirstUser(t, client) + authToken := uuid.NewString() + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionPlan: echo.ProvisionComplete, + ProvisionApply: []*proto.Provision_Response{{ + Type: &proto.Provision_Response_Complete{ + Complete: &proto.Provision_Complete{ + Resources: []*proto.Resource{{ + Name: "example", + Type: "aws_instance", + Agents: []*proto.Agent{{ + Id: uuid.NewString(), + Auth: &proto.Agent_Token{ + Token: authToken, + }, + }}, + }}, + }, + }, + }}, + }) + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + coderdtest.AwaitTemplateVersionJob(t, client, version.ID) + workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) + coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) + + agentClient := codersdk.New(client.URL) + agentClient.SetSessionToken(authToken) + + _, err := agentClient.PostAgentStats(context.Background(), &codersdk.AgentStats{ + ConnsByProto: map[string]int64{"TCP": 1}, + NumConns: 1, + RxPackets: 1, + RxBytes: 1, + TxPackets: 1, + TxBytes: 1, + }) + require.NoError(t, err) + + newWorkspace, err := client.Workspace(context.Background(), workspace.ID) + require.NoError(t, err) + + assert.True(t, + newWorkspace.LastUsedAt.After(workspace.LastUsedAt), + "%s is not after %s", newWorkspace.LastUsedAt, workspace.LastUsedAt, + ) + }) +} + func gitAuthCallback(t *testing.T, id string, client *codersdk.Client) *http.Response { client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse diff --git a/codersdk/templates.go b/codersdk/templates.go index 1d5a642b6ff08..6993f17c0597f 100644 --- a/codersdk/templates.go +++ b/codersdk/templates.go @@ -232,6 +232,6 @@ type AgentStatsReportResponse struct { NumConns int64 `json:"num_comms"` // RxBytes is the number of received bytes. RxBytes int64 `json:"rx_bytes"` - // TxBytes is the number of received bytes. + // TxBytes is the number of transmitted bytes. TxBytes int64 `json:"tx_bytes"` } diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index f64d0cfe76261..f90e615de8917 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -18,7 +18,6 @@ import ( "github.com/google/uuid" "golang.org/x/xerrors" "nhooyr.io/websocket" - "nhooyr.io/websocket/wsjson" "tailscale.com/tailcfg" "cdr.dev/slog" @@ -553,97 +552,88 @@ func (c *Client) WorkspaceAgentListeningPorts(ctx context.Context, agentID uuid. // Stats records the Agent's network connection statistics for use in // user-facing metrics and debugging. -// Each member value must be written and read with atomic. // @typescript-ignore AgentStats type AgentStats struct { + // ConnsByProto is a count of connections by protocol. + ConnsByProto map[string]int64 `json:"conns_by_proto"` + // NumConns is the number of connections received by an agent. NumConns int64 `json:"num_comms"` - RxBytes int64 `json:"rx_bytes"` - TxBytes int64 `json:"tx_bytes"` + // RxPackets is the number of received packets. + RxPackets int64 `json:"rx_packets"` + // RxBytes is the number of received bytes. + RxBytes int64 `json:"rx_bytes"` + // TxPackets is the number of transmitted bytes. + TxPackets int64 `json:"tx_packets"` + // TxBytes is the number of transmitted bytes. + TxBytes int64 `json:"tx_bytes"` } -// AgentReportStats begins a stat streaming connection with the Coder server. -// It is resilient to network failures and intermittent coderd issues. -func (c *Client) AgentReportStats( - ctx context.Context, - log slog.Logger, - stats func() *AgentStats, -) (io.Closer, error) { - serverURL, err := c.URL.Parse("/api/v2/workspaceagents/me/report-stats") +// @typescript-ignore AgentStatsResponse +type AgentStatsResponse struct { + // ReportInterval is the duration after which the agent should send stats + // again. + ReportInterval time.Duration `json:"report_interval"` +} + +func (c *Client) PostAgentStats(ctx context.Context, stats *AgentStats) (AgentStatsResponse, error) { + res, err := c.Request(ctx, http.MethodPost, "/api/v2/workspaceagents/me/report-stats", stats) if err != nil { - return nil, xerrors.Errorf("parse url: %w", err) + return AgentStatsResponse{}, xerrors.Errorf("send request: %w", err) + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return AgentStatsResponse{}, readBodyAsError(res) } - jar, err := cookiejar.New(nil) + var interval AgentStatsResponse + err = json.NewDecoder(res.Body).Decode(&interval) if err != nil { - return nil, xerrors.Errorf("create cookie jar: %w", err) + return AgentStatsResponse{}, xerrors.Errorf("decode stats response: %w", err) } - jar.SetCookies(serverURL, []*http.Cookie{{ - Name: SessionTokenKey, - Value: c.SessionToken(), - }}) - - httpClient := &http.Client{ - Jar: jar, - Transport: c.HTTPClient.Transport, - } + return interval, nil +} - doneCh := make(chan struct{}) +// AgentReportStats begins a stat streaming connection with the Coder server. +// It is resilient to network failures and intermittent coderd issues. +func (c *Client) AgentReportStats( + ctx context.Context, + log slog.Logger, + getStats func() *AgentStats, +) (io.Closer, error) { ctx, cancel := context.WithCancel(ctx) go func() { - defer close(doneCh) - - // If the agent connection succeeds for a while, then fails, then succeeds - // for a while (etc.) the retry may hit the maximum. This is a normal - // case for long-running agents that experience coderd upgrades, so - // we use a short maximum retry limit. - for r := retry.New(time.Second, time.Minute); r.Wait(ctx); { - err = func() error { - conn, res, err := websocket.Dial(ctx, serverURL.String(), &websocket.DialOptions{ - HTTPClient: httpClient, - // Need to disable compression to avoid a data-race. - CompressionMode: websocket.CompressionDisabled, - }) - if err != nil { - if res == nil { - return err - } - return readBodyAsError(res) - } - - for { - var req AgentStatsReportRequest - err := wsjson.Read(ctx, conn, &req) - if err != nil { - _ = conn.Close(websocket.StatusGoingAway, "") - return err - } - - s := stats() + // Immediately trigger a stats push to get the correct interval. + timer := time.NewTimer(time.Nanosecond) + defer timer.Stop() - resp := AgentStatsReportResponse{ - NumConns: s.NumConns, - RxBytes: s.RxBytes, - TxBytes: s.TxBytes, - } + for { + select { + case <-ctx.Done(): + return + case <-timer.C: + } - err = wsjson.Write(ctx, conn, resp) - if err != nil { - _ = conn.Close(websocket.StatusGoingAway, "") - return err + var nextInterval time.Duration + for r := retry.New(100*time.Millisecond, time.Minute); r.Wait(ctx); { + resp, err := c.PostAgentStats(ctx, getStats()) + if err != nil { + if !xerrors.Is(err, context.Canceled) { + log.Error(ctx, "report stats", slog.Error(err)) } + continue } - }() - if err != nil && ctx.Err() == nil { - log.Error(ctx, "report stats", slog.Error(err)) + + nextInterval = resp.ReportInterval + break } + timer.Reset(nextInterval) } }() return closeFunc(func() error { cancel() - <-doneCh return nil }), nil } diff --git a/codersdk/workspaceagents_test.go b/codersdk/workspaceagents_test.go index 8027a1ce86c72..e94181f12c730 100644 --- a/codersdk/workspaceagents_test.go +++ b/codersdk/workspaceagents_test.go @@ -6,13 +6,17 @@ import ( "net/http/httptest" "net/url" "strconv" + "sync/atomic" "testing" + "time" "github.com/stretchr/testify/require" "tailscale.com/tailcfg" + "cdr.dev/slog/sloggers/slogtest" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/codersdk" + "github.com/coder/coder/testutil" ) func TestWorkspaceAgentMetadata(t *testing.T) { @@ -47,3 +51,30 @@ func TestWorkspaceAgentMetadata(t *testing.T) { require.Equal(t, parsed.Hostname(), node.HostName) require.Equal(t, parsed.Port(), strconv.Itoa(node.DERPPort)) } + +func TestAgentReportStats(t *testing.T) { + t.Parallel() + + var numReports atomic.Int64 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + numReports.Add(1) + httpapi.Write(context.Background(), w, http.StatusOK, codersdk.AgentStatsResponse{ + ReportInterval: 5 * time.Millisecond, + }) + })) + parsed, err := url.Parse(srv.URL) + require.NoError(t, err) + client := codersdk.New(parsed) + + ctx := context.Background() + closeStream, err := client.AgentReportStats(ctx, slogtest.Make(t, nil), func() *codersdk.AgentStats { + return &codersdk.AgentStats{} + }) + require.NoError(t, err) + defer closeStream.Close() + + require.Eventually(t, + func() bool { return numReports.Load() >= 3 }, + testutil.WaitMedium, testutil.IntervalFast, + ) +} diff --git a/go.mod b/go.mod index f38504b38d9dc..0d7bf9b8c17f3 100644 --- a/go.mod +++ b/go.mod @@ -40,7 +40,7 @@ replace github.com/tcnksm/go-httpstat => github.com/kylecarbs/go-httpstat v0.0.0 // There are a few minor changes we make to Tailscale that we're slowly upstreaming. Compare here: // https://github.com/tailscale/tailscale/compare/main...coder:tailscale:main -replace tailscale.com => github.com/coder/tailscale v1.1.1-0.20221113171243-7d90f070c5dc +replace tailscale.com => github.com/coder/tailscale v1.1.1-0.20221117204504-2d6503f027c3 // Switch to our fork that imports fixes from http://github.com/tailscale/ssh. // See: https://github.com/coder/coder/issues/3371 diff --git a/go.sum b/go.sum index 29b60a8021290..a33a51154b21d 100644 --- a/go.sum +++ b/go.sum @@ -355,8 +355,8 @@ github.com/coder/retry v1.3.0 h1:5lAAwt/2Cm6lVmnfBY7sOMXcBOwcwJhmV5QGSELIVWY= github.com/coder/retry v1.3.0/go.mod h1:tXuRgZgWjUnU5LZPT4lJh4ew2elUhexhlnXzrJWdyFY= github.com/coder/ssh v0.0.0-20220811105153-fcea99919338 h1:tN5GKFT68YLVzJoA8AHuiMNJ0qlhoD3pGN3JY9gxSko= github.com/coder/ssh v0.0.0-20220811105153-fcea99919338/go.mod h1:ZSS+CUoKHDrqVakTfTWUlKSr9MtMFkC4UvtQKD7O914= -github.com/coder/tailscale v1.1.1-0.20221113171243-7d90f070c5dc h1:qozpteSLz0ifMasetJ+/Qac5Ud/NRNIlgTubGf6TAaQ= -github.com/coder/tailscale v1.1.1-0.20221113171243-7d90f070c5dc/go.mod h1:lkCb74eSJwxeNq8YwyILoHD5vtHktiZnTOxBxo3tbNc= +github.com/coder/tailscale v1.1.1-0.20221117204504-2d6503f027c3 h1:lq8GmpE5bn8A36uxq1h+TWnaQKPugtRkxKrYZA78O9c= +github.com/coder/tailscale v1.1.1-0.20221117204504-2d6503f027c3/go.mod h1:lkCb74eSJwxeNq8YwyILoHD5vtHktiZnTOxBxo3tbNc= github.com/containerd/aufs v0.0.0-20200908144142-dab0cbea06f4/go.mod h1:nukgQABAEopAHvB6j7cnP5zJ+/3aVcE7hCYqvIwAHyE= github.com/containerd/aufs v0.0.0-20201003224125-76a6863f2989/go.mod h1:AkGGQs9NM2vtYHaUen+NljV0/baGCAPELGm2q9ZXpWU= github.com/containerd/aufs v0.0.0-20210316121734-20793ff83c97/go.mod h1:kL5kd6KM5TzQjR79jljyi4olc1Vrx6XBlcyj3gNv2PU= diff --git a/tailnet/conn.go b/tailnet/conn.go index 325783c48bda8..10dec35287a1d 100644 --- a/tailnet/conn.go +++ b/tailnet/conn.go @@ -26,6 +26,7 @@ import ( "tailscale.com/types/ipproto" "tailscale.com/types/key" tslogger "tailscale.com/types/logger" + "tailscale.com/types/netlogtype" "tailscale.com/types/netmap" "tailscale.com/wgengine" "tailscale.com/wgengine/filter" @@ -35,15 +36,14 @@ import ( "tailscale.com/wgengine/router" "tailscale.com/wgengine/wgcfg/nmcfg" + "cdr.dev/slog" "github.com/coder/coder/coderd/database" "github.com/coder/coder/cryptorand" - - "cdr.dev/slog" ) func init() { - // Globally disable network namespacing. - // All networking happens in userspace. + // Globally disable network namespacing. All networking happens in + // userspace. netns.SetEnabled(false) } @@ -55,6 +55,11 @@ type Options struct { // If so, only DERPs can establish connections. BlockEndpoints bool Logger slog.Logger + + // EnableTrafficStats enables per-connection traffic statistics. + // ExtractTrafficStats must be called to reset the counters and be + // periodically called while enabled to avoid unbounded memory use. + EnableTrafficStats bool } // NewConn constructs a new Wireguard server that will accept connections from the addresses provided. @@ -143,8 +148,9 @@ func NewConn(options *Options) (*Conn, error) { } tunDevice, magicConn, dnsManager, ok := wireguardInternals.GetInternals() if !ok { - return nil, xerrors.New("failed to get wireguard internals") + return nil, xerrors.New("get wireguard internals") } + tunDevice.SetStatisticsEnabled(options.EnableTrafficStats) // Update the keys for the magic connection! err = magicConn.SetPrivateKey(nodePrivateKey) @@ -649,6 +655,13 @@ func (c *Conn) forwardTCPToLocal(conn net.Conn, port uint16) { c.logger.Debug(c.dialContext, "forwarded connection closed", slog.F("local_addr", dialAddrStr)) } +// ExtractTrafficStats extracts and resets the counters for all active +// connections. It must be called periodically otherwise the memory used is +// unbounded. EnableTrafficStats must be true when calling NewConn. +func (c *Conn) ExtractTrafficStats() map[netlogtype.Connection]netlogtype.Counts { + return c.tunDevice.ExtractStatistics() +} + type listenKey struct { network string host string