Skip to content

feat: use agent v2 API to post startup #11877

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ type Client interface {
ReportStats(ctx context.Context, log slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error)
PostLifecycle(ctx context.Context, state agentsdk.PostLifecycleRequest) error
PostAppHealth(ctx context.Context, req agentsdk.PostAppHealthsRequest) error
PostStartup(ctx context.Context, req agentsdk.PostStartupRequest) error
PostMetadata(ctx context.Context, req agentsdk.PostMetadataRequest) error
PatchLogs(ctx context.Context, req agentsdk.PatchLogs) error
RewriteDERPMap(derpMap *tailcfg.DERPMap)
Expand Down Expand Up @@ -737,13 +736,18 @@ func (a *agent) run(ctx context.Context) error {
if err != nil {
return xerrors.Errorf("expand directory: %w", err)
}
err = a.client.PostStartup(ctx, agentsdk.PostStartupRequest{
subsys, err := agentsdk.ProtoFromSubsystems(a.subsystems)
if err != nil {
a.logger.Critical(ctx, "failed to convert subsystems", slog.Error(err))
return xerrors.Errorf("failed to convert subsystems: %w", err)
}
_, err = aAPI.UpdateStartup(ctx, &proto.UpdateStartupRequest{Startup: &proto.Startup{
Version: buildinfo.Version(),
ExpandedDirectory: manifest.Directory,
Subsystems: a.subsystems,
})
Subsystems: subsys,
}})
if err != nil {
return xerrors.Errorf("update workspace agent version: %w", err)
return xerrors.Errorf("update workspace agent startup: %w", err)
}

oldManifest := a.manifest.Swap(&manifest)
Expand Down
28 changes: 12 additions & 16 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1394,56 +1394,52 @@ func TestAgent_Startup(t *testing.T) {

t.Run("EmptyDirectory", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)

_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
Directory: "",
}, 0)
assert.Eventually(t, func() bool {
return client.GetStartup().Version != ""
}, testutil.WaitShort, testutil.IntervalFast)
require.Equal(t, "", client.GetStartup().ExpandedDirectory)
startup := testutil.RequireRecvCtx(ctx, t, client.GetStartup())
require.Equal(t, "", startup.GetExpandedDirectory())
})

t.Run("HomeDirectory", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)

_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
Directory: "~",
}, 0)
assert.Eventually(t, func() bool {
return client.GetStartup().Version != ""
}, testutil.WaitShort, testutil.IntervalFast)
startup := testutil.RequireRecvCtx(ctx, t, client.GetStartup())
homeDir, err := os.UserHomeDir()
require.NoError(t, err)
require.Equal(t, homeDir, client.GetStartup().ExpandedDirectory)
require.Equal(t, homeDir, startup.GetExpandedDirectory())
})

t.Run("NotAbsoluteDirectory", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)

_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
Directory: "coder/coder",
}, 0)
assert.Eventually(t, func() bool {
return client.GetStartup().Version != ""
}, testutil.WaitShort, testutil.IntervalFast)
startup := testutil.RequireRecvCtx(ctx, t, client.GetStartup())
homeDir, err := os.UserHomeDir()
require.NoError(t, err)
require.Equal(t, filepath.Join(homeDir, "coder/coder"), client.GetStartup().ExpandedDirectory)
require.Equal(t, filepath.Join(homeDir, "coder/coder"), startup.GetExpandedDirectory())
})

t.Run("HomeEnvironmentVariable", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)

_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
Directory: "$HOME",
}, 0)
assert.Eventually(t, func() bool {
return client.GetStartup().Version != ""
}, testutil.WaitShort, testutil.IntervalFast)
startup := testutil.RequireRecvCtx(ctx, t, client.GetStartup())
homeDir, err := os.UserHomeDir()
require.NoError(t, err)
require.Equal(t, homeDir, client.GetStartup().ExpandedDirectory)
require.Equal(t, homeDir, startup.GetExpandedDirectory())
})
}

