Skip to content

chore: convert workspaceagent HTTP API to agentapi #11280

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
18 changes: 5 additions & 13 deletions coderd/agentapi/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,15 @@ func New(opts Options) *API {
WorkspaceIDFn: api.workspaceID,
Database: opts.Database,
Log: opts.Log,
PublishWorkspaceUpdateFn: api.publishWorkspaceUpdate,
PublishWorkspaceUpdateFn: opts.PublishWorkspaceUpdateFn,
}

api.AppsAPI = &AppsAPI{
AgentFn: api.agent,
WorkspaceIDFn: api.workspaceID,
Database: opts.Database,
Log: opts.Log,
PublishWorkspaceUpdateFn: api.publishWorkspaceUpdate,
PublishWorkspaceUpdateFn: opts.PublishWorkspaceUpdateFn,
}

api.MetadataAPI = &MetadataAPI{
Expand All @@ -135,9 +136,10 @@ func New(opts Options) *API {

api.LogsAPI = &LogsAPI{
AgentFn: api.agent,
WorkspaceIDFn: api.workspaceID,
Database: opts.Database,
Log: opts.Log,
PublishWorkspaceUpdateFn: api.publishWorkspaceUpdate,
PublishWorkspaceUpdateFn: opts.PublishWorkspaceUpdateFn,
PublishWorkspaceAgentLogsUpdateFn: opts.PublishWorkspaceAgentLogsUpdateFn,
}

Expand Down Expand Up @@ -218,13 +220,3 @@ func (a *API) workspaceID(ctx context.Context, agent *database.WorkspaceAgent) (
a.mu.Unlock()
return getWorkspaceAgentByIDRow.Workspace.ID, nil
}

func (a *API) publishWorkspaceUpdate(ctx context.Context, agent *database.WorkspaceAgent) error {
workspaceID, err := a.workspaceID(ctx, agent)
if err != nil {
return err
}

a.opts.PublishWorkspaceUpdateFn(ctx, workspaceID)
return nil
}
8 changes: 5 additions & 3 deletions coderd/agentapi/apps.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ import (

type AppsAPI struct {
AgentFn func(context.Context) (database.WorkspaceAgent, error)
WorkspaceIDFn func(context.Context, *database.WorkspaceAgent) (uuid.UUID, error)
Database database.Store
Log slog.Logger
PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent) error
PublishWorkspaceUpdateFn func(context.Context, uuid.UUID)
}

func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.BatchUpdateAppHealthRequest) (*agentproto.BatchUpdateAppHealthResponse, error) {
Expand Down Expand Up @@ -91,10 +92,11 @@ func (a *AppsAPI) BatchUpdateAppHealths(ctx context.Context, req *agentproto.Bat
}

if a.PublishWorkspaceUpdateFn != nil && len(newApps) > 0 {
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent)
workspaceID, err := a.WorkspaceIDFn(ctx, &workspaceAgent)
if err != nil {
return nil, xerrors.Errorf("publish workspace update: %w", err)
return nil, err
}
a.PublishWorkspaceUpdateFn(ctx, workspaceID)
}
return &agentproto.BatchUpdateAppHealthResponse{}, nil
}
30 changes: 24 additions & 6 deletions coderd/agentapi/apps_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"testing"

"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"

Expand Down Expand Up @@ -56,15 +57,19 @@ func TestBatchUpdateAppHealths(t *testing.T) {
}).Return(nil)

