From 214c8c8c062462952ae7b3524360e7221131481e Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Thu, 3 Aug 2023 10:39:21 +0000 Subject: [PATCH] feat: add PSK for external provisionerd auth Signed-off-by: Spike Curtis --- cli/root.go | 18 +- cli/testdata/coder_server_--help.golden | 4 + cli/testdata/server-config.yaml.golden | 3 + coderd/apidoc/docs.go | 3 + coderd/apidoc/swagger.json | 3 + coderd/coderdtest/coderdtest.go | 6 +- codersdk/client.go | 3 + codersdk/deployment.go | 10 + codersdk/organizations.go | 5 +- codersdk/provisionerdaemons.go | 51 +++-- docs/api/general.md | 1 + docs/api/schemas.md | 4 + docs/cli/provisionerd_start.md | 9 + docs/cli/server.md | 10 + enterprise/cli/provisionerdaemons.go | 24 ++- enterprise/cli/provisionerdaemons_test.go | 56 +++++ enterprise/cli/server.go | 1 + .../coder_provisionerd_start_--help.golden | 3 + .../cli/testdata/coder_server_--help.golden | 4 + enterprise/coderd/coderd.go | 24 ++- .../coderd/coderdenttest/coderdenttest.go | 2 + enterprise/coderd/provisionerdaemons.go | 53 +++-- enterprise/coderd/provisionerdaemons_test.go | 192 ++++++++++++++++-- site/src/api/typesGenerated.ts | 1 + 24 files changed, 429 insertions(+), 61 deletions(-) create mode 100644 enterprise/cli/provisionerdaemons_test.go diff --git a/cli/root.go b/cli/root.go index 4c268235a0f96..036be18a01300 100644 --- a/cli/root.go +++ b/cli/root.go @@ -494,6 +494,15 @@ func addTelemetryHeader(client *codersdk.Client, inv *clibase.Invocation) { // InitClient sets client to a new client. // It reads from global configuration files if flags are not set. func (r *RootCmd) InitClient(client *codersdk.Client) clibase.MiddlewareFunc { + return r.initClientInternal(client, false) +} + +func (r *RootCmd) InitClientMissingTokenOK(client *codersdk.Client) clibase.MiddlewareFunc { + return r.initClientInternal(client, true) +} + +// nolint: revive +func (r *RootCmd) initClientInternal(client *codersdk.Client, allowTokenMissing bool) clibase.MiddlewareFunc { if client == nil { panic("client is nil") } @@ -508,7 +517,7 @@ func (r *RootCmd) InitClient(client *codersdk.Client) clibase.MiddlewareFunc { rawURL, err := conf.URL().Read() // If the configuration files are absent, the user is logged out if os.IsNotExist(err) { - return (errUnauthenticated) + return errUnauthenticated } if err != nil { return err @@ -524,9 +533,10 @@ func (r *RootCmd) InitClient(client *codersdk.Client) clibase.MiddlewareFunc { r.token, err = conf.Session().Read() // If the configuration files are absent, the user is logged out if os.IsNotExist(err) { - return (errUnauthenticated) - } - if err != nil { + if !allowTokenMissing { + return errUnauthenticated + } + } else if err != nil { return err } } diff --git a/cli/testdata/coder_server_--help.golden b/cli/testdata/coder_server_--help.golden index cb7ca61b4913a..121ce98a98bd7 100644 --- a/cli/testdata/coder_server_--help.golden +++ b/cli/testdata/coder_server_--help.golden @@ -373,6 +373,10 @@ updating, and deleting workspace resources. --provisioner-daemon-poll-jitter duration, $CODER_PROVISIONER_DAEMON_POLL_JITTER (default: 100ms) Random jitter added to the poll interval. + --provisioner-daemon-psk string, $CODER_PROVISIONER_DAEMON_PSK + Pre-shared key to authenticate external provisioner daemons to Coder + server. + --provisioner-daemons int, $CODER_PROVISIONER_DAEMONS (default: 3) Number of provisioner daemons to create on start. If builds are stuck in queued state for a long time, consider increasing this. diff --git a/cli/testdata/server-config.yaml.golden b/cli/testdata/server-config.yaml.golden index c7a8df03414e6..7eab5aba07ecc 100644 --- a/cli/testdata/server-config.yaml.golden +++ b/cli/testdata/server-config.yaml.golden @@ -327,6 +327,9 @@ provisioning: # Time to force cancel provisioning tasks that are stuck. # (default: 10m0s, type: duration) forceCancelInterval: 10m0s + # Pre-shared key to authenticate external provisioner daemons to Coder server. + # (default: , type: string) + daemonPSK: "" # Enable one or more experiments. These are not ready for production. Separate # multiple experiments with commas, or enter '*' to opt-in to all available # experiments. diff --git a/coderd/apidoc/docs.go b/coderd/apidoc/docs.go index d5ccfb06dfc47..f8b8fd7d4c575 100644 --- a/coderd/apidoc/docs.go +++ b/coderd/apidoc/docs.go @@ -8747,6 +8747,9 @@ const docTemplate = `{ "daemon_poll_jitter": { "type": "integer" }, + "daemon_psk": { + "type": "string" + }, "daemons": { "type": "integer" }, diff --git a/coderd/apidoc/swagger.json b/coderd/apidoc/swagger.json index 69b3e1f6a5453..c6a1cc39f59a1 100644 --- a/coderd/apidoc/swagger.json +++ b/coderd/apidoc/swagger.json @@ -7857,6 +7857,9 @@ "daemon_poll_jitter": { "type": "integer" }, + "daemon_psk": { + "type": "string" + }, "daemons": { "type": "integer" }, diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index ada8782514381..04e2decd9b74b 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -497,7 +497,11 @@ func NewExternalProvisionerDaemon(t *testing.T, client *codersdk.Client, org uui }() closer := provisionerd.New(func(ctx context.Context) (provisionerdproto.DRPCProvisionerDaemonClient, error) { - return client.ServeProvisionerDaemon(ctx, org, []codersdk.ProvisionerType{codersdk.ProvisionerTypeEcho}, tags) + return client.ServeProvisionerDaemon(ctx, codersdk.ServeProvisionerDaemonRequest{ + Organization: org, + Provisioners: []codersdk.ProvisionerType{codersdk.ProvisionerTypeEcho}, + Tags: tags, + }) }, &provisionerd.Options{ Filesystem: fs, Logger: slogtest.Make(t, nil).Named("provisionerd").Leveled(slog.LevelDebug), diff --git a/codersdk/client.go b/codersdk/client.go index ad9e46ccf7f4a..98a109aa725e1 100644 --- a/codersdk/client.go +++ b/codersdk/client.go @@ -71,6 +71,9 @@ const ( // command that was invoked to produce the request. It is for internal use // only. CLITelemetryHeader = "Coder-CLI-Telemetry" + + // ProvisionerDaemonPSK contains the authentication pre-shared key for an external provisioner daemon + ProvisionerDaemonPSK = "Coder-Provisioner-Daemon-PSK" ) // loggableMimeTypes is a list of MIME types that are safe to log diff --git a/codersdk/deployment.go b/codersdk/deployment.go index e27731b6aa6bf..890273bafcd3a 100644 --- a/codersdk/deployment.go +++ b/codersdk/deployment.go @@ -328,6 +328,7 @@ type ProvisionerConfig struct { DaemonPollInterval clibase.Duration `json:"daemon_poll_interval" typescript:",notnull"` DaemonPollJitter clibase.Duration `json:"daemon_poll_jitter" typescript:",notnull"` ForceCancelInterval clibase.Duration `json:"force_cancel_interval" typescript:",notnull"` + DaemonPSK clibase.String `json:"daemon_psk" typescript:",notnull"` } type RateLimitConfig struct { @@ -1230,6 +1231,15 @@ when required by your organization's security policy.`, Group: &deploymentGroupProvisioning, YAML: "forceCancelInterval", }, + { + Name: "Provisioner Daemon Pre-shared Key (PSK)", + Description: "Pre-shared key to authenticate external provisioner daemons to Coder server.", + Flag: "provisioner-daemon-psk", + Env: "CODER_PROVISIONER_DAEMON_PSK", + Value: &c.Provisioner.DaemonPSK, + Group: &deploymentGroupProvisioning, + YAML: "daemonPSK", + }, // RateLimit settings { Name: "Disable All Rate Limits", diff --git a/codersdk/organizations.go b/codersdk/organizations.go index 26290fd4f4761..96b026a3197a5 100644 --- a/codersdk/organizations.go +++ b/codersdk/organizations.go @@ -149,10 +149,11 @@ func (c *Client) Organization(ctx context.Context, id uuid.UUID) (Organization, return organization, json.NewDecoder(res.Body).Decode(&organization) } -// ProvisionerDaemonsByOrganization returns provisioner daemons available for an organization. +// ProvisionerDaemons returns provisioner daemons available. func (c *Client) ProvisionerDaemons(ctx context.Context) ([]ProvisionerDaemon, error) { res, err := c.Request(ctx, http.MethodGet, - "/api/v2/provisionerdaemons", + // TODO: the organization path parameter is currently ignored. + "/api/v2/organizations/default/provisionerdaemons", nil, ) if err != nil { diff --git a/codersdk/provisionerdaemons.go b/codersdk/provisionerdaemons.go index 1c9378f718b3a..674523055e06f 100644 --- a/codersdk/provisionerdaemons.go +++ b/codersdk/provisionerdaemons.go @@ -164,38 +164,61 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after }), nil } -// ListenProvisionerDaemon returns the gRPC service for a provisioner daemon +// ServeProvisionerDaemonRequest are the parameters to call ServeProvisionerDaemon with +// @typescript-ignore ServeProvisionerDaemonRequest +type ServeProvisionerDaemonRequest struct { + // Organization is the organization for the URL. At present provisioner daemons ARE NOT scoped to organizations + // and so the organization ID is optional. + Organization uuid.UUID `json:"organization" format:"uuid"` + // Provisioners is a list of provisioner types hosted by the provisioner daemon + Provisioners []ProvisionerType `json:"provisioners"` + // Tags is a map of key-value pairs that tag the jobs this provisioner daemon can handle + Tags map[string]string `json:"tags"` + // PreSharedKey is an authentication key to use on the API instead of the normal session token from the client. + PreSharedKey string `json:"pre_shared_key"` +} + +// ServeProvisionerDaemon returns the gRPC service for a provisioner daemon // implementation. The context is during dial, not during the lifetime of the // client. Client should be closed after use. -func (c *Client) ServeProvisionerDaemon(ctx context.Context, organization uuid.UUID, provisioners []ProvisionerType, tags map[string]string) (proto.DRPCProvisionerDaemonClient, error) { - serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", organization)) +func (c *Client) ServeProvisionerDaemon(ctx context.Context, req ServeProvisionerDaemonRequest) (proto.DRPCProvisionerDaemonClient, error) { + serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", req.Organization)) if err != nil { return nil, xerrors.Errorf("parse url: %w", err) } query := serverURL.Query() - for _, provisioner := range provisioners { + for _, provisioner := range req.Provisioners { query.Add("provisioner", string(provisioner)) } - for key, value := range tags { + for key, value := range req.Tags { query.Add("tag", fmt.Sprintf("%s=%s", key, value)) } serverURL.RawQuery = query.Encode() - jar, err := cookiejar.New(nil) - if err != nil { - return nil, xerrors.Errorf("create cookie jar: %w", err) - } - jar.SetCookies(serverURL, []*http.Cookie{{ - Name: SessionTokenCookie, - Value: c.SessionToken(), - }}) httpClient := &http.Client{ - Jar: jar, Transport: c.HTTPClient.Transport, } + headers := http.Header{} + + if req.PreSharedKey == "" { + // use session token if we don't have a PSK. + jar, err := cookiejar.New(nil) + if err != nil { + return nil, xerrors.Errorf("create cookie jar: %w", err) + } + jar.SetCookies(serverURL, []*http.Cookie{{ + Name: SessionTokenCookie, + Value: c.SessionToken(), + }}) + httpClient.Jar = jar + } else { + headers.Set(ProvisionerDaemonPSK, req.PreSharedKey) + } + conn, res, err := websocket.Dial(ctx, serverURL.String(), &websocket.DialOptions{ HTTPClient: httpClient, // Need to disable compression to avoid a data-race. CompressionMode: websocket.CompressionDisabled, + HTTPHeader: headers, }) if err != nil { if res == nil { diff --git a/docs/api/general.md b/docs/api/general.md index bb64823bd86c9..3f1f90a02d851 100644 --- a/docs/api/general.md +++ b/docs/api/general.md @@ -305,6 +305,7 @@ curl -X GET http://coder-server:8080/api/v2/deployment/config \ "provisioner": { "daemon_poll_interval": 0, "daemon_poll_jitter": 0, + "daemon_psk": "string", "daemons": 0, "daemons_echo": true, "force_cancel_interval": 0 diff --git a/docs/api/schemas.md b/docs/api/schemas.md index b9d5e3d1b78a1..ef9187a7c0cd9 100644 --- a/docs/api/schemas.md +++ b/docs/api/schemas.md @@ -2096,6 +2096,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in "provisioner": { "daemon_poll_interval": 0, "daemon_poll_jitter": 0, + "daemon_psk": "string", "daemons": 0, "daemons_echo": true, "force_cancel_interval": 0 @@ -2453,6 +2454,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in "provisioner": { "daemon_poll_interval": 0, "daemon_poll_jitter": 0, + "daemon_psk": "string", "daemons": 0, "daemons_echo": true, "force_cancel_interval": 0 @@ -3480,6 +3482,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in { "daemon_poll_interval": 0, "daemon_poll_jitter": 0, + "daemon_psk": "string", "daemons": 0, "daemons_echo": true, "force_cancel_interval": 0 @@ -3492,6 +3495,7 @@ AuthorizationObject can represent a "set" of objects, such as: all workspaces in | ----------------------- | ------- | -------- | ------------ | ----------- | | `daemon_poll_interval` | integer | false | | | | `daemon_poll_jitter` | integer | false | | | +| `daemon_psk` | string | false | | | | `daemons` | integer | false | | | | `daemons_echo` | boolean | false | | | | `force_cancel_interval` | integer | false | | | diff --git a/docs/cli/provisionerd_start.md b/docs/cli/provisionerd_start.md index 583e520389150..b129605933db3 100644 --- a/docs/cli/provisionerd_start.md +++ b/docs/cli/provisionerd_start.md @@ -42,6 +42,15 @@ How often to poll for provisioner jobs. How much to jitter the poll interval by. +### --psk + +| | | +| ----------- | ------------------------------------------ | +| Type | string | +| Environment | $CODER_PROVISIONER_DAEMON_PSK | + +Pre-shared key to authenticate with Coder server. + ### -t, --tag | | | diff --git a/docs/cli/server.md b/docs/cli/server.md index 90c60d7392f00..9591dc8041f9f 100644 --- a/docs/cli/server.md +++ b/docs/cli/server.md @@ -668,6 +668,16 @@ Collect database metrics (may increase charges for metrics storage). Serve prometheus metrics on the address defined by prometheus address. +### --provisioner-daemon-psk + +| | | +| ----------- | ------------------------------------------ | +| Type | string | +| Environment | $CODER_PROVISIONER_DAEMON_PSK | +| YAML | provisioning.daemonPSK | + +Pre-shared key to authenticate external provisioner daemons to Coder server. + ### --provisioner-daemons | | | diff --git a/enterprise/cli/provisionerdaemons.go b/enterprise/cli/provisionerdaemons.go index f3dfc2ba367d7..baca9fa4b3296 100644 --- a/enterprise/cli/provisionerdaemons.go +++ b/enterprise/cli/provisionerdaemons.go @@ -44,13 +44,14 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd { rawTags []string pollInterval time.Duration pollJitter time.Duration + preSharedKey string ) client := new(codersdk.Client) cmd := &clibase.Cmd{ Use: "start", Short: "Run a provisioner daemon", Middleware: clibase.Chain( - r.InitClient(client), + r.InitClientMissingTokenOK(client), ), Handler: func(inv *clibase.Invocation) error { ctx, cancel := context.WithCancel(inv.Context()) @@ -59,11 +60,6 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd { notifyCtx, notifyStop := signal.NotifyContext(ctx, agpl.InterruptSignals...) defer notifyStop() - org, err := agpl.CurrentOrganization(inv, client) - if err != nil { - return xerrors.Errorf("get current organization: %w", err) - } - tags, err := agpl.ParseProvisionerTags(rawTags) if err != nil { return err @@ -112,9 +108,13 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd { string(database.ProvisionerTypeTerraform): proto.NewDRPCProvisionerClient(terraformClient), } srv := provisionerd.New(func(ctx context.Context) (provisionerdproto.DRPCProvisionerDaemonClient, error) { - return client.ServeProvisionerDaemon(ctx, org.ID, []codersdk.ProvisionerType{ - codersdk.ProvisionerTypeTerraform, - }, tags) + return client.ServeProvisionerDaemon(ctx, codersdk.ServeProvisionerDaemonRequest{ + Provisioners: []codersdk.ProvisionerType{ + codersdk.ProvisionerTypeTerraform, + }, + Tags: tags, + PreSharedKey: preSharedKey, + }) }, &provisionerd.Options{ Logger: logger, JobPollInterval: pollInterval, @@ -182,6 +182,12 @@ func (r *RootCmd) provisionerDaemonStart() *clibase.Cmd { Default: (100 * time.Millisecond).String(), Value: clibase.DurationOf(&pollJitter), }, + { + Flag: "psk", + Env: "CODER_PROVISIONER_DAEMON_PSK", + Description: "Pre-shared key to authenticate with Coder server.", + Value: clibase.StringOf(&preSharedKey), + }, } return cmd diff --git a/enterprise/cli/provisionerdaemons_test.go b/enterprise/cli/provisionerdaemons_test.go new file mode 100644 index 0000000000000..69b23d870757c --- /dev/null +++ b/enterprise/cli/provisionerdaemons_test.go @@ -0,0 +1,56 @@ +package cli_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/cli/clitest" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/enterprise/coderd/coderdenttest" + "github.com/coder/coder/enterprise/coderd/license" + "github.com/coder/coder/pty/ptytest" + "github.com/coder/coder/testutil" +) + +func TestProvisionerDaemon_PSK(t *testing.T) { + t.Parallel() + + client, _ := coderdenttest.New(t, &coderdenttest.Options{ + ProvisionerDaemonPSK: "provisionersftw", + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureExternalProvisionerDaemons: 1, + }, + }, + }) + inv, conf := newCLI(t, "provisionerd", "start", "--psk=provisionersftw") + err := conf.URL().Write(client.URL.String()) + require.NoError(t, err) + pty := ptytest.New(t).Attach(inv) + ctx, cancel := context.WithTimeout(inv.Context(), testutil.WaitLong) + defer cancel() + clitest.Start(t, inv) + pty.ExpectMatchContext(ctx, "starting provisioner daemon") +} + +func TestProvisionerDaemon_SessionToken(t *testing.T) { + t.Parallel() + + client, _ := coderdenttest.New(t, &coderdenttest.Options{ + ProvisionerDaemonPSK: "provisionersftw", + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureExternalProvisionerDaemons: 1, + }, + }, + }) + inv, conf := newCLI(t, "provisionerd", "start") + clitest.SetupConfig(t, client, conf) + pty := ptytest.New(t).Attach(inv) + ctx, cancel := context.WithTimeout(inv.Context(), testutil.WaitLong) + defer cancel() + clitest.Start(t, inv) + pty.ExpectMatchContext(ctx, "starting provisioner daemon") +} diff --git a/enterprise/cli/server.go b/enterprise/cli/server.go index b0561a0de1850..70a06ff0548e4 100644 --- a/enterprise/cli/server.go +++ b/enterprise/cli/server.go @@ -66,6 +66,7 @@ func (r *RootCmd) server() *clibase.Cmd { DERPServerRegionID: int(options.DeploymentValues.DERP.Server.RegionID.Value()), ProxyHealthInterval: options.DeploymentValues.ProxyHealthStatusInterval.Value(), DefaultQuietHoursSchedule: options.DeploymentValues.UserQuietHoursSchedule.DefaultSchedule.Value(), + ProvisionerDaemonPSK: options.DeploymentValues.Provisioner.DaemonPSK.Value(), } api, err := coderd.New(ctx, o) diff --git a/enterprise/cli/testdata/coder_provisionerd_start_--help.golden b/enterprise/cli/testdata/coder_provisionerd_start_--help.golden index 1236cfb5ae7e1..5258c33125173 100644 --- a/enterprise/cli/testdata/coder_provisionerd_start_--help.golden +++ b/enterprise/cli/testdata/coder_provisionerd_start_--help.golden @@ -12,6 +12,9 @@ Run a provisioner daemon --poll-jitter duration, $CODER_PROVISIONERD_POLL_JITTER (default: 100ms) How much to jitter the poll interval by. + --psk string, $CODER_PROVISIONER_DAEMON_PSK + Pre-shared key to authenticate with Coder server. + -t, --tag string-array, $CODER_PROVISIONERD_TAGS Tags to filter provisioner jobs by. diff --git a/enterprise/cli/testdata/coder_server_--help.golden b/enterprise/cli/testdata/coder_server_--help.golden index cb7ca61b4913a..121ce98a98bd7 100644 --- a/enterprise/cli/testdata/coder_server_--help.golden +++ b/enterprise/cli/testdata/coder_server_--help.golden @@ -373,6 +373,10 @@ updating, and deleting workspace resources. --provisioner-daemon-poll-jitter duration, $CODER_PROVISIONER_DAEMON_POLL_JITTER (default: 100ms) Random jitter added to the poll interval. + --provisioner-daemon-psk string, $CODER_PROVISIONER_DAEMON_PSK + Pre-shared key to authenticate external provisioner daemons to Coder + server. + --provisioner-daemons int, $CODER_PROVISIONER_DAEMONS (default: 3) Number of provisioner daemons to create on start. If builds are stuck in queued state for a long time, consider increasing this. diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index c3cc5e0be5ccd..71d975e3ef1d6 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -67,6 +67,10 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { AGPL: coderd.New(options.Options), Options: options, + provisionerDaemonAuth: &provisionerDaemonAuth{ + psk: options.ProvisionerDaemonPSK, + authorizer: options.Authorizer, + }, } defer func() { if err != nil { @@ -193,14 +197,21 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { r.Get("/", api.groupByOrganization) }) }) + // TODO: provisioner daemons are not scoped to organizations in the database, so placing them + // under an organization route doesn't make sense. In order to allow the /serve endpoint to + // work with a pre-shared key (PSK) without an API key, these routes will simply ignore the + // value of {organization}. That is, the route will work with any organization ID, whether or + // not it exits. This doesn't leak any information about the existence of organizations, so is + // fine from a security perspective, but might be a little surprising. + // + // We may in future decide to scope provisioner daemons to organizations, so we'll keep the API + // route as is. r.Route("/organizations/{organization}/provisionerdaemons", func(r chi.Router) { r.Use( api.provisionerDaemonsEnabledMW, - apiKeyMiddleware, - httpmw.ExtractOrganizationParam(api.Database), ) - r.Get("/", api.provisionerDaemons) - r.Get("/serve", api.provisionerDaemonServe) + r.With(apiKeyMiddleware).Get("/", api.provisionerDaemons) + r.With(apiKeyMiddlewareOptional).Get("/serve", api.provisionerDaemonServe) }) r.Route("/templates/{template}/acl", func(r chi.Router) { r.Use( @@ -362,6 +373,9 @@ type Options struct { EntitlementsUpdateInterval time.Duration ProxyHealthInterval time.Duration Keys map[string]ed25519.PublicKey + + // optional pre-shared key for authentication of external provisioner daemons + ProvisionerDaemonPSK string } type API struct { @@ -383,6 +397,8 @@ type API struct { entitlementsUpdateMu sync.Mutex entitlementsMu sync.RWMutex entitlements codersdk.Entitlements + + provisionerDaemonAuth *provisionerDaemonAuth } func (api *API) Close() error { diff --git a/enterprise/coderd/coderdenttest/coderdenttest.go b/enterprise/coderd/coderdenttest/coderdenttest.go index 92e0b627d60ae..64cb15c740fed 100644 --- a/enterprise/coderd/coderdenttest/coderdenttest.go +++ b/enterprise/coderd/coderdenttest/coderdenttest.go @@ -56,6 +56,7 @@ type Options struct { DontAddLicense bool DontAddFirstUser bool ReplicaSyncUpdateInterval time.Duration + ProvisionerDaemonPSK string } // New constructs a codersdk client connected to an in-memory Enterprise API instance. @@ -94,6 +95,7 @@ func NewWithAPI(t *testing.T, options *Options) ( Keys: Keys, ProxyHealthInterval: options.ProxyHealthInterval, DefaultQuietHoursSchedule: oop.DeploymentValues.UserQuietHoursSchedule.DefaultSchedule.Value(), + ProvisionerDaemonPSK: options.ProvisionerDaemonPSK, }) require.NoError(t, err) setHandler(coderAPI.AGPL.RootHandler) diff --git a/enterprise/coderd/provisionerdaemons.go b/enterprise/coderd/provisionerdaemons.go index 055704a6bcb11..1b3010d833200 100644 --- a/enterprise/coderd/provisionerdaemons.go +++ b/enterprise/coderd/provisionerdaemons.go @@ -2,6 +2,7 @@ package coderd import ( "context" + "crypto/subtle" "database/sql" "encoding/json" "errors" @@ -87,6 +88,40 @@ func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) { httpapi.Write(ctx, rw, http.StatusOK, apiDaemons) } +type provisionerDaemonAuth struct { + psk string + authorizer rbac.Authorizer +} + +// authorize returns mutated tags and true if the given HTTP request is authorized to access the provisioner daemon +// protobuf API, and returns nil, false otherwise. +func (p *provisionerDaemonAuth) authorize(r *http.Request, tags map[string]string) (map[string]string, bool) { + ctx := r.Context() + apiKey, ok := httpmw.APIKeyOptional(r) + if ok { + tags = provisionerdserver.MutateTags(apiKey.UserID, tags) + if tags[provisionerdserver.TagScope] == provisionerdserver.ScopeUser { + // Any authenticated user can create provisioner daemons scoped + // for jobs that they own, + return tags, true + } + ua := httpmw.UserAuthorization(r) + if err := p.authorizer.Authorize(ctx, ua.Actor, rbac.ActionCreate, rbac.ResourceProvisionerDaemon); err == nil { + // User is allowed to create provisioner daemons + return tags, true + } + } + + // Check for PSK + if p.psk != "" { + psk := r.Header.Get(codersdk.ProvisionerDaemonPSK) + if subtle.ConstantTimeCompare([]byte(p.psk), []byte(psk)) == 1 { + return tags, true + } + } + return nil, false +} + // Serves the provisioner daemon protobuf API over a WebSocket. // // @Summary Serve provisioner daemon @@ -134,19 +169,11 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) } } - // Any authenticated user can create provisioner daemons scoped - // for jobs that they own, but only authorized users can create - // globally scoped provisioners that attach to all jobs. - apiKey := httpmw.APIKey(r) - tags = provisionerdserver.MutateTags(apiKey.UserID, tags) - - if tags[provisionerdserver.TagScope] == provisionerdserver.ScopeOrganization { - if !api.AGPL.Authorize(r, rbac.ActionCreate, rbac.ResourceProvisionerDaemon) { - httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{ - Message: "You aren't allowed to create provisioner daemons for the organization.", - }) - return - } + tags, authorized := api.provisionerDaemonAuth.authorize(r, tags) + if !authorized { + httpapi.Write(ctx, rw, http.StatusForbidden, + codersdk.Response{Message: "You aren't allowed to create provisioner daemons"}) + return } provisioners := make([]database.ProvisionerType, 0) diff --git a/enterprise/coderd/provisionerdaemons_test.go b/enterprise/coderd/provisionerdaemons_test.go index 28a89431b4f00..1586a92773e73 100644 --- a/enterprise/coderd/provisionerdaemons_test.go +++ b/enterprise/coderd/provisionerdaemons_test.go @@ -17,6 +17,7 @@ import ( "github.com/coder/coder/enterprise/coderd/license" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" + "github.com/coder/coder/testutil" ) func TestProvisionerDaemonServe(t *testing.T) { @@ -28,23 +29,43 @@ func TestProvisionerDaemonServe(t *testing.T) { codersdk.FeatureExternalProvisionerDaemons: 1, }, }}) - srv, err := client.ServeProvisionerDaemon(context.Background(), user.OrganizationID, []codersdk.ProvisionerType{ - codersdk.ProvisionerTypeEcho, - }, map[string]string{}) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + srv, err := client.ServeProvisionerDaemon(ctx, codersdk.ServeProvisionerDaemonRequest{ + Organization: user.OrganizationID, + Provisioners: []codersdk.ProvisionerType{ + codersdk.ProvisionerTypeEcho, + }, + Tags: map[string]string{}, + }) require.NoError(t, err) srv.DRPCConn().Close() + daemons, err := client.ProvisionerDaemons(ctx) + require.NoError(t, err) + require.Len(t, daemons, 1) }) t.Run("NoLicense", func(t *testing.T) { t.Parallel() client, user := coderdenttest.New(t, &coderdenttest.Options{DontAddLicense: true}) - _, err := client.ServeProvisionerDaemon(context.Background(), user.OrganizationID, []codersdk.ProvisionerType{ - codersdk.ProvisionerTypeEcho, - }, map[string]string{}) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + _, err := client.ServeProvisionerDaemon(ctx, codersdk.ServeProvisionerDaemonRequest{ + Organization: user.OrganizationID, + Provisioners: []codersdk.ProvisionerType{ + codersdk.ProvisionerTypeEcho, + }, + Tags: map[string]string{}, + }) require.Error(t, err) var apiError *codersdk.Error require.ErrorAs(t, err, &apiError) require.Equal(t, http.StatusForbidden, apiError.StatusCode()) + + // querying provisioner daemons is forbidden without license + _, err = client.ProvisionerDaemons(ctx) + require.ErrorAs(t, err, &apiError) + require.Equal(t, http.StatusForbidden, apiError.StatusCode()) }) t.Run("Organization", func(t *testing.T) { @@ -55,15 +76,24 @@ func TestProvisionerDaemonServe(t *testing.T) { }, }}) another, _ := coderdtest.CreateAnotherUser(t, client, user.OrganizationID, rbac.RoleOrgAdmin(user.OrganizationID)) - _, err := another.ServeProvisionerDaemon(context.Background(), user.OrganizationID, []codersdk.ProvisionerType{ - codersdk.ProvisionerTypeEcho, - }, map[string]string{ - provisionerdserver.TagScope: provisionerdserver.ScopeOrganization, + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + _, err := another.ServeProvisionerDaemon(ctx, codersdk.ServeProvisionerDaemonRequest{ + Organization: user.OrganizationID, + Provisioners: []codersdk.ProvisionerType{ + codersdk.ProvisionerTypeEcho, + }, + Tags: map[string]string{ + provisionerdserver.TagScope: provisionerdserver.ScopeOrganization, + }, }) require.Error(t, err) var apiError *codersdk.Error require.ErrorAs(t, err, &apiError) require.Equal(t, http.StatusForbidden, apiError.StatusCode()) + daemons, err := client.ProvisionerDaemons(ctx) + require.NoError(t, err) + require.Len(t, daemons, 0) }) t.Run("OrganizationNoPerms", func(t *testing.T) { @@ -74,15 +104,24 @@ func TestProvisionerDaemonServe(t *testing.T) { }, }}) another, _ := coderdtest.CreateAnotherUser(t, client, user.OrganizationID) - _, err := another.ServeProvisionerDaemon(context.Background(), user.OrganizationID, []codersdk.ProvisionerType{ - codersdk.ProvisionerTypeEcho, - }, map[string]string{ - provisionerdserver.TagScope: provisionerdserver.ScopeOrganization, + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + _, err := another.ServeProvisionerDaemon(ctx, codersdk.ServeProvisionerDaemonRequest{ + Organization: user.OrganizationID, + Provisioners: []codersdk.ProvisionerType{ + codersdk.ProvisionerTypeEcho, + }, + Tags: map[string]string{ + provisionerdserver.TagScope: provisionerdserver.ScopeOrganization, + }, }) require.Error(t, err) var apiError *codersdk.Error require.ErrorAs(t, err, &apiError) require.Equal(t, http.StatusForbidden, apiError.StatusCode()) + daemons, err := client.ProvisionerDaemons(ctx) + require.NoError(t, err) + require.Len(t, daemons, 0) }) t.Run("UserLocal", func(t *testing.T) { @@ -141,4 +180,129 @@ func TestProvisionerDaemonServe(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, another, user.OrganizationID, template.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) }) + + t.Run("PSK", func(t *testing.T) { + t.Parallel() + client, user := coderdenttest.New(t, &coderdenttest.Options{ + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureExternalProvisionerDaemons: 1, + }, + }, + ProvisionerDaemonPSK: "provisionersftw", + }) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + another := codersdk.New(client.URL) + srv, err := another.ServeProvisionerDaemon(ctx, codersdk.ServeProvisionerDaemonRequest{ + Organization: user.OrganizationID, + Provisioners: []codersdk.ProvisionerType{ + codersdk.ProvisionerTypeEcho, + }, + Tags: map[string]string{ + provisionerdserver.TagScope: provisionerdserver.ScopeOrganization, + }, + PreSharedKey: "provisionersftw", + }) + require.NoError(t, err) + err = srv.DRPCConn().Close() + require.NoError(t, err) + daemons, err := client.ProvisionerDaemons(ctx) + require.NoError(t, err) + require.Len(t, daemons, 1) + }) + + t.Run("BadPSK", func(t *testing.T) { + t.Parallel() + client, user := coderdenttest.New(t, &coderdenttest.Options{ + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureExternalProvisionerDaemons: 1, + }, + }, + ProvisionerDaemonPSK: "provisionersftw", + }) + another := codersdk.New(client.URL) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + _, err := another.ServeProvisionerDaemon(ctx, codersdk.ServeProvisionerDaemonRequest{ + Organization: user.OrganizationID, + Provisioners: []codersdk.ProvisionerType{ + codersdk.ProvisionerTypeEcho, + }, + Tags: map[string]string{ + provisionerdserver.TagScope: provisionerdserver.ScopeOrganization, + }, + PreSharedKey: "the wrong key", + }) + require.Error(t, err) + var apiError *codersdk.Error + require.ErrorAs(t, err, &apiError) + require.Equal(t, http.StatusForbidden, apiError.StatusCode()) + daemons, err := client.ProvisionerDaemons(ctx) + require.NoError(t, err) + require.Len(t, daemons, 0) + }) + + t.Run("NoAuth", func(t *testing.T) { + t.Parallel() + client, user := coderdenttest.New(t, &coderdenttest.Options{ + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureExternalProvisionerDaemons: 1, + }, + }, + ProvisionerDaemonPSK: "provisionersftw", + }) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + another := codersdk.New(client.URL) + _, err := another.ServeProvisionerDaemon(ctx, codersdk.ServeProvisionerDaemonRequest{ + Organization: user.OrganizationID, + Provisioners: []codersdk.ProvisionerType{ + codersdk.ProvisionerTypeEcho, + }, + Tags: map[string]string{ + provisionerdserver.TagScope: provisionerdserver.ScopeOrganization, + }, + }) + require.Error(t, err) + var apiError *codersdk.Error + require.ErrorAs(t, err, &apiError) + require.Equal(t, http.StatusForbidden, apiError.StatusCode()) + daemons, err := client.ProvisionerDaemons(ctx) + require.NoError(t, err) + require.Len(t, daemons, 0) + }) + + t.Run("NoPSK", func(t *testing.T) { + t.Parallel() + client, user := coderdenttest.New(t, &coderdenttest.Options{ + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureExternalProvisionerDaemons: 1, + }, + }, + }) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + another := codersdk.New(client.URL) + _, err := another.ServeProvisionerDaemon(ctx, codersdk.ServeProvisionerDaemonRequest{ + Organization: user.OrganizationID, + Provisioners: []codersdk.ProvisionerType{ + codersdk.ProvisionerTypeEcho, + }, + Tags: map[string]string{ + provisionerdserver.TagScope: provisionerdserver.ScopeOrganization, + }, + PreSharedKey: "provisionersftw", + }) + require.Error(t, err) + var apiError *codersdk.Error + require.ErrorAs(t, err, &apiError) + require.Equal(t, http.StatusForbidden, apiError.StatusCode()) + daemons, err := client.ProvisionerDaemons(ctx) + require.NoError(t, err) + require.Len(t, daemons, 0) + }) } diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index e374e46e192f1..652e4081a9c4b 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -710,6 +710,7 @@ export interface ProvisionerConfig { readonly daemon_poll_interval: number readonly daemon_poll_jitter: number readonly force_cancel_interval: number + readonly daemon_psk: string } // From codersdk/provisionerdaemons.go