Skip to content

Commit 5146aff

Browse files
ThomasK33dannykopping
authored andcommitted
fix: remove unnecessary user lookup in agent API calls (#17934)
This PR optimizes the agent API by using the `workspace.OwnerUsername` field directly instead of making an additional database query to fetch the owner's username. The change removes the need to call `GetUserByID` in the manifest API and workspace agent RPC endpoints. An issue arose when the agent token was scoped without access to user data (`api_key_scope = "no_user_data"`), causing the agent to fail to fetch the manifest due to an RBAC issue. Change-Id: I3b6e7581134e2374b364ee059e3b18ece3d98b41 Signed-off-by: Thomas Kosiewski <tk@coder.com>
1 parent 32f093e commit 5146aff

File tree

6 files changed

+182
-105
lines changed

6 files changed

+182
-105
lines changed

coderd/agentapi/manifest.go

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifest
4747
scripts []database.WorkspaceAgentScript
4848
metadata []database.WorkspaceAgentMetadatum
4949
workspace database.Workspace
50-
owner database.User
5150
devcontainers []database.WorkspaceAgentDevcontainer
5251
)
5352

@@ -76,10 +75,6 @@ func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifest
7675
if err != nil {
7776
return xerrors.Errorf("getting workspace by id: %w", err)
7877
}
79-
owner, err = a.Database.GetUserByID(ctx, workspace.OwnerID)
80-
if err != nil {
81-
return xerrors.Errorf("getting workspace owner by id: %w", err)
82-
}
8378
return err
8479
})
8580
eg.Go(func() (err error) {
@@ -98,7 +93,7 @@ func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifest
9893
AppSlugOrPort: "{{port}}",
9994
AgentName: workspaceAgent.Name,
10095
WorkspaceName: workspace.Name,
101-
Username: owner.Username,
96+
Username: workspace.OwnerUsername,
10297
}
10398

10499
vscodeProxyURI := vscodeProxyURI(appSlug, a.AccessURL, a.AppHostname)
@@ -115,15 +110,15 @@ func (a *ManifestAPI) GetManifest(ctx context.Context, _ *agentproto.GetManifest
115110
}
116111
}
117112

118-
apps, err := dbAppsToProto(dbApps, workspaceAgent, owner.Username, workspace)
113+
apps, err := dbAppsToProto(dbApps, workspaceAgent, workspace.OwnerUsername, workspace)
119114
if err != nil {
120115
return nil, xerrors.Errorf("converting workspace apps: %w", err)
121116
}
122117

123118
return &agentproto.Manifest{
124119
AgentId: workspaceAgent.ID[:],
125120
AgentName: workspaceAgent.Name,
126-
OwnerUsername: owner.Username,
121+
OwnerUsername: workspace.OwnerUsername,
127122
WorkspaceId: workspace.ID[:],
128123
WorkspaceName: workspace.Name,
129124
GitAuthConfigs: gitAuthConfigs,

coderd/agentapi/manifest_test.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,10 @@ func TestGetManifest(t *testing.T) {
4646
Username: "cool-user",
4747
}
4848
workspace = database.Workspace{
49-
ID: uuid.New(),
50-
OwnerID: owner.ID,
51-
Name: "cool-workspace",
49+
ID: uuid.New(),
50+
OwnerID: owner.ID,
51+
OwnerUsername: owner.Username,
52+
Name: "cool-workspace",
5253
}
5354
agent = database.WorkspaceAgent{
5455
ID: uuid.New(),
@@ -329,7 +330,6 @@ func TestGetManifest(t *testing.T) {
329330
}).Return(metadata, nil)
330331
mDB.EXPECT().GetWorkspaceAgentDevcontainersByAgentID(gomock.Any(), agent.ID).Return(devcontainers, nil)
331332
mDB.EXPECT().GetWorkspaceByID(gomock.Any(), workspace.ID).Return(workspace, nil)
332-
mDB.EXPECT().GetUserByID(gomock.Any(), workspace.OwnerID).Return(owner, nil)
333333

334334
got, err := api.GetManifest(context.Background(), &agentproto.GetManifestRequest{})
335335
require.NoError(t, err)
@@ -396,7 +396,6 @@ func TestGetManifest(t *testing.T) {
396396
}).Return(metadata, nil)
397397
mDB.EXPECT().GetWorkspaceAgentDevcontainersByAgentID(gomock.Any(), agent.ID).Return(devcontainers, nil)
398398
mDB.EXPECT().GetWorkspaceByID(gomock.Any(), workspace.ID).Return(workspace, nil)
399-
mDB.EXPECT().GetUserByID(gomock.Any(), workspace.OwnerID).Return(owner, nil)
400399

