Skip to content

Commit cb4989c

Browse files
authored
feat: add PSK for external provisionerd auth (#8877)
Signed-off-by: Spike Curtis <spike@coder.com>
1 parent b77d6b2 commit cb4989c

24 files changed

+429
-61
lines changed

cli/root.go

+14-4
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,15 @@ func addTelemetryHeader(client *codersdk.Client, inv *clibase.Invocation) {
494494
// InitClient sets client to a new client.
495495
// It reads from global configuration files if flags are not set.
496496
func (r *RootCmd) InitClient(client *codersdk.Client) clibase.MiddlewareFunc {
497+
return r.initClientInternal(client, false)
498+
}
499+
500+
func (r *RootCmd) InitClientMissingTokenOK(client *codersdk.Client) clibase.MiddlewareFunc {
501+
return r.initClientInternal(client, true)
502+
}
503+
504+
// nolint: revive
505+
func (r *RootCmd) initClientInternal(client *codersdk.Client, allowTokenMissing bool) clibase.MiddlewareFunc {
497506
if client == nil {
498507
panic("client is nil")
499508
}
@@ -508,7 +517,7 @@ func (r *RootCmd) InitClient(client *codersdk.Client) clibase.MiddlewareFunc {
508517
rawURL, err := conf.URL().Read()
509518
// If the configuration files are absent, the user is logged out
510519
if os.IsNotExist(err) {
511-
return (errUnauthenticated)
520+
return errUnauthenticated
512521
}
513522
if err != nil {
514523
return err
@@ -524,9 +533,10 @@ func (r *RootCmd) InitClient(client *codersdk.Client) clibase.MiddlewareFunc {
524533
r.token, err = conf.Session().Read()
525534
// If the configuration files are absent, the user is logged out
526535
if os.IsNotExist(err) {
527-
return (errUnauthenticated)
528-
}
529-
if err != nil {
536+
if !allowTokenMissing {
537+
return errUnauthenticated
538+
}
539+
} else if err != nil {
530540
return err
531541
}
532542
}

cli/testdata/coder_server_--help.golden

+4
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,10 @@ updating, and deleting workspace resources.
373373
--provisioner-daemon-poll-jitter duration, $CODER_PROVISIONER_DAEMON_POLL_JITTER (default: 100ms)
374374
Random jitter added to the poll interval.
375375

376+
--provisioner-daemon-psk string, $CODER_PROVISIONER_DAEMON_PSK
377+
Pre-shared key to authenticate external provisioner daemons to Coder
378+
server.
379+
376380
--provisioner-daemons int, $CODER_PROVISIONER_DAEMONS (default: 3)
377381
Number of provisioner daemons to create on start. If builds are stuck
378382
in queued state for a long time, consider increasing this.

cli/testdata/server-config.yaml.golden

+3
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,9 @@ provisioning:
327327
# Time to force cancel provisioning tasks that are stuck.
328328
# (default: 10m0s, type: duration)
329329
forceCancelInterval: 10m0s
330+
# Pre-shared key to authenticate external provisioner daemons to Coder server.
331+
# (default: <unset>, type: string)
332+
daemonPSK: ""
330333
# Enable one or more experiments. These are not ready for production. Separate
331334
# multiple experiments with commas, or enter '*' to opt-in to all available
332335
# experiments.

coderd/apidoc/docs.go

+3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/apidoc/swagger.json

+3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/coderdtest/coderdtest.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,11 @@ func NewExternalProvisionerDaemon(t *testing.T, client *codersdk.Client, org uui
497497
}()
498498

499499
closer := provisionerd.New(func(ctx context.Context) (provisionerdproto.DRPCProvisionerDaemonClient, error) {
500-
return client.ServeProvisionerDaemon(ctx, org, []codersdk.ProvisionerType{codersdk.ProvisionerTypeEcho}, tags)
500+
return client.ServeProvisionerDaemon(ctx, codersdk.ServeProvisionerDaemonRequest{
501+
Organization: org,
502+
Provisioners: []codersdk.ProvisionerType{codersdk.ProvisionerTypeEcho},
503+
Tags: tags,
504+
})
501505
}, &provisionerd.Options{
502506
Filesystem: fs,
503507
Logger: slogtest.Make(t, nil).Named("provisionerd").Leveled(slog.LevelDebug),

codersdk/client.go

+3
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ const (
7171
// command that was invoked to produce the request. It is for internal use
7272
// only.
7373
CLITelemetryHeader = "Coder-CLI-Telemetry"
74+
75+
// ProvisionerDaemonPSK contains the authentication pre-shared key for an external provisioner daemon
76+
ProvisionerDaemonPSK = "Coder-Provisioner-Daemon-PSK"
7477
)
7578

7679
// loggableMimeTypes is a list of MIME types that are safe to log

codersdk/deployment.go

+10
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@ type ProvisionerConfig struct {
328328
DaemonPollInterval clibase.Duration `json:"daemon_poll_interval" typescript:",notnull"`
329329
DaemonPollJitter clibase.Duration `json:"daemon_poll_jitter" typescript:",notnull"`
330330
ForceCancelInterval clibase.Duration `json:"force_cancel_interval" typescript:",notnull"`
331+
DaemonPSK clibase.String `json:"daemon_psk" typescript:",notnull"`
331332
}
332333

333334
type RateLimitConfig struct {
@@ -1230,6 +1231,15 @@ when required by your organization's security policy.`,
12301231
Group: &deploymentGroupProvisioning,
12311232
YAML: "forceCancelInterval",
12321233
},
1234+
{
1235+
Name: "Provisioner Daemon Pre-shared Key (PSK)",
1236+
Description: "Pre-shared key to authenticate external provisioner daemons to Coder server.",
1237+
Flag: "provisioner-daemon-psk",
1238+
Env: "CODER_PROVISIONER_DAEMON_PSK",
1239+
Value: &c.Provisioner.DaemonPSK,
1240+
Group: &deploymentGroupProvisioning,
1241+
YAML: "daemonPSK",
1242+
},
12331243
// RateLimit settings
12341244
{
12351245
Name: "Disable All Rate Limits",

codersdk/organizations.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,11 @@ func (c *Client) Organization(ctx context.Context, id uuid.UUID) (Organization,
149149
return organization, json.NewDecoder(res.Body).Decode(&organization)
150150
}
151151

152-
// ProvisionerDaemonsByOrganization returns provisioner daemons available for an organization.
152+
// ProvisionerDaemons returns provisioner daemons available.
153153
func (c *Client) ProvisionerDaemons(ctx context.Context) ([]ProvisionerDaemon, error) {
154154
res, err := c.Request(ctx, http.MethodGet,
155-
"/api/v2/provisionerdaemons",
155+
// TODO: the organization path parameter is currently ignored.
156+
"/api/v2/organizations/default/provisionerdaemons",
156157
nil,
157158
)
158159
if err != nil {

codersdk/provisionerdaemons.go

+37-14
Original file line numberDiff line numberDiff line change
@@ -164,38 +164,61 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after
164164
}), nil
165165
}
166166

167-
// ListenProvisionerDaemon returns the gRPC service for a provisioner daemon
167+
// ServeProvisionerDaemonRequest are the parameters to call ServeProvisionerDaemon with
168+
// @typescript-ignore ServeProvisionerDaemonRequest
169+
type ServeProvisionerDaemonRequest struct {
170+
// Organization is the organization for the URL. At present provisioner daemons ARE NOT scoped to organizations
171+
// and so the organization ID is optional.
172+
Organization uuid.UUID `json:"organization" format:"uuid"`
173+
// Provisioners is a list of provisioner types hosted by the provisioner daemon
174+
Provisioners []ProvisionerType `json:"provisioners"`
175+
// Tags is a map of key-value pairs that tag the jobs this provisioner daemon can handle
176+
Tags map[string]string `json:"tags"`
177+
// PreSharedKey is an authentication key to use on the API instead of the normal session token from the client.
178+
PreSharedKey string `json:"pre_shared_key"`
179+
}
180+
181+
// ServeProvisionerDaemon returns the gRPC service for a provisioner daemon
168182
// implementation. The context is during dial, not during the lifetime of the
169183
// client. Client should be closed after use.
170-
func (c *Client) ServeProvisionerDaemon(ctx context.Context, organization uuid.UUID, provisioners []ProvisionerType, tags map[string]string) (proto.DRPCProvisionerDaemonClient, error) {
171-
serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", organization))
184+
func (c *Client) ServeProvisionerDaemon(ctx context.Context, req ServeProvisionerDaemonRequest) (proto.DRPCProvisionerDaemonClient, error) {
185+
serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", req.Organization))
172186
if err != nil {
173187
return nil, xerrors.Errorf("parse url: %w", err)
174188
}
175189
query := serverURL.Query()
176-
for _, provisioner := range provisioners {
190+
for _, provisioner := range req.Provisioners {
177191
query.Add("provisioner", string(provisioner))
178192
}
179-
for key, value := range tags {
193+
for key, value := range req.Tags {
180194
query.Add("tag", fmt.Sprintf("%s=%s", key, value))
181195
}
182196
serverURL.RawQuery = query.Encode()
183-
jar, err := cookiejar.New(nil)
184-
if err != nil {
185-
return nil, xerrors.Errorf("create cookie jar: %w", err)
186-
}
187-
jar.SetCookies(serverURL, []*http.Cookie{{
188-
Name: SessionTokenCookie,
189-
Value: c.SessionToken(),
190-
}})
191197
httpClient := &http.Client{
192-
Jar: jar,
193198
Transport: c.HTTPClient.Transport,
194199
}
200+
headers := http.Header{}
201+
202+
if req.PreSharedKey == "" {
203+
// use session token if we don't have a PSK.
204+
jar, err := cookiejar.New(nil)
205+
if err != nil {
206+
return nil, xerrors.Errorf("create cookie jar: %w", err)
207+
}
208+
jar.SetCookies(serverURL, []*http.Cookie{{
209+
Name: SessionTokenCookie,
210+
Value: c.SessionToken(),
211+
}})
212+
httpClient.Jar = jar
213+
} else {
214+
headers.Set(ProvisionerDaemonPSK, req.PreSharedKey)
215+
}
216+
195217
conn, res, err := websocket.Dial(ctx, serverURL.String(), &websocket.DialOptions{
196218
HTTPClient: httpClient,
197219
// Need to disable compression to avoid a data-race.
198220
CompressionMode: websocket.CompressionDisabled,
221+
HTTPHeader: headers,
199222
})
200223
if err != nil {
201224
if res == nil {

docs/api/general.md

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/api/schemas.md

+4
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/cli/provisionerd_start.md

+9
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/cli/server.md

+10
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

enterprise/cli/provisionerdaemons.go

+15-9
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,14 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd {
4444
rawTags []string
4545
pollInterval time.Duration
4646
pollJitter time.Duration
47+
preSharedKey string
4748
)
4849
client := new(codersdk.Client)
4950
cmd := &clibase.Cmd{
5051
Use: "start",
5152
Short: "Run a provisioner daemon",
5253
Middleware: clibase.Chain(
53-
r.InitClient(client),
54+
r.InitClientMissingTokenOK(client),
5455
),
5556
Handler: func(inv *clibase.Invocation) error {
5657
ctx, cancel := context.WithCancel(inv.Context())
@@ -59,11 +60,6 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd {
5960
notifyCtx, notifyStop := signal.NotifyContext(ctx, agpl.InterruptSignals...)
6061
defer notifyStop()
6162

62-
org, err := agpl.CurrentOrganization(inv, client)
63-
if err != nil {
64-
return xerrors.Errorf("get current organization: %w", err)
65-
}
66-
6763
tags, err := agpl.ParseProvisionerTags(rawTags)
6864
if err != nil {
6965
return err
@@ -112,9 +108,13 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd {
112108
string(database.ProvisionerTypeTerraform): proto.NewDRPCProvisionerClient(terraformClient),
113109
}
114110
srv := provisionerd.New(func(ctx context.Context) (provisionerdproto.DRPCProvisionerDaemonClient, error) {
115-
return client.ServeProvisionerDaemon(ctx, org.ID, []codersdk.ProvisionerType{
116-
codersdk.ProvisionerTypeTerraform,
117-
}, tags)
111+
return client.ServeProvisionerDaemon(ctx, codersdk.ServeProvisionerDaemonRequest{
112+
Provisioners: []codersdk.ProvisionerType{
113+
codersdk.ProvisionerTypeTerraform,
114+
},
115+
Tags: tags,
116+
PreSharedKey: preSharedKey,
117+
})
118118
}, &provisionerd.Options{
119119
Logger: logger,
120120
JobPollInterval: pollInterval,
@@ -182,6 +182,12 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd {
182182
Default: (100 * time.Millisecond).String(),
183183
Value: clibase.DurationOf(&pollJitter),
184184
},
185+
{
186+
Flag: "psk",
187+
Env: "CODER_PROVISIONER_DAEMON_PSK",
188+
Description: "Pre-shared key to authenticate with Coder server.",
189+
Value: clibase.StringOf(&preSharedKey),
190+
},
185191
}
186192

187193
return cmd

0 commit comments

Comments
 (0)