Expand Down
31 changes: 11 additions & 20 deletions agent/agenttest/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ type Client struct {

mu sync.Mutex // Protects following.
lifecycleStates []codersdk.WorkspaceAgentLifecycle
startup agentsdk.PostStartupRequest
logs []agentsdk.Log
derpMapUpdates chan *tailcfg.DERPMap
derpMapOnce sync.Once
Expand Down Expand Up @@ -173,10 +172,8 @@ func (c *Client) PostAppHealth(ctx context.Context, req agentsdk.PostAppHealthsR
return nil
}

func (c *Client) GetStartup() agentsdk.PostStartupRequest {
c.mu.Lock()
defer c.mu.Unlock()
return c.startup
func (c *Client) GetStartup() <-chan *agentproto.Startup {
return c.fakeAgentAPI.startupCh
}

func (c *Client) GetMetadata() map[string]agentsdk.Metadata {
Expand All @@ -198,14 +195,6 @@ func (c *Client) PostMetadata(ctx context.Context, req agentsdk.PostMetadataRequ
return nil
}

func (c *Client) PostStartup(ctx context.Context, startup agentsdk.PostStartupRequest) error {
c.mu.Lock()
defer c.mu.Unlock()
c.startup = startup
c.logger.Debug(ctx, "post startup", slog.F("req", startup))
return nil
}

func (c *Client) GetStartupLogs() []agentsdk.Log {
c.mu.Lock()
defer c.mu.Unlock()
Expand Down Expand Up @@ -250,7 +239,8 @@ type FakeAgentAPI struct {
t testing.TB
logger slog.Logger

manifest *agentproto.Manifest
manifest *agentproto.Manifest
startupCh chan *agentproto.Startup

getServiceBannerFunc func() (codersdk.ServiceBannerConfig, error)
}
Expand Down Expand Up @@ -294,9 +284,9 @@ func (*FakeAgentAPI) BatchUpdateAppHealths(context.Context, *agentproto.BatchUpd
panic("implement me")
}

func (*FakeAgentAPI) UpdateStartup(context.Context, *agentproto.UpdateStartupRequest) (*agentproto.Startup, error) {
// TODO implement me
panic("implement me")
func (f *FakeAgentAPI) UpdateStartup(_ context.Context, req *agentproto.UpdateStartupRequest) (*agentproto.Startup, error) {
f.startupCh <- req.GetStartup()
return req.GetStartup(), nil
}

func (*FakeAgentAPI) BatchUpdateMetadata(context.Context, *agentproto.BatchUpdateMetadataRequest) (*agentproto.BatchUpdateMetadataResponse, error) {
Expand All @@ -311,8 +301,9 @@ func (*FakeAgentAPI) BatchCreateLogs(context.Context, *agentproto.BatchCreateLog

func NewFakeAgentAPI(t testing.TB, logger slog.Logger, manifest *agentproto.Manifest) *FakeAgentAPI {
return &FakeAgentAPI{
t: t,
logger: logger.Named("FakeAgentAPI"),
manifest: manifest,
t: t,
logger: logger.Named("FakeAgentAPI"),
manifest: manifest,
startupCh: make(chan *agentproto.Startup, 100),
}
}
2 changes: 0 additions & 2 deletions coderd/agentapi/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ import (
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
)

const AgentAPIVersionDRPC = "2.0"

// API implements the DRPC agent API interface from agent/proto. This struct is
// instantiated once per agent connection and kept alive for the duration of the
// session.
Expand Down
14 changes: 13 additions & 1 deletion coderd/agentapi/lifecycle.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"time"

"github.com/google/uuid"
"golang.org/x/exp/slices"
"golang.org/x/mod/semver"
"golang.org/x/xerrors"
"google.golang.org/protobuf/types/known/timestamppb"
Expand All @@ -16,6 +17,12 @@ import (
"github.com/coder/coder/v2/coderd/database/dbtime"
)

type contextKeyAPIVersion struct{}

func WithAPIVersion(ctx context.Context, version string) context.Context {
return context.WithValue(ctx, contextKeyAPIVersion{}, version)
}

type LifecycleAPI struct {
AgentFn func(context.Context) (database.WorkspaceAgent, error)
WorkspaceIDFn func(context.Context, *database.WorkspaceAgent) (uuid.UUID, error)
Expand Down Expand Up @@ -123,6 +130,10 @@ func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.Upda
}

func (a *LifecycleAPI) UpdateStartup(ctx context.Context, req *agentproto.UpdateStartupRequest) (*agentproto.Startup, error) {
apiVersion, ok := ctx.Value(contextKeyAPIVersion{}).(string)
if !ok {
return nil, xerrors.Errorf("internal error; api version unspecified")
}
workspaceAgent, err := a.AgentFn(ctx)
if err != nil {
return nil, err
Expand Down Expand Up @@ -164,13 +175,14 @@ func (a *LifecycleAPI) UpdateStartup(ctx context.Context, req *agentproto.Update
dbSubsystems = append(dbSubsystems, dbSubsystem)
}
}
slices.Sort(dbSubsystems)

err = a.Database.UpdateWorkspaceAgentStartupByID(ctx, database.UpdateWorkspaceAgentStartupByIDParams{
ID: workspaceAgent.ID,
Version: req.Startup.Version,
ExpandedDirectory: req.Startup.ExpandedDirectory,
Subsystems: dbSubsystems,
APIVersion: AgentAPIVersionDRPC,
APIVersion: apiVersion,
})
if err != nil {
return nil, xerrors.Errorf("update workspace agent startup in database: %w", err)
Expand Down
11 changes: 7 additions & 4 deletions coderd/agentapi/lifecycle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -382,10 +382,11 @@ func TestUpdateStartup(t *testing.T) {
database.WorkspaceAgentSubsystemEnvbuilder,
database.WorkspaceAgentSubsystemExectrace,
},
APIVersion: agentapi.AgentAPIVersionDRPC,
APIVersion: "2.0",
}).Return(nil)

resp, err := api.UpdateStartup(context.Background(), &agentproto.UpdateStartupRequest{
ctx := agentapi.WithAPIVersion(context.Background(), "2.0")
resp, err := api.UpdateStartup(ctx, &agentproto.UpdateStartupRequest{
Startup: startup,
})
require.NoError(t, err)
Expand Down Expand Up @@ -416,7 +417,8 @@ func TestUpdateStartup(t *testing.T) {
Subsystems: []agentproto.Startup_Subsystem{},
}

resp, err := api.UpdateStartup(context.Background(), &agentproto.UpdateStartupRequest{
ctx := agentapi.WithAPIVersion(context.Background(), "2.0")
resp, err := api.UpdateStartup(ctx, &agentproto.UpdateStartupRequest{
Startup: startup,
})
require.Error(t, err)
Expand Down Expand Up @@ -451,7 +453,8 @@ func TestUpdateStartup(t *testing.T) {
},
}

resp, err := api.UpdateStartup(context.Background(), &agentproto.UpdateStartupRequest{
ctx := agentapi.WithAPIVersion(context.Background(), "2.0")
resp, err := api.UpdateStartup(ctx, &agentproto.UpdateStartupRequest{
Startup: startup,
})
require.Error(t, err)
Expand Down
30 changes: 19 additions & 11 deletions coderd/workspaceagents_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"github.com/coder/coder/v2/agent"
"github.com/coder/coder/v2/agent/agenttest"
agentproto "github.com/coder/coder/v2/agent/proto"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
"github.com/coder/coder/v2/coderd/database"
Expand Down Expand Up @@ -1389,13 +1388,13 @@ func TestWorkspaceAgent_Startup(t *testing.T) {
}
)

err := agentClient.PostStartup(ctx, agentsdk.PostStartupRequest{
err := postStartup(ctx, t, agentClient, &agentproto.Startup{
Version: expectedVersion,
ExpandedDirectory: expectedDir,
Subsystems: []codersdk.AgentSubsystem{
Subsystems: []agentproto.Startup_Subsystem{
// Not sorted.
expectedSubsystems[1],
expectedSubsystems[0],
agentproto.Startup_EXECTRACE,
agentproto.Startup_ENVBOX,
},
})
require.NoError(t, err)
Expand All @@ -1409,7 +1408,7 @@ func TestWorkspaceAgent_Startup(t *testing.T) {
require.Equal(t, expectedDir, wsagent.ExpandedDirectory)
// Sorted
require.Equal(t, expectedSubsystems, wsagent.Subsystems)
require.Equal(t, coderd.AgentAPIVersionREST, wsagent.APIVersion)
require.Equal(t, agentproto.CurrentVersion.String(), wsagent.APIVersion)
})

t.Run("InvalidSemver", func(t *testing.T) {
Expand All @@ -1427,13 +1426,10 @@ func TestWorkspaceAgent_Startup(t *testing.T) {

ctx := testutil.Context(t, testutil.WaitMedium)

err := agentClient.PostStartup(ctx, agentsdk.PostStartupRequest{
err := postStartup(ctx, t, agentClient, &agentproto.Startup{
Version: "1.2.3",
})
require.Error(t, err)
cerr, ok := codersdk.AsError(err)
require.True(t, ok)
require.Equal(t, http.StatusBadRequest, cerr.StatusCode())
require.ErrorContains(t, err, "invalid agent semver version")
})
}

Expand Down Expand Up @@ -1640,3 +1636,15 @@ func requireGetManifest(ctx context.Context, t testing.TB, client agent.Client)
require.NoError(t, err)
return manifest
}

func postStartup(ctx context.Context, t testing.TB, client agent.Client, startup *agentproto.Startup) error {
conn, err := client.Listen(ctx)
require.NoError(t, err)
defer func() {
cErr := conn.Close()
require.NoError(t, cErr)
}()
aAPI := agentproto.NewDRPCAgentClient(conn)
_, err = aAPI.UpdateStartup(ctx, &agentproto.UpdateStartupRequest{Startup: startup})
return err
}
1 change: 1 addition & 0 deletions coderd/workspaceagentsrpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
Auth: tailnet.AgentTunnelAuth{},
}
ctx = tailnet.WithStreamID(ctx, streamID)
ctx = agentapi.WithAPIVersion(ctx, version)
err = agentAPI.Serve(ctx, mux)
if err != nil {
api.Logger.Warn(ctx, "workspace agent RPC listen error", slog.Error(err))
Expand Down
12 changes: 0 additions & 12 deletions codersdk/agentsdk/agentsdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -556,18 +556,6 @@ type PostStartupRequest struct {
Subsystems []codersdk.AgentSubsystem `json:"subsystems"`
}

func (c *Client) PostStartup(ctx context.Context, req PostStartupRequest) error {
res, err := c.SDK.Request(ctx, http.MethodPost, "/api/v2/workspaceagents/me/startup", req)
if err != nil {
return err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return codersdk.ReadBodyAsError(res)
}
return nil
}

type Log struct {
CreatedAt time.Time `json:"created_at"`
Output string `json:"output"`
Expand Down
12 changes: 12 additions & 0 deletions codersdk/agentsdk/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,15 @@ func ProtoFromServiceBanner(sb codersdk.ServiceBannerConfig) *proto.ServiceBanne
BackgroundColor: sb.BackgroundColor,
}
}

func ProtoFromSubsystems(ss []codersdk.AgentSubsystem) ([]proto.Startup_Subsystem, error) {
ret := make([]proto.Startup_Subsystem, len(ss))
for i, s := range ss {
pi, ok := proto.Startup_Subsystem_value[strings.ToUpper(string(s))]
if !ok {
return nil, xerrors.Errorf("unknown subsystem: %s", s)
}
ret[i] = proto.Startup_Subsystem(pi)
}
return ret, nil
}
Loading