Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 67 additions & 14 deletions cli/exp_taskcreate.go → cli/exp_task_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cli

import (
"fmt"
"io"
"strings"

"github.com/google/uuid"
Expand All @@ -20,43 +21,49 @@ func (r *RootCmd) taskCreate() *serpent.Command {
templateName string
templateVersionName string
presetName string
taskInput string
stdin bool
)

cmd := &serpent.Command{
Use: "create [template]",
Use: "create [input]",
Short: "Create an experimental task",
Middleware: serpent.Chain(
serpent.RequireRangeArgs(0, 1),
r.InitClient(client),
),
Options: serpent.OptionSet{
{
Flag: "input",
Env: "CODER_TASK_INPUT",
Value: serpent.StringOf(&taskInput),
Required: true,
},
{
Name: "template",
Flag: "template",
Env: "CODER_TASK_TEMPLATE_NAME",
Value: serpent.StringOf(&templateName),
},
{
Name: "template-version",
Flag: "template-version",
Env: "CODER_TASK_TEMPLATE_VERSION",
Value: serpent.StringOf(&templateVersionName),
},
{
Name: "preset",
Flag: "preset",
Env: "CODER_TASK_PRESET_NAME",
Value: serpent.StringOf(&presetName),
Default: PresetNone,
},
{
Name: "stdin",
Flag: "stdin",
Description: "Reads from stdin for the task input.",
Value: serpent.BoolOf(&stdin),
},
},
Handler: func(inv *serpent.Invocation) error {
var (
ctx = inv.Context()
expClient = codersdk.NewExperimentalClient(client)

taskInput string
templateVersionID uuid.UUID
templateVersionPresetID uuid.UUID
)
Expand All @@ -66,22 +73,68 @@ func (r *RootCmd) taskCreate() *serpent.Command {
return xerrors.Errorf("get current organization: %w", err)
}

if len(inv.Args) > 0 {
templateName, templateVersionName, _ = strings.Cut(inv.Args[0], "@")
if stdin {
bytes, err := io.ReadAll(inv.Stdin)
if err != nil {
return xerrors.Errorf("reading stdin: %w", err)
}

taskInput = string(bytes)
} else {
if len(inv.Args) != 1 {
return xerrors.Errorf("expected an input for task")
}

taskInput = inv.Args[0]
}

if templateName == "" {
return xerrors.Errorf("template name not provided")
if taskInput == "" {
return xerrors.Errorf("a task cannot be started with an empty input")
}

if templateVersionName != "" {
switch {
case templateName == "":
templates, err := client.Templates(ctx, codersdk.TemplateFilter{SearchQuery: "has-ai-task:true", OrganizationID: organization.ID})
if err != nil {
return xerrors.Errorf("list templates: %w", err)
}

if len(templates) == 0 {
return xerrors.Errorf("no task templates configured")
}

// When a deployment has only 1 AI task template, we will
// allow omitting the template. Otherwise we will require
// the user to be explicit with their choice of template.
if len(templates) > 1 {
templateNames := make([]string, 0, len(templates))
for _, template := range templates {
templateNames = append(templateNames, template.Name)
}

return xerrors.Errorf("template name not provided, available templates: %s", strings.Join(templateNames, ", "))
}

if templateVersionName != "" {
templateVersion, err := client.TemplateVersionByOrganizationAndName(ctx, organization.ID, templates[0].Name, templateVersionName)
if err != nil {
return xerrors.Errorf("get template version: %w", err)
}

templateVersionID = templateVersion.ID
} else {
templateVersionID = templates[0].ActiveVersionID
}

case templateVersionName != "":
templateVersion, err := client.TemplateVersionByOrganizationAndName(ctx, organization.ID, templateName, templateVersionName)
if err != nil {
return xerrors.Errorf("get template version: %w", err)
}

templateVersionID = templateVersion.ID
} else {

default:
template, err := client.TemplateByName(ctx, organization.ID, templateName)
if err != nil {
return xerrors.Errorf("get template: %w", err)
Expand Down
94 changes: 80 additions & 14 deletions cli/exp_taskcreate_test.go → cli/exp_task_create_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,14 @@ func TestTaskCreate(t *testing.T) {
Name: presetName,
},
})
case "/api/v2/templates":
httpapi.Write(ctx, w, http.StatusOK, []codersdk.Template{
{
ID: templateID,
Name: templateName,
ActiveVersionID: templateVersionID,
},
})
case "/api/experimental/tasks/me":
var req codersdk.CreateTaskRequest
if !httpapi.Read(ctx, w, r, &req) {
Expand Down Expand Up @@ -88,71 +96,80 @@ func TestTaskCreate(t *testing.T) {
tests := []struct {
args []string
env []string
stdin string
expectError string
expectOutput string
handler func(t *testing.T, ctx context.Context) http.HandlerFunc
}{
{
args: []string{"my-template@my-template-version", "--input", "my custom prompt", "--org", organizationID.String()},
args: []string{"--stdin"},
stdin: "reads prompt from stdin",
expectOutput: fmt.Sprintf("The task %s has been created at %s!", cliui.Keyword("task-wild-goldfish-27"), cliui.Timestamp(taskCreatedAt)),
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "my-template-version", "", "reads prompt from stdin")
},
},
{
args: []string{"my custom prompt"},
expectOutput: fmt.Sprintf("The task %s has been created at %s!", cliui.Keyword("task-wild-goldfish-27"), cliui.Timestamp(taskCreatedAt)),
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "my-template-version", "", "my custom prompt")
},
},
{
args: []string{"my-template", "--input", "my custom prompt", "--org", organizationID.String()},
env: []string{"CODER_TASK_TEMPLATE_VERSION=my-template-version"},
args: []string{"my custom prompt", "--template", "my-template", "--template-version", "my-template-version", "--org", organizationID.String()},
expectOutput: fmt.Sprintf("The task %s has been created at %s!", cliui.Keyword("task-wild-goldfish-27"), cliui.Timestamp(taskCreatedAt)),
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "my-template-version", "", "my custom prompt")
},
},
{
args: []string{"--input", "my custom prompt", "--org", organizationID.String()},
env: []string{"CODER_TASK_TEMPLATE_NAME=my-template", "CODER_TASK_TEMPLATE_VERSION=my-template-version"},
args: []string{"my custom prompt", "--template", "my-template", "--org", organizationID.String()},
env: []string{"CODER_TASK_TEMPLATE_VERSION=my-template-version"},
expectOutput: fmt.Sprintf("The task %s has been created at %s!", cliui.Keyword("task-wild-goldfish-27"), cliui.Timestamp(taskCreatedAt)),
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "my-template-version", "", "my custom prompt")
},
},
{
env: []string{"CODER_TASK_TEMPLATE_NAME=my-template", "CODER_TASK_TEMPLATE_VERSION=my-template-version", "CODER_TASK_INPUT=my custom prompt", "CODER_ORGANIZATION=" + organizationID.String()},
args: []string{"my custom prompt", "--org", organizationID.String()},
env: []string{"CODER_TASK_TEMPLATE_NAME=my-template", "CODER_TASK_TEMPLATE_VERSION=my-template-version"},
expectOutput: fmt.Sprintf("The task %s has been created at %s!", cliui.Keyword("task-wild-goldfish-27"), cliui.Timestamp(taskCreatedAt)),
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "my-template-version", "", "my custom prompt")
},
},
{
args: []string{"my-template", "--input", "my custom prompt", "--org", organizationID.String()},
args: []string{"my custom prompt", "--template", "my-template", "--org", organizationID.String()},
expectOutput: fmt.Sprintf("The task %s has been created at %s!", cliui.Keyword("task-wild-goldfish-27"), cliui.Timestamp(taskCreatedAt)),
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "", "", "my custom prompt")
},
},
{
args: []string{"my-template", "--input", "my custom prompt", "--preset", "my-preset", "--org", organizationID.String()},
args: []string{"my custom prompt", "--template", "my-template", "--preset", "my-preset", "--org", organizationID.String()},
expectOutput: fmt.Sprintf("The task %s has been created at %s!", cliui.Keyword("task-wild-goldfish-27"), cliui.Timestamp(taskCreatedAt)),
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "", "my-preset", "my custom prompt")
},
},
{
args: []string{"my-template", "--input", "my custom prompt"},
args: []string{"my custom prompt", "--template", "my-template"},
env: []string{"CODER_TASK_PRESET_NAME=my-preset"},
expectOutput: fmt.Sprintf("The task %s has been created at %s!", cliui.Keyword("task-wild-goldfish-27"), cliui.Timestamp(taskCreatedAt)),
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "", "my-preset", "my custom prompt")
},
},
{
args: []string{"my-template", "--input", "my custom prompt", "--preset", "not-real-preset"},
args: []string{"my custom prompt", "--template", "my-template", "--preset", "not-real-preset"},
expectError: `preset "not-real-preset" not found`,
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return templateAndVersionFoundHandler(t, ctx, organizationID, "my-template", "", "my-preset", "my custom prompt")
},
},
{
args: []string{"my-template@not-real-template-version", "--input", "my custom prompt"},
args: []string{"my custom prompt", "--template", "my-template", "--template-version", "not-real-template-version"},
expectError: httpapi.ResourceNotFoundResponse.Message,
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -163,6 +180,11 @@ func TestTaskCreate(t *testing.T) {
ID: organizationID,
}},
})
case fmt.Sprintf("/api/v2/organizations/%s/templates/my-template", organizationID):
httpapi.Write(ctx, w, http.StatusOK, codersdk.Template{
ID: templateID,
ActiveVersionID: templateVersionID,
})
case fmt.Sprintf("/api/v2/organizations/%s/templates/my-template/versions/not-real-template-version", organizationID):
httpapi.ResourceNotFound(w)
default:
Expand All @@ -172,7 +194,7 @@ func TestTaskCreate(t *testing.T) {
},
},
{
args: []string{"not-real-template", "--input", "my custom prompt", "--org", organizationID.String()},
args: []string{"my custom prompt", "--template", "not-real-template", "--org", organizationID.String()},
expectError: httpapi.ResourceNotFoundResponse.Message,
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -192,7 +214,7 @@ func TestTaskCreate(t *testing.T) {
},
},
{
args: []string{"template-in-different-org", "--input", "my-custom-prompt", "--org", anotherOrganizationID.String()},
args: []string{"my-custom-prompt", "--template", "template-in-different-org", "--org", anotherOrganizationID.String()},
expectError: httpapi.ResourceNotFoundResponse.Message,
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -212,7 +234,7 @@ func TestTaskCreate(t *testing.T) {
},
},
{
args: []string{"no-org", "--input", "my-custom-prompt"},
args: []string{"no-org-prompt"},
expectError: "Must select an organization with --org=<org_name>",
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -225,6 +247,49 @@ func TestTaskCreate(t *testing.T) {
}
},
},
{
args: []string{"no task templates"},
expectError: "no task templates configured",
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/v2/users/me/organizations":
httpapi.Write(ctx, w, http.StatusOK, []codersdk.Organization{
{MinimalOrganization: codersdk.MinimalOrganization{
ID: organizationID,
}},
})
case "/api/v2/templates":
httpapi.Write(ctx, w, http.StatusOK, []codersdk.Template{})
default:
t.Errorf("unexpected path: %s", r.URL.Path)
}
}
},
},
{
args: []string{"no template name provided"},
expectError: "template name not provided, available templates: wibble, wobble",
handler: func(t *testing.T, ctx context.Context) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/v2/users/me/organizations":
httpapi.Write(ctx, w, http.StatusOK, []codersdk.Organization{
{MinimalOrganization: codersdk.MinimalOrganization{
ID: organizationID,
}},
})
case "/api/v2/templates":
httpapi.Write(ctx, w, http.StatusOK, []codersdk.Template{
{Name: "wibble"},
{Name: "wobble"},
})
default:
t.Errorf("unexpected path: %s", r.URL.Path)
}
}
},
},
}

