Skip to content

Commit e97287f

Browse files
committed
Add user local provisioner daemons
1 parent b1ce65b commit e97287f

18 files changed

+137
-19
lines changed

coderd/autobuild/executor/lifecycle_executor.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@ func build(ctx context.Context, store database.Store, workspace database.Workspa
277277
Type: database.ProvisionerJobTypeWorkspaceBuild,
278278
StorageMethod: priorJob.StorageMethod,
279279
FileID: priorJob.FileID,
280+
Tags: priorJob.Tags,
280281
Input: input,
281282
})
282283
if err != nil {

coderd/coderd.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"crypto/tls"
66
"crypto/x509"
7+
"encoding/json"
78
"fmt"
89
"io"
910
"net/http"
@@ -36,6 +37,7 @@ import (
3637
"github.com/coder/coder/coderd/audit"
3738
"github.com/coder/coder/coderd/awsidentity"
3839
"github.com/coder/coder/coderd/database"
40+
"github.com/coder/coder/coderd/database/dbtype"
3941
"github.com/coder/coder/coderd/gitauth"
4042
"github.com/coder/coder/coderd/gitsshkey"
4143
"github.com/coder/coder/coderd/httpapi"
@@ -659,11 +661,19 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context) (client pro
659661
CreatedAt: database.Now(),
660662
Name: name,
661663
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho, database.ProvisionerTypeTerraform},
664+
Tags: dbtype.Map{
665+
provisionerdserver.TagScope: provisionerdserver.ScopeOrganization,
666+
},
662667
})
663668
if err != nil {
664669
return nil, xerrors.Errorf("insert provisioner daemon %q: %w", name, err)
665670
}
666671

