diff --git a/agent/agent.go b/agent/agent.go index 132c4b200181f..bcf4e50cc8bc9 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -92,7 +92,6 @@ type Client interface { ReportStats(ctx context.Context, log slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error) PostLifecycle(ctx context.Context, state agentsdk.PostLifecycleRequest) error PostAppHealth(ctx context.Context, req agentsdk.PostAppHealthsRequest) error - PostStartup(ctx context.Context, req agentsdk.PostStartupRequest) error PostMetadata(ctx context.Context, req agentsdk.PostMetadataRequest) error PatchLogs(ctx context.Context, req agentsdk.PatchLogs) error RewriteDERPMap(derpMap *tailcfg.DERPMap) @@ -737,13 +736,18 @@ func (a *agent) run(ctx context.Context) error { if err != nil { return xerrors.Errorf("expand directory: %w", err) } - err = a.client.PostStartup(ctx, agentsdk.PostStartupRequest{ + subsys, err := agentsdk.ProtoFromSubsystems(a.subsystems) + if err != nil { + a.logger.Critical(ctx, "failed to convert subsystems", slog.Error(err)) + return xerrors.Errorf("failed to convert subsystems: %w", err) + } + _, err = aAPI.UpdateStartup(ctx, &proto.UpdateStartupRequest{Startup: &proto.Startup{ Version: buildinfo.Version(), ExpandedDirectory: manifest.Directory, - Subsystems: a.subsystems, - }) + Subsystems: subsys, + }}) if err != nil { - return xerrors.Errorf("update workspace agent version: %w", err) + return xerrors.Errorf("update workspace agent startup: %w", err) } oldManifest := a.manifest.Swap(&manifest) diff --git a/agent/agent_test.go b/agent/agent_test.go index 9f9d7efbaa39f..f7cbe41e96ec0 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -1394,56 +1394,52 @@ func TestAgent_Startup(t *testing.T) { t.Run("EmptyDirectory", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ Directory: "", }, 0) - assert.Eventually(t, func() bool { - return client.GetStartup().Version != "" - }, testutil.WaitShort, testutil.IntervalFast) - require.Equal(t, "", client.GetStartup().ExpandedDirectory) + startup := testutil.RequireRecvCtx(ctx, t, client.GetStartup()) + require.Equal(t, "", startup.GetExpandedDirectory()) }) t.Run("HomeDirectory", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ Directory: "~", }, 0) - assert.Eventually(t, func() bool { - return client.GetStartup().Version != "" - }, testutil.WaitShort, testutil.IntervalFast) + startup := testutil.RequireRecvCtx(ctx, t, client.GetStartup()) homeDir, err := os.UserHomeDir() require.NoError(t, err) - require.Equal(t, homeDir, client.GetStartup().ExpandedDirectory) + require.Equal(t, homeDir, startup.GetExpandedDirectory()) }) t.Run("NotAbsoluteDirectory", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ Directory: "coder/coder", }, 0) - assert.Eventually(t, func() bool { - return client.GetStartup().Version != "" - }, testutil.WaitShort, testutil.IntervalFast) + startup := testutil.RequireRecvCtx(ctx, t, client.GetStartup()) homeDir, err := os.UserHomeDir() require.NoError(t, err) - require.Equal(t, filepath.Join(homeDir, "coder/coder"), client.GetStartup().ExpandedDirectory) + require.Equal(t, filepath.Join(homeDir, "coder/coder"), startup.GetExpandedDirectory()) }) t.Run("HomeEnvironmentVariable", func(t *testing.T) { t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) _, client, _, _, _ := setupAgent(t, agentsdk.Manifest{ Directory: "$HOME", }, 0) - assert.Eventually(t, func() bool { - return client.GetStartup().Version != "" - }, testutil.WaitShort, testutil.IntervalFast) + startup := testutil.RequireRecvCtx(ctx, t, client.GetStartup()) homeDir, err := os.UserHomeDir() require.NoError(t, err) - require.Equal(t, homeDir, client.GetStartup().ExpandedDirectory) + require.Equal(t, homeDir, startup.GetExpandedDirectory()) }) } diff --git a/agent/agenttest/client.go b/agent/agenttest/client.go index a1fcc8c44a52e..3f721a3135a55 100644 --- a/agent/agenttest/client.go +++ b/agent/agenttest/client.go @@ -88,7 +88,6 @@ type Client struct { mu sync.Mutex // Protects following. lifecycleStates []codersdk.WorkspaceAgentLifecycle - startup agentsdk.PostStartupRequest logs []agentsdk.Log derpMapUpdates chan *tailcfg.DERPMap derpMapOnce sync.Once @@ -173,10 +172,8 @@ func (c *Client) PostAppHealth(ctx context.Context, req agentsdk.PostAppHealthsR return nil } -func (c *Client) GetStartup() agentsdk.PostStartupRequest { - c.mu.Lock() - defer c.mu.Unlock() - return c.startup +func (c *Client) GetStartup() <-chan *agentproto.Startup { + return c.fakeAgentAPI.startupCh } func (c *Client) GetMetadata() map[string]agentsdk.Metadata { @@ -198,14 +195,6 @@ func (c *Client) PostMetadata(ctx context.Context, req agentsdk.PostMetadataRequ return nil } -func (c *Client) PostStartup(ctx context.Context, startup agentsdk.PostStartupRequest) error { - c.mu.Lock() - defer c.mu.Unlock() - c.startup = startup - c.logger.Debug(ctx, "post startup", slog.F("req", startup)) - return nil -} - func (c *Client) GetStartupLogs() []agentsdk.Log { c.mu.Lock() defer c.mu.Unlock() @@ -250,7 +239,8 @@ type FakeAgentAPI struct { t testing.TB logger slog.Logger - manifest *agentproto.Manifest + manifest *agentproto.Manifest + startupCh chan *agentproto.Startup getServiceBannerFunc func() (codersdk.ServiceBannerConfig, error) } @@ -294,9 +284,9 @@ func (*FakeAgentAPI) BatchUpdateAppHealths(context.Context, *agentproto.BatchUpd panic("implement me") } -func (*FakeAgentAPI) UpdateStartup(context.Context, *agentproto.UpdateStartupRequest) (*agentproto.Startup, error) { - // TODO implement me - panic("implement me") +func (f *FakeAgentAPI) UpdateStartup(_ context.Context, req *agentproto.UpdateStartupRequest) (*agentproto.Startup, error) { + f.startupCh <- req.GetStartup() + return req.GetStartup(), nil } func (*FakeAgentAPI) BatchUpdateMetadata(context.Context, *agentproto.BatchUpdateMetadataRequest) (*agentproto.BatchUpdateMetadataResponse, error) { @@ -311,8 +301,9 @@ func (*FakeAgentAPI) BatchCreateLogs(context.Context, *agentproto.BatchCreateLog func NewFakeAgentAPI(t testing.TB, logger slog.Logger, manifest *agentproto.Manifest) *FakeAgentAPI { return &FakeAgentAPI{ - t: t, - logger: logger.Named("FakeAgentAPI"), - manifest: manifest, + t: t, + logger: logger.Named("FakeAgentAPI"), + manifest: manifest, + startupCh: make(chan *agentproto.Startup, 100), } } diff --git a/coderd/agentapi/api.go b/coderd/agentapi/api.go index e6db4736af5f3..73b50d9c0c446 100644 --- a/coderd/agentapi/api.go +++ b/coderd/agentapi/api.go @@ -29,8 +29,6 @@ import ( tailnetproto "github.com/coder/coder/v2/tailnet/proto" ) -const AgentAPIVersionDRPC = "2.0" - // API implements the DRPC agent API interface from agent/proto. This struct is // instantiated once per agent connection and kept alive for the duration of the // session. diff --git a/coderd/agentapi/lifecycle.go b/coderd/agentapi/lifecycle.go index 662d0c0c2e28e..9c34b2c5485ad 100644 --- a/coderd/agentapi/lifecycle.go +++ b/coderd/agentapi/lifecycle.go @@ -6,6 +6,7 @@ import ( "time" "github.com/google/uuid" + "golang.org/x/exp/slices" "golang.org/x/mod/semver" "golang.org/x/xerrors" "google.golang.org/protobuf/types/known/timestamppb" @@ -16,6 +17,12 @@ import ( "github.com/coder/coder/v2/coderd/database/dbtime" ) +type contextKeyAPIVersion struct{} + +func WithAPIVersion(ctx context.Context, version string) context.Context { + return context.WithValue(ctx, contextKeyAPIVersion{}, version) +} + type LifecycleAPI struct { AgentFn func(context.Context) (database.WorkspaceAgent, error) WorkspaceIDFn func(context.Context, *database.WorkspaceAgent) (uuid.UUID, error) @@ -123,6 +130,10 @@ func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.Upda } func (a *LifecycleAPI) UpdateStartup(ctx context.Context, req *agentproto.UpdateStartupRequest) (*agentproto.Startup, error) { + apiVersion, ok := ctx.Value(contextKeyAPIVersion{}).(string) + if !ok { + return nil, xerrors.Errorf("internal error; api version unspecified") + } workspaceAgent, err := a.AgentFn(ctx) if err != nil { return nil, err @@ -164,13 +175,14 @@ func (a *LifecycleAPI) UpdateStartup(ctx context.Context, req *agentproto.Update dbSubsystems = append(dbSubsystems, dbSubsystem) } } + slices.Sort(dbSubsystems) err = a.Database.UpdateWorkspaceAgentStartupByID(ctx, database.UpdateWorkspaceAgentStartupByIDParams{ ID: workspaceAgent.ID, Version: req.Startup.Version, ExpandedDirectory: req.Startup.ExpandedDirectory, Subsystems: dbSubsystems, - APIVersion: AgentAPIVersionDRPC, + APIVersion: apiVersion, }) if err != nil { return nil, xerrors.Errorf("update workspace agent startup in database: %w", err) diff --git a/coderd/agentapi/lifecycle_test.go b/coderd/agentapi/lifecycle_test.go index 855ff9329acc9..3a88ee5cb3726 100644 --- a/coderd/agentapi/lifecycle_test.go +++ b/coderd/agentapi/lifecycle_test.go @@ -382,10 +382,11 @@ func TestUpdateStartup(t *testing.T) { database.WorkspaceAgentSubsystemEnvbuilder, database.WorkspaceAgentSubsystemExectrace, }, - APIVersion: agentapi.AgentAPIVersionDRPC, + APIVersion: "2.0", }).Return(nil) - resp, err := api.UpdateStartup(context.Background(), &agentproto.UpdateStartupRequest{ + ctx := agentapi.WithAPIVersion(context.Background(), "2.0") + resp, err := api.UpdateStartup(ctx, &agentproto.UpdateStartupRequest{ Startup: startup, }) require.NoError(t, err) @@ -416,7 +417,8 @@ func TestUpdateStartup(t *testing.T) { Subsystems: []agentproto.Startup_Subsystem{}, } - resp, err := api.UpdateStartup(context.Background(), &agentproto.UpdateStartupRequest{ + ctx := agentapi.WithAPIVersion(context.Background(), "2.0") + resp, err := api.UpdateStartup(ctx, &agentproto.UpdateStartupRequest{ Startup: startup, }) require.Error(t, err) @@ -451,7 +453,8 @@ func TestUpdateStartup(t *testing.T) { }, } - resp, err := api.UpdateStartup(context.Background(), &agentproto.UpdateStartupRequest{ + ctx := agentapi.WithAPIVersion(context.Background(), "2.0") + resp, err := api.UpdateStartup(ctx, &agentproto.UpdateStartupRequest{ Startup: startup, }) require.Error(t, err) diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index 53bece8addd4b..bf303dd0bf703 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -24,7 +24,6 @@ import ( "github.com/coder/coder/v2/agent" "github.com/coder/coder/v2/agent/agenttest" agentproto "github.com/coder/coder/v2/agent/proto" - "github.com/coder/coder/v2/coderd" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/coderdtest/oidctest" "github.com/coder/coder/v2/coderd/database" @@ -1389,13 +1388,13 @@ func TestWorkspaceAgent_Startup(t *testing.T) { } ) - err := agentClient.PostStartup(ctx, agentsdk.PostStartupRequest{ + err := postStartup(ctx, t, agentClient, &agentproto.Startup{ Version: expectedVersion, ExpandedDirectory: expectedDir, - Subsystems: []codersdk.AgentSubsystem{ + Subsystems: []agentproto.Startup_Subsystem{ // Not sorted. - expectedSubsystems[1], - expectedSubsystems[0], + agentproto.Startup_EXECTRACE, + agentproto.Startup_ENVBOX, }, }) require.NoError(t, err) @@ -1409,7 +1408,7 @@ func TestWorkspaceAgent_Startup(t *testing.T) { require.Equal(t, expectedDir, wsagent.ExpandedDirectory) // Sorted require.Equal(t, expectedSubsystems, wsagent.Subsystems) - require.Equal(t, coderd.AgentAPIVersionREST, wsagent.APIVersion) + require.Equal(t, agentproto.CurrentVersion.String(), wsagent.APIVersion) }) t.Run("InvalidSemver", func(t *testing.T) { @@ -1427,13 +1426,10 @@ func TestWorkspaceAgent_Startup(t *testing.T) { ctx := testutil.Context(t, testutil.WaitMedium) - err := agentClient.PostStartup(ctx, agentsdk.PostStartupRequest{ + err := postStartup(ctx, t, agentClient, &agentproto.Startup{ Version: "1.2.3", }) - require.Error(t, err) - cerr, ok := codersdk.AsError(err) - require.True(t, ok) - require.Equal(t, http.StatusBadRequest, cerr.StatusCode()) + require.ErrorContains(t, err, "invalid agent semver version") }) } @@ -1640,3 +1636,15 @@ func requireGetManifest(ctx context.Context, t testing.TB, client agent.Client) require.NoError(t, err) return manifest } + +func postStartup(ctx context.Context, t testing.TB, client agent.Client, startup *agentproto.Startup) error { + conn, err := client.Listen(ctx) + require.NoError(t, err) + defer func() { + cErr := conn.Close() + require.NoError(t, cErr) + }() + aAPI := agentproto.NewDRPCAgentClient(conn) + _, err = aAPI.UpdateStartup(ctx, &agentproto.UpdateStartupRequest{Startup: startup}) + return err +} diff --git a/coderd/workspaceagentsrpc.go b/coderd/workspaceagentsrpc.go index e5398b271570e..c59e50387f784 100644 --- a/coderd/workspaceagentsrpc.go +++ b/coderd/workspaceagentsrpc.go @@ -154,6 +154,7 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) { Auth: tailnet.AgentTunnelAuth{}, } ctx = tailnet.WithStreamID(ctx, streamID) + ctx = agentapi.WithAPIVersion(ctx, version) err = agentAPI.Serve(ctx, mux) if err != nil { api.Logger.Warn(ctx, "workspace agent RPC listen error", slog.Error(err)) diff --git a/codersdk/agentsdk/agentsdk.go b/codersdk/agentsdk/agentsdk.go index ad1e6bbda880c..cc19dd88e3ca9 100644 --- a/codersdk/agentsdk/agentsdk.go +++ b/codersdk/agentsdk/agentsdk.go @@ -556,18 +556,6 @@ type PostStartupRequest struct { Subsystems []codersdk.AgentSubsystem `json:"subsystems"` } -func (c *Client) PostStartup(ctx context.Context, req PostStartupRequest) error { - res, err := c.SDK.Request(ctx, http.MethodPost, "/api/v2/workspaceagents/me/startup", req) - if err != nil { - return err - } - defer res.Body.Close() - if res.StatusCode != http.StatusOK { - return codersdk.ReadBodyAsError(res) - } - return nil -} - type Log struct { CreatedAt time.Time `json:"created_at"` Output string `json:"output"` diff --git a/codersdk/agentsdk/convert.go b/codersdk/agentsdk/convert.go index 2a20cf5d0c98c..1838e52b013d9 100644 --- a/codersdk/agentsdk/convert.go +++ b/codersdk/agentsdk/convert.go @@ -266,3 +266,15 @@ func ProtoFromServiceBanner(sb codersdk.ServiceBannerConfig) *proto.ServiceBanne BackgroundColor: sb.BackgroundColor, } } + +func ProtoFromSubsystems(ss []codersdk.AgentSubsystem) ([]proto.Startup_Subsystem, error) { + ret := make([]proto.Startup_Subsystem, len(ss)) + for i, s := range ss { + pi, ok := proto.Startup_Subsystem_value[strings.ToUpper(string(s))] + if !ok { + return nil, xerrors.Errorf("unknown subsystem: %s", s) + } + ret[i] = proto.Startup_Subsystem(pi) + } + return ret, nil +} diff --git a/codersdk/agentsdk/convert_test.go b/codersdk/agentsdk/convert_test.go index 9205777e13a28..3417416b43fe1 100644 --- a/codersdk/agentsdk/convert_test.go +++ b/codersdk/agentsdk/convert_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/require" "tailscale.com/tailcfg" + "github.com/coder/coder/v2/agent/proto" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/agentsdk" "github.com/coder/coder/v2/tailnet" @@ -144,3 +145,19 @@ func TestManifest(t *testing.T) { require.Equal(t, manifest.Metadata, back.Metadata) require.Equal(t, manifest.Scripts, back.Scripts) } + +func TestSubsystems(t *testing.T) { + t.Parallel() + ss := []codersdk.AgentSubsystem{ + codersdk.AgentSubsystemEnvbox, + codersdk.AgentSubsystemEnvbuilder, + codersdk.AgentSubsystemExectrace, + } + ps, err := agentsdk.ProtoFromSubsystems(ss) + require.NoError(t, err) + require.Equal(t, ps, []proto.Startup_Subsystem{ + proto.Startup_ENVBOX, + proto.Startup_ENVBUILDER, + proto.Startup_EXECTRACE, + }) +}