Skip to content

Commit 934a8d3

Browse files
committed
feat: use agent v2 API to post startup
1 parent 3b2aa65 commit 934a8d3

File tree

11 files changed

+102
-71
lines changed

11 files changed

+102
-71
lines changed

agent/agent.go

+9-5
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ type Client interface {
9292
ReportStats(ctx context.Context, log slog.Logger, statsChan <-chan *agentsdk.Stats, setInterval func(time.Duration)) (io.Closer, error)
9393
PostLifecycle(ctx context.Context, state agentsdk.PostLifecycleRequest) error
9494
PostAppHealth(ctx context.Context, req agentsdk.PostAppHealthsRequest) error
95-
PostStartup(ctx context.Context, req agentsdk.PostStartupRequest) error
9695
PostMetadata(ctx context.Context, req agentsdk.PostMetadataRequest) error
9796
PatchLogs(ctx context.Context, req agentsdk.PatchLogs) error
9897
RewriteDERPMap(derpMap *tailcfg.DERPMap)
@@ -737,13 +736,18 @@ func (a *agent) run(ctx context.Context) error {
737736
if err != nil {
738737
return xerrors.Errorf("expand directory: %w", err)
739738
}
740-
err = a.client.PostStartup(ctx, agentsdk.PostStartupRequest{
739+
subsys, err := agentsdk.ProtoFromSubsystems(a.subsystems)
740+
if err != nil {
741+
a.logger.Critical(ctx, "failed to convert subsystems", slog.Error(err))
742+
return xerrors.Errorf("failed to convert subsystems: %w", err)
743+
}
744+
_, err = aAPI.UpdateStartup(ctx, &proto.UpdateStartupRequest{Startup: &proto.Startup{
741745
Version: buildinfo.Version(),
742746
ExpandedDirectory: manifest.Directory,
743-
Subsystems: a.subsystems,
744-
})
747+
Subsystems: subsys,
748+
}})
745749
if err != nil {
746-
return xerrors.Errorf("update workspace agent version: %w", err)
750+
return xerrors.Errorf("update workspace agent startup: %w", err)
747751
}
748752

749753
oldManifest := a.manifest.Swap(&manifest)

agent/agent_test.go

+12-16
Original file line numberDiff line numberDiff line change
@@ -1395,56 +1395,52 @@ func TestAgent_Startup(t *testing.T) {
13951395

13961396
t.Run("EmptyDirectory", func(t *testing.T) {
13971397
t.Parallel()
1398+
ctx := testutil.Context(t, testutil.WaitShort)
13981399

13991400
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
14001401
Directory: "",
14011402
}, 0)
1402-
assert.Eventually(t, func() bool {
1403-
return client.GetStartup().Version != ""
1404-
}, testutil.WaitShort, testutil.IntervalFast)
1405-
require.Equal(t, "", client.GetStartup().ExpandedDirectory)
1403+
startup := testutil.RequireRecvCtx(ctx, t, client.GetStartup())
1404+
require.Equal(t, "", startup.GetExpandedDirectory())
14061405
})
14071406

14081407
t.Run("HomeDirectory", func(t *testing.T) {
14091408
t.Parallel()
1409+
ctx := testutil.Context(t, testutil.WaitShort)
14101410

14111411
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
14121412
Directory: "~",
14131413
}, 0)
1414-
assert.Eventually(t, func() bool {
1415-
return client.GetStartup().Version != ""
1416-
}, testutil.WaitShort, testutil.IntervalFast)
1414+
startup := testutil.RequireRecvCtx(ctx, t, client.GetStartup())
14171415
homeDir, err := os.UserHomeDir()
14181416
require.NoError(t, err)
1419-
require.Equal(t, homeDir, client.GetStartup().ExpandedDirectory)
1417+
require.Equal(t, homeDir, startup.GetExpandedDirectory())
14201418
})
14211419