672+
tags, err := json.Marshal(daemon.Tags)
673+
if err != nil {
674+
return nil, xerrors.Errorf("marshal tags: %w", err)
675+
}
676+
667677
mux := drpcmux.New()
668678
err = proto.DRPCRegisterProvisionerDaemon(mux, &provisionerdserver.Server{
669679
AccessURL: api.AccessURL,
@@ -672,6 +682,7 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context) (client pro
672682
Pubsub: api.Pubsub,
673683
Provisioners: daemon.Provisioners,
674684
Telemetry: api.Telemetry,
685+
Tags: tags,
675686
Logger: api.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)),
676687
})
677688
if err != nil {

coderd/database/databasefake/databasefake.go

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package databasefake
33
import (
44
"context"
55
"database/sql"
6-
"reflect"
6+
"encoding/json"
77
"sort"
88
"strings"
99
"sync"
@@ -147,7 +147,27 @@ func (q *fakeQuerier) AcquireProvisionerJob(_ context.Context, arg database.Acqu
147147
if !found {
148148
continue
149149
}
150-
if !reflect.DeepEqual(arg.Tags, provisionerJob.Tags) {
150+
tags := map[string]string{}
151+
if arg.Tags != nil {
152+
err := json.Unmarshal(arg.Tags, &tags)
153+
if err != nil {
154+
return provisionerJob, xerrors.Errorf("unmarshal: %w", err)
155+
}
156+
}
157+
158+
missing := false
159+
for key, value := range provisionerJob.Tags {
160+
provided, found := tags[key]
161+
if !found {
162+
missing = true
163+
break
164+
}
165+
if provided != value {
166+
missing = true
167+
break
168+
}
169+
}
170+
if missing {
151171
continue
152172
}
153173
provisionerJob.StartedAt = arg.StartedAt
@@ -2291,6 +2311,7 @@ func (q *fakeQuerier) InsertProvisionerJob(_ context.Context, arg database.Inser
22912311
FileID: arg.FileID,
22922312
Type: arg.Type,
22932313
Input: arg.Input,
2314+
Tags: arg.Tags,
22942315
}
22952316
q.provisionerJobs = append(q.provisionerJobs, job)
22962317
return job, nil

coderd/database/dump.sql

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
ALTER TABLE provisioner_daemons ADD COLUMN tags jsonb NOT NULL DEFAULT '{}';
2-
ALTER TABLE provisioner_jobs ADD COLUMN tags jsonb NOT NULL DEFAULT '{}';
2+
3+
-- We must add the organization scope by default, otherwise pending jobs
4+
-- could be provisioned on new daemons that don't match the tags.
5+
ALTER TABLE provisioner_jobs ADD COLUMN tags jsonb NOT NULL DEFAULT '{"scope":"organization"}';

coderd/provisionerdserver/provisionerdserver.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ type Server struct {
3232
ID uuid.UUID
3333
Logger slog.Logger
3434
Provisioners []database.ProvisionerType
35+
Tags json.RawMessage
3536
Database database.Store
3637
Pubsub database.Pubsub
3738
Telemetry telemetry.Reporter
@@ -50,6 +51,7 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
5051
Valid: true,
5152
},
5253
Types: server.Provisioners,
54+
Tags: server.Tags,
5355
})
5456
if errors.Is(err, sql.ErrNoRows) {
5557
// The provisioner daemon assumes no jobs are available if
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package provisionerdserver
2+
3+
import "github.com/google/uuid"
4+
5+
const (
6+
TagScope = "scope"
7+
TagOwner = "owner"
8+
9+
ScopeUser = "user"
10+
ScopeOrganization = "organization"
11+
)
12+
13+
// MutateTags adjusts the "owner" tag dependent on the "scope".
14+
// If the scope is "user", the "owner" is changed to the user ID.
15+
// This is for user-scoped provisioner daemons, where users should
16+
// own their own operations.
17+
func MutateTags(userID uuid.UUID, tags map[string]string) map[string]string {
18+
if tags == nil {
19+
tags = map[string]string{}
20+
}
21+
_, ok := tags[TagScope]
22+
if !ok {
23+
tags[TagScope] = ScopeOrganization
24+
}
25+
switch tags[TagScope] {
26+
case ScopeUser:
27+
tags[TagOwner] = userID.String()
28+
case ScopeOrganization:
29+
default:
30+
tags[TagScope] = ScopeOrganization
31+
}
32+
return tags
33+
}

coderd/provisionerjobs.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,7 @@ func convertProvisionerJob(provisionerJob database.ProvisionerJob) codersdk.Prov
311311
CreatedAt: provisionerJob.CreatedAt,
312312
Error: provisionerJob.Error.String,
313313
FileID: provisionerJob.FileID,
314+
Tags: provisionerJob.Tags,
314315
}
315316
// Applying values optional to the struct.
316317
if provisionerJob.StartedAt.Valid {

coderd/templateversions.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,8 @@ func (api *API) postTemplateVersionDryRun(rw http.ResponseWriter, r *http.Reques
288288
FileID: job.FileID,
289289
Type: database.ProvisionerJobTypeTemplateVersionDryRun,
290290
Input: input,
291+
// Copy tags from the previous run.
292+
Tags: job.Tags,
291293
})
292294
if err != nil {
293295
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
@@ -717,6 +719,9 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht
717719
return
718720
}
719721

722+
// Ensures the "owner" is properly applied.
723+
tags := provisionerdserver.MutateTags(apiKey.UserID, req.ProvisionerTags)
724+
720725
file, err := api.Database.GetFileByID(ctx, req.FileID)
721726
if errors.Is(err, sql.ErrNoRows) {
722727
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
@@ -815,6 +820,7 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht
815820
FileID: file.ID,
816821
Type: database.ProvisionerJobTypeTemplateVersionImport,
817822
Input: []byte{'{', '}'},
823+
Tags: tags,
818824
})
819825
if err != nil {
820826
return xerrors.Errorf("insert provisioner job: %w", err)

coderd/templateversions_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"github.com/coder/coder/coderd/audit"
1414
"github.com/coder/coder/coderd/coderdtest"
1515
"github.com/coder/coder/coderd/database"
16+
"github.com/coder/coder/coderd/provisionerdserver"
1617
"github.com/coder/coder/codersdk"
1718
"github.com/coder/coder/provisioner/echo"
1819
"github.com/coder/coder/provisionersdk/proto"
@@ -122,6 +123,7 @@ func TestPostTemplateVersionsByOrganization(t *testing.T) {
122123
})
123124
require.NoError(t, err)
124125
require.Equal(t, "bananas", version.Name)
126+
require.Equal(t, provisionerdserver.ScopeOrganization, version.Job.Tags[provisionerdserver.TagScope])
125127

126128
require.Len(t, auditor.AuditLogs, 1)
127129
assert.Equal(t, database.AuditActionCreate, auditor.AuditLogs[0].Action)

coderd/workspacebuilds.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,8 @@ func (api *API) postWorkspaceBuilds(rw http.ResponseWriter, r *http.Request) {
428428
return
429429
}
430430

431+
tags := provisionerdserver.MutateTags(workspace.OwnerID, templateVersionJob.Tags)
432+
431433
// Store prior build number to compute new build number
432434
var priorBuildNum int32
433435
priorHistory, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(ctx, workspace.ID)
@@ -513,6 +515,7 @@ func (api *API) postWorkspaceBuilds(rw http.ResponseWriter, r *http.Request) {
513515
StorageMethod: templateVersionJob.StorageMethod,
514516
FileID: templateVersionJob.FileID,
515517
Input: input,
518+
Tags: tags,
516519
})
517520
if err != nil {
518521
return xerrors.Errorf("insert provisioner job: %w", err)

coderd/workspaces.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,8 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req
428428
return
429429
}
430430