401400
got, err := api.GetManifest(context.Background(), &agentproto.GetManifestRequest{})
402401
require.NoError(t, err)

coderd/workspaceagents_test.go

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -437,25 +437,55 @@ func TestWorkspaceAgentConnectRPC(t *testing.T) {
437437
t.Run("Connect", func(t *testing.T) {
438438
t.Parallel()
439439

440-
client, db := coderdtest.NewWithDatabase(t, nil)
441-
user := coderdtest.CreateFirstUser(t, client)
442-
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
443-
OrganizationID: user.OrganizationID,
444-
OwnerID: user.UserID,
445-
}).WithAgent().Do()
446-
_ = agenttest.New(t, client.URL, r.AgentToken)
447-
resources := coderdtest.AwaitWorkspaceAgents(t, client, r.Workspace.ID)
440+
for _, tc := range []struct {
441+
name string
442+
apiKeyScope rbac.ScopeName
443+
}{
444+
{
445+
name: "empty (backwards compat)",
446+
apiKeyScope: "",
447+
},
448+
{
449+
name: "all",
450+
apiKeyScope: rbac.ScopeAll,
451+
},
452+
{
453+
name: "no_user_data",
454+
apiKeyScope: rbac.ScopeNoUserData,
455+
},
456+
{
457+
name: "application_connect",
458+
apiKeyScope: rbac.ScopeApplicationConnect,
459+
},
460+
} {
461+
t.Run(tc.name, func(t *testing.T) {
462+
client, db := coderdtest.NewWithDatabase(t, nil)
463+
user := coderdtest.CreateFirstUser(t, client)
464+
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
465+
OrganizationID: user.OrganizationID,
466+
OwnerID: user.UserID,
467+
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
468+
for _, agent := range agents {
469+
agent.ApiKeyScope = string(tc.apiKeyScope)
470+
}
448471

449-
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
450-
defer cancel()
472+
return agents
473+
}).Do()
474+
_ = agenttest.New(t, client.URL, r.AgentToken)
475+
resources := coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).AgentNames([]string{}).Wait()
451476

452-
conn, err := workspacesdk.New(client).
453-
DialAgent(ctx, resources[0].Agents[0].ID, nil)
454-
require.NoError(t, err)
455-
defer func() {
456-
_ = conn.Close()
457-
}()
458-
conn.AwaitReachable(ctx)
477+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
478+
defer cancel()
479+
480+
conn, err := workspacesdk.New(client).
481+
DialAgent(ctx, resources[0].Agents[0].ID, nil)
482+
require.NoError(t, err)
483+
defer func() {
484+
_ = conn.Close()
485+
}()
486+
conn.AwaitReachable(ctx)
487+
})
488+
}
459489
})
460490

461491
t.Run("FailNonLatestBuild", func(t *testing.T) {

coderd/workspaceagentsrpc.go

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,17 +76,8 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
7676
return
7777
}
7878

79-
owner, err := api.Database.GetUserByID(ctx, workspace.OwnerID)
80-
if err != nil {
81-
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
82-
Message: "Internal error fetching user.",
83-
Detail: err.Error(),
84-
})
85-
return
86-
}
87-
8879
logger = logger.With(
89-
slog.F("owner", owner.Username),
80+
slog.F("owner", workspace.OwnerUsername),
9081
slog.F("workspace_name", workspace.Name),
9182
slog.F("agent_name", workspaceAgent.Name),
9283
)
@@ -170,7 +161,7 @@ func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
170161
})
171162

172163
streamID := tailnet.StreamID{
173-
Name: fmt.Sprintf("%s-%s-%s", owner.Username, workspace.Name, workspaceAgent.Name),
164+
Name: fmt.Sprintf("%s-%s-%s", workspace.OwnerUsername, workspace.Name, workspaceAgent.Name),
174165
ID: workspaceAgent.ID,
175166
Auth: tailnet.AgentCoordinateeAuth{ID: workspaceAgent.ID},
176167
}

coderd/workspaceagentsrpc_test.go