14221420
t.Run("NotAbsoluteDirectory", func(t *testing.T) {
14231421
t.Parallel()
1422+
ctx := testutil.Context(t, testutil.WaitShort)
14241423

14251424
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
14261425
Directory: "coder/coder",
14271426
}, 0)
1428-
assert.Eventually(t, func() bool {
1429-
return client.GetStartup().Version != ""
1430-
}, testutil.WaitShort, testutil.IntervalFast)
1427+
startup := testutil.RequireRecvCtx(ctx, t, client.GetStartup())
14311428
homeDir, err := os.UserHomeDir()
14321429
require.NoError(t, err)
1433-
require.Equal(t, filepath.Join(homeDir, "coder/coder"), client.GetStartup().ExpandedDirectory)
1430+
require.Equal(t, filepath.Join(homeDir, "coder/coder"), startup.GetExpandedDirectory())
14341431
})
14351432

14361433
t.Run("HomeEnvironmentVariable", func(t *testing.T) {
14371434
t.Parallel()
1435+
ctx := testutil.Context(t, testutil.WaitShort)
14381436

14391437
_, client, _, _, _ := setupAgent(t, agentsdk.Manifest{
14401438
Directory: "$HOME",
14411439
}, 0)
1442-
assert.Eventually(t, func() bool {
1443-
return client.GetStartup().Version != ""
1444-
}, testutil.WaitShort, testutil.IntervalFast)
1440+
startup := testutil.RequireRecvCtx(ctx, t, client.GetStartup())
14451441
homeDir, err := os.UserHomeDir()
14461442
require.NoError(t, err)
1447-
require.Equal(t, homeDir, client.GetStartup().ExpandedDirectory)
1443+
require.Equal(t, homeDir, startup.GetExpandedDirectory())
14481444
})
14491445
}
14501446

agent/agenttest/client.go