431+
tags := provisionerdserver.MutateTags(user.ID, templateVersionJob.Tags)
432+
431433
var (
432434
provisionerJob database.ProvisionerJob
433435
workspaceBuild database.WorkspaceBuild
@@ -490,6 +492,7 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req
490492
StorageMethod: templateVersionJob.StorageMethod,
491493
FileID: templateVersionJob.FileID,
492494
Input: input,
495+
Tags: tags,
493496
})
494497
if err != nil {
495498
return xerrors.Errorf("insert provisioner job: %w", err)

codersdk/organizations.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,12 @@ type Organization struct {
3636
type CreateTemplateVersionRequest struct {
3737
Name string `json:"name,omitempty" validate:"omitempty,template_name"`
3838
// TemplateID optionally associates a version with a template.
39-
TemplateID uuid.UUID `json:"template_id,omitempty"`
39+
TemplateID uuid.UUID `json:"template_id,omitempty"`
40+
StorageMethod ProvisionerStorageMethod `json:"storage_method" validate:"oneof=file,required"`
41+
FileID uuid.UUID `json:"file_id" validate:"required"`
42+
Provisioner ProvisionerType `json:"provisioner" validate:"oneof=terraform echo,required"`
43+
ProvisionerTags map[string]string `json:"tags"`
4044

41-
StorageMethod ProvisionerStorageMethod `json:"storage_method" validate:"oneof=file,required"`
42-
FileID uuid.UUID `json:"file_id" validate:"required"`
43-
Provisioner ProvisionerType `json:"provisioner" validate:"oneof=terraform echo,required"`
4445
// ParameterValues allows for additional parameters to be provided
4546
// during the dry-run provision stage.
4647
ParameterValues []CreateParameterRequest `json:"parameter_values,omitempty"`

codersdk/provisionerdaemons.go

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ type ProvisionerJob struct {
7676
Status ProvisionerJobStatus `json:"status"`
7777
WorkerID *uuid.UUID `json:"worker_id,omitempty"`
7878
FileID uuid.UUID `json:"file_id"`
79+
Tags map[string]string `json:"tags"`
7980
}
8081

8182
type ProvisionerJobLog struct {
@@ -166,10 +167,6 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after
166167
}), nil
167168
}
168169

169-
type CreateProvisionerDaemonRequest struct {
170-
Name string `json:"name" validate:"required"`
171-
}
172-
173170
// ListenProvisionerDaemon returns the gRPC service for a provisioner daemon implementation.
174171
func (c *Client) ServeProvisionerDaemon(ctx context.Context, organization uuid.UUID, provisioners []ProvisionerType, tags map[string]string) (proto.DRPCProvisionerDaemonClient, error) {
175172
serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", organization))

enterprise/coderd/coderd.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,10 @@ func New(ctx context.Context, options *Options) (*API, error) {
9191
})
9292
})
9393
r.Route("/organizations/{organization}/provisionerdaemons", func(r chi.Router) {
94-
r.Use(apiKeyMiddleware)
94+
r.Use(
95+
apiKeyMiddleware,
96+
httpmw.ExtractOrganizationParam(api.Database),
97+
)
9598
r.Get("/", api.provisionerDaemons)
9699
r.Get("/serve", api.provisionerDaemonServe)
97100
})

