Skip to content

Commit 2599850

Browse files
authored
feat: use agent v2 API to post startup (#11877)
Uses the v2 Agent API to post startup information.
1 parent da8bb1c commit 2599850

File tree

11 files changed

+101
-71
lines changed

11 files changed

+101
-71
lines changed

agent/agent.go

+9-5
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ type Client interface {
9292
ReportStats(ctx context.Context, log slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error)
9393
PostLifecycle(ctx context.Context, state agentsdk.PostLifecycleRequest) error
9494
PostAppHealth(ctx context.Context, req agentsdk.PostAppHealthsRequest) error
95-
PostStartup(ctx context.Context, req agentsdk.PostStartupRequest) error
9695
PostMetadata(ctx context.Context, req agentsdk.PostMetadataRequest) error
9796
PatchLogs(ctx context.Context, req agentsdk.PatchLogs) error
9897
RewriteDERPMap(derpMap *tailcfg.DERPMap)
@@ -737,13 +736,18 @@ func (a *agent) run(ctx context.Context) error {
737736
if err != nil {
738737
return xerrors.Errorf("expand directory: %w", err)
739738
}
740-
err = a.client.PostStartup(ctx, agentsdk.PostStartupRequest{
739+
subsys, err := agentsdk.ProtoFromSubsystems(a.subsystems)
740+
if err != nil {
741+
a.logger.Critical(ctx, "failed to convert subsystems", slog.Error(err))
742+
return xerrors.Errorf("failed to convert subsystems: %w", err)
743+
}
744+
_, err = aAPI.UpdateStartup(ctx, &proto.UpdateStartupRequest{Startup: &proto.Startup{
741745
Version: buildinfo.Version(),
742746
ExpandedDirectory: manifest.Directory,
743-
Subsystems: a.subsystems,
744-
})
747+
Subsystems: subsys,
748+
}})
745749
if err != nil {
746-
return xerrors.Errorf("update workspace agent version: %w", err)
750+
return xerrors.Errorf("update workspace agent startup: %w", err)
747751
}
748752

749753
oldManifest := a.manifest.Swap(&manifest)

agent/agent_test.go

+12-16
Original file line numberDiff line numberDiff line change
@@ -1394,56 +1394,52 @@ func TestAgent_Startup(t *testing.T) {
13941394

13951395
t.Run("EmptyDirectory", func(t *testing.T) {
13961396
t.Parallel()
1397+
ctx := testutil.Context(t, testutil.WaitShort)
13971398

13981399
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
13991400
Directory: "",
14001401
}, 0)
1401-
assert.Eventually(t, func() bool {
1402-
return client.GetStartup().Version != ""
1403-
}, testutil.WaitShort, testutil.IntervalFast)
1404-
require.Equal(t, "", client.GetStartup().ExpandedDirectory)
1402+
startup := testutil.RequireRecvCtx(ctx, t, client.GetStartup())
1403+
require.Equal(t, "", startup.GetExpandedDirectory())
14051404
})
14061405

14071406
t.Run("HomeDirectory", func(t *testing.T) {
14081407
t.Parallel()
1408+
ctx := testutil.Context(t, testutil.WaitShort)
14091409

14101410
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
14111411
Directory: "~",
14121412
}, 0)
1413-
assert.Eventually(t, func() bool {
1414-
return client.GetStartup().Version != ""
1415-
}, testutil.WaitShort, testutil.IntervalFast)
1413+
startup := testutil.RequireRecvCtx(ctx, t, client.GetStartup())
14161414
homeDir, err := os.UserHomeDir()
14171415
require.NoError(t, err)
1418-
require.Equal(t, homeDir, client.GetStartup().ExpandedDirectory)
1416+
require.Equal(t, homeDir, startup.GetExpandedDirectory())
14191417
})
14201418

14211419
t.Run("NotAbsoluteDirectory", func(t *testing.T) {
14221420
t.Parallel()
1421+
ctx := testutil.Context(t, testutil.WaitShort)
14231422

14241423
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
14251424
Directory: "coder/coder",
14261425
}, 0)
1427-
assert.Eventually(t, func() bool {
1428-
return client.GetStartup().Version != ""
1429-
}, testutil.WaitShort, testutil.IntervalFast)
1426+
startup := testutil.RequireRecvCtx(ctx, t, client.GetStartup())
14301427
homeDir, err := os.UserHomeDir()
14311428
require.NoError(t, err)
1432-
require.Equal(t, filepath.Join(homeDir, "coder/coder"), client.GetStartup().ExpandedDirectory)
1429+
require.Equal(t, filepath.Join(homeDir, "coder/coder"), startup.GetExpandedDirectory())
14331430
})
14341431

