Skip to content

Commit b6703b1

Browse files
authored
feat: Add external provisioner daemons (#4935)
* Start to port over provisioner daemons PR * Move to Enterprise * Begin adding tests for external registration * Move provisioner daemons query to enterprise * Move around provisioner daemons schema * Add tags to provisioner daemons * make gen * Add user local provisioner daemons * Add provisioner daemons * Add feature for external daemons * Add command to start a provisioner daemon * Add provisioner tags to template push and create * Rename migration files * Fix tests * Fix entitlements test * PR comments * Update migration * Fix FE types
1 parent 66d20ca commit b6703b1

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+1094
-371
lines changed

.vscode/settings.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"codersdk",
1818
"cronstrue",
1919
"databasefake",
20+
"dbtype",
2021
"DERP",
2122
"derphttp",
2223
"derpmap",

cli/deployment/config.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ func newConfig() *codersdk.DeploymentConfig {
143143
Name: "Cache Directory",
144144
Usage: "The directory to cache temporary files. If unspecified and $CACHE_DIRECTORY is set, it will be used for compatibility with systemd.",
145145
Flag: "cache-dir",
146-
Default: defaultCacheDir(),
146+
Default: DefaultCacheDir(),
147147
},
148148
InMemoryDatabase: &codersdk.DeploymentConfigField[bool]{
149149
Name: "In Memory Database",
@@ -672,7 +672,7 @@ func formatEnv(key string) string {
672672
return "CODER_" + strings.ToUpper(strings.NewReplacer("-", "_", ".", "_").Replace(key))
673673
}
674674

675-
func defaultCacheDir() string {
675+
func DefaultCacheDir() string {
676676
defaultCacheDir, err := os.UserCacheDir()
677677
if err != nil {
678678
defaultCacheDir = os.TempDir()

cli/gitaskpass.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func gitAskpass() *cobra.Command {
2626
RunE: func(cmd *cobra.Command, args []string) error {
2727
ctx := cmd.Context()
2828

29-
ctx, stop := signal.NotifyContext(ctx, interruptSignals...)
29+
ctx, stop := signal.NotifyContext(ctx, InterruptSignals...)
3030
defer stop()
3131

3232
user, host, err := gitauth.ParseAskpass(args[0])

cli/gitssh.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ func gitssh() *cobra.Command {
2929

3030
// Catch interrupt signals to ensure the temporary private
3131
// key file is cleaned up on most cases.
32-
ctx, stop := signal.NotifyContext(ctx, interruptSignals...)
32+
ctx, stop := signal.NotifyContext(ctx, InterruptSignals...)
3333
defer stop()
3434

3535
// Early check so errors are reported immediately.

cli/server.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*co
108108
//
109109
// To get out of a graceful shutdown, the user can send
110110
// SIGQUIT with ctrl+\ or SIGKILL with `kill -9`.
111-
notifyCtx, notifyStop := signal.NotifyContext(ctx, interruptSignals...)
111+
notifyCtx, notifyStop := signal.NotifyContext(ctx, InterruptSignals...)
112112
defer notifyStop()
113113

114114
// Clean up idle connections at the end, e.g.
@@ -946,7 +946,7 @@ func newProvisionerDaemon(
946946
return provisionerd.New(func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
947947
// This debounces calls to listen every second. Read the comment
948948
// in provisionerdserver.go to learn more!
949-
return coderAPI.ListenProvisionerDaemon(ctx, time.Second)
949+
return coderAPI.CreateInMemoryProvisionerDaemon(ctx, time.Second)
950950
}, &provisionerd.Options{
951951
Logger: logger,
952952
PollInterval: 500 * time.Millisecond,

cli/signal_unix.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import (
77
"syscall"
88
)
99

10-
var interruptSignals = []os.Signal{
10+
var InterruptSignals = []os.Signal{
1111
os.Interrupt,
1212
syscall.SIGTERM,
1313
syscall.SIGHUP,

cli/signal_windows.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@ import (
66
"os"
77
)
88

9-
var interruptSignals = []os.Signal{os.Interrupt}
9+
var InterruptSignals = []os.Signal{os.Interrupt}

cli/templatecreate.go

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@ import (
2424

2525
func templateCreate() *cobra.Command {
2626
var (
27-
directory string
28-
provisioner string
29-
parameterFile string
30-
defaultTTL time.Duration
27+
directory string
28+
provisioner string
29+
provisionerTags []string
30+
parameterFile string
31+
defaultTTL time.Duration
3132
)
3233
cmd := &cobra.Command{
3334
Use: "create [name]",
@@ -87,12 +88,18 @@ func templateCreate() *cobra.Command {
8788
}
8889
spin.Stop()
8990

91+
tags, err := ParseProvisionerTags(provisionerTags)
92+
if err != nil {
93+
return err
94+
}
95+
9096
job, _, err := createValidTemplateVersion(cmd, createValidTemplateVersionArgs{
91-
Client: client,
92-
Organization: organization,
93-
Provisioner: database.ProvisionerType(provisioner),
94-
FileID: resp.ID,
95-
ParameterFile: parameterFile,
97+
Client: client,
98+
Organization: organization,
99+
Provisioner: database.ProvisionerType(provisioner),
100+
FileID: resp.ID,
101+
ParameterFile: parameterFile,
102+
ProvisionerTags: tags,
96103
})
97104
if err != nil {
98105
return err
@@ -131,6 +138,7 @@ func templateCreate() *cobra.Command {
131138
cmd.Flags().StringVarP(&directory, "directory", "d", currentDirectory, "Specify the directory to create from")
132139
cmd.Flags().StringVarP(&provisioner, "test.provisioner", "", "terraform", "Customize the provisioner backend")
133140
cmd.Flags().StringVarP(&parameterFile, "parameter-file", "", "", "Specify a file path with parameter values.")
141+
cmd.Flags().StringArrayVarP(&provisionerTags, "provisioner-tag", "", []string{}, "Specify a set of tags to target provisioner daemons.")
134142
cmd.Flags().DurationVarP(&defaultTTL, "default-ttl", "", 24*time.Hour, "Specify a default TTL for workspaces created from this template.")
135143
// This is for testing!
136144
err := cmd.Flags().MarkHidden("test.provisioner")
@@ -154,6 +162,7 @@ type createValidTemplateVersionArgs struct {
154162
// before prompting the user. Set to false to always prompt for param
155163
// values.
156164
ReuseParameters bool
165+
ProvisionerTags map[string]string
157166
}
158167

159168
func createValidTemplateVersion(cmd *cobra.Command, args createValidTemplateVersionArgs, parameters ...codersdk.CreateParameterRequest) (*codersdk.TemplateVersion, []codersdk.CreateParameterRequest, error) {
@@ -165,6 +174,7 @@ func createValidTemplateVersion(cmd *cobra.Command, args createValidTemplateVers
165174
FileID: args.FileID,
166175
Provisioner: codersdk.ProvisionerType(args.Provisioner),
167176
ParameterValues: parameters,
177+
ProvisionerTags: args.ProvisionerTags,
168178
}
169179
if args.Template != nil {
170180
req.TemplateID = args.Template.ID
@@ -334,3 +344,15 @@ func prettyDirectoryPath(dir string) string {
334344
}
335345
return pretty
336346
}
347+
348+
func ParseProvisionerTags(rawTags []string) (map[string]string, error) {
349+
tags := map[string]string{}
350+
for _, rawTag := range rawTags {
351+
parts := strings.SplitN(rawTag, "=", 2)
352+
if len(parts) < 2 {
353+
return nil, xerrors.Errorf("invalid tag format for %q. must be key=value", rawTag)
354+
}
355+
tags[parts[0]] = parts[1]
356+
}
357+
return tags, nil
358+
}

cli/templatepush.go

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,12 @@ import (
1818

1919
func templatePush() *cobra.Command {
2020
var (
21-
directory string
22-
versionName string
23-
provisioner string
24-
parameterFile string
25-
alwaysPrompt bool
21+
directory string
22+
versionName string
23+
provisioner string
24+
parameterFile string
25+
alwaysPrompt bool
26+
provisionerTags []string
2627
)
2728

2829
cmd := &cobra.Command{
@@ -75,6 +76,11 @@ func templatePush() *cobra.Command {
7576
}
7677
spin.Stop()
7778

79+
tags, err := ParseProvisionerTags(provisionerTags)
80+
if err != nil {
81+
return err
82+
}
83+
7884
job, _, err := createValidTemplateVersion(cmd, createValidTemplateVersionArgs{
7985
Name: versionName,
8086
Client: client,
@@ -84,6 +90,7 @@ func templatePush() *cobra.Command {
8490
ParameterFile: parameterFile,
8591
Template: &template,
8692
ReuseParameters: !alwaysPrompt,
93+
ProvisionerTags: tags,
8794
})
8895
if err != nil {
8996
return err

coderd/autobuild/executor/lifecycle_executor.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ func build(ctx context.Context, store database.Store, workspace database.Workspa
278278
Type: database.ProvisionerJobTypeWorkspaceBuild,
279279
StorageMethod: priorJob.StorageMethod,
280280
FileID: priorJob.FileID,
281+
Tags: priorJob.Tags,
281282
Input: input,
282283
})
283284
if err != nil {

coderd/coderd.go

Lines changed: 84 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
package coderd
22

33
import (
4+
"context"
45
"crypto/tls"
56
"crypto/x509"
7+
"encoding/json"
68
"fmt"
79
"io"
810
"net/http"
@@ -18,10 +20,13 @@ import (
1820
"github.com/go-chi/chi/v5/middleware"
1921
"github.com/google/uuid"
2022
"github.com/klauspost/compress/zstd"
23+
"github.com/moby/moby/pkg/namesgenerator"
2124
"github.com/prometheus/client_golang/prometheus"
2225
"go.opentelemetry.io/otel/trace"
2326
"golang.org/x/xerrors"
2427
"google.golang.org/api/idtoken"
28+
"storj.io/drpc/drpcmux"
29+
"storj.io/drpc/drpcserver"
2530
"tailscale.com/derp"
2631
"tailscale.com/derp/derphttp"
2732
"tailscale.com/tailcfg"
@@ -32,17 +37,20 @@ import (
3237
"github.com/coder/coder/coderd/audit"
3338
"github.com/coder/coder/coderd/awsidentity"
3439
"github.com/coder/coder/coderd/database"
40+
"github.com/coder/coder/coderd/database/dbtype"
3541
"github.com/coder/coder/coderd/gitauth"
3642
"github.com/coder/coder/coderd/gitsshkey"
3743
"github.com/coder/coder/coderd/httpapi"
3844
"github.com/coder/coder/coderd/httpmw"
3945
"github.com/coder/coder/coderd/metricscache"
46+
"github.com/coder/coder/coderd/provisionerdserver"
4047
"github.com/coder/coder/coderd/rbac"
4148
"github.com/coder/coder/coderd/telemetry"
4249
"github.com/coder/coder/coderd/tracing"
4350
"github.com/coder/coder/coderd/wsconncache"
4451
"github.com/coder/coder/codersdk"
4552
"github.com/coder/coder/provisionerd/proto"
53+
"github.com/coder/coder/provisionersdk"
4654
"github.com/coder/coder/site"
4755
"github.com/coder/coder/tailnet"
4856
)
@@ -323,13 +331,6 @@ func New(options *Options) *API {
323331
r.Get("/{fileID}", api.fileByID)
324332
r.Post("/", api.postFile)
325333
})
326-
327-
r.Route("/provisionerdaemons", func(r chi.Router) {
328-
r.Use(
329-
apiKeyMiddleware,
330-
)
331-
r.Get("/", api.provisionerDaemons)
332-
})
333334
r.Route("/organizations", func(r chi.Router) {
334335
r.Use(
335336
apiKeyMiddleware,
@@ -595,18 +596,20 @@ type API struct {
595596
// RootHandler serves "/"
596597
RootHandler chi.Router
597598

598-
metricsCache *metricscache.Cache
599-
siteHandler http.Handler
600-
websocketWaitMutex sync.Mutex
601-
websocketWaitGroup sync.WaitGroup
599+
metricsCache *metricscache.Cache
600+
siteHandler http.Handler
601+
602+
WebsocketWaitMutex sync.Mutex
603+
WebsocketWaitGroup sync.WaitGroup
604+
602605
workspaceAgentCache *wsconncache.Cache
603606
}
604607

605608
// Close waits for all WebSocket connections to drain before returning.
606609
func (api *API) Close() error {
607-
api.websocketWaitMutex.Lock()
608-
api.websocketWaitGroup.Wait()
609-
api.websocketWaitMutex.Unlock()
610+
api.WebsocketWaitMutex.Lock()
611+
api.WebsocketWaitGroup.Wait()
612+
api.WebsocketWaitMutex.Unlock()
610613

611614
api.metricsCache.Close()
612615
coordinator := api.TailnetCoordinator.Load()
@@ -635,3 +638,70 @@ func compressHandler(h http.Handler) http.Handler {
635638

636639
return cmp.Handler(h)
637640
}
641+
642+
// CreateInMemoryProvisionerDaemon is an in-memory connection to a provisionerd. Useful when starting coderd and provisionerd
643+
// in the same process.
644+
func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, debounce time.Duration) (client proto.DRPCProvisionerDaemonClient, err error) {
645+
clientSession, serverSession := provisionersdk.TransportPipe()
646+
defer func() {
647+
if err != nil {
648+
_ = clientSession.Close()
649+
_ = serverSession.Close()
650+
}
651+
}()
652+
653+
name := namesgenerator.GetRandomName(1)
654+
daemon, err := api.Database.InsertProvisionerDaemon(ctx, database.InsertProvisionerDaemonParams{
655+
ID: uuid.New(),
656+
CreatedAt: database.Now(),
657+
Name: name,
658+
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho, database.ProvisionerTypeTerraform},
659+
Tags: dbtype.StringMap{
660+
provisionerdserver.TagScope: provisionerdserver.ScopeOrganization,
661+
},
662+
})
663+
if err != nil {
664+
return nil, xerrors.Errorf("insert provisioner daemon %q: %w", name, err)
665+
}
666+
667+
tags, err := json.Marshal(daemon.Tags)
668+
if err != nil {
669+
return nil, xerrors.Errorf("marshal tags: %w", err)
670+
}
671+
672+
mux := drpcmux.New()
673+
err = proto.DRPCRegisterProvisionerDaemon(mux, &provisionerdserver.Server{
674+
AccessURL: api.AccessURL,
675+
ID: daemon.ID,
676+
Database: api.Database,
677+
Pubsub: api.Pubsub,
678+
Provisioners: daemon.Provisioners,
679+
Telemetry: api.Telemetry,
680+
Tags: tags,
681+
QuotaCommitter: &api.QuotaCommitter,
682+
AcquireJobDebounce: debounce,
683+
Logger: api.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)),
684+
})
685+
if err != nil {
686+
return nil, err
687+
}
688+
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
689+
Log: func(err error) {
690+
if xerrors.Is(err, io.EOF) {
691+
return
692+
}
693+
api.Logger.Debug(ctx, "drpc server error", slog.Error(err))
694+
},
695+
})
696+
go func() {
697+
err := server.Serve(ctx, serverSession)
698+
if err != nil && !xerrors.Is(err, io.EOF) {
699+
api.Logger.Debug(ctx, "provisioner daemon disconnected", slog.Error(err))
700+
}
701+
// close the sessions so we don't leak goroutines serving them.
702+
_ = clientSession.Close()
703+
_ = serverSession.Close()
704+
}()
705+
706+
return proto.NewDRPCProvisionerDaemonClient(provisionersdk.Conn(clientSession)), nil
707+
}

coderd/coderdtest/authorize.go

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ import (
1919
"github.com/coder/coder/codersdk"
2020
"github.com/coder/coder/provisioner/echo"
2121
"github.com/coder/coder/provisionersdk/proto"
22-
"github.com/coder/coder/testutil"
2322
)
2423

2524
func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
@@ -204,11 +203,6 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
204203
AssertAction: rbac.ActionRead,
205204
AssertObject: rbac.ResourceTemplate.InOrg(a.Version.OrganizationID),
206205
},
207-
"GET:/api/v2/provisionerdaemons": {
208-
StatusCode: http.StatusOK,
209-
AssertObject: rbac.ResourceProvisionerDaemon,
210-
},
211-
212206
"POST:/api/v2/parameters/{scope}/{id}": {
213207
AssertAction: rbac.ActionUpdate,
214208
AssertObject: rbac.ResourceTemplate,
@@ -303,16 +297,6 @@ func NewAuthTester(ctx context.Context, t *testing.T, client *codersdk.Client, a
303297
if !ok {
304298
t.Fail()
305299
}
306-
// The provisioner will call to coderd and register itself. This is async,
307-
// so we wait for it to occur.
308-
require.Eventually(t, func() bool {
309-
provisionerds, err := client.ProvisionerDaemons(ctx)
310-
return assert.NoError(t, err) && len(provisionerds) > 0
311-
}, testutil.WaitLong, testutil.IntervalSlow)
312-
313-
provisionerds, err := client.ProvisionerDaemons(ctx)
314-
require.NoError(t, err, "fetch provisioners")
315-
require.Len(t, provisionerds, 1)
316300

317301
organization, err := client.Organization(ctx, admin.OrganizationID)
318302
require.NoError(t, err, "fetch org")

0 commit comments

Comments
 (0)