enterprise/coderd/coderdenttest/coderdenttest_test.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ func TestAuthorizeAllEndpoints(t *testing.T) {
4747
a.URLParams["{groupName}"] = group.Name
4848

4949
skipRoutes, assertRoute := coderdtest.AGPLRoutes(a)
50+
skipRoutes["GET:/api/v2/organizations/{organization}/provisionerdaemons/serve"] = "This route checks for RBAC dependent on input parameters!"
51+
5052
assertRoute["GET:/api/v2/entitlements"] = coderdtest.RouteCheck{
5153
NoAuthorize: true,
5254
}
@@ -84,6 +86,14 @@ func TestAuthorizeAllEndpoints(t *testing.T) {
8486
AssertAction: rbac.ActionRead,
8587
AssertObject: groupObj,
8688
}
89+
assertRoute["GET:/api/v2/organizations/{organization}/provisionerdaemons"] = coderdtest.RouteCheck{
90+
AssertAction: rbac.ActionRead,
91+
AssertObject: rbac.ResourceProvisionerDaemon,
92+
}
93+
assertRoute["GET:/api/v2/organizations/{organization}/provisionerdaemons"] = coderdtest.RouteCheck{
94+
AssertAction: rbac.ActionRead,
95+
AssertObject: rbac.ResourceProvisionerDaemon,
96+
}
8797
assertRoute["GET:/api/v2/groups/{group}"] = coderdtest.RouteCheck{
8898
AssertAction: rbac.ActionRead,
8999
AssertObject: groupObj,

enterprise/coderd/provisionerdaemons.go

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package coderd
22

33
import (
44
"database/sql"
5+
"encoding/json"
56
"errors"
67
"fmt"
78
"io"
@@ -30,6 +31,11 @@ import (
3031

3132
func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) {
3233
ctx := r.Context()
34+
org := httpmw.OrganizationParam(r)
35+
if !api.Authorize(r, rbac.ActionRead, rbac.ResourceProvisionerDaemon.InOrg(org.ID)) {
36+
httpapi.Forbidden(rw)
37+
return
38+
}
3339
daemons, err := api.Database.GetProvisionerDaemons(ctx)
3440
if errors.Is(err, sql.ErrNoRows) {
3541
err = nil
@@ -97,8 +103,15 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
97103
// for jobs that they own, but only authorized users can create
98104
// globally scoped provisioners that attach to all jobs.
99105
apiKey := httpmw.APIKey(r)
100-
if !api.AGPL.Authorize(r, rbac.ActionCreate, rbac.ResourceProvisionerDaemon) {
101-
tags["owner"] = apiKey.UserID.String()
106+
tags = provisionerdserver.MutateTags(apiKey.UserID, tags)
107+
108+
if tags[provisionerdserver.TagScope] == provisionerdserver.ScopeOrganization {
109+
if !api.AGPL.Authorize(r, rbac.ActionCreate, rbac.ResourceProvisionerDaemon) {
110+
httpapi.Write(r.Context(), rw, http.StatusUnauthorized, codersdk.Response{
111+
Message: "You aren't allowed to create provisioner daemons for the organization.",
112+
})
113+
return
114+
}
102115
}
103116

104117
name := namesgenerator.GetRandomName(1)
@@ -117,6 +130,15 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
117130
return
118131
}
119132

133+
rawTags, err := json.Marshal(daemon.Tags)
134+
if err != nil {
135+
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
136+
Message: "Internal error marshaling daemon tags.",
137+
Detail: err.Error(),
138+
})
139+
return
140+
}
141+
120142
api.AGPL.WebsocketWaitMutex.Lock()
121143
api.AGPL.WebsocketWaitGroup.Add(1)
122144
api.AGPL.WebsocketWaitMutex.Unlock()
@@ -155,6 +177,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
155177
Provisioners: daemon.Provisioners,
156178
Telemetry: api.Telemetry,
157179
Logger: api.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)),
180+
Tags: rawTags,
158181
})
159182
if err != nil {
160183
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("drpc register provisioner daemon: %s", err))

0 commit comments

Comments
 (0)