14351432
t.Run("HomeEnvironmentVariable", func(t *testing.T) {
14361433
t.Parallel()
1434+
ctx := testutil.Context(t, testutil.WaitShort)
14371435

14381436
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
14391437
Directory: "$HOME",
14401438
}, 0)
1441-
assert.Eventually(t, func() bool {
1442-
return client.GetStartup().Version != ""
1443-
}, testutil.WaitShort, testutil.IntervalFast)
1439+
startup := testutil.RequireRecvCtx(ctx, t, client.GetStartup())
14441440
homeDir, err := os.UserHomeDir()
14451441
require.NoError(t, err)
1446-
require.Equal(t, homeDir, client.GetStartup().ExpandedDirectory)
1442+
require.Equal(t, homeDir, startup.GetExpandedDirectory())
14471443
})
14481444
}
14491445

agent/agenttest/client.go

+11-20
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ type Client struct {
8888

8989
mu sync.Mutex // Protects following.
9090
lifecycleStates []codersdk.WorkspaceAgentLifecycle
91-
startup agentsdk.PostStartupRequest
9291
logs []agentsdk.Log
9392
derpMapUpdates chan *tailcfg.DERPMap
9493
derpMapOnce sync.Once
@@ -173,10 +172,8 @@ func (c *Client) PostAppHealth(ctx context.Context, req agentsdk.PostAppHealthsR
173172
return nil
174173
}
175174

176-
func (c *Client) GetStartup() agentsdk.PostStartupRequest {
177-
c.mu.Lock()
178-
defer c.mu.Unlock()
179-
return c.startup
175+
func (c *Client) GetStartup() <-chan *agentproto.Startup {
176+
return c.fakeAgentAPI.startupCh
180177
}
181178

182179
func (c *Client) GetMetadata() map[string]agentsdk.Metadata {
@@ -198,14 +195,6 @@ func (c *Client) PostMetadata(ctx context.Context, req agentsdk.PostMetadataRequ
198195
return nil
199196
}
200197

201-
func (c *Client) PostStartup(ctx context.Context, startup agentsdk.PostStartupRequest) error {
202-
c.mu.Lock()
203-
defer c.mu.Unlock()
204-
c.startup = startup
205-
c.logger.Debug(ctx, "post startup", slog.F("req", startup))
206-
return nil
207-
}
208-
209198
func (c *Client) GetStartupLogs() []agentsdk.Log {
210199
c.mu.Lock()
211200
defer c.mu.Unlock()
@@ -250,7 +239,8 @@ type FakeAgentAPI struct {
250239
t testing.TB
251240
logger slog.Logger
252241

253-
manifest *agentproto.Manifest
242+
manifest *agentproto.Manifest
243+
startupCh chan *agentproto.Startup
254244

255245
getServiceBannerFunc func() (codersdk.ServiceBannerConfig, error)
256246
}
@@ -294,9 +284,9 @@ func (*FakeAgentAPI) BatchUpdateAppHealths(context.Context, *agentproto.BatchUpd
294284
panic("implement me")
295285
}
296286

297-
func (*FakeAgentAPI) UpdateStartup(context.Context, *agentproto.UpdateStartupRequest) (*agentproto.Startup, error) {
298-
// TODO implement me
299-
panic("implement me")
287+
func (f *FakeAgentAPI) UpdateStartup(_ context.Context, req *agentproto.UpdateStartupRequest) (*agentproto.Startup, error) {
288+
f.startupCh <- req.GetStartup()
289+
return req.GetStartup(), nil
300290
}
301291

302292
func (*FakeAgentAPI) BatchUpdateMetadata(context.Context, *agentproto.BatchUpdateMetadataRequest) (*agentproto.BatchUpdateMetadataResponse, error) {
@@ -311,8 +301,9 @@ func (*FakeAgentAPI) BatchCreateLogs(context.Context, *agentproto.BatchCreateLog
311301

312302
func NewFakeAgentAPI(t testing.TB, logger slog.Logger, manifest *agentproto.Manifest) *FakeAgentAPI {
313303
return &FakeAgentAPI{
314-
t: t,
315-
logger: logger.Named("FakeAgentAPI"),
316-
manifest: manifest,
304+
t: t,
305+
logger: logger.Named("FakeAgentAPI"),
306+
manifest: manifest,
307+
startupCh: make(chan *agentproto.Startup, 100),
317308
}
318309
}

coderd/agentapi/api.go

-2
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@ import (
2929
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
3030
)
3131

32-
const AgentAPIVersionDRPC = "2.0"
33-
3432
// API implements the DRPC agent API interface from agent/proto. This struct is
3533
// instantiated once per agent connection and kept alive for the duration of the
3634
// session.

coderd/agentapi/lifecycle.go

+13-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"time"
77

88
"github.com/google/uuid"
9+
"golang.org/x/exp/slices"
910
"golang.org/x/mod/semver"
1011
"golang.org/x/xerrors"
1112
"google.golang.org/protobuf/types/known/timestamppb"
@@ -16,6 +17,12 @@ import (
1617
"github.com/coder/coder/v2/coderd/database/dbtime"
1718
)
1819

20+
type contextKeyAPIVersion struct{}
21+
22+
func WithAPIVersion(ctx context.Context, version string) context.Context {
23+
return context.WithValue(ctx, contextKeyAPIVersion{}, version)
24+
}
25+
1926
type LifecycleAPI struct {
2027
AgentFn func(context.Context) (database.WorkspaceAgent, error)
2128
WorkspaceIDFn func(context.Context, *database.WorkspaceAgent) (uuid.UUID, error)
@@ -123,6 +130,10 @@ func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.Upda
123130
}
124131

125132
func (a *LifecycleAPI) UpdateStartup(ctx context.Context, req *agentproto.UpdateStartupRequest) (*agentproto.Startup, error) {
133+
apiVersion, ok := ctx.Value(contextKeyAPIVersion{}).(string)
134+
if !ok {
135+
return nil, xerrors.Errorf("internal error; api version unspecified")
136+
}
126137
workspaceAgent, err := a.AgentFn(ctx)
127138
if err != nil {
128139
return nil, err
@@ -164,13 +175,14 @@ func (a *LifecycleAPI) UpdateStartup(ctx context.Context, req *agentproto.Update
164175
dbSubsystems = append(dbSubsystems, dbSubsystem)
165176
}
166177
}
178+
slices.Sort(dbSubsystems)
167179

168180
err = a.Database.UpdateWorkspaceAgentStartupByID(ctx, database.UpdateWorkspaceAgentStartupByIDParams{
169181
ID: workspaceAgent.ID,
170182
Version: req.Startup.Version,
171183
ExpandedDirectory: req.Startup.ExpandedDirectory,
172184
Subsystems: dbSubsystems,
173-
APIVersion: AgentAPIVersionDRPC,
185+
APIVersion: apiVersion,
174186
})
175187
if err != nil {
176188
return nil, xerrors.Errorf("update workspace agent startup in database: %w", err)

coderd/agentapi/lifecycle_test.go

+7-4
Original file line numberDiff line numberDiff line change
@@ -382,10 +382,11 @@ func TestUpdateStartup(t *testing.T) {
382382
database.WorkspaceAgentSubsystemEnvbuilder,
383383
database.WorkspaceAgentSubsystemExectrace,
384384
},
385-
APIVersion: agentapi.AgentAPIVersionDRPC,
385+
APIVersion: "2.0",
386386
}).Return(nil)
387387

388-
resp, err := api.UpdateStartup(context.Background(), &agentproto.UpdateStartupRequest{
388+
ctx := agentapi.WithAPIVersion(context.Background(), "2.0")
389+
resp, err := api.UpdateStartup(ctx, &agentproto.UpdateStartupRequest{
389390
Startup: startup,
390391
})
391392
require.NoError(t, err)
@@ -416,7 +417,8 @@ func TestUpdateStartup(t *testing.T) {
416417
Subsystems: []agentproto.Startup_Subsystem{},
417418
}
418419

419-
resp, err := api.UpdateStartup(context.Background(), &agentproto.UpdateStartupRequest{
420+
ctx := agentapi.WithAPIVersion(context.Background(), "2.0")
421+
resp, err := api.UpdateStartup(ctx, &agentproto.UpdateStartupRequest{
420422
Startup: startup,
421423
})
422424
require.Error(t, err)
@@ -451,7 +453,8 @@ func TestUpdateStartup(t *testing.T) {
451453
},
452454
}
453455

454-
resp, err := api.UpdateStartup(context.Background(), &agentproto.UpdateStartupRequest{
456+
ctx := agentapi.WithAPIVersion(context.Background(), "2.0")
457+
resp, err := api.UpdateStartup(ctx, &agentproto.UpdateStartupRequest{
455458
Startup: startup,
456459
})
457460
require.Error(t, err)

coderd/workspaceagents_test.go

+19-11
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ import (
2424
"github.com/coder/coder/v2/agent"
2525
"github.com/coder/coder/v2/agent/agenttest"
2626
agentproto "github.com/coder/coder/v2/agent/proto"
27-
"github.com/coder/coder/v2/coderd"
2827
"github.com/coder/coder/v2/coderd/coderdtest"
2928
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
3029
"github.com/coder/coder/v2/coderd/database"
@@ -1389,13 +1388,13 @@ func TestWorkspaceAgent_Startup(t *testing.T) {
13891388
}
13901389
)
13911390

1392-
err := agentClient.PostStartup(ctx, agentsdk.PostStartupRequest{
1391+
err := postStartup(ctx, t, agentClient, &agentproto.Startup{
13931392
Version: expectedVersion,
13941393
ExpandedDirectory: expectedDir,
1395-
Subsystems: []codersdk.AgentSubsystem{
1394+
Subsystems: []agentproto.Startup_Subsystem{
13961395
// Not sorted.
1397-
expectedSubsystems[1],
1398-
expectedSubsystems[0],
1396+
agentproto.Startup_EXECTRACE,
1397+
agentproto.Startup_ENVBOX,
13991398
},
14001399
})
14011400
require.NoError(t, err)
@@ -1409,7 +1408,7 @@ func TestWorkspaceAgent_Startup(t *testing.T) {
14091408
require.Equal(t, expectedDir, wsagent.ExpandedDirectory)
14101409
// Sorted
14111410
require.Equal(t, expectedSubsystems, wsagent.Subsystems)
1412-
require.Equal(t, coderd.AgentAPIVersionREST, wsagent.APIVersion)
1411+
require.Equal(t, agentproto.CurrentVersion.String(), wsagent.APIVersion)
14131412
})
14141413

14151414
t.Run("InvalidSemver", func(t *testing.T) {
@@ -1427,13 +1426,10 @@ func TestWorkspaceAgent_Startup(t *testing.T) {
14271426

14281427
ctx := testutil.Context(t, testutil.WaitMedium)
14291428

1430-
err := agentClient.PostStartup(ctx, agentsdk.PostStartupRequest{
1429+
err := postStartup(ctx, t, agentClient, &agentproto.Startup{
14311430
Version: "1.2.3",
14321431
})
1433-
require.Error(t, err)
1434-
cerr, ok := codersdk.AsError(err)
1435-
require.True(t, ok)
1436-
require.Equal(t, http.StatusBadRequest, cerr.StatusCode())
1432+
require.ErrorContains(t, err, "invalid agent semver version")
14371433
})
14381434
}
14391435

@@ -1640,3 +1636,15 @@ func requireGetManifest(ctx context.Context, t testing.TB, client agent.Client)
16401636
require.NoError(t, err)
16411637
return manifest
16421638
}
1639+
1640+
func postStartup(ctx context.Context, t testing.TB, client agent.Client, startup *agentproto.Startup) error {
1641+
conn, err := client.Listen(ctx)
1642+
require.NoError(t, err)
1643+
defer func() {
1644+
cErr := conn.Close()
1645+
require.NoError(t, cErr)
1646+
}()
1647+
aAPI := agentproto.NewDRPCAgentClient(conn)
1648+
_, err = aAPI.UpdateStartup(ctx, &agentproto.UpdateStartupRequest{Startup: startup})
1649+
return err
1650+
}

coderd/workspaceagentsrpc.go

+1
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
154154
Auth: tailnet.AgentTunnelAuth{},
155155
}
156156
ctx = tailnet.WithStreamID(ctx, streamID)
157+
ctx = agentapi.WithAPIVersion(ctx, version)
157158
err = agentAPI.Serve(ctx, mux)
158159
if err != nil {
159160
api.Logger.Warn(ctx, "workspace agent RPC listen error", slog.Error(err))

codersdk/agentsdk/agentsdk.go

-12
Original file line numberDiff line numberDiff line change
@@ -556,18 +556,6 @@ type PostStartupRequest struct {
556556
Subsystems []codersdk.AgentSubsystem `json:"subsystems"`
557557
}
558558

559-
func (c *Client) PostStartup(ctx context.Context, req PostStartupRequest) error {
560-
res, err := c.SDK.Request(ctx, http.MethodPost, "/api/v2/workspaceagents/me/startup", req)
561-
if err != nil {
562-
return err
563-
}
564-
defer res.Body.Close()
565-
if res.StatusCode != http.StatusOK {
566-
return codersdk.ReadBodyAsError(res)
567-
}
568-
return nil
569-
}
570-
571559
type Log struct {
572560
CreatedAt time.Time `json:"created_at"`
573561
Output string `json:"output"`

codersdk/agentsdk/convert.go

+12
Original file line numberDiff line numberDiff line change
@@ -266,3 +266,15 @@ func ProtoFromServiceBanner(sb codersdk.ServiceBannerConfig) *proto.ServiceBanne
266266
BackgroundColor: sb.BackgroundColor,
267267
}
268268
}
269+
270+
func ProtoFromSubsystems(ss []codersdk.AgentSubsystem) ([]proto.Startup_Subsystem, error) {
271+
ret := make([]proto.Startup_Subsystem, len(ss))
272+
for i, s := range ss {
273+
pi, ok := proto.Startup_Subsystem_value[strings.ToUpper(string(s))]
274+
if !ok {
275+
return nil, xerrors.Errorf("unknown subsystem: %s", s)
276+
}
277+
ret[i] = proto.Startup_Subsystem(pi)
278+
}
279+
return ret, nil
280+
}

0 commit comments

Comments
 (0)