diff --git a/cli/parameterslist.go b/cli/parameterslist.go index 8955ba96ca0c3..495b48247ff31 100644 --- a/cli/parameterslist.go +++ b/cli/parameterslist.go @@ -45,16 +45,23 @@ func parameterList() *cobra.Command { return xerrors.Errorf("get workspace template: %w", err) } scopeID = template.ID - - case codersdk.ParameterScopeImportJob, "template_version": - scope = string(codersdk.ParameterScopeImportJob) + case codersdk.ParameterImportJob, "template_version": + scope = string(codersdk.ParameterImportJob) scopeID, err = uuid.Parse(name) if err != nil { return xerrors.Errorf("%q must be a uuid for this scope type", name) } + + // Could be a template_version id or a job id. Check for the + // version id. + tv, err := client.TemplateVersion(cmd.Context(), scopeID) + if err == nil { + scopeID = tv.Job.ID + } + default: return xerrors.Errorf("%q is an unsupported scope, use %v", scope, []codersdk.ParameterScope{ - codersdk.ParameterWorkspace, codersdk.ParameterTemplate, codersdk.ParameterScopeImportJob, + codersdk.ParameterWorkspace, codersdk.ParameterTemplate, codersdk.ParameterImportJob, }) } diff --git a/cli/templatecreate.go b/cli/templatecreate.go index 457bf5cca784b..6ec0f86b4021e 100644 --- a/cli/templatecreate.go +++ b/cli/templatecreate.go @@ -82,7 +82,13 @@ func templateCreate() *cobra.Command { } spin.Stop() - job, parameters, err := createValidTemplateVersion(cmd, client, organization, database.ProvisionerType(provisioner), resp.Hash, parameterFile) + job, _, err := createValidTemplateVersion(cmd, createValidTemplateVersionArgs{ + Client: client, + Organization: organization, + Provisioner: database.ProvisionerType(provisioner), + FileHash: resp.Hash, + ParameterFile: parameterFile, + }) if err != nil { return err } @@ -98,7 +104,6 @@ func templateCreate() *cobra.Command { createReq := codersdk.CreateTemplateRequest{ Name: templateName, VersionID: job.ID, - ParameterValues: parameters, MaxTTLMillis: ptr.Ref(maxTTL.Milliseconds()), MinAutostartIntervalMillis: ptr.Ref(minAutostartInterval.Milliseconds()), } @@ -133,14 +138,34 @@ func templateCreate() *cobra.Command { return cmd } -func createValidTemplateVersion(cmd *cobra.Command, client *codersdk.Client, organization codersdk.Organization, provisioner database.ProvisionerType, hash string, parameterFile string, parameters ...codersdk.CreateParameterRequest) (*codersdk.TemplateVersion, []codersdk.CreateParameterRequest, error) { +type createValidTemplateVersionArgs struct { + Client *codersdk.Client + Organization codersdk.Organization + Provisioner database.ProvisionerType + FileHash string + ParameterFile string + // Template is only required if updating a template's active version. + Template *codersdk.Template + // ReuseParameters will attempt to reuse params from the Template field + // before prompting the user. Set to false to always prompt for param + // values. + ReuseParameters bool +} + +func createValidTemplateVersion(cmd *cobra.Command, args createValidTemplateVersionArgs, parameters ...codersdk.CreateParameterRequest) (*codersdk.TemplateVersion, []codersdk.CreateParameterRequest, error) { before := time.Now() - version, err := client.CreateTemplateVersion(cmd.Context(), organization.ID, codersdk.CreateTemplateVersionRequest{ + client := args.Client + + req := codersdk.CreateTemplateVersionRequest{ StorageMethod: codersdk.ProvisionerStorageMethodFile, - StorageSource: hash, - Provisioner: codersdk.ProvisionerType(provisioner), + StorageSource: args.FileHash, + Provisioner: codersdk.ProvisionerType(args.Provisioner), ParameterValues: parameters, - }) + } + if args.Template != nil { + req.TemplateID = args.Template.ID + } + version, err := client.CreateTemplateVersion(cmd.Context(), args.Organization.ID, req) if err != nil { return nil, nil, err } @@ -175,33 +200,77 @@ func createValidTemplateVersion(cmd *cobra.Command, client *codersdk.Client, org return nil, nil, err } + // lastParameterValues are pulled from the current active template version if + // templateID is provided. This allows pulling params from the last + // version instead of prompting if we are updating template versions. + lastParameterValues := make(map[string]codersdk.Parameter) + if args.ReuseParameters && args.Template != nil { + activeVersion, err := client.TemplateVersion(cmd.Context(), args.Template.ActiveVersionID) + if err != nil { + return nil, nil, xerrors.Errorf("Fetch current active template version: %w", err) + } + + // We don't want to compute the params, we only want to copy from this scope + values, err := client.Parameters(cmd.Context(), codersdk.ParameterImportJob, activeVersion.Job.ID) + if err != nil { + return nil, nil, xerrors.Errorf("Fetch previous version parameters: %w", err) + } + for _, value := range values { + lastParameterValues[value.Name] = value + } + } + if provisionerd.IsMissingParameterError(version.Job.Error) { valuesBySchemaID := map[string]codersdk.TemplateVersionParameter{} for _, parameterValue := range parameterValues { valuesBySchemaID[parameterValue.SchemaID.String()] = parameterValue } + sort.Slice(parameterSchemas, func(i, j int) bool { return parameterSchemas[i].Name < parameterSchemas[j].Name }) + + // parameterMapFromFile can be nil if parameter file is not specified + var parameterMapFromFile map[string]string + if args.ParameterFile != "" { + _, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Paragraph.Render("Attempting to read the variables from the parameter file.")+"\r\n") + parameterMapFromFile, err = createParameterMapFromFile(args.ParameterFile) + if err != nil { + return nil, nil, err + } + } + + // pulled params come from the last template version + pulled := make([]string, 0) missingSchemas := make([]codersdk.ParameterSchema, 0) for _, parameterSchema := range parameterSchemas { _, ok := valuesBySchemaID[parameterSchema.ID.String()] if ok { continue } - missingSchemas = append(missingSchemas, parameterSchema) - } - _, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Paragraph.Render("This template has required variables! They are scoped to the template, and not viewable after being set.")+"\r\n") - // parameterMapFromFile can be nil if parameter file is not specified - var parameterMapFromFile map[string]string - if parameterFile != "" { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Paragraph.Render("Attempting to read the variables from the parameter file.")+"\r\n") - parameterMapFromFile, err = createParameterMapFromFile(parameterFile) - if err != nil { - return nil, nil, err + // The file values are handled below. So don't handle them here, + // just check if a value is present in the file. + _, fileOk := parameterMapFromFile[parameterSchema.Name] + if inherit, ok := lastParameterValues[parameterSchema.Name]; ok && !fileOk { + // If the value is not in the param file, and can be pulled from the last template version, + // then don't mark it as missing. + parameters = append(parameters, codersdk.CreateParameterRequest{ + CloneID: inherit.ID, + }) + pulled = append(pulled, fmt.Sprintf("%q", parameterSchema.Name)) + continue } + + missingSchemas = append(missingSchemas, parameterSchema) } + _, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Paragraph.Render("This template has required variables! They are scoped to the template, and not viewable after being set.")) + if len(pulled) > 0 { + _, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Paragraph.Render(fmt.Sprintf("The following parameter values are being pulled from the latest template version: %s.", strings.Join(pulled, ", ")))) + _, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Paragraph.Render("Use \"--always-prompt\" flag to change the values.")) + } + _, _ = fmt.Fprint(cmd.OutOrStdout(), "\r\n") + for _, parameterSchema := range missingSchemas { parameterValue, err := getParameterValueFromMapOrInput(cmd, parameterMapFromFile, parameterSchema) if err != nil { @@ -218,7 +287,7 @@ func createValidTemplateVersion(cmd *cobra.Command, client *codersdk.Client, org // This recursion is only 1 level deep in practice. // The first pass populates the missing parameters, so it does not enter this `if` block again. - return createValidTemplateVersion(cmd, client, organization, provisioner, hash, parameterFile, parameters...) + return createValidTemplateVersion(cmd, args, parameters...) } if version.Job.Status != codersdk.ProvisionerJobSucceeded { diff --git a/cli/templateupdate.go b/cli/templateupdate.go index 80276fd9d66c9..fcbba2b122c05 100644 --- a/cli/templateupdate.go +++ b/cli/templateupdate.go @@ -10,14 +10,17 @@ import ( "golang.org/x/xerrors" "github.com/coder/coder/cli/cliui" + "github.com/coder/coder/coderd/database" "github.com/coder/coder/codersdk" "github.com/coder/coder/provisionersdk" ) func templateUpdate() *cobra.Command { var ( - directory string - provisioner string + directory string + provisioner string + parameterFile string + alwaysPrompt bool ) cmd := &cobra.Command{ @@ -64,42 +67,30 @@ func templateUpdate() *cobra.Command { } spin.Stop() - before := time.Now() - templateVersion, err := client.CreateTemplateVersion(cmd.Context(), organization.ID, codersdk.CreateTemplateVersionRequest{ - TemplateID: template.ID, - StorageMethod: codersdk.ProvisionerStorageMethodFile, - StorageSource: resp.Hash, - Provisioner: codersdk.ProvisionerType(provisioner), + job, _, err := createValidTemplateVersion(cmd, createValidTemplateVersionArgs{ + Client: client, + Organization: organization, + Provisioner: database.ProvisionerType(provisioner), + FileHash: resp.Hash, + ParameterFile: parameterFile, + Template: &template, + ReuseParameters: !alwaysPrompt, }) if err != nil { return err } - logs, err := client.TemplateVersionLogsAfter(cmd.Context(), templateVersion.ID, before) - if err != nil { - return err - } - for { - log, ok := <-logs - if !ok { - break - } - _, _ = fmt.Printf("%s (%s): %s\n", provisioner, log.Level, log.Output) - } - templateVersion, err = client.TemplateVersion(cmd.Context(), templateVersion.ID) - if err != nil { - return err - } - if templateVersion.Job.Status != codersdk.ProvisionerJobSucceeded { - return xerrors.Errorf("job failed: %s", templateVersion.Job.Error) + if job.Job.Status != codersdk.ProvisionerJobSucceeded { + return xerrors.Errorf("job failed: %s", job.Job.Status) } err = client.UpdateActiveTemplateVersion(cmd.Context(), template.ID, codersdk.UpdateActiveTemplateVersion{ - ID: templateVersion.ID, + ID: job.ID, }) if err != nil { return err } + _, _ = fmt.Printf("Updated version!\n") return nil }, @@ -108,6 +99,8 @@ func templateUpdate() *cobra.Command { currentDirectory, _ := os.Getwd() cmd.Flags().StringVarP(&directory, "directory", "d", currentDirectory, "Specify the directory to create from") cmd.Flags().StringVarP(&provisioner, "test.provisioner", "", "terraform", "Customize the provisioner backend") + cmd.Flags().StringVarP(¶meterFile, "parameter-file", "", "", "Specify a file path with parameter values.") + cmd.Flags().BoolVar(&alwaysPrompt, "always-prompt", false, "Always prompt all parameters. Does not pull parameter values from active template version") cliui.AllowSkipPrompt(cmd) // This is for testing! err := cmd.Flags().MarkHidden("test.provisioner") diff --git a/cli/templateupdate_test.go b/cli/templateupdate_test.go index 134e884370733..71cabc6ad408b 100644 --- a/cli/templateupdate_test.go +++ b/cli/templateupdate_test.go @@ -4,6 +4,7 @@ import ( "context" "testing" + "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -17,47 +18,152 @@ import ( func TestTemplateUpdate(t *testing.T) { t.Parallel() + // NewParameter will: + // 1. Create a template version with 0 params + // 2. Create a new version with 1 param + // 2a. Expects 1 param prompt, fills in value + // 3. Assert 1 param value in new version + // 4. Creates a new version with same param + // 4a. Expects 0 prompts as the param value is carried over + // 5. Assert 1 param value in new version + // 6. Creates a new version with 0 params + // 7. Asset 0 params in new version + t.Run("NewParameter", func(t *testing.T) { + t.Parallel() + client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true}) + user := coderdtest.CreateFirstUser(t, client) + // Create initial template version to update + lastActiveVersion := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil) + _ = coderdtest.AwaitTemplateVersionJob(t, client, lastActiveVersion.ID) + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, lastActiveVersion.ID) - client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true}) - user := coderdtest.CreateFirstUser(t, client) - version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil) - _ = coderdtest.AwaitTemplateVersionJob(t, client, version.ID) - template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + // Create new template version with a new parameter + source := clitest.CreateTemplateVersionSource(t, &echo.Responses{ + Parse: createTestParseResponse(), + Provision: echo.ProvisionComplete, + }) + cmd, root := clitest.New(t, "templates", "update", template.Name, "-y", "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho)) + clitest.SetupConfig(t, client, root) + pty := ptytest.New(t) + cmd.SetIn(pty.Input()) + cmd.SetOut(pty.Output()) - // Test the cli command. - source := clitest.CreateTemplateVersionSource(t, &echo.Responses{ - Parse: echo.ParseComplete, - Provision: echo.ProvisionComplete, + execDone := make(chan error) + go func() { + execDone <- cmd.Execute() + }() + + matches := []struct { + match string + write string + }{ + // Expect to be prompted for the new param + {match: "Enter a value:", write: "peter-pan"}, + } + for _, m := range matches { + pty.ExpectMatch(m.match) + pty.WriteLine(m.write) + } + + require.NoError(t, <-execDone) + + // Assert template version changed and we have the new param + latestTV, latestParams := latestTemplateVersion(t, client, template.ID) + assert.NotEqual(t, lastActiveVersion.ID, latestTV.ID) + require.Len(t, latestParams, 1, "expect 1 param") + lastActiveVersion = latestTV + + // Second update of the same source requires no prompt since the params + // are carried over. + cmd, root = clitest.New(t, "templates", "update", template.Name, "-y", "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho)) + clitest.SetupConfig(t, client, root) + go func() { + execDone <- cmd.Execute() + }() + require.NoError(t, <-execDone) + + // Assert template version changed and we have the carried over param + latestTV, latestParams = latestTemplateVersion(t, client, template.ID) + assert.NotEqual(t, lastActiveVersion.ID, latestTV.ID) + require.Len(t, latestParams, 1, "expect 1 param") + lastActiveVersion = latestTV + + // Remove the param + source = clitest.CreateTemplateVersionSource(t, &echo.Responses{ + Parse: echo.ParseComplete, + Provision: echo.ProvisionComplete, + }) + + cmd, root = clitest.New(t, "templates", "update", template.Name, "-y", "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho)) + clitest.SetupConfig(t, client, root) + go func() { + execDone <- cmd.Execute() + }() + require.NoError(t, <-execDone) + // Assert template version changed and the param was removed + latestTV, latestParams = latestTemplateVersion(t, client, template.ID) + assert.NotEqual(t, lastActiveVersion.ID, latestTV.ID) + require.Len(t, latestParams, 0, "expect 0 param") + lastActiveVersion = latestTV }) - cmd, root := clitest.New(t, "templates", "update", template.Name, "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho)) - clitest.SetupConfig(t, client, root) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) - - execDone := make(chan error) - go func() { - execDone <- cmd.Execute() - }() - - matches := []struct { - match string - write string - }{ - {match: "Upload", write: "yes"}, - } - for _, m := range matches { - pty.ExpectMatch(m.match) - pty.WriteLine(m.write) - } - - require.NoError(t, <-execDone) - - // Assert that the template version changed. - templateVersions, err := client.TemplateVersionsByTemplate(context.Background(), codersdk.TemplateVersionsByTemplateRequest{ - TemplateID: template.ID, + + t.Run("OK", func(t *testing.T) { + t.Parallel() + client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true}) + user := coderdtest.CreateFirstUser(t, client) + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil) + _ = coderdtest.AwaitTemplateVersionJob(t, client, version.ID) + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + + // Test the cli command. + source := clitest.CreateTemplateVersionSource(t, &echo.Responses{ + Parse: echo.ParseComplete, + Provision: echo.ProvisionComplete, + }) + cmd, root := clitest.New(t, "templates", "update", template.Name, "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho)) + clitest.SetupConfig(t, client, root) + pty := ptytest.New(t) + cmd.SetIn(pty.Input()) + cmd.SetOut(pty.Output()) + + execDone := make(chan error) + go func() { + execDone <- cmd.Execute() + }() + + matches := []struct { + match string + write string + }{ + {match: "Upload", write: "yes"}, + } + for _, m := range matches { + pty.ExpectMatch(m.match) + pty.WriteLine(m.write) + } + + require.NoError(t, <-execDone) + + // Assert that the template version changed. + templateVersions, err := client.TemplateVersionsByTemplate(context.Background(), codersdk.TemplateVersionsByTemplateRequest{ + TemplateID: template.ID, + }) + require.NoError(t, err) + assert.Len(t, templateVersions, 2) + assert.NotEqual(t, template.ActiveVersionID, templateVersions[1].ID) }) +} + +func latestTemplateVersion(t *testing.T, client *codersdk.Client, templateID uuid.UUID) (codersdk.TemplateVersion, []codersdk.Parameter) { + t.Helper() + + ctx := context.Background() + newTemplate, err := client.Template(ctx, templateID) + require.NoError(t, err) + tv, err := client.TemplateVersion(ctx, newTemplate.ActiveVersionID) require.NoError(t, err) - assert.Len(t, templateVersions, 2) - assert.NotEqual(t, template.ActiveVersionID, templateVersions[1].ID) + params, err := client.Parameters(ctx, codersdk.ParameterImportJob, tv.Job.ID) + require.NoError(t, err) + + return tv, params } diff --git a/coderd/database/databasefake/databasefake.go b/coderd/database/databasefake/databasefake.go index e41c70b874ed8..ada3051085863 100644 --- a/coderd/database/databasefake/databasefake.go +++ b/coderd/database/databasefake/databasefake.go @@ -12,6 +12,7 @@ import ( "golang.org/x/exp/slices" "github.com/coder/coder/coderd/database" + "github.com/coder/coder/coderd/util/slice" ) // New returns an in-memory fake of the database. @@ -103,6 +104,19 @@ func (q *fakeQuerier) AcquireProvisionerJob(_ context.Context, arg database.Acqu return database.ProvisionerJob{}, sql.ErrNoRows } +func (q *fakeQuerier) ParameterValue(_ context.Context, id uuid.UUID) (database.ParameterValue, error) { + q.mutex.Lock() + defer q.mutex.Unlock() + + for _, parameterValue := range q.parameterValues { + if parameterValue.ID.String() != id.String() { + continue + } + return parameterValue, nil + } + return database.ParameterValue{}, sql.ErrNoRows +} + func (q *fakeQuerier) DeleteParameterValueByID(_ context.Context, id uuid.UUID) error { q.mutex.Lock() defer q.mutex.Unlock() @@ -744,17 +758,27 @@ func (q *fakeQuerier) GetOrganizationsByUserID(_ context.Context, userID uuid.UU return organizations, nil } -func (q *fakeQuerier) GetParameterValuesByScope(_ context.Context, arg database.GetParameterValuesByScopeParams) ([]database.ParameterValue, error) { +func (q *fakeQuerier) ParameterValues(_ context.Context, arg database.ParameterValuesParams) ([]database.ParameterValue, error) { q.mutex.RLock() defer q.mutex.RUnlock() parameterValues := make([]database.ParameterValue, 0) for _, parameterValue := range q.parameterValues { - if parameterValue.Scope != arg.Scope { - continue + if len(arg.Scopes) > 0 { + if !slice.Contains(arg.Scopes, parameterValue.Scope) { + continue + } } - if parameterValue.ScopeID != arg.ScopeID { - continue + if len(arg.ScopeIds) > 0 { + if !slice.Contains(arg.ScopeIds, parameterValue.ScopeID) { + continue + } + } + + if len(arg.Ids) > 0 { + if !slice.Contains(arg.Ids, parameterValue.ID) { + continue + } } parameterValues = append(parameterValues, parameterValue) } diff --git a/coderd/database/querier.go b/coderd/database/querier.go index f3108dc8b00ec..4e10a68133a67 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -45,7 +45,6 @@ type querier interface { GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]ParameterSchema, error) GetParameterSchemasCreatedAfter(ctx context.Context, createdAt time.Time) ([]ParameterSchema, error) GetParameterValueByScopeAndName(ctx context.Context, arg GetParameterValueByScopeAndNameParams) (ParameterValue, error) - GetParameterValuesByScope(ctx context.Context, arg GetParameterValuesByScopeParams) ([]ParameterValue, error) GetProvisionerDaemonByID(ctx context.Context, id uuid.UUID) (ProvisionerDaemon, error) GetProvisionerDaemons(ctx context.Context) ([]ProvisionerDaemon, error) GetProvisionerJobByID(ctx context.Context, id uuid.UUID) (ProvisionerJob, error) @@ -109,6 +108,8 @@ type querier interface { InsertWorkspaceApp(ctx context.Context, arg InsertWorkspaceAppParams) (WorkspaceApp, error) InsertWorkspaceBuild(ctx context.Context, arg InsertWorkspaceBuildParams) (WorkspaceBuild, error) InsertWorkspaceResource(ctx context.Context, arg InsertWorkspaceResourceParams) (WorkspaceResource, error) + ParameterValue(ctx context.Context, id uuid.UUID) (ParameterValue, error) + ParameterValues(ctx context.Context, arg ParameterValuesParams) ([]ParameterValue, error) UpdateAPIKeyByID(ctx context.Context, arg UpdateAPIKeyByIDParams) error UpdateGitSSHKey(ctx context.Context, arg UpdateGitSSHKeyParams) error UpdateMemberRoles(ctx context.Context, arg UpdateMemberRolesParams) (OrganizationMember, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index b91a784a2fc84..0ad6f9d554c94 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1079,54 +1079,6 @@ func (q *sqlQuerier) GetParameterValueByScopeAndName(ctx context.Context, arg Ge return i, err } -const getParameterValuesByScope = `-- name: GetParameterValuesByScope :many -SELECT - id, created_at, updated_at, scope, scope_id, name, source_scheme, source_value, destination_scheme -FROM - parameter_values -WHERE - scope = $1 - AND scope_id = $2 -` - -type GetParameterValuesByScopeParams struct { - Scope ParameterScope `db:"scope" json:"scope"` - ScopeID uuid.UUID `db:"scope_id" json:"scope_id"` -} - -func (q *sqlQuerier) GetParameterValuesByScope(ctx context.Context, arg GetParameterValuesByScopeParams) ([]ParameterValue, error) { - rows, err := q.db.QueryContext(ctx, getParameterValuesByScope, arg.Scope, arg.ScopeID) - if err != nil { - return nil, err - } - defer rows.Close() - var items []ParameterValue - for rows.Next() { - var i ParameterValue - if err := rows.Scan( - &i.ID, - &i.CreatedAt, - &i.UpdatedAt, - &i.Scope, - &i.ScopeID, - &i.Name, - &i.SourceScheme, - &i.SourceValue, - &i.DestinationScheme, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - const insertParameterValue = `-- name: InsertParameterValue :one INSERT INTO parameter_values ( @@ -1183,6 +1135,103 @@ func (q *sqlQuerier) InsertParameterValue(ctx context.Context, arg InsertParamet return i, err } +const parameterValue = `-- name: ParameterValue :one +SELECT id, created_at, updated_at, scope, scope_id, name, source_scheme, source_value, destination_scheme FROM + parameter_values +WHERE + id = $1 +` + +func (q *sqlQuerier) ParameterValue(ctx context.Context, id uuid.UUID) (ParameterValue, error) { + row := q.db.QueryRowContext(ctx, parameterValue, id) + var i ParameterValue + err := row.Scan( + &i.ID, + &i.CreatedAt, + &i.UpdatedAt, + &i.Scope, + &i.ScopeID, + &i.Name, + &i.SourceScheme, + &i.SourceValue, + &i.DestinationScheme, + ) + return i, err +} + +const parameterValues = `-- name: ParameterValues :many +SELECT + id, created_at, updated_at, scope, scope_id, name, source_scheme, source_value, destination_scheme +FROM + parameter_values +WHERE + CASE + WHEN cardinality($1 :: parameter_scope[]) > 0 THEN + scope = ANY($1 :: parameter_scope[]) + ELSE true + END + AND CASE + WHEN cardinality($2 :: uuid[]) > 0 THEN + scope_id = ANY($2 :: uuid[]) + ELSE true + END + AND CASE + WHEN cardinality($3 :: uuid[]) > 0 THEN + id = ANY($3 :: uuid[]) + ELSE true + END + AND CASE + WHEN cardinality($4 :: text[]) > 0 THEN + "name" = ANY($4 :: text[]) + ELSE true + END +` + +type ParameterValuesParams struct { + Scopes []ParameterScope `db:"scopes" json:"scopes"` + ScopeIds []uuid.UUID `db:"scope_ids" json:"scope_ids"` + Ids []uuid.UUID `db:"ids" json:"ids"` + Names []string `db:"names" json:"names"` +} + +func (q *sqlQuerier) ParameterValues(ctx context.Context, arg ParameterValuesParams) ([]ParameterValue, error) { + rows, err := q.db.QueryContext(ctx, parameterValues, + pq.Array(arg.Scopes), + pq.Array(arg.ScopeIds), + pq.Array(arg.Ids), + pq.Array(arg.Names), + ) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ParameterValue + for rows.Next() { + var i ParameterValue + if err := rows.Scan( + &i.ID, + &i.CreatedAt, + &i.UpdatedAt, + &i.Scope, + &i.ScopeID, + &i.Name, + &i.SourceScheme, + &i.SourceValue, + &i.DestinationScheme, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getProvisionerDaemonByID = `-- name: GetProvisionerDaemonByID :one SELECT id, created_at, updated_at, name, provisioners diff --git a/coderd/database/queries/parametervalues.sql b/coderd/database/queries/parametervalues.sql index 5a6ea6ece2818..d747725d271a4 100644 --- a/coderd/database/queries/parametervalues.sql +++ b/coderd/database/queries/parametervalues.sql @@ -1,17 +1,43 @@ +-- name: ParameterValue :one +SELECT * FROM + parameter_values +WHERE + id = $1; + + -- name: DeleteParameterValueByID :exec DELETE FROM parameter_values WHERE id = $1; --- name: GetParameterValuesByScope :many +-- name: ParameterValues :many SELECT * FROM parameter_values WHERE - scope = $1 - AND scope_id = $2; + CASE + WHEN cardinality(@scopes :: parameter_scope[]) > 0 THEN + scope = ANY(@scopes :: parameter_scope[]) + ELSE true + END + AND CASE + WHEN cardinality(@scope_ids :: uuid[]) > 0 THEN + scope_id = ANY(@scope_ids :: uuid[]) + ELSE true + END + AND CASE + WHEN cardinality(@ids :: uuid[]) > 0 THEN + id = ANY(@ids :: uuid[]) + ELSE true + END + AND CASE + WHEN cardinality(@names :: text[]) > 0 THEN + "name" = ANY(@names :: text[]) + ELSE true + END +; -- name: GetParameterValueByScopeAndName :one SELECT diff --git a/coderd/parameter/compute.go b/coderd/parameter/compute.go index 121cfa9ad4351..8a79850b5c99b 100644 --- a/coderd/parameter/compute.go +++ b/coderd/parameter/compute.go @@ -61,9 +61,9 @@ func Compute(ctx context.Context, db database.Store, scope ComputeScope, options } // Job parameters come second! - err = compute.injectScope(ctx, database.GetParameterValuesByScopeParams{ - Scope: database.ParameterScopeImportJob, - ScopeID: scope.TemplateImportJobID, + err = compute.injectScope(ctx, database.ParameterValuesParams{ + Scopes: []database.ParameterScope{database.ParameterScopeImportJob}, + ScopeIds: []uuid.UUID{scope.TemplateImportJobID}, }) if err != nil { return nil, err @@ -105,9 +105,9 @@ func Compute(ctx context.Context, db database.Store, scope ComputeScope, options if scope.TemplateID.Valid { // Template parameters come third! - err = compute.injectScope(ctx, database.GetParameterValuesByScopeParams{ - Scope: database.ParameterScopeTemplate, - ScopeID: scope.TemplateID.UUID, + err = compute.injectScope(ctx, database.ParameterValuesParams{ + Scopes: []database.ParameterScope{database.ParameterScopeTemplate}, + ScopeIds: []uuid.UUID{scope.TemplateID.UUID}, }) if err != nil { return nil, err @@ -116,9 +116,9 @@ func Compute(ctx context.Context, db database.Store, scope ComputeScope, options if scope.WorkspaceID.Valid { // Workspace parameters come last! - err = compute.injectScope(ctx, database.GetParameterValuesByScopeParams{ - Scope: database.ParameterScopeWorkspace, - ScopeID: scope.WorkspaceID.UUID, + err = compute.injectScope(ctx, database.ParameterValuesParams{ + Scopes: []database.ParameterScope{database.ParameterScopeWorkspace}, + ScopeIds: []uuid.UUID{scope.WorkspaceID.UUID}, }) if err != nil { return nil, err @@ -148,13 +148,13 @@ type compute struct { } // Validates and computes the value for parameters; setting the value on "parameterByName". -func (c *compute) injectScope(ctx context.Context, scopeParams database.GetParameterValuesByScopeParams) error { - scopedParameters, err := c.db.GetParameterValuesByScope(ctx, scopeParams) +func (c *compute) injectScope(ctx context.Context, scopeParams database.ParameterValuesParams) error { + scopedParameters, err := c.db.ParameterValues(ctx, scopeParams) if errors.Is(err, sql.ErrNoRows) { err = nil } if err != nil { - return xerrors.Errorf("get %s parameters: %w", scopeParams.Scope, err) + return xerrors.Errorf("get %s parameters: %w", scopeParams.Scopes, err) } for _, scopedParameter := range scopedParameters { diff --git a/coderd/parameters.go b/coderd/parameters.go index 615e8f84948ca..1d22d267fbe2b 100644 --- a/coderd/parameters.go +++ b/coderd/parameters.go @@ -91,9 +91,9 @@ func (api *API) parameters(rw http.ResponseWriter, r *http.Request) { return } - parameterValues, err := api.Database.GetParameterValuesByScope(r.Context(), database.GetParameterValuesByScopeParams{ - Scope: scope, - ScopeID: scopeID, + parameterValues, err := api.Database.ParameterValues(r.Context(), database.ParameterValuesParams{ + Scopes: []database.ParameterScope{scope}, + ScopeIds: []uuid.UUID{scopeID}, }) if errors.Is(err, sql.ErrNoRows) { err = nil @@ -214,6 +214,8 @@ func (api *API) parameterRBACResource(rw http.ResponseWriter, r *http.Request, s switch scope { case database.ParameterScopeWorkspace: resource, err = api.Database.GetWorkspaceByID(ctx, scopeID) + case database.ParameterScopeImportJob: + resource, err = api.Database.GetTemplateVersionByJobID(ctx, scopeID) case database.ParameterScopeTemplate: resource, err = api.Database.GetTemplateByID(ctx, scopeID) default: @@ -237,12 +239,9 @@ func (api *API) parameterRBACResource(rw http.ResponseWriter, r *http.Request, s } func readScopeAndID(rw http.ResponseWriter, r *http.Request) (database.ParameterScope, uuid.UUID, bool) { - var scope database.ParameterScope - switch chi.URLParam(r, "scope") { - case string(codersdk.ParameterTemplate): - scope = database.ParameterScopeTemplate - case string(codersdk.ParameterWorkspace): - scope = database.ParameterScopeWorkspace + scope := database.ParameterScope(chi.URLParam(r, "scope")) + switch scope { + case database.ParameterScopeTemplate, database.ParameterScopeImportJob, database.ParameterScopeWorkspace: default: httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{ Message: fmt.Sprintf("Invalid scope %q.", scope), diff --git a/coderd/templates.go b/coderd/templates.go index 60ab03ef938d3..b71daffcb7ec8 100644 --- a/coderd/templates.go +++ b/coderd/templates.go @@ -220,7 +220,7 @@ func (api *API) postTemplateByOrganization(rw http.ResponseWriter, r *http.Reque CreatedAt: database.Now(), UpdatedAt: database.Now(), Scope: database.ParameterScopeTemplate, - ScopeID: dbTemplate.ID, + ScopeID: template.ID, SourceScheme: database.ParameterSourceScheme(parameterValue.SourceScheme), SourceValue: parameterValue.SourceValue, DestinationScheme: database.ParameterDestinationScheme(parameterValue.DestinationScheme), diff --git a/coderd/templateversions.go b/coderd/templateversions.go index 472fd87f8d9e8..de327ac7a0bea 100644 --- a/coderd/templateversions.go +++ b/coderd/templateversions.go @@ -559,9 +559,16 @@ func (api *API) patchActiveTemplateVersion(rw http.ResponseWriter, r *http.Reque }) return } - err = api.Database.UpdateTemplateActiveVersionByID(r.Context(), database.UpdateTemplateActiveVersionByIDParams{ - ID: template.ID, - ActiveVersionID: req.ID, + + err = api.Database.InTx(func(store database.Store) error { + err = store.UpdateTemplateActiveVersionByID(r.Context(), database.UpdateTemplateActiveVersionByIDParams{ + ID: template.ID, + ActiveVersionID: req.ID, + }) + if err != nil { + return xerrors.Errorf("update active version: %w", err) + } + return nil }) if err != nil { httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{ @@ -631,7 +638,53 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht var provisionerJob database.ProvisionerJob err = api.Database.InTx(func(db database.Store) error { jobID := uuid.New() + inherits := make([]uuid.UUID, 0) for _, parameterValue := range req.ParameterValues { + if parameterValue.CloneID != uuid.Nil { + inherits = append(inherits, parameterValue.CloneID) + } + } + + // Expand inherited params + if len(inherits) > 0 { + if req.TemplateID == uuid.Nil { + return xerrors.Errorf("cannot inherit parameters if template_id is not set") + } + + inheritedParams, err := db.ParameterValues(r.Context(), database.ParameterValuesParams{ + Ids: inherits, + }) + if err != nil { + return xerrors.Errorf("fetch inherited params: %w", err) + } + for _, copy := range inheritedParams { + // This is a bit inefficient, as we make a new db call for each + // param. + version, err := db.GetTemplateVersionByJobID(r.Context(), copy.ScopeID) + if err != nil { + return xerrors.Errorf("fetch template version for param %q: %w", copy.Name, err) + } + if !version.TemplateID.Valid || version.TemplateID.UUID != req.TemplateID { + return xerrors.Errorf("cannot inherit parameters from other templates") + } + if copy.Scope != database.ParameterScopeImportJob { + return xerrors.Errorf("copy parameter scope is %q, must be %q", copy.Scope, database.ParameterScopeImportJob) + } + // Add the copied param to the list to process + req.ParameterValues = append(req.ParameterValues, codersdk.CreateParameterRequest{ + Name: copy.Name, + SourceValue: copy.SourceValue, + SourceScheme: codersdk.ParameterSourceScheme(copy.SourceScheme), + DestinationScheme: codersdk.ParameterDestinationScheme(copy.DestinationScheme), + }) + } + } + + for _, parameterValue := range req.ParameterValues { + if parameterValue.CloneID != uuid.Nil { + continue + } + _, err = db.InsertParameterValue(r.Context(), database.InsertParameterValueParams{ ID: uuid.New(), Name: parameterValue.Name, diff --git a/coderd/util/slice/slice.go b/coderd/util/slice/slice.go new file mode 100644 index 0000000000000..dfea2ed26f6c6 --- /dev/null +++ b/coderd/util/slice/slice.go @@ -0,0 +1,10 @@ +package slice + +func Contains[T comparable](haystack []T, needle T) bool { + for _, hay := range haystack { + if needle == hay { + return true + } + } + return false +} diff --git a/coderd/util/slice/slice_test.go b/coderd/util/slice/slice_test.go new file mode 100644 index 0000000000000..1f485b77e3067 --- /dev/null +++ b/coderd/util/slice/slice_test.go @@ -0,0 +1,35 @@ +package slice_test + +import ( + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/coderd/util/slice" +) + +func TestContains(t *testing.T) { + t.Parallel() + + assertSetContains(t, []int{1, 2, 3, 4, 5}, []int{1, 2, 3, 4, 5}, []int{0, 6, -1, -2, 100}) + assertSetContains(t, []string{"hello", "world", "foo", "bar", "baz"}, []string{"hello", "world", "baz"}, []string{"not", "words", "in", "set"}) + assertSetContains(t, + []uuid.UUID{uuid.New(), uuid.MustParse("c7c6686d-a93c-4df2-bef9-5f837e9a33d5"), uuid.MustParse("8f3b3e0b-2c3f-46a5-a365-fd5b62bd8818")}, + []uuid.UUID{uuid.MustParse("c7c6686d-a93c-4df2-bef9-5f837e9a33d5")}, + []uuid.UUID{uuid.MustParse("1d00e27d-8de6-46f8-80d5-1da0ca83874a")}, + ) +} + +func assertSetContains[T comparable](t *testing.T, set []T, in []T, out []T) { + t.Helper() + for _, e := range set { + require.True(t, slice.Contains(set, e), "elements in set should be in the set") + } + for _, e := range in { + require.True(t, slice.Contains(set, e), "expect element in set") + } + for _, e := range out { + require.False(t, slice.Contains(set, e), "expect element in set") + } +} diff --git a/codersdk/parameters.go b/codersdk/parameters.go index 85c9bd92576e5..a4943817144be 100644 --- a/codersdk/parameters.go +++ b/codersdk/parameters.go @@ -14,9 +14,9 @@ import ( type ParameterScope string const ( - ParameterTemplate ParameterScope = "template" - ParameterWorkspace ParameterScope = "workspace" - ParameterScopeImportJob ParameterScope = "import_job" + ParameterTemplate ParameterScope = "template" + ParameterWorkspace ParameterScope = "workspace" + ParameterImportJob ParameterScope = "import_job" ) type ParameterSourceScheme string @@ -78,6 +78,13 @@ type ParameterSchema struct { // CreateParameterRequest is used to create a new parameter value for a scope. type CreateParameterRequest struct { + // CloneID allows copying the value of another parameter. + // The other param must be related to the same template_id for this to + // succeed. + // No other fields are required if using this, as all fields will be copied + // from the other parameter. + CloneID uuid.UUID `json:"copy_from_parameter,omitempty" validate:""` + Name string `json:"name" validate:"required"` SourceValue string `json:"source_value" validate:"required"` SourceScheme ParameterSourceScheme `json:"source_scheme" validate:"oneof=data,required"` diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 0f2521ad79d87..d6b9ba0577105 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -51,6 +51,7 @@ export interface CreateOrganizationRequest { // From codersdk/parameters.go:80:6 export interface CreateParameterRequest { + readonly copy_from_parameter?: string readonly name: string readonly source_value: string readonly source_scheme: ParameterSourceScheme