diff --git a/.vscode/settings.json b/.vscode/settings.json index ba58f6f4ee1bf..09aab5fcbc198 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -17,6 +17,7 @@ "codersdk", "cronstrue", "databasefake", + "dbtype", "DERP", "derphttp", "derpmap", diff --git a/cli/deployment/config.go b/cli/deployment/config.go index ba7a5061fa638..ad4b6efea26c0 100644 --- a/cli/deployment/config.go +++ b/cli/deployment/config.go @@ -143,7 +143,7 @@ func newConfig() *codersdk.DeploymentConfig { Name: "Cache Directory", Usage: "The directory to cache temporary files. If unspecified and $CACHE_DIRECTORY is set, it will be used for compatibility with systemd.", Flag: "cache-dir", - Default: defaultCacheDir(), + Default: DefaultCacheDir(), }, InMemoryDatabase: &codersdk.DeploymentConfigField[bool]{ Name: "In Memory Database", @@ -672,7 +672,7 @@ func formatEnv(key string) string { return "CODER_" + strings.ToUpper(strings.NewReplacer("-", "_", ".", "_").Replace(key)) } -func defaultCacheDir() string { +func DefaultCacheDir() string { defaultCacheDir, err := os.UserCacheDir() if err != nil { defaultCacheDir = os.TempDir() diff --git a/cli/gitaskpass.go b/cli/gitaskpass.go index 3930fe16c8a57..20740be7ae3bf 100644 --- a/cli/gitaskpass.go +++ b/cli/gitaskpass.go @@ -26,7 +26,7 @@ func gitAskpass() *cobra.Command { RunE: func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() - ctx, stop := signal.NotifyContext(ctx, interruptSignals...) + ctx, stop := signal.NotifyContext(ctx, InterruptSignals...) defer stop() user, host, err := gitauth.ParseAskpass(args[0]) diff --git a/cli/gitssh.go b/cli/gitssh.go index b18b919f79515..09ebc396fdbde 100644 --- a/cli/gitssh.go +++ b/cli/gitssh.go @@ -29,7 +29,7 @@ func gitssh() *cobra.Command { // Catch interrupt signals to ensure the temporary private // key file is cleaned up on most cases. - ctx, stop := signal.NotifyContext(ctx, interruptSignals...) + ctx, stop := signal.NotifyContext(ctx, InterruptSignals...) defer stop() // Early check so errors are reported immediately. diff --git a/cli/server.go b/cli/server.go index 95d18fcf59828..e8a009a8977c4 100644 --- a/cli/server.go +++ b/cli/server.go @@ -108,7 +108,7 @@ func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*co // // To get out of a graceful shutdown, the user can send // SIGQUIT with ctrl+\ or SIGKILL with `kill -9`. - notifyCtx, notifyStop := signal.NotifyContext(ctx, interruptSignals...) + notifyCtx, notifyStop := signal.NotifyContext(ctx, InterruptSignals...) defer notifyStop() // Clean up idle connections at the end, e.g. @@ -946,7 +946,7 @@ func newProvisionerDaemon( return provisionerd.New(func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { // This debounces calls to listen every second. Read the comment // in provisionerdserver.go to learn more! - return coderAPI.ListenProvisionerDaemon(ctx, time.Second) + return coderAPI.CreateInMemoryProvisionerDaemon(ctx, time.Second) }, &provisionerd.Options{ Logger: logger, PollInterval: 500 * time.Millisecond, diff --git a/cli/signal_unix.go b/cli/signal_unix.go index 7d2cd0e5022c5..05d619c0232e4 100644 --- a/cli/signal_unix.go +++ b/cli/signal_unix.go @@ -7,7 +7,7 @@ import ( "syscall" ) -var interruptSignals = []os.Signal{ +var InterruptSignals = []os.Signal{ os.Interrupt, syscall.SIGTERM, syscall.SIGHUP, diff --git a/cli/signal_windows.go b/cli/signal_windows.go index 17652adfb626d..3624415a6452f 100644 --- a/cli/signal_windows.go +++ b/cli/signal_windows.go @@ -6,4 +6,4 @@ import ( "os" ) -var interruptSignals = []os.Signal{os.Interrupt} +var InterruptSignals = []os.Signal{os.Interrupt} diff --git a/cli/templatecreate.go b/cli/templatecreate.go index 1f8833d0c957a..8c8ce9e034610 100644 --- a/cli/templatecreate.go +++ b/cli/templatecreate.go @@ -24,10 +24,11 @@ import ( func templateCreate() *cobra.Command { var ( - directory string - provisioner string - parameterFile string - defaultTTL time.Duration + directory string + provisioner string + provisionerTags []string + parameterFile string + defaultTTL time.Duration ) cmd := &cobra.Command{ Use: "create [name]", @@ -87,12 +88,18 @@ func templateCreate() *cobra.Command { } spin.Stop() + tags, err := ParseProvisionerTags(provisionerTags) + if err != nil { + return err + } + job, _, err := createValidTemplateVersion(cmd, createValidTemplateVersionArgs{ - Client: client, - Organization: organization, - Provisioner: database.ProvisionerType(provisioner), - FileID: resp.ID, - ParameterFile: parameterFile, + Client: client, + Organization: organization, + Provisioner: database.ProvisionerType(provisioner), + FileID: resp.ID, + ParameterFile: parameterFile, + ProvisionerTags: tags, }) if err != nil { return err @@ -131,6 +138,7 @@ func templateCreate() *cobra.Command { cmd.Flags().StringVarP(&directory, "directory", "d", currentDirectory, "Specify the directory to create from") cmd.Flags().StringVarP(&provisioner, "test.provisioner", "", "terraform", "Customize the provisioner backend") cmd.Flags().StringVarP(¶meterFile, "parameter-file", "", "", "Specify a file path with parameter values.") + cmd.Flags().StringArrayVarP(&provisionerTags, "provisioner-tag", "", []string{}, "Specify a set of tags to target provisioner daemons.") cmd.Flags().DurationVarP(&defaultTTL, "default-ttl", "", 24*time.Hour, "Specify a default TTL for workspaces created from this template.") // This is for testing! err := cmd.Flags().MarkHidden("test.provisioner") @@ -154,6 +162,7 @@ type createValidTemplateVersionArgs struct { // before prompting the user. Set to false to always prompt for param // values. ReuseParameters bool + ProvisionerTags map[string]string } func createValidTemplateVersion(cmd *cobra.Command, args createValidTemplateVersionArgs, parameters ...codersdk.CreateParameterRequest) (*codersdk.TemplateVersion, []codersdk.CreateParameterRequest, error) { @@ -165,6 +174,7 @@ func createValidTemplateVersion(cmd *cobra.Command, args createValidTemplateVers FileID: args.FileID, Provisioner: codersdk.ProvisionerType(args.Provisioner), ParameterValues: parameters, + ProvisionerTags: args.ProvisionerTags, } if args.Template != nil { req.TemplateID = args.Template.ID @@ -334,3 +344,15 @@ func prettyDirectoryPath(dir string) string { } return pretty } + +func ParseProvisionerTags(rawTags []string) (map[string]string, error) { + tags := map[string]string{} + for _, rawTag := range rawTags { + parts := strings.SplitN(rawTag, "=", 2) + if len(parts) < 2 { + return nil, xerrors.Errorf("invalid tag format for %q. must be key=value", rawTag) + } + tags[parts[0]] = parts[1] + } + return tags, nil +} diff --git a/cli/templatepush.go b/cli/templatepush.go index 9eed180667e7c..c3a5e2f0c0ecb 100644 --- a/cli/templatepush.go +++ b/cli/templatepush.go @@ -18,11 +18,12 @@ import ( func templatePush() *cobra.Command { var ( - directory string - versionName string - provisioner string - parameterFile string - alwaysPrompt bool + directory string + versionName string + provisioner string + parameterFile string + alwaysPrompt bool + provisionerTags []string ) cmd := &cobra.Command{ @@ -75,6 +76,11 @@ func templatePush() *cobra.Command { } spin.Stop() + tags, err := ParseProvisionerTags(provisionerTags) + if err != nil { + return err + } + job, _, err := createValidTemplateVersion(cmd, createValidTemplateVersionArgs{ Name: versionName, Client: client, @@ -84,6 +90,7 @@ func templatePush() *cobra.Command { ParameterFile: parameterFile, Template: &template, ReuseParameters: !alwaysPrompt, + ProvisionerTags: tags, }) if err != nil { return err diff --git a/coderd/autobuild/executor/lifecycle_executor.go b/coderd/autobuild/executor/lifecycle_executor.go index ba2795a3a202f..3ed07da8f59f5 100644 --- a/coderd/autobuild/executor/lifecycle_executor.go +++ b/coderd/autobuild/executor/lifecycle_executor.go @@ -278,6 +278,7 @@ func build(ctx context.Context, store database.Store, workspace database.Workspa Type: database.ProvisionerJobTypeWorkspaceBuild, StorageMethod: priorJob.StorageMethod, FileID: priorJob.FileID, + Tags: priorJob.Tags, Input: input, }) if err != nil { diff --git a/coderd/coderd.go b/coderd/coderd.go index d695bb508db47..077e8ecfc182c 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -1,8 +1,10 @@ package coderd import ( + "context" "crypto/tls" "crypto/x509" + "encoding/json" "fmt" "io" "net/http" @@ -18,10 +20,13 @@ import ( "github.com/go-chi/chi/v5/middleware" "github.com/google/uuid" "github.com/klauspost/compress/zstd" + "github.com/moby/moby/pkg/namesgenerator" "github.com/prometheus/client_golang/prometheus" "go.opentelemetry.io/otel/trace" "golang.org/x/xerrors" "google.golang.org/api/idtoken" + "storj.io/drpc/drpcmux" + "storj.io/drpc/drpcserver" "tailscale.com/derp" "tailscale.com/derp/derphttp" "tailscale.com/tailcfg" @@ -32,17 +37,20 @@ import ( "github.com/coder/coder/coderd/audit" "github.com/coder/coder/coderd/awsidentity" "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/database/dbtype" "github.com/coder/coder/coderd/gitauth" "github.com/coder/coder/coderd/gitsshkey" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/httpmw" "github.com/coder/coder/coderd/metricscache" + "github.com/coder/coder/coderd/provisionerdserver" "github.com/coder/coder/coderd/rbac" "github.com/coder/coder/coderd/telemetry" "github.com/coder/coder/coderd/tracing" "github.com/coder/coder/coderd/wsconncache" "github.com/coder/coder/codersdk" "github.com/coder/coder/provisionerd/proto" + "github.com/coder/coder/provisionersdk" "github.com/coder/coder/site" "github.com/coder/coder/tailnet" ) @@ -323,13 +331,6 @@ func New(options *Options) *API { r.Get("/{fileID}", api.fileByID) r.Post("/", api.postFile) }) - - r.Route("/provisionerdaemons", func(r chi.Router) { - r.Use( - apiKeyMiddleware, - ) - r.Get("/", api.provisionerDaemons) - }) r.Route("/organizations", func(r chi.Router) { r.Use( apiKeyMiddleware, @@ -595,18 +596,20 @@ type API struct { // RootHandler serves "/" RootHandler chi.Router - metricsCache *metricscache.Cache - siteHandler http.Handler - websocketWaitMutex sync.Mutex - websocketWaitGroup sync.WaitGroup + metricsCache *metricscache.Cache + siteHandler http.Handler + + WebsocketWaitMutex sync.Mutex + WebsocketWaitGroup sync.WaitGroup + workspaceAgentCache *wsconncache.Cache } // Close waits for all WebSocket connections to drain before returning. func (api *API) Close() error { - api.websocketWaitMutex.Lock() - api.websocketWaitGroup.Wait() - api.websocketWaitMutex.Unlock() + api.WebsocketWaitMutex.Lock() + api.WebsocketWaitGroup.Wait() + api.WebsocketWaitMutex.Unlock() api.metricsCache.Close() coordinator := api.TailnetCoordinator.Load() @@ -635,3 +638,70 @@ func compressHandler(h http.Handler) http.Handler { return cmp.Handler(h) } + +// CreateInMemoryProvisionerDaemon is an in-memory connection to a provisionerd. Useful when starting coderd and provisionerd +// in the same process. +func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, debounce time.Duration) (client proto.DRPCProvisionerDaemonClient, err error) { + clientSession, serverSession := provisionersdk.TransportPipe() + defer func() { + if err != nil { + _ = clientSession.Close() + _ = serverSession.Close() + } + }() + + name := namesgenerator.GetRandomName(1) + daemon, err := api.Database.InsertProvisionerDaemon(ctx, database.InsertProvisionerDaemonParams{ + ID: uuid.New(), + CreatedAt: database.Now(), + Name: name, + Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho, database.ProvisionerTypeTerraform}, + Tags: dbtype.StringMap{ + provisionerdserver.TagScope: provisionerdserver.ScopeOrganization, + }, + }) + if err != nil { + return nil, xerrors.Errorf("insert provisioner daemon %q: %w", name, err) + } + + tags, err := json.Marshal(daemon.Tags) + if err != nil { + return nil, xerrors.Errorf("marshal tags: %w", err) + } + + mux := drpcmux.New() + err = proto.DRPCRegisterProvisionerDaemon(mux, &provisionerdserver.Server{ + AccessURL: api.AccessURL, + ID: daemon.ID, + Database: api.Database, + Pubsub: api.Pubsub, + Provisioners: daemon.Provisioners, + Telemetry: api.Telemetry, + Tags: tags, + QuotaCommitter: &api.QuotaCommitter, + AcquireJobDebounce: debounce, + Logger: api.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)), + }) + if err != nil { + return nil, err + } + server := drpcserver.NewWithOptions(mux, drpcserver.Options{ + Log: func(err error) { + if xerrors.Is(err, io.EOF) { + return + } + api.Logger.Debug(ctx, "drpc server error", slog.Error(err)) + }, + }) + go func() { + err := server.Serve(ctx, serverSession) + if err != nil && !xerrors.Is(err, io.EOF) { + api.Logger.Debug(ctx, "provisioner daemon disconnected", slog.Error(err)) + } + // close the sessions so we don't leak goroutines serving them. + _ = clientSession.Close() + _ = serverSession.Close() + }() + + return proto.NewDRPCProvisionerDaemonClient(provisionersdk.Conn(clientSession)), nil +} diff --git a/coderd/coderdtest/authorize.go b/coderd/coderdtest/authorize.go index 91f5da9e30fe9..ef308fba97dff 100644 --- a/coderd/coderdtest/authorize.go +++ b/coderd/coderdtest/authorize.go @@ -19,7 +19,6 @@ import ( "github.com/coder/coder/codersdk" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" - "github.com/coder/coder/testutil" ) func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { @@ -204,11 +203,6 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) { AssertAction: rbac.ActionRead, AssertObject: rbac.ResourceTemplate.InOrg(a.Version.OrganizationID), }, - "GET:/api/v2/provisionerdaemons": { - StatusCode: http.StatusOK, - AssertObject: rbac.ResourceProvisionerDaemon, - }, - "POST:/api/v2/parameters/{scope}/{id}": { AssertAction: rbac.ActionUpdate, AssertObject: rbac.ResourceTemplate, @@ -303,16 +297,6 @@ func NewAuthTester(ctx context.Context, t *testing.T, client *codersdk.Client, a if !ok { t.Fail() } - // The provisioner will call to coderd and register itself. This is async, - // so we wait for it to occur. - require.Eventually(t, func() bool { - provisionerds, err := client.ProvisionerDaemons(ctx) - return assert.NoError(t, err) && len(provisionerds) > 0 - }, testutil.WaitLong, testutil.IntervalSlow) - - provisionerds, err := client.ProvisionerDaemons(ctx) - require.NoError(t, err, "fetch provisioners") - require.Len(t, provisionerds, 1) organization, err := client.Organization(ctx, admin.OrganizationID) require.NoError(t, err, "fetch org") diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index 0f253a0ceeffe..abb8fec4828c8 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -69,7 +69,7 @@ import ( "github.com/coder/coder/cryptorand" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionerd" - "github.com/coder/coder/provisionerd/proto" + provisionerdproto "github.com/coder/coder/provisionerd/proto" "github.com/coder/coder/provisionersdk" sdkproto "github.com/coder/coder/provisionersdk/proto" "github.com/coder/coder/tailnet" @@ -328,8 +328,43 @@ func NewProvisionerDaemon(t *testing.T, coderAPI *coderd.API) io.Closer { assert.NoError(t, err) }() - closer := provisionerd.New(func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) { - return coderAPI.ListenProvisionerDaemon(ctx, 0) + closer := provisionerd.New(func(ctx context.Context) (provisionerdproto.DRPCProvisionerDaemonClient, error) { + return coderAPI.CreateInMemoryProvisionerDaemon(ctx, 0) + }, &provisionerd.Options{ + Filesystem: fs, + Logger: slogtest.Make(t, nil).Named("provisionerd").Leveled(slog.LevelDebug), + PollInterval: 50 * time.Millisecond, + UpdateInterval: 250 * time.Millisecond, + ForceCancelInterval: time.Second, + Provisioners: provisionerd.Provisioners{ + string(database.ProvisionerTypeEcho): sdkproto.NewDRPCProvisionerClient(provisionersdk.Conn(echoClient)), + }, + WorkDirectory: t.TempDir(), + }) + t.Cleanup(func() { + _ = closer.Close() + }) + return closer +} + +func NewExternalProvisionerDaemon(t *testing.T, client *codersdk.Client, org uuid.UUID, tags map[string]string) io.Closer { + echoClient, echoServer := provisionersdk.TransportPipe() + ctx, cancelFunc := context.WithCancel(context.Background()) + t.Cleanup(func() { + _ = echoClient.Close() + _ = echoServer.Close() + cancelFunc() + }) + fs := afero.NewMemMapFs() + go func() { + err := echo.Serve(ctx, fs, &provisionersdk.ServeOptions{ + Listener: echoServer, + }) + assert.NoError(t, err) + }() + + closer := provisionerd.New(func(ctx context.Context) (provisionerdproto.DRPCProvisionerDaemonClient, error) { + return client.ServeProvisionerDaemon(ctx, org, []codersdk.ProvisionerType{codersdk.ProvisionerTypeEcho}, tags) }, &provisionerd.Options{ Filesystem: fs, Logger: slogtest.Make(t, nil).Named("provisionerd").Leveled(slog.LevelDebug), diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index 466c86a1bb5ef..51d4016c4f0c3 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -3,6 +3,7 @@ package databasefake import ( "context" "database/sql" + "encoding/json" "sort" "strings" "sync" @@ -146,6 +147,29 @@ func (q *fakeQuerier) AcquireProvisionerJob(_ context.Context, arg database.Acqu if !found { continue } + tags := map[string]string{} + if arg.Tags != nil { + err := json.Unmarshal(arg.Tags, &tags) + if err != nil { + return provisionerJob, xerrors.Errorf("unmarshal: %w", err) + } + } + + missing := false + for key, value := range provisionerJob.Tags { + provided, found := tags[key] + if !found { + missing = true + break + } + if provided != value { + missing = true + break + } + } + if missing { + continue + } provisionerJob.StartedAt = arg.StartedAt provisionerJob.UpdatedAt = arg.StartedAt.Time provisionerJob.WorkerID = arg.WorkerID @@ -2244,6 +2268,7 @@ func (q *fakeQuerier) InsertProvisionerDaemon(_ context.Context, arg database.In CreatedAt: arg.CreatedAt, Name: arg.Name, Provisioners: arg.Provisioners, + Tags: arg.Tags, } q.provisionerDaemons = append(q.provisionerDaemons, daemon) return daemon, nil @@ -2264,6 +2289,7 @@ func (q *fakeQuerier) InsertProvisionerJob(_ context.Context, arg database.Inser FileID: arg.FileID, Type: arg.Type, Input: arg.Input, + Tags: arg.Tags, } q.provisionerJobs = append(q.provisionerJobs, job) return job, nil diff --git a/coderd/database/dbtype/dbtype.go b/coderd/database/dbtype/dbtype.go new file mode 100644 index 0000000000000..9ab47c16f5552 --- /dev/null +++ b/coderd/database/dbtype/dbtype.go @@ -0,0 +1,30 @@ +package dbtype + +import ( + "database/sql/driver" + "encoding/json" + + "golang.org/x/xerrors" +) + +type StringMap map[string]string + +func (m *StringMap) Scan(src interface{}) error { + if src == nil { + return nil + } + switch src := src.(type) { + case []byte: + err := json.Unmarshal(src, m) + if err != nil { + return err + } + default: + return xerrors.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, m) + } + return nil +} + +func (m StringMap) Value() (driver.Value, error) { + return json.Marshal(m) +} diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index a7c6394b346cf..cece10169071e 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -269,7 +269,8 @@ CREATE TABLE provisioner_daemons ( updated_at timestamp with time zone, name character varying(64) NOT NULL, provisioners provisioner_type[] NOT NULL, - replica_id uuid + replica_id uuid, + tags jsonb DEFAULT '{}'::jsonb NOT NULL ); CREATE TABLE provisioner_job_logs ( @@ -306,7 +307,8 @@ CREATE TABLE provisioner_jobs ( type provisioner_job_type NOT NULL, input jsonb NOT NULL, worker_id uuid, - file_id uuid NOT NULL + file_id uuid NOT NULL, + tags jsonb DEFAULT '{"scope": "organization"}'::jsonb NOT NULL ); CREATE TABLE replicas ( diff --git a/coderd/database/migrations/000079_provisioner_daemon_tags.down.sql b/coderd/database/migrations/000079_provisioner_daemon_tags.down.sql new file mode 100644 index 0000000000000..4674e60ace5c2 --- /dev/null +++ b/coderd/database/migrations/000079_provisioner_daemon_tags.down.sql @@ -0,0 +1,2 @@ +ALTER TABLE provisioner_daemons DROP COLUMN tags; +ALTER TABLE provisioner_jobs DROP COLUMN tags; diff --git a/coderd/database/migrations/000079_provisioner_daemon_tags.up.sql b/coderd/database/migrations/000079_provisioner_daemon_tags.up.sql new file mode 100644 index 0000000000000..778214074625a --- /dev/null +++ b/coderd/database/migrations/000079_provisioner_daemon_tags.up.sql @@ -0,0 +1,5 @@ +ALTER TABLE provisioner_daemons ADD COLUMN tags jsonb NOT NULL DEFAULT '{}'; + +-- We must add the organization scope by default, otherwise pending jobs +-- could be provisioned on new daemons that don't match the tags. +ALTER TABLE provisioner_jobs ADD COLUMN tags jsonb NOT NULL DEFAULT '{"scope":"organization"}'; diff --git a/coderd/database/models.go b/coderd/database/models.go index 08dfdfd6d6c72..9b6f2cc1b163e 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -10,6 +10,7 @@ import ( "fmt" "time" + "github.com/coder/coder/coderd/database/dbtype" "github.com/google/uuid" "github.com/lib/pq" "github.com/tabbed/pqtype" @@ -525,6 +526,7 @@ type ProvisionerDaemon struct { Name string `db:"name" json:"name"` Provisioners []ProvisionerType `db:"provisioners" json:"provisioners"` ReplicaID uuid.NullUUID `db:"replica_id" json:"replica_id"` + Tags dbtype.StringMap `db:"tags" json:"tags"` } type ProvisionerJob struct { @@ -543,6 +545,7 @@ type ProvisionerJob struct { Input json.RawMessage `db:"input" json:"input"` WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"` FileID uuid.UUID `db:"file_id" json:"file_id"` + Tags dbtype.StringMap `db:"tags" json:"tags"` } type ProvisionerJobLog struct { diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 608815fd33865..30c577becdbc6 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -10,6 +10,7 @@ import ( "encoding/json" "time" + "github.com/coder/coder/coderd/database/dbtype" "github.com/google/uuid" "github.com/lib/pq" "github.com/tabbed/pqtype" @@ -2243,7 +2244,7 @@ func (q *sqlQuerier) ParameterValues(ctx context.Context, arg ParameterValuesPar const getProvisionerDaemonByID = `-- name: GetProvisionerDaemonByID :one SELECT - id, created_at, updated_at, name, provisioners, replica_id + id, created_at, updated_at, name, provisioners, replica_id, tags FROM provisioner_daemons WHERE @@ -2260,13 +2261,14 @@ func (q *sqlQuerier) GetProvisionerDaemonByID(ctx context.Context, id uuid.UUID) &i.Name, pq.Array(&i.Provisioners), &i.ReplicaID, + &i.Tags, ) return i, err } const getProvisionerDaemons = `-- name: GetProvisionerDaemons :many SELECT - id, created_at, updated_at, name, provisioners, replica_id + id, created_at, updated_at, name, provisioners, replica_id, tags FROM provisioner_daemons ` @@ -2287,6 +2289,7 @@ func (q *sqlQuerier) GetProvisionerDaemons(ctx context.Context) ([]ProvisionerDa &i.Name, pq.Array(&i.Provisioners), &i.ReplicaID, + &i.Tags, ); err != nil { return nil, err } @@ -2307,10 +2310,11 @@ INSERT INTO id, created_at, "name", - provisioners + provisioners, + tags ) VALUES - ($1, $2, $3, $4) RETURNING id, created_at, updated_at, name, provisioners, replica_id + ($1, $2, $3, $4, $5) RETURNING id, created_at, updated_at, name, provisioners, replica_id, tags ` type InsertProvisionerDaemonParams struct { @@ -2318,6 +2322,7 @@ type InsertProvisionerDaemonParams struct { CreatedAt time.Time `db:"created_at" json:"created_at"` Name string `db:"name" json:"name"` Provisioners []ProvisionerType `db:"provisioners" json:"provisioners"` + Tags dbtype.StringMap `db:"tags" json:"tags"` } func (q *sqlQuerier) InsertProvisionerDaemon(ctx context.Context, arg InsertProvisionerDaemonParams) (ProvisionerDaemon, error) { @@ -2326,6 +2331,7 @@ func (q *sqlQuerier) InsertProvisionerDaemon(ctx context.Context, arg InsertProv arg.CreatedAt, arg.Name, pq.Array(arg.Provisioners), + arg.Tags, ) var i ProvisionerDaemon err := row.Scan( @@ -2335,6 +2341,7 @@ func (q *sqlQuerier) InsertProvisionerDaemon(ctx context.Context, arg InsertProv &i.Name, pq.Array(&i.Provisioners), &i.ReplicaID, + &i.Tags, ) return i, err } @@ -2487,19 +2494,22 @@ WHERE AND nested.canceled_at IS NULL AND nested.completed_at IS NULL AND nested.provisioner = ANY($3 :: provisioner_type [ ]) + -- Ensure the caller satisfies all job tags. + AND nested.tags <@ $4 :: jsonb ORDER BY nested.created_at FOR UPDATE SKIP LOCKED LIMIT 1 - ) RETURNING id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id + ) RETURNING id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id, tags ` type AcquireProvisionerJobParams struct { StartedAt sql.NullTime `db:"started_at" json:"started_at"` WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"` Types []ProvisionerType `db:"types" json:"types"` + Tags json.RawMessage `db:"tags" json:"tags"` } // Acquires the lock for a single job that isn't started, completed, @@ -2509,7 +2519,12 @@ type AcquireProvisionerJobParams struct { // multiple provisioners from acquiring the same jobs. See: // https://www.postgresql.org/docs/9.5/sql-select.html#SQL-FOR-UPDATE-SHARE func (q *sqlQuerier) AcquireProvisionerJob(ctx context.Context, arg AcquireProvisionerJobParams) (ProvisionerJob, error) { - row := q.db.QueryRowContext(ctx, acquireProvisionerJob, arg.StartedAt, arg.WorkerID, pq.Array(arg.Types)) + row := q.db.QueryRowContext(ctx, acquireProvisionerJob, + arg.StartedAt, + arg.WorkerID, + pq.Array(arg.Types), + arg.Tags, + ) var i ProvisionerJob err := row.Scan( &i.ID, @@ -2527,13 +2542,14 @@ func (q *sqlQuerier) AcquireProvisionerJob(ctx context.Context, arg AcquireProvi &i.Input, &i.WorkerID, &i.FileID, + &i.Tags, ) return i, err } const getProvisionerJobByID = `-- name: GetProvisionerJobByID :one SELECT - id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id + id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id, tags FROM provisioner_jobs WHERE @@ -2559,13 +2575,14 @@ func (q *sqlQuerier) GetProvisionerJobByID(ctx context.Context, id uuid.UUID) (P &i.Input, &i.WorkerID, &i.FileID, + &i.Tags, ) return i, err } const getProvisionerJobsByIDs = `-- name: GetProvisionerJobsByIDs :many SELECT - id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id + id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id, tags FROM provisioner_jobs WHERE @@ -2597,6 +2614,7 @@ func (q *sqlQuerier) GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUI &i.Input, &i.WorkerID, &i.FileID, + &i.Tags, ); err != nil { return nil, err } @@ -2612,7 +2630,7 @@ func (q *sqlQuerier) GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUI } const getProvisionerJobsCreatedAfter = `-- name: GetProvisionerJobsCreatedAfter :many -SELECT id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id FROM provisioner_jobs WHERE created_at > $1 +SELECT id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id, tags FROM provisioner_jobs WHERE created_at > $1 ` func (q *sqlQuerier) GetProvisionerJobsCreatedAfter(ctx context.Context, createdAt time.Time) ([]ProvisionerJob, error) { @@ -2640,6 +2658,7 @@ func (q *sqlQuerier) GetProvisionerJobsCreatedAfter(ctx context.Context, created &i.Input, &i.WorkerID, &i.FileID, + &i.Tags, ); err != nil { return nil, err } @@ -2666,10 +2685,11 @@ INSERT INTO storage_method, file_id, "type", - "input" + "input", + tags ) VALUES - ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) RETURNING id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id, tags ` type InsertProvisionerJobParams struct { @@ -2683,6 +2703,7 @@ type InsertProvisionerJobParams struct { FileID uuid.UUID `db:"file_id" json:"file_id"` Type ProvisionerJobType `db:"type" json:"type"` Input json.RawMessage `db:"input" json:"input"` + Tags dbtype.StringMap `db:"tags" json:"tags"` } func (q *sqlQuerier) InsertProvisionerJob(ctx context.Context, arg InsertProvisionerJobParams) (ProvisionerJob, error) { @@ -2697,6 +2718,7 @@ func (q *sqlQuerier) InsertProvisionerJob(ctx context.Context, arg InsertProvisi arg.FileID, arg.Type, arg.Input, + arg.Tags, ) var i ProvisionerJob err := row.Scan( @@ -2715,6 +2737,7 @@ func (q *sqlQuerier) InsertProvisionerJob(ctx context.Context, arg InsertProvisi &i.Input, &i.WorkerID, &i.FileID, + &i.Tags, ) return i, err } diff --git a/coderd/database/queries/provisionerdaemons.sql b/coderd/database/queries/provisionerdaemons.sql index 30ff6d9d43eda..65908876e8a36 100644 --- a/coderd/database/queries/provisionerdaemons.sql +++ b/coderd/database/queries/provisionerdaemons.sql @@ -18,10 +18,11 @@ INSERT INTO id, created_at, "name", - provisioners + provisioners, + tags ) VALUES - ($1, $2, $3, $4) RETURNING *; + ($1, $2, $3, $4, $5) RETURNING *; -- name: UpdateProvisionerDaemonByID :exec UPDATE diff --git a/coderd/database/queries/provisionerjobs.sql b/coderd/database/queries/provisionerjobs.sql index 6c4097b6b2ef5..f3013bf0dbde4 100644 --- a/coderd/database/queries/provisionerjobs.sql +++ b/coderd/database/queries/provisionerjobs.sql @@ -22,6 +22,8 @@ WHERE AND nested.canceled_at IS NULL AND nested.completed_at IS NULL AND nested.provisioner = ANY(@types :: provisioner_type [ ]) + -- Ensure the caller satisfies all job tags. + AND nested.tags <@ @tags :: jsonb ORDER BY nested.created_at FOR UPDATE @@ -61,10 +63,11 @@ INSERT INTO storage_method, file_id, "type", - "input" + "input", + tags ) VALUES - ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING *; + ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) RETURNING *; -- name: UpdateProvisionerJobByID :exec UPDATE diff --git a/coderd/database/sqlc.yaml b/coderd/database/sqlc.yaml index 5bff20ea4d8f3..564a82eba1309 100644 --- a/coderd/database/sqlc.yaml +++ b/coderd/database/sqlc.yaml @@ -17,6 +17,10 @@ packages: output_db_file_name: db_tmp.go overrides: + - column: "provisioner_daemons.tags" + go_type: "github.com/coder/coder/coderd/database/dbtype.StringMap" + - column: "provisioner_jobs.tags" + go_type: "github.com/coder/coder/coderd/database/dbtype.StringMap" - column: "users.rbac_roles" go_type: "github.com/lib/pq.StringArray" - column: "templates.user_acl" diff --git a/coderd/provisionerdaemons.go b/coderd/provisionerdaemons.go deleted file mode 100644 index fb21c75956aaf..0000000000000 --- a/coderd/provisionerdaemons.go +++ /dev/null @@ -1,113 +0,0 @@ -package coderd - -import ( - "context" - "database/sql" - "errors" - "fmt" - "io" - "net/http" - "time" - - "github.com/google/uuid" - "github.com/moby/moby/pkg/namesgenerator" - "golang.org/x/xerrors" - "storj.io/drpc/drpcmux" - "storj.io/drpc/drpcserver" - - "cdr.dev/slog" - - "github.com/coder/coder/coderd/database" - "github.com/coder/coder/coderd/httpapi" - "github.com/coder/coder/coderd/provisionerdserver" - "github.com/coder/coder/coderd/rbac" - "github.com/coder/coder/codersdk" - "github.com/coder/coder/provisionerd/proto" - "github.com/coder/coder/provisionersdk" -) - -func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) { - ctx := r.Context() - daemons, err := api.Database.GetProvisionerDaemons(ctx) - if errors.Is(err, sql.ErrNoRows) { - err = nil - } - if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Internal error fetching provisioner daemons.", - Detail: err.Error(), - }) - return - } - if daemons == nil { - daemons = []database.ProvisionerDaemon{} - } - daemons, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, daemons) - if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Internal error fetching provisioner daemons.", - Detail: err.Error(), - }) - return - } - - httpapi.Write(ctx, rw, http.StatusOK, daemons) -} - -// ListenProvisionerDaemon is an in-memory connection to a provisionerd. Useful when starting coderd and provisionerd -// in the same process. -func (api *API) ListenProvisionerDaemon(ctx context.Context, acquireJobDebounce time.Duration) (client proto.DRPCProvisionerDaemonClient, err error) { - clientSession, serverSession := provisionersdk.TransportPipe() - defer func() { - if err != nil { - _ = clientSession.Close() - _ = serverSession.Close() - } - }() - - name := namesgenerator.GetRandomName(1) - daemon, err := api.Database.InsertProvisionerDaemon(ctx, database.InsertProvisionerDaemonParams{ - ID: uuid.New(), - CreatedAt: database.Now(), - Name: name, - Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho, database.ProvisionerTypeTerraform}, - }) - if err != nil { - return nil, xerrors.Errorf("insert provisioner daemon %q: %w", name, err) - } - - mux := drpcmux.New() - err = proto.DRPCRegisterProvisionerDaemon(mux, &provisionerdserver.Server{ - AccessURL: api.AccessURL, - ID: daemon.ID, - Database: api.Database, - Pubsub: api.Pubsub, - Provisioners: daemon.Provisioners, - Telemetry: api.Telemetry, - Logger: api.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)), - AcquireJobDebounce: acquireJobDebounce, - QuotaCommitter: &api.QuotaCommitter, - }) - if err != nil { - return nil, err - } - server := drpcserver.NewWithOptions(mux, drpcserver.Options{ - Log: func(err error) { - if xerrors.Is(err, io.EOF) { - return - } - api.Logger.Debug(ctx, "drpc server error", slog.Error(err)) - }, - }) - go func() { - err := server.Serve(ctx, serverSession) - if err != nil && !xerrors.Is(err, io.EOF) { - api.Logger.Debug(ctx, "provisioner daemon disconnected", slog.Error(err)) - } - // close the sessions so we don't leak goroutines serving them. - _ = clientSession.Close() - _ = serverSession.Close() - }() - - return proto.NewDRPCProvisionerDaemonClient(provisionersdk.Conn(clientSession)), nil -} diff --git a/coderd/provisionerdaemons_test.go b/coderd/provisionerdaemons_test.go deleted file mode 100644 index d3b0be35cd020..0000000000000 --- a/coderd/provisionerdaemons_test.go +++ /dev/null @@ -1,76 +0,0 @@ -package coderd_test - -import ( - "context" - "crypto/rand" - "runtime" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/coder/coder/coderd/coderdtest" - "github.com/coder/coder/codersdk" - "github.com/coder/coder/provisionersdk" - "github.com/coder/coder/testutil" -) - -func TestProvisionerDaemons(t *testing.T) { - t.Parallel() - t.Run("PayloadTooBig", func(t *testing.T) { - t.Parallel() - if runtime.GOOS == "windows" { - // Takes too long to allocate memory on Windows! - t.Skip() - } - client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) - user := coderdtest.CreateFirstUser(t, client) - data := make([]byte, provisionersdk.MaxMessageSize) - rand.Read(data) - - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - resp, err := client.Upload(ctx, codersdk.ContentTypeTar, data) - require.NoError(t, err) - t.Log(resp.ID) - - version, err := client.CreateTemplateVersion(ctx, user.OrganizationID, codersdk.CreateTemplateVersionRequest{ - StorageMethod: codersdk.ProvisionerStorageMethodFile, - FileID: resp.ID, - Provisioner: codersdk.ProvisionerTypeEcho, - }) - require.NoError(t, err) - require.Eventually(t, func() bool { - var err error - version, err = client.TemplateVersion(ctx, version.ID) - return assert.NoError(t, err) && version.Job.Error != "" - }, testutil.WaitShort, testutil.IntervalFast) - }) -} - -func TestProvisionerDaemonsByOrganization(t *testing.T) { - t.Parallel() - t.Run("NoAuth", func(t *testing.T) { - t.Parallel() - client := coderdtest.New(t, nil) - - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - _, err := client.ProvisionerDaemons(ctx) - require.Error(t, err) - }) - - t.Run("Get", func(t *testing.T) { - t.Parallel() - client := coderdtest.New(t, nil) - _ = coderdtest.CreateFirstUser(t, client) - - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) - defer cancel() - - _, err := client.ProvisionerDaemons(ctx) - require.NoError(t, err) - }) -} diff --git a/coderd/provisionerdserver/provisionerdserver.go b/coderd/provisionerdserver/provisionerdserver.go index 0b267de5b17fc..240d042ee5ee0 100644 --- a/coderd/provisionerdserver/provisionerdserver.go +++ b/coderd/provisionerdserver/provisionerdserver.go @@ -39,6 +39,7 @@ type Server struct { ID uuid.UUID Logger slog.Logger Provisioners []database.ProvisionerType + Tags json.RawMessage Database database.Store Pubsub database.Pubsub Telemetry telemetry.Reporter @@ -71,6 +72,7 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac Valid: true, }, Types: server.Provisioners, + Tags: server.Tags, }) if errors.Is(err, sql.ErrNoRows) { // The provisioner daemon assumes no jobs are available if diff --git a/coderd/provisionerdserver/provisionertags.go b/coderd/provisionerdserver/provisionertags.go new file mode 100644 index 0000000000000..7c9e029839d35 --- /dev/null +++ b/coderd/provisionerdserver/provisionertags.go @@ -0,0 +1,33 @@ +package provisionerdserver + +import "github.com/google/uuid" + +const ( + TagScope = "scope" + TagOwner = "owner" + + ScopeUser = "user" + ScopeOrganization = "organization" +) + +// MutateTags adjusts the "owner" tag dependent on the "scope". +// If the scope is "user", the "owner" is changed to the user ID. +// This is for user-scoped provisioner daemons, where users should +// own their own operations. +func MutateTags(userID uuid.UUID, tags map[string]string) map[string]string { + if tags == nil { + tags = map[string]string{} + } + _, ok := tags[TagScope] + if !ok { + tags[TagScope] = ScopeOrganization + } + switch tags[TagScope] { + case ScopeUser: + tags[TagOwner] = userID.String() + case ScopeOrganization: + default: + tags[TagScope] = ScopeOrganization + } + return tags +} diff --git a/coderd/provisionerjobs.go b/coderd/provisionerjobs.go index c9eedf8dd68cc..130a42beb8b2e 100644 --- a/coderd/provisionerjobs.go +++ b/coderd/provisionerjobs.go @@ -131,10 +131,10 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job return } - api.websocketWaitMutex.Lock() - api.websocketWaitGroup.Add(1) - api.websocketWaitMutex.Unlock() - defer api.websocketWaitGroup.Done() + api.WebsocketWaitMutex.Lock() + api.WebsocketWaitGroup.Add(1) + api.WebsocketWaitMutex.Unlock() + defer api.WebsocketWaitGroup.Done() conn, err := websocket.Accept(rw, r, nil) if err != nil { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ @@ -312,6 +312,7 @@ func convertProvisionerJob(provisionerJob database.ProvisionerJob) codersdk.Prov CreatedAt: provisionerJob.CreatedAt, Error: provisionerJob.Error.String, FileID: provisionerJob.FileID, + Tags: provisionerJob.Tags, } // Applying values optional to the struct. if provisionerJob.StartedAt.Valid { diff --git a/coderd/templateversions.go b/coderd/templateversions.go index c5a962b20907a..2952baede44e8 100644 --- a/coderd/templateversions.go +++ b/coderd/templateversions.go @@ -291,6 +291,8 @@ func (api *API) postTemplateVersionDryRun(rw http.ResponseWriter, r *http.Reques FileID: job.FileID, Type: database.ProvisionerJobTypeTemplateVersionDryRun, Input: input, + // Copy tags from the previous run. + Tags: job.Tags, }) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ @@ -764,6 +766,9 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht return } + // Ensures the "owner" is properly applied. + tags := provisionerdserver.MutateTags(apiKey.UserID, req.ProvisionerTags) + file, err := api.Database.GetFileByID(ctx, req.FileID) if errors.Is(err, sql.ErrNoRows) { httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{ @@ -862,6 +867,7 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht FileID: file.ID, Type: database.ProvisionerJobTypeTemplateVersionImport, Input: []byte{'{', '}'}, + Tags: tags, }) if err != nil { return xerrors.Errorf("insert provisioner job: %w", err) diff --git a/coderd/templateversions_test.go b/coderd/templateversions_test.go index b9c17e98a92f5..65c1ec31d06e6 100644 --- a/coderd/templateversions_test.go +++ b/coderd/templateversions_test.go @@ -13,6 +13,7 @@ import ( "github.com/coder/coder/coderd/audit" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/provisionerdserver" "github.com/coder/coder/codersdk" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" @@ -122,6 +123,7 @@ func TestPostTemplateVersionsByOrganization(t *testing.T) { }) require.NoError(t, err) require.Equal(t, "bananas", version.Name) + require.Equal(t, provisionerdserver.ScopeOrganization, version.Job.Tags[provisionerdserver.TagScope]) require.Len(t, auditor.AuditLogs, 1) assert.Equal(t, database.AuditActionCreate, auditor.AuditLogs[0].Action) diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 933060d76fb16..8f336d6d4074f 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -181,10 +181,10 @@ func (api *API) postWorkspaceAgentVersion(rw http.ResponseWriter, r *http.Reques func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() - api.websocketWaitMutex.Lock() - api.websocketWaitGroup.Add(1) - api.websocketWaitMutex.Unlock() - defer api.websocketWaitGroup.Done() + api.WebsocketWaitMutex.Lock() + api.WebsocketWaitGroup.Add(1) + api.WebsocketWaitMutex.Unlock() + defer api.WebsocketWaitGroup.Done() workspaceAgent := httpmw.WorkspaceAgentParam(r) workspace := httpmw.WorkspaceParam(r) @@ -442,10 +442,10 @@ func (api *API) workspaceAgentConnection(rw http.ResponseWriter, r *http.Request func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() - api.websocketWaitMutex.Lock() - api.websocketWaitGroup.Add(1) - api.websocketWaitMutex.Unlock() - defer api.websocketWaitGroup.Done() + api.WebsocketWaitMutex.Lock() + api.WebsocketWaitGroup.Add(1) + api.WebsocketWaitMutex.Unlock() + defer api.WebsocketWaitGroup.Done() workspaceAgent := httpmw.WorkspaceAgent(r) resource, err := api.Database.GetWorkspaceResourceByID(ctx, workspaceAgent.ResourceID) if err != nil { @@ -614,10 +614,10 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R } } - api.websocketWaitMutex.Lock() - api.websocketWaitGroup.Add(1) - api.websocketWaitMutex.Unlock() - defer api.websocketWaitGroup.Done() + api.WebsocketWaitMutex.Lock() + api.WebsocketWaitGroup.Add(1) + api.WebsocketWaitMutex.Unlock() + defer api.WebsocketWaitGroup.Done() workspaceAgent := httpmw.WorkspaceAgentParam(r) conn, err := websocket.Accept(rw, r, nil) @@ -759,10 +759,10 @@ func convertWorkspaceAgent(derpMap *tailcfg.DERPMap, coordinator tailnet.Coordin func (api *API) workspaceAgentReportStats(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() - api.websocketWaitMutex.Lock() - api.websocketWaitGroup.Add(1) - api.websocketWaitMutex.Unlock() - defer api.websocketWaitGroup.Done() + api.WebsocketWaitMutex.Lock() + api.WebsocketWaitGroup.Add(1) + api.WebsocketWaitMutex.Unlock() + defer api.WebsocketWaitGroup.Done() workspaceAgent := httpmw.WorkspaceAgent(r) resource, err := api.Database.GetWorkspaceResourceByID(ctx, workspaceAgent.ResourceID) diff --git a/coderd/workspacebuilds.go b/coderd/workspacebuilds.go index a797c50031d28..0d823d0a8d82b 100644 --- a/coderd/workspacebuilds.go +++ b/coderd/workspacebuilds.go @@ -428,6 +428,8 @@ func (api *API) postWorkspaceBuilds(rw http.ResponseWriter, r *http.Request) { return } + tags := provisionerdserver.MutateTags(workspace.OwnerID, templateVersionJob.Tags) + // Store prior build number to compute new build number var priorBuildNum int32 priorHistory, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(ctx, workspace.ID) @@ -513,6 +515,7 @@ func (api *API) postWorkspaceBuilds(rw http.ResponseWriter, r *http.Request) { StorageMethod: templateVersionJob.StorageMethod, FileID: templateVersionJob.FileID, Input: input, + Tags: tags, }) if err != nil { return xerrors.Errorf("insert provisioner job: %w", err) diff --git a/coderd/workspaces.go b/coderd/workspaces.go index 81016d3b67054..3f26dd6367c53 100644 --- a/coderd/workspaces.go +++ b/coderd/workspaces.go @@ -373,6 +373,8 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req return } + tags := provisionerdserver.MutateTags(user.ID, templateVersionJob.Tags) + var ( provisionerJob database.ProvisionerJob workspaceBuild database.WorkspaceBuild @@ -435,6 +437,7 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req StorageMethod: templateVersionJob.StorageMethod, FileID: templateVersionJob.FileID, Input: input, + Tags: tags, }) if err != nil { return xerrors.Errorf("insert provisioner job: %w", err) diff --git a/codersdk/features.go b/codersdk/features.go index 07e836715ab6e..3b707a7c42489 100644 --- a/codersdk/features.go +++ b/codersdk/features.go @@ -15,13 +15,14 @@ const ( ) const ( - FeatureUserLimit = "user_limit" - FeatureAuditLog = "audit_log" - FeatureBrowserOnly = "browser_only" - FeatureSCIM = "scim" - FeatureTemplateRBAC = "template_rbac" - FeatureHighAvailability = "high_availability" - FeatureMultipleGitAuth = "multiple_git_auth" + FeatureUserLimit = "user_limit" + FeatureAuditLog = "audit_log" + FeatureBrowserOnly = "browser_only" + FeatureSCIM = "scim" + FeatureTemplateRBAC = "template_rbac" + FeatureHighAvailability = "high_availability" + FeatureMultipleGitAuth = "multiple_git_auth" + FeatureExternalProvisionerDaemons = "external_provisioner_daemons" ) var FeatureNames = []string{ @@ -32,6 +33,7 @@ var FeatureNames = []string{ FeatureTemplateRBAC, FeatureHighAvailability, FeatureMultipleGitAuth, + FeatureExternalProvisionerDaemons, } type Feature struct { diff --git a/codersdk/organizations.go b/codersdk/organizations.go index ce49794123cb4..21ede686d96c6 100644 --- a/codersdk/organizations.go +++ b/codersdk/organizations.go @@ -36,11 +36,12 @@ type Organization struct { type CreateTemplateVersionRequest struct { Name string `json:"name,omitempty" validate:"omitempty,template_name"` // TemplateID optionally associates a version with a template. - TemplateID uuid.UUID `json:"template_id,omitempty"` + TemplateID uuid.UUID `json:"template_id,omitempty"` + StorageMethod ProvisionerStorageMethod `json:"storage_method" validate:"oneof=file,required"` + FileID uuid.UUID `json:"file_id" validate:"required"` + Provisioner ProvisionerType `json:"provisioner" validate:"oneof=terraform echo,required"` + ProvisionerTags map[string]string `json:"tags"` - StorageMethod ProvisionerStorageMethod `json:"storage_method" validate:"oneof=file,required"` - FileID uuid.UUID `json:"file_id" validate:"required"` - Provisioner ProvisionerType `json:"provisioner" validate:"oneof=terraform echo,required"` // ParameterValues allows for additional parameters to be provided // during the dry-run provision stage. ParameterValues []CreateParameterRequest `json:"parameter_values,omitempty"` diff --git a/codersdk/provisionerdaemons.go b/codersdk/provisionerdaemons.go index 823a80e025794..f80f5af586832 100644 --- a/codersdk/provisionerdaemons.go +++ b/codersdk/provisionerdaemons.go @@ -13,20 +13,22 @@ import ( "time" "github.com/google/uuid" + "github.com/hashicorp/yamux" "golang.org/x/xerrors" "nhooyr.io/websocket" + + "github.com/coder/coder/provisionerd/proto" + "github.com/coder/coder/provisionersdk" ) type LogSource string +type LogLevel string + const ( LogSourceProvisionerDaemon LogSource = "provisioner_daemon" LogSourceProvisioner LogSource = "provisioner" -) -type LogLevel string - -const ( LogLevelTrace LogLevel = "trace" LogLevelDebug LogLevel = "debug" LogLevelInfo LogLevel = "info" @@ -40,6 +42,7 @@ type ProvisionerDaemon struct { UpdatedAt sql.NullTime `json:"updated_at"` Name string `json:"name"` Provisioners []ProvisionerType `json:"provisioners"` + Tags map[string]string `json:"tags"` } // ProvisionerJobStatus represents the at-time state of a job. @@ -73,6 +76,7 @@ type ProvisionerJob struct { Status ProvisionerJobStatus `json:"status"` WorkerID *uuid.UUID `json:"worker_id,omitempty"` FileID uuid.UUID `json:"file_id"` + Tags map[string]string `json:"tags"` } type ProvisionerJobLog struct { @@ -162,3 +166,51 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after return nil }), nil } + +// ListenProvisionerDaemon returns the gRPC service for a provisioner daemon implementation. +func (c *Client) ServeProvisionerDaemon(ctx context.Context, organization uuid.UUID, provisioners []ProvisionerType, tags map[string]string) (proto.DRPCProvisionerDaemonClient, error) { + serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", organization)) + if err != nil { + return nil, xerrors.Errorf("parse url: %w", err) + } + query := serverURL.Query() + for _, provisioner := range provisioners { + query.Add("provisioner", string(provisioner)) + } + for key, value := range tags { + query.Add("tag", fmt.Sprintf("%s=%s", key, value)) + } + serverURL.RawQuery = query.Encode() + jar, err := cookiejar.New(nil) + if err != nil { + return nil, xerrors.Errorf("create cookie jar: %w", err) + } + jar.SetCookies(serverURL, []*http.Cookie{{ + Name: SessionTokenKey, + Value: c.SessionToken(), + }}) + httpClient := &http.Client{ + Jar: jar, + } + conn, res, err := websocket.Dial(ctx, serverURL.String(), &websocket.DialOptions{ + HTTPClient: httpClient, + // Need to disable compression to avoid a data-race. + CompressionMode: websocket.CompressionDisabled, + }) + if err != nil { + if res == nil { + return nil, err + } + return nil, readBodyAsError(res) + } + // Align with the frame size of yamux. + conn.SetReadLimit(256 * 1024) + + config := yamux.DefaultConfig() + config.LogOutput = io.Discard + session, err := yamux.Client(websocket.NetConn(ctx, conn, websocket.MessageBinary), config) + if err != nil { + return nil, xerrors.Errorf("multiplex client: %w", err) + } + return proto.NewDRPCProvisionerDaemonClient(provisionersdk.Conn(session)), nil +} diff --git a/enterprise/cli/provisionerdaemons.go b/enterprise/cli/provisionerdaemons.go new file mode 100644 index 0000000000000..285922d88bc06 --- /dev/null +++ b/enterprise/cli/provisionerdaemons.go @@ -0,0 +1,155 @@ +package cli + +import ( + "context" + "fmt" + "os" + "os/signal" + "time" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/sloghuman" + agpl "github.com/coder/coder/cli" + "github.com/coder/coder/cli/cliflag" + "github.com/coder/coder/cli/cliui" + "github.com/coder/coder/cli/deployment" + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/provisioner/terraform" + "github.com/coder/coder/provisionerd" + provisionerdproto "github.com/coder/coder/provisionerd/proto" + "github.com/coder/coder/provisionersdk" + "github.com/coder/coder/provisionersdk/proto" + + "github.com/spf13/cobra" + "golang.org/x/xerrors" +) + +func provisionerDaemons() *cobra.Command { + cmd := &cobra.Command{ + Use: "provisionerd", + Short: "Manage provisioner daemons", + } + cmd.AddCommand(provisionerDaemonStart()) + + return cmd +} + +func provisionerDaemonStart() *cobra.Command { + var ( + cacheDir string + rawTags []string + ) + cmd := &cobra.Command{ + Use: "start", + Short: "Run a provisioner daemon", + RunE: func(cmd *cobra.Command, args []string) error { + ctx, cancel := context.WithCancel(cmd.Context()) + defer cancel() + + notifyCtx, notifyStop := signal.NotifyContext(ctx, agpl.InterruptSignals...) + defer notifyStop() + + client, err := agpl.CreateClient(cmd) + if err != nil { + return xerrors.Errorf("create client: %w", err) + } + org, err := agpl.CurrentOrganization(cmd, client) + if err != nil { + return xerrors.Errorf("get current organization: %w", err) + } + + tags, err := agpl.ParseProvisionerTags(rawTags) + if err != nil { + return err + } + + err = os.MkdirAll(cacheDir, 0o700) + if err != nil { + return xerrors.Errorf("mkdir %q: %w", cacheDir, err) + } + + terraformClient, terraformServer := provisionersdk.TransportPipe() + go func() { + <-ctx.Done() + _ = terraformClient.Close() + _ = terraformServer.Close() + }() + + logger := slog.Make(sloghuman.Sink(cmd.ErrOrStderr())) + errCh := make(chan error, 1) + go func() { + defer cancel() + + err := terraform.Serve(ctx, &terraform.ServeOptions{ + ServeOptions: &provisionersdk.ServeOptions{ + Listener: terraformServer, + }, + CachePath: cacheDir, + Logger: logger.Named("terraform"), + }) + if err != nil && !xerrors.Is(err, context.Canceled) { + select { + case errCh <- err: + default: + } + } + }() + + tempDir, err := os.MkdirTemp("", "provisionerd") + if err != nil { + return err + } + + logger.Info(ctx, "starting provisioner daemon", slog.F("tags", tags)) + + provisioners := provisionerd.Provisioners{ + string(database.ProvisionerTypeTerraform): proto.NewDRPCProvisionerClient(provisionersdk.Conn(terraformClient)), + } + srv := provisionerd.New(func(ctx context.Context) (provisionerdproto.DRPCProvisionerDaemonClient, error) { + return client.ServeProvisionerDaemon(ctx, org.ID, []codersdk.ProvisionerType{ + codersdk.ProvisionerTypeTerraform, + }, tags) + }, &provisionerd.Options{ + Logger: logger, + PollInterval: 500 * time.Millisecond, + UpdateInterval: 500 * time.Millisecond, + Provisioners: provisioners, + WorkDirectory: tempDir, + }) + + var exitErr error + select { + case <-notifyCtx.Done(): + exitErr = notifyCtx.Err() + _, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Bold.Render( + "Interrupt caught, gracefully exiting. Use ctrl+\\ to force quit", + )) + case exitErr = <-errCh: + } + if exitErr != nil && !xerrors.Is(exitErr, context.Canceled) { + cmd.Printf("Unexpected error, shutting down server: %s\n", exitErr) + } + + shutdown, shutdownCancel := context.WithTimeout(ctx, time.Minute) + defer shutdownCancel() + err = srv.Shutdown(shutdown) + if err != nil { + return xerrors.Errorf("shutdown: %w", err) + } + + cancel() + if xerrors.Is(exitErr, context.Canceled) { + return nil + } + return exitErr + }, + } + + cliflag.StringVarP(cmd.Flags(), &cacheDir, "cache-dir", "c", "CODER_CACHE_DIRECTORY", deployment.DefaultCacheDir(), + "Specify a directory to cache provisioner job files.") + cliflag.StringArrayVarP(cmd.Flags(), &rawTags, "tag", "t", "CODER_PROVISIONERD_TAGS", []string{}, + "Specify a list of tags to target provisioner jobs.") + + return cmd +} diff --git a/enterprise/cli/root.go b/enterprise/cli/root.go index 41337f14c77dd..269bb7c0f1615 100644 --- a/enterprise/cli/root.go +++ b/enterprise/cli/root.go @@ -12,6 +12,7 @@ func enterpriseOnly() []*cobra.Command { features(), licenses(), groups(), + provisionerDaemons(), } } diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 10c40b1b38718..5c3b88d936576 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -90,7 +90,15 @@ func New(ctx context.Context, options *Options) (*API, error) { r.Get("/", api.group) }) }) - + r.Route("/organizations/{organization}/provisionerdaemons", func(r chi.Router) { + r.Use( + api.provisionerDaemonsEnabledMW, + apiKeyMiddleware, + httpmw.ExtractOrganizationParam(api.Database), + ) + r.Get("/", api.provisionerDaemons) + r.Get("/serve", api.provisionerDaemonServe) + }) r.Route("/templates/{template}/acl", func(r chi.Router) { r.Use( api.templateRBACEnabledMW, @@ -100,7 +108,6 @@ func New(ctx context.Context, options *Options) (*API, error) { r.Get("/", api.templateACL) r.Patch("/", api.patchTemplateACL) }) - r.Route("/groups/{group}", func(r chi.Router) { r.Use( api.templateRBACEnabledMW, @@ -111,7 +118,6 @@ func New(ctx context.Context, options *Options) (*API, error) { r.Patch("/", api.patchGroup) r.Delete("/", api.deleteGroup) }) - r.Route("/workspace-quota", func(r chi.Router) { r.Use( apiKeyMiddleware, @@ -222,12 +228,13 @@ func (api *API) updateEntitlements(ctx context.Context) error { defer api.entitlementsMu.Unlock() entitlements, err := license.Entitlements(ctx, api.Database, api.Logger, len(api.replicaManager.All()), len(api.GitAuthConfigs), api.Keys, map[string]bool{ - codersdk.FeatureAuditLog: api.AuditLogging, - codersdk.FeatureBrowserOnly: api.BrowserOnly, - codersdk.FeatureSCIM: len(api.SCIMAPIKey) != 0, - codersdk.FeatureHighAvailability: api.DERPServerRelayAddress != "", - codersdk.FeatureMultipleGitAuth: len(api.GitAuthConfigs) > 1, - codersdk.FeatureTemplateRBAC: api.RBAC, + codersdk.FeatureAuditLog: api.AuditLogging, + codersdk.FeatureBrowserOnly: api.BrowserOnly, + codersdk.FeatureSCIM: len(api.SCIMAPIKey) != 0, + codersdk.FeatureHighAvailability: api.DERPServerRelayAddress != "", + codersdk.FeatureMultipleGitAuth: len(api.GitAuthConfigs) > 1, + codersdk.FeatureTemplateRBAC: api.RBAC, + codersdk.FeatureExternalProvisionerDaemons: true, }) if err != nil { return err diff --git a/enterprise/coderd/coderd_test.go b/enterprise/coderd/coderd_test.go index a6c3c6d86973b..6d8e90c9b185d 100644 --- a/enterprise/coderd/coderd_test.go +++ b/enterprise/coderd/coderd_test.go @@ -41,9 +41,10 @@ func TestEntitlements(t *testing.T) { }) _ = coderdtest.CreateFirstUser(t, client) coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - UserLimit: 100, - AuditLog: true, - TemplateRBAC: true, + UserLimit: 100, + AuditLog: true, + TemplateRBAC: true, + ExternalProvisionerDaemons: true, }) res, err := client.Entitlements(context.Background()) require.NoError(t, err) diff --git a/enterprise/coderd/coderdenttest/coderdenttest.go b/enterprise/coderd/coderdenttest/coderdenttest.go index 3348227c7d29b..0b36f938a8265 100644 --- a/enterprise/coderd/coderdenttest/coderdenttest.go +++ b/enterprise/coderd/coderdenttest/coderdenttest.go @@ -99,19 +99,20 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c } type LicenseOptions struct { - AccountType string - AccountID string - Trial bool - AllFeatures bool - GraceAt time.Time - ExpiresAt time.Time - UserLimit int64 - AuditLog bool - BrowserOnly bool - SCIM bool - TemplateRBAC bool - HighAvailability bool - MultipleGitAuth bool + AccountType string + AccountID string + Trial bool + AllFeatures bool + GraceAt time.Time + ExpiresAt time.Time + UserLimit int64 + AuditLog bool + BrowserOnly bool + SCIM bool + TemplateRBAC bool + HighAvailability bool + MultipleGitAuth bool + ExternalProvisionerDaemons bool } // AddLicense generates a new license with the options provided and inserts it. @@ -158,6 +159,11 @@ func GenerateLicense(t *testing.T, options LicenseOptions) string { multipleGitAuth = 1 } + externalProvisionerDaemons := int64(0) + if options.ExternalProvisionerDaemons { + externalProvisionerDaemons = 1 + } + c := &license.Claims{ RegisteredClaims: jwt.RegisteredClaims{ Issuer: "test@testing.test", @@ -172,13 +178,14 @@ func GenerateLicense(t *testing.T, options LicenseOptions) string { Version: license.CurrentVersion, AllFeatures: options.AllFeatures, Features: license.Features{ - UserLimit: options.UserLimit, - AuditLog: auditLog, - BrowserOnly: browserOnly, - SCIM: scim, - HighAvailability: highAvailability, - TemplateRBAC: rbacEnabled, - MultipleGitAuth: multipleGitAuth, + UserLimit: options.UserLimit, + AuditLog: auditLog, + BrowserOnly: browserOnly, + SCIM: scim, + HighAvailability: highAvailability, + TemplateRBAC: rbacEnabled, + MultipleGitAuth: multipleGitAuth, + ExternalProvisionerDaemons: externalProvisionerDaemons, }, } tok := jwt.NewWithClaims(jwt.SigningMethodEdDSA, c) diff --git a/enterprise/coderd/coderdenttest/coderdenttest_test.go b/enterprise/coderd/coderdenttest/coderdenttest_test.go index 319c805163271..e1a99291cd9f7 100644 --- a/enterprise/coderd/coderdenttest/coderdenttest_test.go +++ b/enterprise/coderd/coderdenttest/coderdenttest_test.go @@ -33,7 +33,8 @@ func TestAuthorizeAllEndpoints(t *testing.T) { ctx, _ := testutil.Context(t) admin := coderdtest.CreateFirstUser(t, client) license := coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ - TemplateRBAC: true, + TemplateRBAC: true, + ExternalProvisionerDaemons: true, }) group, err := client.CreateGroup(ctx, admin.OrganizationID, codersdk.CreateGroupRequest{ Name: "testgroup", @@ -47,6 +48,8 @@ func TestAuthorizeAllEndpoints(t *testing.T) { a.URLParams["{groupName}"] = group.Name skipRoutes, assertRoute := coderdtest.AGPLRoutes(a) + skipRoutes["GET:/api/v2/organizations/{organization}/provisionerdaemons/serve"] = "This route checks for RBAC dependent on input parameters!" + assertRoute["GET:/api/v2/entitlements"] = coderdtest.RouteCheck{ NoAuthorize: true, } @@ -84,6 +87,14 @@ func TestAuthorizeAllEndpoints(t *testing.T) { AssertAction: rbac.ActionRead, AssertObject: groupObj, } + assertRoute["GET:/api/v2/organizations/{organization}/provisionerdaemons"] = coderdtest.RouteCheck{ + AssertAction: rbac.ActionRead, + AssertObject: rbac.ResourceProvisionerDaemon, + } + assertRoute["GET:/api/v2/organizations/{organization}/provisionerdaemons"] = coderdtest.RouteCheck{ + AssertAction: rbac.ActionRead, + AssertObject: rbac.ResourceProvisionerDaemon, + } assertRoute["GET:/api/v2/groups/{group}"] = coderdtest.RouteCheck{ AssertAction: rbac.ActionRead, AssertObject: groupObj, diff --git a/enterprise/coderd/license/license.go b/enterprise/coderd/license/license.go index 5307c490e3ae5..ab37354a46bbe 100644 --- a/enterprise/coderd/license/license.go +++ b/enterprise/coderd/license/license.go @@ -117,6 +117,12 @@ func Entitlements( Enabled: true, } } + if claims.Features.ExternalProvisionerDaemons > 0 { + entitlements.Features[codersdk.FeatureExternalProvisionerDaemons] = codersdk.Feature{ + Entitlement: entitlement, + Enabled: true, + } + } if claims.AllFeatures { allFeatures = true } @@ -238,13 +244,14 @@ var ( ) type Features struct { - UserLimit int64 `json:"user_limit"` - AuditLog int64 `json:"audit_log"` - BrowserOnly int64 `json:"browser_only"` - SCIM int64 `json:"scim"` - TemplateRBAC int64 `json:"template_rbac"` - HighAvailability int64 `json:"high_availability"` - MultipleGitAuth int64 `json:"multiple_git_auth"` + UserLimit int64 `json:"user_limit"` + AuditLog int64 `json:"audit_log"` + BrowserOnly int64 `json:"browser_only"` + SCIM int64 `json:"scim"` + TemplateRBAC int64 `json:"template_rbac"` + HighAvailability int64 `json:"high_availability"` + MultipleGitAuth int64 `json:"multiple_git_auth"` + ExternalProvisionerDaemons int64 `json:"external_provisioner_daemons"` } type Claims struct { diff --git a/enterprise/coderd/license/license_test.go b/enterprise/coderd/license/license_test.go index d1262e0833178..c0833ed9c594b 100644 --- a/enterprise/coderd/license/license_test.go +++ b/enterprise/coderd/license/license_test.go @@ -20,12 +20,13 @@ import ( func TestEntitlements(t *testing.T) { t.Parallel() all := map[string]bool{ - codersdk.FeatureAuditLog: true, - codersdk.FeatureBrowserOnly: true, - codersdk.FeatureSCIM: true, - codersdk.FeatureHighAvailability: true, - codersdk.FeatureTemplateRBAC: true, - codersdk.FeatureMultipleGitAuth: true, + codersdk.FeatureAuditLog: true, + codersdk.FeatureBrowserOnly: true, + codersdk.FeatureSCIM: true, + codersdk.FeatureHighAvailability: true, + codersdk.FeatureTemplateRBAC: true, + codersdk.FeatureMultipleGitAuth: true, + codersdk.FeatureExternalProvisionerDaemons: true, } t.Run("Defaults", func(t *testing.T) { @@ -61,13 +62,14 @@ func TestEntitlements(t *testing.T) { db := databasefake.New() db.InsertLicense(context.Background(), database.InsertLicenseParams{ JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ - UserLimit: 100, - AuditLog: true, - BrowserOnly: true, - SCIM: true, - HighAvailability: true, - TemplateRBAC: true, - MultipleGitAuth: true, + UserLimit: 100, + AuditLog: true, + BrowserOnly: true, + SCIM: true, + HighAvailability: true, + TemplateRBAC: true, + MultipleGitAuth: true, + ExternalProvisionerDaemons: true, }), Exp: time.Now().Add(time.Hour), }) @@ -84,14 +86,15 @@ func TestEntitlements(t *testing.T) { db := databasefake.New() db.InsertLicense(context.Background(), database.InsertLicenseParams{ JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{ - UserLimit: 100, - AuditLog: true, - BrowserOnly: true, - SCIM: true, - HighAvailability: true, - TemplateRBAC: true, - GraceAt: time.Now().Add(-time.Hour), - ExpiresAt: time.Now().Add(time.Hour), + UserLimit: 100, + AuditLog: true, + BrowserOnly: true, + SCIM: true, + HighAvailability: true, + TemplateRBAC: true, + ExternalProvisionerDaemons: true, + GraceAt: time.Now().Add(-time.Hour), + ExpiresAt: time.Now().Add(time.Hour), }), Exp: time.Now().Add(time.Hour), }) diff --git a/enterprise/coderd/licenses_test.go b/enterprise/coderd/licenses_test.go index 0605ff2f742c1..4e105b7831e08 100644 --- a/enterprise/coderd/licenses_test.go +++ b/enterprise/coderd/licenses_test.go @@ -101,25 +101,27 @@ func TestGetLicense(t *testing.T) { assert.Equal(t, int32(1), licenses[0].ID) assert.Equal(t, "testing", licenses[0].Claims["account_id"]) assert.Equal(t, map[string]interface{}{ - codersdk.FeatureUserLimit: json.Number("0"), - codersdk.FeatureAuditLog: json.Number("1"), - codersdk.FeatureSCIM: json.Number("1"), - codersdk.FeatureBrowserOnly: json.Number("1"), - codersdk.FeatureHighAvailability: json.Number("0"), - codersdk.FeatureTemplateRBAC: json.Number("1"), - codersdk.FeatureMultipleGitAuth: json.Number("0"), + codersdk.FeatureUserLimit: json.Number("0"), + codersdk.FeatureAuditLog: json.Number("1"), + codersdk.FeatureSCIM: json.Number("1"), + codersdk.FeatureBrowserOnly: json.Number("1"), + codersdk.FeatureHighAvailability: json.Number("0"), + codersdk.FeatureTemplateRBAC: json.Number("1"), + codersdk.FeatureMultipleGitAuth: json.Number("0"), + codersdk.FeatureExternalProvisionerDaemons: json.Number("0"), }, licenses[0].Claims["features"]) assert.Equal(t, int32(2), licenses[1].ID) assert.Equal(t, "testing2", licenses[1].Claims["account_id"]) assert.Equal(t, true, licenses[1].Claims["trial"]) assert.Equal(t, map[string]interface{}{ - codersdk.FeatureUserLimit: json.Number("200"), - codersdk.FeatureAuditLog: json.Number("1"), - codersdk.FeatureSCIM: json.Number("1"), - codersdk.FeatureBrowserOnly: json.Number("1"), - codersdk.FeatureHighAvailability: json.Number("0"), - codersdk.FeatureTemplateRBAC: json.Number("0"), - codersdk.FeatureMultipleGitAuth: json.Number("0"), + codersdk.FeatureUserLimit: json.Number("200"), + codersdk.FeatureAuditLog: json.Number("1"), + codersdk.FeatureSCIM: json.Number("1"), + codersdk.FeatureBrowserOnly: json.Number("1"), + codersdk.FeatureHighAvailability: json.Number("0"), + codersdk.FeatureTemplateRBAC: json.Number("0"), + codersdk.FeatureMultipleGitAuth: json.Number("0"), + codersdk.FeatureExternalProvisionerDaemons: json.Number("0"), }, licenses[1].Claims["features"]) }) } diff --git a/enterprise/coderd/provisionerdaemons.go b/enterprise/coderd/provisionerdaemons.go new file mode 100644 index 0000000000000..e9f357f970e30 --- /dev/null +++ b/enterprise/coderd/provisionerdaemons.go @@ -0,0 +1,245 @@ +package coderd + +import ( + "database/sql" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + + "github.com/google/uuid" + "github.com/hashicorp/yamux" + "github.com/moby/moby/pkg/namesgenerator" + "golang.org/x/xerrors" + "nhooyr.io/websocket" + "storj.io/drpc/drpcmux" + "storj.io/drpc/drpcserver" + + "cdr.dev/slog" + + "github.com/coder/coder/coderd" + "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/httpapi" + "github.com/coder/coder/coderd/httpmw" + "github.com/coder/coder/coderd/provisionerdserver" + "github.com/coder/coder/coderd/rbac" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/provisionerd/proto" +) + +func (api *API) provisionerDaemonsEnabledMW(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + api.entitlementsMu.RLock() + epd := api.entitlements.Features[codersdk.FeatureExternalProvisionerDaemons].Enabled + api.entitlementsMu.RUnlock() + + if !epd { + httpapi.Write(r.Context(), rw, http.StatusForbidden, codersdk.Response{ + Message: "External provisioner daemons is an Enterprise feature. Contact sales!", + }) + return + } + + next.ServeHTTP(rw, r) + }) +} + +func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + org := httpmw.OrganizationParam(r) + if !api.Authorize(r, rbac.ActionRead, rbac.ResourceProvisionerDaemon.InOrg(org.ID)) { + httpapi.Forbidden(rw) + return + } + daemons, err := api.Database.GetProvisionerDaemons(ctx) + if errors.Is(err, sql.ErrNoRows) { + err = nil + } + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error fetching provisioner daemons.", + Detail: err.Error(), + }) + return + } + if daemons == nil { + daemons = []database.ProvisionerDaemon{} + } + daemons, err = coderd.AuthorizeFilter(api.AGPL.HTTPAuth, r, rbac.ActionRead, daemons) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error fetching provisioner daemons.", + Detail: err.Error(), + }) + return + } + apiDaemons := make([]codersdk.ProvisionerDaemon, 0) + for _, daemon := range daemons { + apiDaemons = append(apiDaemons, convertProvisionerDaemon(daemon)) + } + httpapi.Write(ctx, rw, http.StatusOK, apiDaemons) +} + +// Serves the provisioner daemon protobuf API over a WebSocket. +func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) { + tags := map[string]string{} + if r.URL.Query().Has("tag") { + for _, tag := range r.URL.Query()["tag"] { + parts := strings.SplitN(tag, "=", 2) + if len(parts) < 2 { + httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ + Message: fmt.Sprintf("Invalid format for tag %q. Key and value must be separated with =.", tag), + }) + return + } + tags[parts[0]] = parts[1] + } + } + if !r.URL.Query().Has("provisioner") { + httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ + Message: "The provisioner query parameter must be specified.", + }) + return + } + + provisionersMap := map[codersdk.ProvisionerType]struct{}{} + for _, provisioner := range r.URL.Query()["provisioner"] { + switch provisioner { + case string(codersdk.ProvisionerTypeEcho): + provisionersMap[codersdk.ProvisionerTypeEcho] = struct{}{} + case string(codersdk.ProvisionerTypeTerraform): + provisionersMap[codersdk.ProvisionerTypeTerraform] = struct{}{} + default: + httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ + Message: fmt.Sprintf("Unknown provisioner type %q", provisioner), + }) + return + } + } + + // Any authenticated user can create provisioner daemons scoped + // for jobs that they own, but only authorized users can create + // globally scoped provisioners that attach to all jobs. + apiKey := httpmw.APIKey(r) + tags = provisionerdserver.MutateTags(apiKey.UserID, tags) + + if tags[provisionerdserver.TagScope] == provisionerdserver.ScopeOrganization { + if !api.AGPL.Authorize(r, rbac.ActionCreate, rbac.ResourceProvisionerDaemon) { + httpapi.Write(r.Context(), rw, http.StatusForbidden, codersdk.Response{ + Message: "You aren't allowed to create provisioner daemons for the organization.", + }) + return + } + } + + provisioners := make([]database.ProvisionerType, 0) + for p := range provisionersMap { + switch p { + case codersdk.ProvisionerTypeTerraform: + provisioners = append(provisioners, database.ProvisionerTypeTerraform) + case codersdk.ProvisionerTypeEcho: + provisioners = append(provisioners, database.ProvisionerTypeEcho) + } + } + + name := namesgenerator.GetRandomName(1) + daemon, err := api.Database.InsertProvisionerDaemon(r.Context(), database.InsertProvisionerDaemonParams{ + ID: uuid.New(), + CreatedAt: database.Now(), + Name: name, + Provisioners: provisioners, + Tags: tags, + }) + if err != nil { + httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error writing provisioner daemon.", + Detail: err.Error(), + }) + return + } + + rawTags, err := json.Marshal(daemon.Tags) + if err != nil { + httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error marshaling daemon tags.", + Detail: err.Error(), + }) + return + } + + api.AGPL.WebsocketWaitMutex.Lock() + api.AGPL.WebsocketWaitGroup.Add(1) + api.AGPL.WebsocketWaitMutex.Unlock() + defer api.AGPL.WebsocketWaitGroup.Done() + + conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{ + // Need to disable compression to avoid a data-race. + CompressionMode: websocket.CompressionDisabled, + }) + if err != nil { + httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ + Message: "Internal error accepting websocket connection.", + Detail: err.Error(), + }) + return + } + // Align with the frame size of yamux. + conn.SetReadLimit(256 * 1024) + + // Multiplexes the incoming connection using yamux. + // This allows multiple function calls to occur over + // the same connection. + config := yamux.DefaultConfig() + config.LogOutput = io.Discard + session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config) + if err != nil { + _ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("multiplex server: %s", err)) + return + } + mux := drpcmux.New() + err = proto.DRPCRegisterProvisionerDaemon(mux, &provisionerdserver.Server{ + AccessURL: api.AccessURL, + ID: daemon.ID, + Database: api.Database, + Pubsub: api.Pubsub, + Provisioners: daemon.Provisioners, + Telemetry: api.Telemetry, + Logger: api.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)), + Tags: rawTags, + }) + if err != nil { + _ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("drpc register provisioner daemon: %s", err)) + return + } + server := drpcserver.NewWithOptions(mux, drpcserver.Options{ + Log: func(err error) { + if xerrors.Is(err, io.EOF) { + return + } + api.Logger.Debug(r.Context(), "drpc server error", slog.Error(err)) + }, + }) + err = server.Serve(r.Context(), session) + if err != nil && !xerrors.Is(err, io.EOF) { + api.Logger.Debug(r.Context(), "provisioner daemon disconnected", slog.Error(err)) + _ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err)) + return + } + _ = conn.Close(websocket.StatusGoingAway, "") +} + +func convertProvisionerDaemon(daemon database.ProvisionerDaemon) codersdk.ProvisionerDaemon { + result := codersdk.ProvisionerDaemon{ + ID: daemon.ID, + CreatedAt: daemon.CreatedAt, + UpdatedAt: daemon.UpdatedAt, + Name: daemon.Name, + Tags: daemon.Tags, + } + for _, provisionerType := range daemon.Provisioners { + result.Provisioners = append(result.Provisioners, codersdk.ProvisionerType(provisionerType)) + } + return result +} diff --git a/enterprise/coderd/provisionerdaemons_test.go b/enterprise/coderd/provisionerdaemons_test.go new file mode 100644 index 0000000000000..f603f3569e807 --- /dev/null +++ b/enterprise/coderd/provisionerdaemons_test.go @@ -0,0 +1,139 @@ +package coderd_test + +import ( + "context" + "net/http" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/coderd/provisionerdserver" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/enterprise/coderd/coderdenttest" + "github.com/coder/coder/provisioner/echo" + "github.com/coder/coder/provisionersdk/proto" +) + +func TestProvisionerDaemonServe(t *testing.T) { + t.Parallel() + t.Run("NoLicense", func(t *testing.T) { + t.Parallel() + client := coderdenttest.New(t, nil) + user := coderdtest.CreateFirstUser(t, client) + _, err := client.ServeProvisionerDaemon(context.Background(), user.OrganizationID, []codersdk.ProvisionerType{ + codersdk.ProvisionerTypeEcho, + }, map[string]string{}) + require.Error(t, err) + var apiError *codersdk.Error + require.ErrorAs(t, err, &apiError) + require.Equal(t, http.StatusForbidden, apiError.StatusCode()) + }) + + t.Run("Organization", func(t *testing.T) { + t.Parallel() + client := coderdenttest.New(t, nil) + user := coderdtest.CreateFirstUser(t, client) + coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ + ExternalProvisionerDaemons: true, + }) + srv, err := client.ServeProvisionerDaemon(context.Background(), user.OrganizationID, []codersdk.ProvisionerType{ + codersdk.ProvisionerTypeEcho, + }, map[string]string{}) + require.NoError(t, err) + srv.DRPCConn().Close() + }) + + t.Run("OrganizationNoPerms", func(t *testing.T) { + t.Parallel() + client := coderdenttest.New(t, nil) + user := coderdtest.CreateFirstUser(t, client) + coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ + ExternalProvisionerDaemons: true, + }) + another := coderdtest.CreateAnotherUser(t, client, user.OrganizationID) + _, err := another.ServeProvisionerDaemon(context.Background(), user.OrganizationID, []codersdk.ProvisionerType{ + codersdk.ProvisionerTypeEcho, + }, map[string]string{ + provisionerdserver.TagScope: provisionerdserver.ScopeOrganization, + }) + require.Error(t, err) + var apiError *codersdk.Error + require.ErrorAs(t, err, &apiError) + require.Equal(t, http.StatusForbidden, apiError.StatusCode()) + }) + + t.Run("UserLocal", func(t *testing.T) { + t.Parallel() + client := coderdenttest.New(t, nil) + user := coderdtest.CreateFirstUser(t, client) + coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{ + ExternalProvisionerDaemons: true, + }) + closer := coderdtest.NewExternalProvisionerDaemon(t, client, user.OrganizationID, map[string]string{ + provisionerdserver.TagScope: provisionerdserver.ScopeUser, + }) + defer closer.Close() + + authToken := uuid.NewString() + data, err := echo.Tar(&echo.Responses{ + Parse: echo.ParseComplete, + ProvisionPlan: []*proto.Provision_Response{{ + Type: &proto.Provision_Response_Complete{ + Complete: &proto.Provision_Complete{ + Resources: []*proto.Resource{{ + Name: "example", + Type: "aws_instance", + Agents: []*proto.Agent{{ + Id: uuid.NewString(), + Name: "example", + }}, + }}, + }, + }, + }}, + ProvisionApply: []*proto.Provision_Response{{ + Type: &proto.Provision_Response_Complete{ + Complete: &proto.Provision_Complete{ + Resources: []*proto.Resource{{ + Name: "example", + Type: "aws_instance", + Agents: []*proto.Agent{{ + Id: uuid.NewString(), + Name: "example", + Auth: &proto.Agent_Token{ + Token: authToken, + }, + }}, + }}, + }, + }, + }}, + }) + require.NoError(t, err) + file, err := client.Upload(context.Background(), codersdk.ContentTypeTar, data) + require.NoError(t, err) + + version, err := client.CreateTemplateVersion(context.Background(), user.OrganizationID, codersdk.CreateTemplateVersionRequest{ + Name: "example", + StorageMethod: codersdk.ProvisionerStorageMethodFile, + FileID: file.ID, + Provisioner: codersdk.ProvisionerTypeEcho, + ProvisionerTags: map[string]string{ + provisionerdserver.TagScope: provisionerdserver.ScopeUser, + }, + }) + require.NoError(t, err) + coderdtest.AwaitTemplateVersionJob(t, client, version.ID) + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + another := coderdtest.CreateAnotherUser(t, client, user.OrganizationID) + _ = closer.Close() + closer = coderdtest.NewExternalProvisionerDaemon(t, another, user.OrganizationID, map[string]string{ + provisionerdserver.TagScope: provisionerdserver.ScopeUser, + }) + defer closer.Close() + workspace := coderdtest.CreateWorkspace(t, another, user.OrganizationID, template.ID) + coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) + }) +} diff --git a/provisionerd/provisionerd.go b/provisionerd/provisionerd.go index ce16ff2709172..831e5d9640692 100644 --- a/provisionerd/provisionerd.go +++ b/provisionerd/provisionerd.go @@ -58,6 +58,9 @@ type Options struct { // New creates and starts a provisioner daemon. func New(clientDialer Dialer, opts *Options) *Server { + if opts == nil { + opts = &Options{} + } if opts.PollInterval == 0 { opts.PollInterval = 5 * time.Second } diff --git a/site/src/api/api.ts b/site/src/api/api.ts index d3cf623453679..8b344af736936 100644 --- a/site/src/api/api.ts +++ b/site/src/api/api.ts @@ -69,12 +69,14 @@ export const provisioners: TypesGen.ProvisionerDaemon[] = [ name: "Terraform", created_at: "", provisioners: [], + tags: {}, }, { id: "cdr-basic", name: "Basic", created_at: "", provisioners: [], + tags: {}, }, ] diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 0e27b3853801d..7faaf7357c22c 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -197,6 +197,7 @@ export interface CreateTemplateVersionRequest { readonly storage_method: ProvisionerStorageMethod readonly file_id: string readonly provisioner: ProvisionerType + readonly tags: Record readonly parameter_values?: CreateParameterRequest[] } @@ -540,6 +541,7 @@ export interface ProvisionerDaemon { readonly updated_at?: string readonly name: string readonly provisioners: ProvisionerType[] + readonly tags: Record } // From codersdk/provisionerdaemons.go @@ -553,6 +555,7 @@ export interface ProvisionerJob { readonly status: ProvisionerJobStatus readonly worker_id?: string readonly file_id: string + readonly tags: Record } // From codersdk/provisionerdaemons.go diff --git a/site/src/testHelpers/entities.ts b/site/src/testHelpers/entities.ts index c43d7a64d19a8..7d1e5213518b7 100644 --- a/site/src/testHelpers/entities.ts +++ b/site/src/testHelpers/entities.ts @@ -131,6 +131,7 @@ export const MockProvisioner: TypesGen.ProvisionerDaemon = { id: "test-provisioner", name: "Test Provisioner", provisioners: ["echo"], + tags: {}, } export const MockProvisionerJob: TypesGen.ProvisionerJob = { @@ -139,6 +140,7 @@ export const MockProvisionerJob: TypesGen.ProvisionerJob = { status: "succeeded", file_id: "fc0774ce-cc9e-48d4-80ae-88f7a4d4a8b0", completed_at: "2022-05-17T17:39:01.382927298Z", + tags: {}, } export const MockFailedProvisionerJob: TypesGen.ProvisionerJob = {