for _, tt := range tests {
Expand All @@ -244,6 +309,7 @@ func TestTaskCreate(t *testing.T) {

inv, root := clitest.New(t, append(args, tt.args...)...)
inv.Environ = serpent.ParseEnviron(tt.env, "")
inv.Stdin = strings.NewReader(tt.stdin)
inv.Stdout = &sb
inv.Stderr = &sb
clitest.SetupConfig(t, client, root)
Expand Down
18 changes: 18 additions & 0 deletions coderd/apikey.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"github.com/moby/moby/pkg/namesgenerator"
"golang.org/x/xerrors"

"cdr.dev/slog"

"github.com/coder/coder/v2/coderd/apikey"
"github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/database"
Expand Down Expand Up @@ -56,6 +58,14 @@ func (api *API) postToken(rw http.ResponseWriter, r *http.Request) {
return
}

// TODO(Cian): System users technically just have the 'member' role
// and we don't want to disallow all members from creating API keys.
if user.IsSystem {
api.Logger.Warn(ctx, "disallowed creating api key for system user", slog.F("user_id", user.ID))
httpapi.Forbidden(rw)
return
}

scope := database.APIKeyScopeAll
if scope != "" {
scope = database.APIKeyScope(createToken.Scope)
Expand Down Expand Up @@ -124,6 +134,14 @@ func (api *API) postAPIKey(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
user := httpmw.UserParam(r)

// TODO(Cian): System users technically just have the 'member' role
// and we don't want to disallow all members from creating API keys.
if user.IsSystem {
api.Logger.Warn(ctx, "disallowed creating api key for system user", slog.F("user_id", user.ID))
httpapi.Forbidden(rw)
return
}

cookie, _, err := api.createAPIKey(ctx, apikey.CreateParams{
UserID: user.ID,
DefaultLifetime: api.DeploymentValues.Sessions.DefaultTokenDuration.Value(),
Expand Down
Loading
Loading