Lines changed: 125 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"github.com/coder/coder/v2/coderd/database"
1414
"github.com/coder/coder/v2/coderd/database/dbfake"
1515
"github.com/coder/coder/v2/coderd/database/dbtime"
16+
"github.com/coder/coder/v2/coderd/rbac"
1617
"github.com/coder/coder/v2/codersdk/agentsdk"
1718
"github.com/coder/coder/v2/provisionersdk/proto"
1819
"github.com/coder/coder/v2/testutil"
@@ -22,6 +23,30 @@ import (
2223
func TestWorkspaceAgentReportStats(t *testing.T) {
2324
t.Parallel()
2425

26+
for _, tc := range []struct {
27+
name string
28+
apiKeyScope rbac.ScopeName
29+
}{
30+
{
31+
name: "empty (backwards compat)",
32+
apiKeyScope: "",
33+
},
34+
{
35+
name: "all",
36+
apiKeyScope: rbac.ScopeAll,
37+
},
38+
{
39+
name: "no_user_data",
40+
apiKeyScope: rbac.ScopeNoUserData,
41+
},
42+
{
43+
name: "application_connect",
44+
apiKeyScope: rbac.ScopeApplicationConnect,
45+
},
46+
} {
47+
t.Run(tc.name, func(t *testing.T) {
48+
t.Parallel()
49+
2550
tickCh := make(chan time.Time)
2651
flushCh := make(chan int, 1)
2752
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{
@@ -32,78 +57,114 @@ func TestWorkspaceAgentReportStats(t *testing.T) {
3257
r := dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{
3358
OrganizationID: user.OrganizationID,
3459
OwnerID: user.UserID,
35-
}).WithAgent().Do()
60+
}).WithAgent(func(agent []*proto.Agent) []*proto.Agent {
61+
for _, a := range agent {
62+
a.ApiKeyScope = string(tc.apiKeyScope)
63+
}
3664

37-
ac := agentsdk.New(client.URL)
38-
ac.SetSessionToken(r.AgentToken)
39-
conn, err := ac.ConnectRPC(context.Background())
40-
require.NoError(t, err)
41-
defer func() {
42-
_ = conn.Close()
43-
}()
44-
agentAPI := agentproto.NewDRPCAgentClient(conn)
65+
return agent
66+
},
67+
).Do()
4568

46-
_, err = agentAPI.UpdateStats(context.Background(), &agentproto.UpdateStatsRequest{
47-
Stats: &agentproto.Stats{
48-
ConnectionsByProto: map[string]int64{"TCP": 1},
49-
ConnectionCount: 1,
50-
RxPackets: 1,
51-
RxBytes: 1,
52-
TxPackets: 1,
53-
TxBytes: 1,
54-
SessionCountVscode: 1,
55-
SessionCountJetbrains: 0,
56-
SessionCountReconnectingPty: 0,
57-
SessionCountSsh: 0,
58-
ConnectionMedianLatencyMs: 10,
59-
},
60-
})
61-
require.NoError(t, err)
69+
ac := agentsdk.New(client.URL)
70+
ac.SetSessionToken(r.AgentToken)
71+
conn, err := ac.ConnectRPC(context.Background())
72+
require.NoError(t, err)
73+
defer func() {
74+
_ = conn.Close()
75+
}()
76+
agentAPI := agentproto.NewDRPCAgentClient(conn)
6277

63-
tickCh <- dbtime.Now()
64-
count := <-flushCh
65-
require.Equal(t, 1, count, "expected one flush with one id")
78+
_, err = agentAPI.UpdateStats(context.Background(), &agentproto.UpdateStatsRequest{
79+
Stats: &agentproto.Stats{
80+
ConnectionsByProto: map[string]int64{"TCP": 1},
81+
ConnectionCount: 1,
82+
RxPackets: 1,
83+
RxBytes: 1,
84+
TxPackets: 1,
85+
TxBytes: 1,
86+
SessionCountVscode: 1,
87+
SessionCountJetbrains: 0,
88+
SessionCountReconnectingPty: 0,
89+
SessionCountSsh: 0,
90+
ConnectionMedianLatencyMs: 10,
91+
},
92+
})
93+
require.NoError(t, err)
6694

67-
newWorkspace, err := client.Workspace(context.Background(), r.Workspace.ID)
68-
require.NoError(t, err)
95+
tickCh <- dbtime.Now()
96+
count := <-flushCh
97+
require.Equal(t, 1, count, "expected one flush with one id")
6998

70-
assert.True(t,
71-
newWorkspace.LastUsedAt.After(r.Workspace.LastUsedAt),
72-
"%s is not after %s", newWorkspace.LastUsedAt, r.Workspace.LastUsedAt,
73-
)
99+
newWorkspace, err := client.Workspace(context.Background(), r.Workspace.ID)
100+
require.NoError(t, err)
101+
102+
assert.True(t,
103+
newWorkspace.LastUsedAt.After(r.Workspace.LastUsedAt),
104+
"%s is not after %s", newWorkspace.LastUsedAt, r.Workspace.LastUsedAt,
105+
)
106+
})
107+
}
74108
}
75109