publishCalled := false
workspaceID := uuid.New()
api := &agentapi.AppsAPI{
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
WorkspaceIDFn: func(context.Context, *database.WorkspaceAgent) (uuid.UUID, error) {
return workspaceID, nil
},
Database: dbM,
Log: slogtest.Make(t, nil),
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error {
PublishWorkspaceUpdateFn: func(_ context.Context, id uuid.UUID) {
publishCalled = true
return nil
assert.Equal(t, workspaceID, id)
},
}

Expand Down Expand Up @@ -98,11 +103,13 @@ func TestBatchUpdateAppHealths(t *testing.T) {
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
WorkspaceIDFn: func(ctx context.Context, wa *database.WorkspaceAgent) (uuid.UUID, error) {
return uuid.New(), nil
},
Database: dbM,
Log: slogtest.Make(t, nil),
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error {
PublishWorkspaceUpdateFn: func(_ context.Context, _ uuid.UUID) {
publishCalled = true
return nil
},
}

Expand Down Expand Up @@ -137,11 +144,13 @@ func TestBatchUpdateAppHealths(t *testing.T) {
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
WorkspaceIDFn: func(ctx context.Context, wa *database.WorkspaceAgent) (uuid.UUID, error) {
return uuid.New(), nil
},
Database: dbM,
Log: slogtest.Make(t, nil),
PublishWorkspaceUpdateFn: func(ctx context.Context, wa *database.WorkspaceAgent) error {
PublishWorkspaceUpdateFn: func(_ context.Context, _ uuid.UUID) {
publishCalled = true
return nil
},
}

Expand Down Expand Up @@ -172,6 +181,9 @@ func TestBatchUpdateAppHealths(t *testing.T) {
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
WorkspaceIDFn: func(ctx context.Context, wa *database.WorkspaceAgent) (uuid.UUID, error) {
return uuid.New(), nil
},
Database: dbM,
Log: slogtest.Make(t, nil),
PublishWorkspaceUpdateFn: nil,
Expand Down Expand Up @@ -201,6 +213,9 @@ func TestBatchUpdateAppHealths(t *testing.T) {
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
WorkspaceIDFn: func(ctx context.Context, wa *database.WorkspaceAgent) (uuid.UUID, error) {
return uuid.New(), nil
},
Database: dbM,
Log: slogtest.Make(t, nil),
PublishWorkspaceUpdateFn: nil,
Expand Down Expand Up @@ -231,6 +246,9 @@ func TestBatchUpdateAppHealths(t *testing.T) {
AgentFn: func(context.Context) (database.WorkspaceAgent, error) {
return agent, nil
},
WorkspaceIDFn: func(ctx context.Context, wa *database.WorkspaceAgent) (uuid.UUID, error) {
return uuid.New(), nil
},
Database: dbM,
Log: slogtest.Make(t, nil),
PublishWorkspaceUpdateFn: nil,
Expand Down
32 changes: 26 additions & 6 deletions coderd/agentapi/lifecycle.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package agentapi
import (
"context"
"database/sql"
"sort"
"time"

"github.com/google/uuid"
Expand All @@ -16,12 +17,26 @@ import (
"github.com/coder/coder/v2/coderd/database/dbtime"
)

type WorkspaceAgentAPIVersionContextKey struct{}

func WorkspaceAgentAPIVersion(ctx context.Context) string {
v, ok := ctx.Value(WorkspaceAgentAPIVersionContextKey{}).(string)
if !ok {
return AgentAPIVersionDRPC
}
return v
}

func SetWorkspaceAgentAPIVersion(ctx context.Context, version string) context.Context {
return context.WithValue(ctx, WorkspaceAgentAPIVersionContextKey{}, version)
}

type LifecycleAPI struct {
AgentFn func(context.Context) (database.WorkspaceAgent, error)
WorkspaceIDFn func(context.Context, *database.WorkspaceAgent) (uuid.UUID, error)
Database database.Store
Log slog.Logger
PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent) error
PublishWorkspaceUpdateFn func(context.Context, uuid.UUID)

TimeNowFn func() time.Time // defaults to dbtime.Now()
}
Expand Down Expand Up @@ -113,10 +128,7 @@ func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.Upda
}

if a.PublishWorkspaceUpdateFn != nil {
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent)
if err != nil {
return nil, xerrors.Errorf("publish workspace update: %w", err)
}
a.PublishWorkspaceUpdateFn(ctx, workspaceID)
}

return req.Lifecycle, nil
Expand Down Expand Up @@ -165,12 +177,20 @@ func (a *LifecycleAPI) UpdateStartup(ctx context.Context, req *agentproto.Update
}
}

// Sort subsystems.
sort.Slice(dbSubsystems, func(i, j int) bool {
return dbSubsystems[i] < dbSubsystems[j]
})

// Get API version from context (or default to DRPC). This is only used when
// shimming an old version to this version.
apiVersion := WorkspaceAgentAPIVersion(ctx)
err = a.Database.UpdateWorkspaceAgentStartupByID(ctx, database.UpdateWorkspaceAgentStartupByIDParams{
ID: workspaceAgent.ID,
Version: req.Startup.Version,
ExpandedDirectory: req.Startup.ExpandedDirectory,
Subsystems: dbSubsystems,
APIVersion: AgentAPIVersionDRPC,
APIVersion: apiVersion,
})
if err != nil {
return nil, xerrors.Errorf("update workspace agent startup in database: %w", err)
Expand Down
12 changes: 4 additions & 8 deletions coderd/agentapi/lifecycle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,8 @@ func TestUpdateLifecycle(t *testing.T) {
},
Database: dbM,
Log: slogtest.Make(t, nil),
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent) error {
PublishWorkspaceUpdateFn: func(_ context.Context, _ uuid.UUID) {
publishCalled = true
return nil
},
}

Expand Down Expand Up @@ -161,9 +160,8 @@ func TestUpdateLifecycle(t *testing.T) {
},
Database: dbM,
Log: slogtest.Make(t, nil),
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent) error {
PublishWorkspaceUpdateFn: func(_ context.Context, _ uuid.UUID) {
publishCalled = true
return nil
},
}

Expand Down Expand Up @@ -244,9 +242,8 @@ func TestUpdateLifecycle(t *testing.T) {
},
Database: dbM,
Log: slogtest.Make(t, nil),
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent) error {
PublishWorkspaceUpdateFn: func(_ context.Context, _ uuid.UUID) {
atomic.AddInt64(&publishCalled, 1)
return nil
},
}

Expand Down Expand Up @@ -319,9 +316,8 @@ func TestUpdateLifecycle(t *testing.T) {
},
Database: dbM,
Log: slogtest.Make(t, nil),
PublishWorkspaceUpdateFn: func(ctx context.Context, agent *database.WorkspaceAgent) error {
PublishWorkspaceUpdateFn: func(_ context.Context, _ uuid.UUID) {
publishCalled = true
return nil
},
}

Expand Down
18 changes: 9 additions & 9 deletions coderd/agentapi/logs.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ import (

type LogsAPI struct {
AgentFn func(context.Context) (database.WorkspaceAgent, error)
WorkspaceIDFn func(context.Context, *database.WorkspaceAgent) (uuid.UUID, error)
Database database.Store
Log slog.Logger
PublishWorkspaceUpdateFn func(context.Context, *database.WorkspaceAgent) error
PublishWorkspaceUpdateFn func(context.Context, uuid.UUID)
PublishWorkspaceAgentLogsUpdateFn func(ctx context.Context, workspaceAgentID uuid.UUID, msg agentsdk.LogsNotifyMessage)

TimeNowFn func() time.Time // defaults to dbtime.Now()
Expand Down Expand Up @@ -48,6 +49,11 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea
return nil, xerrors.Errorf("parse log source ID %q: %w", req.LogSourceId, err)
}

workspaceID, err := a.WorkspaceIDFn(ctx, &workspaceAgent)
if err != nil {
return nil, err
}

// This is to support the legacy API where the log source ID was
// not provided in the request body. We default to the external
// log source in this case.
Expand Down Expand Up @@ -123,10 +129,7 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea
}

if a.PublishWorkspaceUpdateFn != nil {
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent)
if err != nil {
return nil, xerrors.Errorf("publish workspace update: %w", err)
}
a.PublishWorkspaceUpdateFn(ctx, workspaceID)
}
return nil, xerrors.New("workspace agent log limit exceeded")
}
Expand All @@ -143,10 +146,7 @@ func (a *LogsAPI) BatchCreateLogs(ctx context.Context, req *agentproto.BatchCrea
if workspaceAgent.LogsLength == 0 && a.PublishWorkspaceUpdateFn != nil {
// If these are the first logs being appended, we publish a UI update
// to notify the UI that logs are now available.
err = a.PublishWorkspaceUpdateFn(ctx, &workspaceAgent)
if err != nil {
return nil, xerrors.Errorf("publish workspace update: %w", err)
}
a.PublishWorkspaceUpdateFn(ctx, workspaceID)
}

return &agentproto.BatchCreateLogsResponse{}, nil
Expand Down
Loading