+11-20
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ type Client struct {
8888

8989
mu sync.Mutex // Protects following.
9090
lifecycleStates []codersdk.WorkspaceAgentLifecycle
91-
startup agentsdk.PostStartupRequest
9291
logs []agentsdk.Log
9392
derpMapUpdates chan *tailcfg.DERPMap
9493
derpMapOnce sync.Once
@@ -173,10 +172,8 @@ func (c *Client) PostAppHealth(ctx context.Context, req agentsdk.PostAppHealthsR
173172
return nil
174173
}
175174

176-
func (c *Client) GetStartup() agentsdk.PostStartupRequest {
177-
c.mu.Lock()
178-
defer c.mu.Unlock()
179-
return c.startup
175+
func (c *Client) GetStartup() <-chan *agentproto.Startup {
176+
return c.fakeAgentAPI.startupCh
180177
}
181178

182179
func (c *Client) GetMetadata() map[string]agentsdk.Metadata {
@@ -198,14 +195,6 @@ func (c *Client) PostMetadata(ctx context.Context, req agentsdk.PostMetadataRequ
198195
return nil
199196
}
200197

201-
func (c *Client) PostStartup(ctx context.Context, startup agentsdk.PostStartupRequest) error {
202-
c.mu.Lock()
203-
defer c.mu.Unlock()
204-
c.startup = startup
205-
c.logger.Debug(ctx, "post startup", slog.F("req", startup))
206-
return nil
207-
}
208-
209198
func (c *Client) GetStartupLogs() []agentsdk.Log {
210199
c.mu.Lock()
211200
defer c.mu.Unlock()
@@ -250,7 +239,8 @@ type FakeAgentAPI struct {
250239
t testing.TB
251240
logger slog.Logger
252241

253-
manifest *agentproto.Manifest
242+
manifest *agentproto.Manifest
243+
startupCh chan *agentproto.Startup
254244

255245
getServiceBannerFunc func() (codersdk.ServiceBannerConfig, error)
256246
}
@@ -294,9 +284,9 @@ func (*FakeAgentAPI) BatchUpdateAppHealths(context.Context, *agentproto.BatchUpd
294284
panic("implement me")
295285
}
296286

297-
func (*FakeAgentAPI) UpdateStartup(context.Context, *agentproto.UpdateStartupRequest) (*agentproto.Startup, error) {
298-
// TODO implement me
299-
panic("implement me")
287+
func (f *FakeAgentAPI) UpdateStartup(_ context.Context, req *agentproto.UpdateStartupRequest) (*agentproto.Startup, error) {
288+
f.startupCh <- req.GetStartup()
289+
return req.GetStartup(), nil
300290
}
301291

302292
func (*FakeAgentAPI) BatchUpdateMetadata(context.Context, *agentproto.BatchUpdateMetadataRequest) (*agentproto.BatchUpdateMetadataResponse, error) {
@@ -311,8 +301,9 @@ func (*FakeAgentAPI) BatchCreateLogs(context.Context, *agentproto.BatchCreateLog
311301

312302
func NewFakeAgentAPI(t testing.TB, logger slog.Logger, manifest *agentproto.Manifest) *FakeAgentAPI {
313303
return &FakeAgentAPI{
314-
t: t,
315-
logger: logger.Named("FakeAgentAPI"),
316-
manifest: manifest,
304+
t: t,
305+
logger: logger.Named("FakeAgentAPI"),
306+
manifest: manifest,
307+
startupCh: make(chan *agentproto.Startup, 100),
317308
}
318309
}

coderd/agentapi/api.go

-2
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@ import (
2929
tailnetproto "github.com/coder/coder/v2/tailnet/proto"
3030
)
3131

32-
const AgentAPIVersionDRPC = "2.0"
33-
3432
// API implements the DRPC agent API interface from agent/proto. This struct is
3533
// instantiated once per agent connection and kept alive for the duration of the
3634
// session.

coderd/agentapi/lifecycle.go

+13-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"time"
77

88
"github.com/google/uuid"
9+
"golang.org/x/exp/slices"
910
"golang.org/x/mod/semver"
1011
"golang.org/x/xerrors"
1112
"google.golang.org/protobuf/types/known/timestamppb"
@@ -16,6 +17,12 @@ import (
1617
"github.com/coder/coder/v2/coderd/database/dbtime"
1718
)
1819

20+
type contextKeyAPIVersion struct{}
21+
22+
func WithAPIVersion(ctx context.Context, version string) context.Context {
23+
return context.WithValue(ctx, contextKeyAPIVersion{}, version)
24+
}
25+
1926
type LifecycleAPI struct {
2027
AgentFn func(context.Context) (database.WorkspaceAgent, error)
2128
WorkspaceIDFn func(context.Context, *database.WorkspaceAgent) (uuid.UUID, error)
@@ -123,6 +130,10 @@ func (a *LifecycleAPI) UpdateLifecycle(ctx context.Context, req *agentproto.Upda
123130
}
124131

125132
func (a *LifecycleAPI) UpdateStartup(ctx context.Context, req *agentproto.UpdateStartupRequest) (*agentproto.Startup, error) {
133+
apiVersion, ok := ctx.Value(contextKeyAPIVersion{}).(string)
134+
if !ok {
135+
return nil, xerrors.Errorf("internal error; api version unspecified")
136+
}
126137
workspaceAgent, err := a.AgentFn(ctx)
127138
if err != nil {
128139
return nil, err
@@ -164,13 +175,14 @@ func (a *LifecycleAPI) UpdateStartup(ctx context.Context, req *agentproto.Update
164175
dbSubsystems = append(dbSubsystems, dbSubsystem)
165176
}
166177
}
178+
slices.Sort(dbSubsystems)
167179

168180
err = a.Database.UpdateWorkspaceAgentStartupByID(ctx, database.UpdateWorkspaceAgentStartupByIDParams{
169181
ID: workspaceAgent.ID,
170182
Version: req.Startup.Version,
171183
ExpandedDirectory: req.Startup.ExpandedDirectory,
172184
Subsystems: dbSubsystems,
173-
APIVersion: AgentAPIVersionDRPC,
185+
APIVersion: apiVersion,
174186
})
175187
if err != nil {
176188
return nil, xerrors.Errorf("update workspace agent startup in database: %w", err)

coderd/agentapi/lifecycle_test.go

+7-4
Original file line numberDiff line numberDiff line change
@@ -382,10 +382,11 @@ func TestUpdateStartup(t *testing.T) {
382382
database.WorkspaceAgentSubsystemEnvbuilder,
383383
database.WorkspaceAgentSubsystemExectrace,
384384
},
385-
APIVersion: agentapi.AgentAPIVersionDRPC,
385+
APIVersion: "2.0",
386386
}).Return(nil)
387387

388-
resp, err := api.UpdateStartup(context.Background(), &agentproto.UpdateStartupRequest{
388+
ctx := agentapi.WithAPIVersion(context.Background(), "2.0")
389+
resp, err := api.UpdateStartup(ctx, &agentproto.UpdateStartupRequest{
389390
Startup: startup,
390391
})
391392
require.NoError(t, err)
@@ -416,7 +417,8 @@ func TestUpdateStartup(t *testing.T) {
416417
Subsystems: []agentproto.Startup_Subsystem{},
417418
}
418419

419-
resp, err := api.UpdateStartup(context.Background(), &agentproto.UpdateStartupRequest{
420+
ctx := agentapi.WithAPIVersion(context.Background(), "2.0")
421+
resp, err := api.UpdateStartup(ctx, &agentproto.UpdateStartupRequest{
420422
Startup: startup,
421423
})
422424
require.Error(t, err)
@@ -451,7 +453,8 @@ func TestUpdateStartup(t *testing.T) {
451453
},
452454
}
453455

454-
resp, err := api.UpdateStartup(context.Background(), &agentproto.UpdateStartupRequest{
456+
ctx := agentapi.WithAPIVersion(context.Background(), "2.0")
457+
resp, err := api.UpdateStartup(ctx, &agentproto.UpdateStartupRequest{
455458
Startup: startup,
456459
})
457460
require.Error(t, err)

coderd/workspaceagents_test.go

+19-11
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ import (
2424
"github.com/coder/coder/v2/agent"
2525
"github.com/coder/coder/v2/agent/agenttest"
2626
agentproto "github.com/coder/coder/v2/agent/proto"
27-
"github.com/coder/coder/v2/coderd"
2827
"github.com/coder/coder/v2/coderd/coderdtest"
2928
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
3029
"github.com/coder/coder/v2/coderd/database"
@@ -1389,13 +1388,13 @@ func TestWorkspaceAgent_Startup(t *testing.T) {
13891388
}
13901389
)
13911390

1392-
err := agentClient.PostStartup(ctx, agentsdk.PostStartupRequest{
1391+
err := postStartup(ctx, t, agentClient, &agentproto.Startup{
13931392
Version: expectedVersion,
13941393
ExpandedDirectory: expectedDir,
1395-
Subsystems: []codersdk.AgentSubsystem{
1394+
Subsystems: []agentproto.Startup_Subsystem{
13961395
// Not sorted.
1397-
expectedSubsystems[1],
1398-
expectedSubsystems[0],
1396+
agentproto.Startup_EXECTRACE,
1397+
agentproto.Startup_ENVBOX,
13991398
},
14001399
})
14011400
require.NoError(t, err)
@@ -1409,7 +1408,7 @@ func TestWorkspaceAgent_Startup(t *testing.T) {
14091408
require.Equal(t, expectedDir, wsagent.ExpandedDirectory)
14101409
// Sorted
14111410
require.Equal(t, expectedSubsystems, wsagent.Subsystems)
1412-
require.Equal(t, coderd.AgentAPIVersionREST, wsagent.APIVersion)
1411+
require.Equal(t, agentproto.CurrentVersion.String(), wsagent.APIVersion)
14131412
})
14141413

14151414
t.Run("InvalidSemver", func(t *testing.T) {
@@ -1427,13 +1426,10 @@ func TestWorkspaceAgent_Startup(t *testing.T) {
14271426

14281427
ctx := testutil.Context(t, testutil.WaitMedium)
14291428

1430-
err := agentClient.PostStartup(ctx, agentsdk.PostStartupRequest{
1429+
err := postStartup(ctx, t, agentClient, &agentproto.Startup{
14311430
Version: "1.2.3",
14321431
})
1433-
require.Error(t, err)
1434-
cerr, ok := codersdk.AsError(err)
1435-
require.True(t, ok)
1436-
require.Equal(t, http.StatusBadRequest, cerr.StatusCode())
1432+
require.ErrorContains(t, err, "invalid agent semver version")
14371433
})
14381434
}
14391435

@@ -1639,3 +1635,15 @@ func requireGetManifest(ctx context.Context, t testing.TB, client agent.Client)
16391635
require.NoError(t, err)
16401636
return manifest
16411637
}
1638+
1639+
func postStartup(ctx context.Context, t testing.TB, client agent.Client, startup *agentproto.Startup) error {
1640+
conn, err := client.Listen(ctx)
1641+
require.NoError(t, err)
1642+
defer func() {
1643+
cErr := conn.Close()
1644+
require.NoError(t, cErr)
1645+
}()
1646+
aAPI := agentproto.NewDRPCAgentClient(conn)
1647+
_, err = aAPI.UpdateStartup(ctx, &agentproto.UpdateStartupRequest{Startup: startup})
1648+
return err
1649+
}

coderd/workspaceagentsrpc.go

+1
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
154154
Auth: tailnet.AgentTunnelAuth{},
155155
}
156156
ctx = tailnet.WithStreamID(ctx, streamID)
157+
ctx = agentapi.WithAPIVersion(ctx, version)
157158
err = agentAPI.Serve(ctx, mux)
158159
if err != nil {
159160
api.Logger.Warn(ctx, "workspace agent RPC listen error", slog.Error(err))

codersdk/agentsdk/agentsdk.go

-12
Original file line numberDiff line numberDiff line change
@@ -556,18 +556,6 @@ type PostStartupRequest struct {
556556
Subsystems []codersdk.AgentSubsystem `json:"subsystems"`
557557
}
558558

559-
func (c *Client) PostStartup(ctx context.Context, req PostStartupRequest) error {
560-
res, err := c.SDK.Request(ctx, http.MethodPost, "/api/v2/workspaceagents/me/startup", req)
561-
if err != nil {
562-
return err
563-
}
564-
defer res.Body.Close()
565-
if res.StatusCode != http.StatusOK {
566-
return codersdk.ReadBodyAsError(res)
567-
}
568-
return nil
569-
}
570-
571559
type Log struct {
572560
CreatedAt time.Time `json:"created_at"`
573561
Output string `json:"output"`

codersdk/agentsdk/convert.go

+12
Original file line numberDiff line numberDiff line change
@@ -266,3 +266,15 @@ func ProtoFromServiceBanner(sb codersdk.ServiceBannerConfig) *proto.ServiceBanne
266266
BackgroundColor: sb.BackgroundColor,
267267
}
268268
}
269+
270+
func ProtoFromSubsystems(ss []codersdk.AgentSubsystem) ([]proto.Startup_Subsystem, error) {
271+
ret := make([]proto.Startup_Subsystem, len(ss))
272+
for i, s := range ss {
273+
pi, ok := proto.Startup_Subsystem_value[strings.ToUpper(string(s))]
274+
if !ok {
275+
return nil, xerrors.Errorf("unknown subsystem: %s", s)
276+
}
277+
ret[i] = proto.Startup_Subsystem(pi)
278+
}
279+
return ret, nil
280+
}

0 commit comments

Comments
 (0)