76110
func TestAgentAPI_LargeManifest(t *testing.T) {
77111
t.Parallel()
78-
ctx := testutil.Context(t, testutil.WaitLong)
79-
client, store := coderdtest.NewWithDatabase(t, nil)
80-
adminUser := coderdtest.CreateFirstUser(t, client)
81-
n := 512000
82-
longScript := make([]byte, n)
83-
for i := range longScript {
84-
longScript[i] = 'q'
112+
113+
for _, tc := range []struct {
114+
name string
115+
apiKeyScope rbac.ScopeName
116+
}{
117+
{
118+
name: "empty (backwards compat)",
119+
apiKeyScope: "",
120+
},
121+
{
122+
name: "all",
123+
apiKeyScope: rbac.ScopeAll,
124+
},
125+
{
126+
name: "no_user_data",
127+
apiKeyScope: rbac.ScopeNoUserData,
128+
},
129+
{
130+
name: "application_connect",
131+
apiKeyScope: rbac.ScopeApplicationConnect,
132+
},
133+
} {
134+
t.Run(tc.name, func(t *testing.T) {
135+
t.Parallel()
136+
ctx := testutil.Context(t, testutil.WaitLong)
137+
client, store := coderdtest.NewWithDatabase(t, nil)
138+
adminUser := coderdtest.CreateFirstUser(t, client)
139+
n := 512000
140+
longScript := make([]byte, n)
141+
for i := range longScript {
142+
longScript[i] = 'q'
143+
}
144+
r := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
145+
OrganizationID: adminUser.OrganizationID,
146+
OwnerID: adminUser.UserID,
147+
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
148+
agents[0].Scripts = []*proto.Script{
149+
{
150+
Script: string(longScript),
151+
},
152+
}
153+
agents[0].ApiKeyScope = string(tc.apiKeyScope)
154+
return agents
155+
}).Do()
156+
ac := agentsdk.New(client.URL)
157+
ac.SetSessionToken(r.AgentToken)
158+
conn, err := ac.ConnectRPC(ctx)
159+
defer func() {
160+
_ = conn.Close()
161+
}()
162+
require.NoError(t, err)
163+
agentAPI := agentproto.NewDRPCAgentClient(conn)
164+
manifest, err := agentAPI.GetManifest(ctx, &agentproto.GetManifestRequest{})
165+
require.NoError(t, err)
166+
require.Len(t, manifest.Scripts, 1)
167+
require.Len(t, manifest.Scripts[0].Script, n)
168+
})
85169
}
86-
r := dbfake.WorkspaceBuild(t, store, database.WorkspaceTable{
87-
OrganizationID: adminUser.OrganizationID,
88-
OwnerID: adminUser.UserID,
89-
}).WithAgent(func(agents []*proto.Agent) []*proto.Agent {
90-
agents[0].Scripts = []*proto.Script{
91-
{
92-
Script: string(longScript),
93-
},
94-
}
95-
return agents
96-
}).Do()
97-
ac := agentsdk.New(client.URL)
98-
ac.SetSessionToken(r.AgentToken)
99-
conn, err := ac.ConnectRPC(ctx)
100-
defer func() {
101-
_ = conn.Close()
102-
}()
103-
require.NoError(t, err)
104-
agentAPI := agentproto.NewDRPCAgentClient(conn)
105-
manifest, err := agentAPI.GetManifest(ctx, &agentproto.GetManifestRequest{})
106-
require.NoError(t, err)
107-
require.Len(t, manifest.Scripts, 1)
108-
require.Len(t, manifest.Scripts[0].Script, n)
109170
}

0 commit comments

Comments
 (0)