diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index efaa07e958620..9123c7fba5135 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -1474,13 +1474,12 @@ func (q *querier) GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]databas return q.db.GetUsersByIDs(ctx, ids) } -// GetWorkspaceAgentByAuthToken is used in http middleware to get the workspace agent. -// This should only be used by a system user in that middleware. -func (q *querier) GetWorkspaceAgentByAuthToken(ctx context.Context, authToken uuid.UUID) (database.WorkspaceAgent, error) { +func (q *querier) GetWorkspaceAgentAndOwnerByAuthToken(ctx context.Context, authToken uuid.UUID) (database.GetWorkspaceAgentAndOwnerByAuthTokenRow, error) { + // This is a system function if err := q.authorizeContext(ctx, rbac.ActionRead, rbac.ResourceSystem); err != nil { - return database.WorkspaceAgent{}, err + return database.GetWorkspaceAgentAndOwnerByAuthTokenRow{}, err } - return q.db.GetWorkspaceAgentByAuthToken(ctx, authToken) + return q.db.GetWorkspaceAgentAndOwnerByAuthToken(ctx, authToken) } func (q *querier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (database.WorkspaceAgent, error) { diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 801e299a35be7..3b41e67a0c0df 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -1319,10 +1319,6 @@ func (s *MethodTestSuite) TestSystemFunctions() { dbgen.WorkspaceBuild(s.T(), db, database.WorkspaceBuild{}) check.Args().Asserts(rbac.ResourceSystem, rbac.ActionRead) })) - s.Run("GetWorkspaceAgentByAuthToken", s.Subtest(func(db database.Store, check *expects) { - agt := dbgen.WorkspaceAgent(s.T(), db, database.WorkspaceAgent{}) - check.Args(agt.AuthToken).Asserts(rbac.ResourceSystem, rbac.ActionRead).Returns(agt) - })) s.Run("GetActiveUserCount", s.Subtest(func(db database.Store, check *expects) { check.Args().Asserts(rbac.ResourceSystem, rbac.ActionRead).Returns(int64(0)) })) diff --git a/coderd/database/dbfake/dbfake.go b/coderd/database/dbfake/dbfake.go index 8fc30e1b174b3..162aa9195cb91 100644 --- a/coderd/database/dbfake/dbfake.go +++ b/coderd/database/dbfake/dbfake.go @@ -2791,18 +2791,72 @@ func (q *FakeQuerier) GetUsersByIDs(_ context.Context, ids []uuid.UUID) ([]datab return users, nil } -func (q *FakeQuerier) GetWorkspaceAgentByAuthToken(_ context.Context, authToken uuid.UUID) (database.WorkspaceAgent, error) { +func (q *FakeQuerier) GetWorkspaceAgentAndOwnerByAuthToken(_ context.Context, authToken uuid.UUID) (database.GetWorkspaceAgentAndOwnerByAuthTokenRow, error) { q.mutex.RLock() defer q.mutex.RUnlock() - // The schema sorts this by created at, so we iterate the array backwards. - for i := len(q.workspaceAgents) - 1; i >= 0; i-- { - agent := q.workspaceAgents[i] - if agent.AuthToken == authToken { - return agent, nil + // map of build number -> row + rows := make(map[int32]database.GetWorkspaceAgentAndOwnerByAuthTokenRow) + + // We want to return the latest build number + var latestBuildNumber int32 + + for _, agt := range q.workspaceAgents { + if agt.AuthToken != authToken { + continue + } + // get the related workspace and user + for _, res := range q.workspaceResources { + if agt.ResourceID != res.ID { + continue + } + for _, build := range q.workspaceBuilds { + if build.JobID != res.JobID { + continue + } + for _, ws := range q.workspaces { + if build.WorkspaceID != ws.ID { + continue + } + var row database.GetWorkspaceAgentAndOwnerByAuthTokenRow + row.WorkspaceID = ws.ID + usr, err := q.getUserByIDNoLock(ws.OwnerID) + if err != nil { + return database.GetWorkspaceAgentAndOwnerByAuthTokenRow{}, sql.ErrNoRows + } + row.OwnerID = usr.ID + row.OwnerRoles = append(usr.RBACRoles, "member") + // We also need to get org roles for the user + row.OwnerName = usr.Username + row.WorkspaceAgent = agt + for _, mem := range q.organizationMembers { + if mem.UserID == usr.ID { + row.OwnerRoles = append(row.OwnerRoles, fmt.Sprintf("organization-member:%s", mem.OrganizationID.String())) + } + } + // And group memberships + for _, groupMem := range q.groupMembers { + if groupMem.UserID == usr.ID { + row.OwnerGroups = append(row.OwnerGroups, groupMem.GroupID.String()) + } + } + + // Keep track of the latest build number + rows[build.BuildNumber] = row + if build.BuildNumber > latestBuildNumber { + latestBuildNumber = build.BuildNumber + } + } + } } } - return database.WorkspaceAgent{}, sql.ErrNoRows + + if len(rows) == 0 { + return database.GetWorkspaceAgentAndOwnerByAuthTokenRow{}, sql.ErrNoRows + } + + // Return the row related to the latest build + return rows[latestBuildNumber], nil } func (q *FakeQuerier) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (database.WorkspaceAgent, error) { diff --git a/coderd/database/dbmetrics/dbmetrics.go b/coderd/database/dbmetrics/dbmetrics.go index 9c9601232fb30..7edba848d7588 100644 --- a/coderd/database/dbmetrics/dbmetrics.go +++ b/coderd/database/dbmetrics/dbmetrics.go @@ -781,11 +781,11 @@ func (m metricsStore) GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]dat return users, err } -func (m metricsStore) GetWorkspaceAgentByAuthToken(ctx context.Context, authToken uuid.UUID) (database.WorkspaceAgent, error) { +func (m metricsStore) GetWorkspaceAgentAndOwnerByAuthToken(ctx context.Context, authToken uuid.UUID) (database.GetWorkspaceAgentAndOwnerByAuthTokenRow, error) { start := time.Now() - agent, err := m.s.GetWorkspaceAgentByAuthToken(ctx, authToken) - m.queryLatencies.WithLabelValues("GetWorkspaceAgentByAuthToken").Observe(time.Since(start).Seconds()) - return agent, err + r0, r1 := m.s.GetWorkspaceAgentAndOwnerByAuthToken(ctx, authToken) + m.queryLatencies.WithLabelValues("GetWorkspaceAgentAndOwnerByAuthToken").Observe(time.Since(start).Seconds()) + return r0, r1 } func (m metricsStore) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (database.WorkspaceAgent, error) { diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 23a7e8b5935c0..9ce7c73b6d85e 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -1616,19 +1616,19 @@ func (mr *MockStoreMockRecorder) GetUsersByIDs(arg0, arg1 interface{}) *gomock.C return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUsersByIDs", reflect.TypeOf((*MockStore)(nil).GetUsersByIDs), arg0, arg1) } -// GetWorkspaceAgentByAuthToken mocks base method. -func (m *MockStore) GetWorkspaceAgentByAuthToken(arg0 context.Context, arg1 uuid.UUID) (database.WorkspaceAgent, error) { +// GetWorkspaceAgentAndOwnerByAuthToken mocks base method. +func (m *MockStore) GetWorkspaceAgentAndOwnerByAuthToken(arg0 context.Context, arg1 uuid.UUID) (database.GetWorkspaceAgentAndOwnerByAuthTokenRow, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetWorkspaceAgentByAuthToken", arg0, arg1) - ret0, _ := ret[0].(database.WorkspaceAgent) + ret := m.ctrl.Call(m, "GetWorkspaceAgentAndOwnerByAuthToken", arg0, arg1) + ret0, _ := ret[0].(database.GetWorkspaceAgentAndOwnerByAuthTokenRow) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetWorkspaceAgentByAuthToken indicates an expected call of GetWorkspaceAgentByAuthToken. -func (mr *MockStoreMockRecorder) GetWorkspaceAgentByAuthToken(arg0, arg1 interface{}) *gomock.Call { +// GetWorkspaceAgentAndOwnerByAuthToken indicates an expected call of GetWorkspaceAgentAndOwnerByAuthToken. +func (mr *MockStoreMockRecorder) GetWorkspaceAgentAndOwnerByAuthToken(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceAgentByAuthToken", reflect.TypeOf((*MockStore)(nil).GetWorkspaceAgentByAuthToken), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWorkspaceAgentAndOwnerByAuthToken", reflect.TypeOf((*MockStore)(nil).GetWorkspaceAgentAndOwnerByAuthToken), arg0, arg1) } // GetWorkspaceAgentByID mocks base method. diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 3a8f97307114d..6ddaeffe2d9ff 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -151,7 +151,7 @@ type sqlcQuerier interface { // to look up references to actions. eg. a user could build a workspace // for another user, then be deleted... we still want them to appear! GetUsersByIDs(ctx context.Context, ids []uuid.UUID) ([]User, error) - GetWorkspaceAgentByAuthToken(ctx context.Context, authToken uuid.UUID) (WorkspaceAgent, error) + GetWorkspaceAgentAndOwnerByAuthToken(ctx context.Context, authToken uuid.UUID) (GetWorkspaceAgentAndOwnerByAuthTokenRow, error) GetWorkspaceAgentByID(ctx context.Context, id uuid.UUID) (WorkspaceAgent, error) GetWorkspaceAgentByInstanceID(ctx context.Context, authInstanceID string) (WorkspaceAgent, error) GetWorkspaceAgentLifecycleStateByID(ctx context.Context, id uuid.UUID) (GetWorkspaceAgentLifecycleStateByIDRow, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 1a8c7598f7f59..fe51d8cd6d244 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -6248,54 +6248,113 @@ func (q *sqlQuerier) DeleteOldWorkspaceAgentLogs(ctx context.Context) error { return err } -const getWorkspaceAgentByAuthToken = `-- name: GetWorkspaceAgentByAuthToken :one +const getWorkspaceAgentAndOwnerByAuthToken = `-- name: GetWorkspaceAgentAndOwnerByAuthToken :one SELECT - id, created_at, updated_at, name, first_connected_at, last_connected_at, disconnected_at, resource_id, auth_token, auth_instance_id, architecture, environment_variables, operating_system, startup_script, instance_metadata, resource_metadata, directory, version, last_connected_replica_id, connection_timeout_seconds, troubleshooting_url, motd_file, lifecycle_state, startup_script_timeout_seconds, expanded_directory, shutdown_script, shutdown_script_timeout_seconds, logs_length, logs_overflowed, startup_script_behavior, started_at, ready_at, subsystems -FROM - workspace_agents + workspace_agents.id, workspace_agents.created_at, workspace_agents.updated_at, workspace_agents.name, workspace_agents.first_connected_at, workspace_agents.last_connected_at, workspace_agents.disconnected_at, workspace_agents.resource_id, workspace_agents.auth_token, workspace_agents.auth_instance_id, workspace_agents.architecture, workspace_agents.environment_variables, workspace_agents.operating_system, workspace_agents.startup_script, workspace_agents.instance_metadata, workspace_agents.resource_metadata, workspace_agents.directory, workspace_agents.version, workspace_agents.last_connected_replica_id, workspace_agents.connection_timeout_seconds, workspace_agents.troubleshooting_url, workspace_agents.motd_file, workspace_agents.lifecycle_state, workspace_agents.startup_script_timeout_seconds, workspace_agents.expanded_directory, workspace_agents.shutdown_script, workspace_agents.shutdown_script_timeout_seconds, workspace_agents.logs_length, workspace_agents.logs_overflowed, workspace_agents.startup_script_behavior, workspace_agents.started_at, workspace_agents.ready_at, workspace_agents.subsystems, + workspaces.id AS workspace_id, + users.id AS owner_id, + users.username AS owner_name, + users.status AS owner_status, + array_cat( + array_append(users.rbac_roles, 'member'), + array_append(ARRAY[]::text[], 'organization-member:' || organization_members.organization_id::text) + )::text[] as owner_roles, + array_agg(COALESCE(group_members.group_id::text, ''))::text[] AS owner_groups +FROM users + INNER JOIN + workspaces + ON + workspaces.owner_id = users.id + INNER JOIN + workspace_builds + ON + workspace_builds.workspace_id = workspaces.id + INNER JOIN + workspace_resources + ON + workspace_resources.job_id = workspace_builds.job_id + INNER JOIN + workspace_agents + ON + workspace_agents.resource_id = workspace_resources.id + INNER JOIN -- every user is a member of some org + organization_members + ON + organization_members.user_id = users.id + LEFT JOIN -- as they may not be a member of any groups + group_members + ON + group_members.user_id = users.id WHERE - auth_token = $1 + -- TODO: we can add more conditions here, such as: + -- 1) The user must be active + -- 2) The user must not be deleted + -- 3) The workspace must be running + workspace_agents.auth_token = $1 +GROUP BY + workspace_agents.id, + workspaces.id, + users.id, + organization_members.organization_id, + workspace_builds.build_number ORDER BY - created_at DESC + workspace_builds.build_number DESC +LIMIT 1 ` -func (q *sqlQuerier) GetWorkspaceAgentByAuthToken(ctx context.Context, authToken uuid.UUID) (WorkspaceAgent, error) { - row := q.db.QueryRowContext(ctx, getWorkspaceAgentByAuthToken, authToken) - var i WorkspaceAgent +type GetWorkspaceAgentAndOwnerByAuthTokenRow struct { + WorkspaceAgent WorkspaceAgent `db:"workspace_agent" json:"workspace_agent"` + WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"` + OwnerID uuid.UUID `db:"owner_id" json:"owner_id"` + OwnerName string `db:"owner_name" json:"owner_name"` + OwnerStatus UserStatus `db:"owner_status" json:"owner_status"` + OwnerRoles []string `db:"owner_roles" json:"owner_roles"` + OwnerGroups []string `db:"owner_groups" json:"owner_groups"` +} + +func (q *sqlQuerier) GetWorkspaceAgentAndOwnerByAuthToken(ctx context.Context, authToken uuid.UUID) (GetWorkspaceAgentAndOwnerByAuthTokenRow, error) { + row := q.db.QueryRowContext(ctx, getWorkspaceAgentAndOwnerByAuthToken, authToken) + var i GetWorkspaceAgentAndOwnerByAuthTokenRow err := row.Scan( - &i.ID, - &i.CreatedAt, - &i.UpdatedAt, - &i.Name, - &i.FirstConnectedAt, - &i.LastConnectedAt, - &i.DisconnectedAt, - &i.ResourceID, - &i.AuthToken, - &i.AuthInstanceID, - &i.Architecture, - &i.EnvironmentVariables, - &i.OperatingSystem, - &i.StartupScript, - &i.InstanceMetadata, - &i.ResourceMetadata, - &i.Directory, - &i.Version, - &i.LastConnectedReplicaID, - &i.ConnectionTimeoutSeconds, - &i.TroubleshootingURL, - &i.MOTDFile, - &i.LifecycleState, - &i.StartupScriptTimeoutSeconds, - &i.ExpandedDirectory, - &i.ShutdownScript, - &i.ShutdownScriptTimeoutSeconds, - &i.LogsLength, - &i.LogsOverflowed, - &i.StartupScriptBehavior, - &i.StartedAt, - &i.ReadyAt, - pq.Array(&i.Subsystems), + &i.WorkspaceAgent.ID, + &i.WorkspaceAgent.CreatedAt, + &i.WorkspaceAgent.UpdatedAt, + &i.WorkspaceAgent.Name, + &i.WorkspaceAgent.FirstConnectedAt, + &i.WorkspaceAgent.LastConnectedAt, + &i.WorkspaceAgent.DisconnectedAt, + &i.WorkspaceAgent.ResourceID, + &i.WorkspaceAgent.AuthToken, + &i.WorkspaceAgent.AuthInstanceID, + &i.WorkspaceAgent.Architecture, + &i.WorkspaceAgent.EnvironmentVariables, + &i.WorkspaceAgent.OperatingSystem, + &i.WorkspaceAgent.StartupScript, + &i.WorkspaceAgent.InstanceMetadata, + &i.WorkspaceAgent.ResourceMetadata, + &i.WorkspaceAgent.Directory, + &i.WorkspaceAgent.Version, + &i.WorkspaceAgent.LastConnectedReplicaID, + &i.WorkspaceAgent.ConnectionTimeoutSeconds, + &i.WorkspaceAgent.TroubleshootingURL, + &i.WorkspaceAgent.MOTDFile, + &i.WorkspaceAgent.LifecycleState, + &i.WorkspaceAgent.StartupScriptTimeoutSeconds, + &i.WorkspaceAgent.ExpandedDirectory, + &i.WorkspaceAgent.ShutdownScript, + &i.WorkspaceAgent.ShutdownScriptTimeoutSeconds, + &i.WorkspaceAgent.LogsLength, + &i.WorkspaceAgent.LogsOverflowed, + &i.WorkspaceAgent.StartupScriptBehavior, + &i.WorkspaceAgent.StartedAt, + &i.WorkspaceAgent.ReadyAt, + pq.Array(&i.WorkspaceAgent.Subsystems), + &i.WorkspaceID, + &i.OwnerID, + &i.OwnerName, + &i.OwnerStatus, + pq.Array(&i.OwnerRoles), + pq.Array(&i.OwnerGroups), ) return i, err } diff --git a/coderd/database/queries/workspaceagents.sql b/coderd/database/queries/workspaceagents.sql index dcc15081615e2..9906d367e7bcf 100644 --- a/coderd/database/queries/workspaceagents.sql +++ b/coderd/database/queries/workspaceagents.sql @@ -1,13 +1,3 @@ --- name: GetWorkspaceAgentByAuthToken :one -SELECT - * -FROM - workspace_agents -WHERE - auth_token = $1 -ORDER BY - created_at DESC; - -- name: GetWorkspaceAgentByID :one SELECT * @@ -200,3 +190,56 @@ WHERE WHERE wb.workspace_id = @workspace_id :: uuid ); + +-- name: GetWorkspaceAgentAndOwnerByAuthToken :one +SELECT + sqlc.embed(workspace_agents), + workspaces.id AS workspace_id, + users.id AS owner_id, + users.username AS owner_name, + users.status AS owner_status, + array_cat( + array_append(users.rbac_roles, 'member'), + array_append(ARRAY[]::text[], 'organization-member:' || organization_members.organization_id::text) + )::text[] as owner_roles, + array_agg(COALESCE(group_members.group_id::text, ''))::text[] AS owner_groups +FROM users + INNER JOIN + workspaces + ON + workspaces.owner_id = users.id + INNER JOIN + workspace_builds + ON + workspace_builds.workspace_id = workspaces.id + INNER JOIN + workspace_resources + ON + workspace_resources.job_id = workspace_builds.job_id + INNER JOIN + workspace_agents + ON + workspace_agents.resource_id = workspace_resources.id + INNER JOIN -- every user is a member of some org + organization_members + ON + organization_members.user_id = users.id + LEFT JOIN -- as they may not be a member of any groups + group_members + ON + group_members.user_id = users.id +WHERE + -- TODO: we can add more conditions here, such as: + -- 1) The user must be active + -- 2) The user must not be deleted + -- 3) The workspace must be running + workspace_agents.auth_token = @auth_token +GROUP BY + workspace_agents.id, + workspaces.id, + users.id, + organization_members.organization_id, + workspace_builds.build_number +ORDER BY + workspace_builds.build_number DESC +LIMIT 1; diff --git a/coderd/httpmw/workspaceagent.go b/coderd/httpmw/workspaceagent.go index f05598af5276f..883a54e404c4e 100644 --- a/coderd/httpmw/workspaceagent.go +++ b/coderd/httpmw/workspaceagent.go @@ -74,8 +74,9 @@ func ExtractWorkspaceAgent(opts ExtractWorkspaceAgentConfig) func(http.Handler) }) return } + //nolint:gocritic // System needs to be able to get workspace agents. - agent, err := opts.DB.GetWorkspaceAgentByAuthToken(dbauthz.AsSystemRestricted(ctx), token) + row, err := opts.DB.GetWorkspaceAgentAndOwnerByAuthToken(dbauthz.AsSystemRestricted(ctx), token) if err != nil { if errors.Is(err, sql.ErrNoRows) { optionalWrite(http.StatusUnauthorized, codersdk.Response{ @@ -86,56 +87,23 @@ func ExtractWorkspaceAgent(opts ExtractWorkspaceAgentConfig) func(http.Handler) } httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Internal error fetching workspace agent.", + Message: "Internal error checking workspace agent authorization.", Detail: err.Error(), }) return } - //nolint:gocritic // System needs to be able to get workspace agents. - subject, err := getAgentSubject(dbauthz.AsSystemRestricted(ctx), opts.DB, agent) - if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Internal error fetching workspace agent.", - Detail: err.Error(), - }) - return - } + subject := rbac.Subject{ + ID: row.OwnerID.String(), + Roles: rbac.RoleNames(row.OwnerRoles), + Groups: row.OwnerGroups, + Scope: rbac.WorkspaceAgentScope(row.WorkspaceID, row.OwnerID), + }.WithCachedASTValue() - ctx = context.WithValue(ctx, workspaceAgentContextKey{}, agent) + ctx = context.WithValue(ctx, workspaceAgentContextKey{}, row.WorkspaceAgent) // Also set the dbauthz actor for the request. ctx = dbauthz.As(ctx, subject) next.ServeHTTP(rw, r.WithContext(ctx)) }) } } - -func getAgentSubject(ctx context.Context, db database.Store, agent database.WorkspaceAgent) (rbac.Subject, error) { - // TODO: make a different query that gets the workspace owner and roles along with the agent. - workspace, err := db.GetWorkspaceByAgentID(ctx, agent.ID) - if err != nil { - return rbac.Subject{}, err - } - - user, err := db.GetUserByID(ctx, workspace.OwnerID) - if err != nil { - return rbac.Subject{}, err - } - - roles, err := db.GetAuthorizationUserRoles(ctx, user.ID) - if err != nil { - return rbac.Subject{}, err - } - - // A user that creates a workspace can use this agent auth token and - // impersonate the workspace. So to prevent privilege escalation, the - // subject inherits the roles of the user that owns the workspace. - // We then add a workspace-agent scope to limit the permissions - // to only what the workspace agent needs. - return rbac.Subject{ - ID: user.ID.String(), - Roles: rbac.RoleNames(roles.Roles), - Groups: roles.Groups, - Scope: rbac.WorkspaceAgentScope(workspace.ID, user.ID), - }.WithCachedASTValue(), nil -} diff --git a/coderd/httpmw/workspaceagent_test.go b/coderd/httpmw/workspaceagent_test.go index eb8e06fe9993b..62472fe13513d 100644 --- a/coderd/httpmw/workspaceagent_test.go +++ b/coderd/httpmw/workspaceagent_test.go @@ -10,8 +10,8 @@ import ( "github.com/stretchr/testify/require" "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbfake" "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/codersdk" ) @@ -19,26 +19,19 @@ import ( func TestWorkspaceAgent(t *testing.T) { t.Parallel() - setup := func(db database.Store, token uuid.UUID) *http.Request { - r := httptest.NewRequest("GET", "/", nil) - r.Header.Set(codersdk.SessionTokenHeader, token.String()) - return r - } - t.Run("None", func(t *testing.T) { t.Parallel() - db := dbfake.New() - rtr := chi.NewRouter() - rtr.Use( - httpmw.ExtractWorkspaceAgent(httpmw.ExtractWorkspaceAgentConfig{ + db, _ := dbtestutil.NewDB(t) + + req, rtr := setup(t, db, uuid.New(), httpmw.ExtractWorkspaceAgent( + httpmw.ExtractWorkspaceAgentConfig{ DB: db, Optional: false, - }), - ) - rtr.Get("/", nil) - r := setup(db, uuid.New()) + })) + rw := httptest.NewRecorder() - rtr.ServeHTTP(rw, r) + req.Header.Set(codersdk.SessionTokenHeader, uuid.New().String()) + rtr.ServeHTTP(rw, req) res := rw.Result() defer res.Body.Close() @@ -47,42 +40,71 @@ func TestWorkspaceAgent(t *testing.T) { t.Run("Found", func(t *testing.T) { t.Parallel() - db := dbfake.New() - var ( - user = dbgen.User(t, db, database.User{}) - workspace = dbgen.Workspace(t, db, database.Workspace{ - OwnerID: user.ID, - }) - job = dbgen.ProvisionerJob(t, db, database.ProvisionerJob{}) - resource = dbgen.WorkspaceResource(t, db, database.WorkspaceResource{ - JobID: job.ID, - }) - _ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ - WorkspaceID: workspace.ID, - JobID: job.ID, - }) - agent = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ - ResourceID: resource.ID, - }) - ) - - rtr := chi.NewRouter() - rtr.Use( - httpmw.ExtractWorkspaceAgent(httpmw.ExtractWorkspaceAgentConfig{ + db, _ := dbtestutil.NewDB(t) + authToken := uuid.New() + req, rtr := setup(t, db, authToken, httpmw.ExtractWorkspaceAgent( + httpmw.ExtractWorkspaceAgentConfig{ DB: db, Optional: false, - }), - ) - rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) { - _ = httpmw.WorkspaceAgent(r) - rw.WriteHeader(http.StatusOK) - }) - r := setup(db, agent.AuthToken) + })) + rw := httptest.NewRecorder() - rtr.ServeHTTP(rw, r) + req.Header.Set(codersdk.SessionTokenHeader, authToken.String()) + rtr.ServeHTTP(rw, req) res := rw.Result() - defer res.Body.Close() + t.Cleanup(func() { _ = res.Body.Close() }) require.Equal(t, http.StatusOK, res.StatusCode) }) } + +func setup(t testing.TB, db database.Store, authToken uuid.UUID, mw func(http.Handler) http.Handler) (*http.Request, http.Handler) { + t.Helper() + org := dbgen.Organization(t, db, database.Organization{}) + user := dbgen.User(t, db, database.User{ + Status: database.UserStatusActive, + }) + _ = dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + templateVersion := dbgen.TemplateVersion(t, db, database.TemplateVersion{ + OrganizationID: org.ID, + CreatedBy: user.ID, + }) + template := dbgen.Template(t, db, database.Template{ + OrganizationID: org.ID, + ActiveVersionID: templateVersion.ID, + CreatedBy: user.ID, + }) + workspace := dbgen.Workspace(t, db, database.Workspace{ + OwnerID: user.ID, + OrganizationID: org.ID, + TemplateID: template.ID, + }) + job := dbgen.ProvisionerJob(t, db, database.ProvisionerJob{ + OrganizationID: org.ID, + }) + resource := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{ + JobID: job.ID, + }) + _ = dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ + WorkspaceID: workspace.ID, + JobID: job.ID, + TemplateVersionID: templateVersion.ID, + }) + _ = dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ + ResourceID: resource.ID, + AuthToken: authToken, + }) + + req := httptest.NewRequest("GET", "/", nil) + rtr := chi.NewRouter() + rtr.Use(mw) + rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) { + _ = httpmw.WorkspaceAgent(r) + rw.WriteHeader(http.StatusOK) + }) + + return req, rtr +} diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index 84498a3dd7f84..d1b379d3d74e7 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -406,7 +406,9 @@ func TestWorkspaceAgentListen(t *testing.T) { _, err = agentClient.Listen(ctx) require.Error(t, err) - require.ErrorContains(t, err, "build is outdated") + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusForbidden, sdkErr.StatusCode()) }) }