diff --git a/cli/server.go b/cli/server.go index f9e744761b22e..4d78cb47e475f 100644 --- a/cli/server.go +++ b/cli/server.go @@ -31,6 +31,8 @@ import ( "sync/atomic" "time" + "github.com/anthropics/anthropic-sdk-go" + anthropicoption "github.com/anthropics/anthropic-sdk-go/option" "github.com/charmbracelet/lipgloss" "github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-systemd/daemon" @@ -629,6 +631,13 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. vals.WorkspaceHostnameSuffix.String()) } + var anthropicClient atomic.Pointer[anthropic.Client] + if vals.AnthropicAPIKey.String() != "" { + client := anthropic.NewClient(anthropicoption.WithAPIKey(vals.AnthropicAPIKey.String())) + + anthropicClient.Store(&client) + } + options := &coderd.Options{ AccessURL: vals.AccessURL.Value(), AppHostname: appHostname, @@ -666,6 +675,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. AllowWorkspaceRenames: vals.AllowWorkspaceRenames.Value(), Entitlements: entitlements.New(), NotificationsEnqueuer: notifications.NewNoopEnqueuer(), // Changed further down if notifications enabled. + AnthropicClient: &anthropicClient, } if httpServers.TLSConfig != nil { options.TLSCertificates = httpServers.TLSConfig.Certificates diff --git a/coderd/aitasks.go b/coderd/aitasks.go index e1d72f264a025..0c45a880ac1f8 100644 --- a/coderd/aitasks.go +++ b/coderd/aitasks.go @@ -1,15 +1,20 @@ package coderd import ( + "context" "database/sql" "errors" "fmt" + "io" "net/http" "slices" "strings" + "github.com/anthropics/anthropic-sdk-go" "github.com/google/uuid" + "golang.org/x/xerrors" + "github.com/coder/aisdk-go" "github.com/coder/coder/v2/coderd/audit" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/httpapi" @@ -69,6 +74,74 @@ func (api *API) aiTasksPrompts(rw http.ResponseWriter, r *http.Request) { }) } +func (api *API) generateTaskName(ctx context.Context, prompt, fallback string) (string, error) { + var ( + stream aisdk.DataStream + err error + ) + + conversation := []aisdk.Message{ + { + Role: "system", + Parts: []aisdk.Part{{ + Type: aisdk.PartTypeText, + Text: `You are a task summarizer. +You summarize AI prompts into workspace names. +You will only respond with a workspace name. +The workspace name **MUST** follow this regex /^[a-z0-9]+(?:-[a-z0-9]+)*$/ +The workspace name **MUST** be 32 characters or **LESS**. +The workspace name **MUST** be all lower case. +The workspace name **MUST** end in a number between 0 and 100. +The workspace name **MUST** be prefixed with "task".`, + }}, + }, + { + Role: "user", + Parts: []aisdk.Part{{ + Type: aisdk.PartTypeText, + Text: prompt, + }}, + }, + } + + anthropicClient := api.anthropicClient.Load() + if anthropicClient == nil { + return fallback, nil + } + + stream, err = anthropicDataStream(ctx, *anthropicClient, conversation) + if err != nil { + return "", xerrors.Errorf("create anthropic data stream: %w", err) + } + + var acc aisdk.DataStreamAccumulator + stream = stream.WithAccumulator(&acc) + + if err := stream.Pipe(io.Discard); err != nil { + return "", err + } + + if len(acc.Messages()) == 0 { + return fallback, nil + } + + return acc.Messages()[0].Content, nil +} + +func anthropicDataStream(ctx context.Context, client anthropic.Client, input []aisdk.Message) (aisdk.DataStream, error) { + messages, system, err := aisdk.MessagesToAnthropic(input) + if err != nil { + return nil, xerrors.Errorf("convert messages to anthropic format: %w", err) + } + + return aisdk.AnthropicToDataStream(client.Messages.NewStreaming(ctx, anthropic.MessageNewParams{ + Model: anthropic.ModelClaude3_5HaikuLatest, + MaxTokens: 24, + System: system, + Messages: messages, + })), nil +} + // This endpoint is experimental and not guaranteed to be stable, so we're not // generating public-facing documentation for it. func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) { @@ -104,8 +177,21 @@ func (api *API) tasksCreate(rw http.ResponseWriter, r *http.Request) { return } + taskName, err := api.generateTaskName(ctx, req.Prompt, req.Name) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error generating name for task.", + Detail: err.Error(), + }) + return + } + + if taskName == "" { + taskName = req.Name + } + createReq := codersdk.CreateWorkspaceRequest{ - Name: req.Name, + Name: taskName, TemplateVersionID: req.TemplateVersionID, TemplateVersionPresetID: req.TemplateVersionPresetID, RichParameterValues: []codersdk.WorkspaceBuildParameter{ diff --git a/coderd/coderd.go b/coderd/coderd.go index 2aa30c9d7a45c..39724c174b972 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -20,6 +20,8 @@ import ( "sync/atomic" "time" + "github.com/anthropics/anthropic-sdk-go" + "github.com/coder/coder/v2/coderd/oauth2provider" "github.com/coder/coder/v2/coderd/pproflabel" "github.com/coder/coder/v2/coderd/prebuilds" @@ -276,6 +278,8 @@ type Options struct { // WebPushDispatcher is a way to send notifications over Web Push. WebPushDispatcher webpush.Dispatcher + + AnthropicClient *atomic.Pointer[anthropic.Client] } // @title Coder API @@ -475,6 +479,10 @@ func New(options *Options) *API { options.NotificationsEnqueuer = notifications.NewNoopEnqueuer() } + if options.AnthropicClient == nil { + options.AnthropicClient = &atomic.Pointer[anthropic.Client]{} + } + r := chi.NewRouter() // We add this middleware early, to make sure that authorization checks made // by other middleware get recorded. @@ -600,7 +608,8 @@ func New(options *Options) *API { options.Database, options.Pubsub, ), - dbRolluper: options.DatabaseRolluper, + dbRolluper: options.DatabaseRolluper, + anthropicClient: options.AnthropicClient, } api.WorkspaceAppsProvider = workspaceapps.NewDBTokenProvider( options.Logger.Named("workspaceapps"), @@ -1723,6 +1732,8 @@ type API struct { // dbRolluper rolls up template usage stats from raw agent and app // stats. This is used to provide insights in the WebUI. dbRolluper *dbrollup.Rolluper + + anthropicClient *atomic.Pointer[anthropic.Client] } // Close waits for all WebSocket connections to drain before returning. diff --git a/codersdk/deployment.go b/codersdk/deployment.go index 1d6fa4572772e..9ffe71aa229a6 100644 --- a/codersdk/deployment.go +++ b/codersdk/deployment.go @@ -497,6 +497,7 @@ type DeploymentValues struct { WorkspaceHostnameSuffix serpent.String `json:"workspace_hostname_suffix,omitempty" typescript:",notnull"` Prebuilds PrebuildsConfig `json:"workspace_prebuilds,omitempty" typescript:",notnull"` HideAITasks serpent.Bool `json:"hide_ai_tasks,omitempty" typescript:",notnull"` + AnthropicAPIKey serpent.String `json:"anthropic_api_key,omitempty" typescript:",notnull"` Config serpent.YAMLConfigPath `json:"config,omitempty" typescript:",notnull"` WriteConfig serpent.Bool `json:"write_config,omitempty" typescript:",notnull"` @@ -3205,6 +3206,13 @@ Write out the current server config as YAML to stdout.`, Group: &deploymentGroupClient, YAML: "hideAITasks", }, + { + Name: "Anthropic API Key", + Description: "API Key for accessing Anthropic's API platform.", + Env: "ANTHROPIC_API_KEY", + Value: &c.AnthropicAPIKey, + Group: &deploymentGroupClient, + }, } return opts diff --git a/go.mod b/go.mod index e10c7a248db7e..6d703cdd1245e 100644 --- a/go.mod +++ b/go.mod @@ -477,6 +477,7 @@ require ( ) require ( + github.com/anthropics/anthropic-sdk-go v1.4.0 github.com/brianvoe/gofakeit/v7 v7.3.0 github.com/coder/agentapi-sdk-go v0.0.0-20250505131810-560d1d88d225 github.com/coder/aisdk-go v0.0.9 @@ -500,7 +501,6 @@ require ( github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.50.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.50.0 // indirect github.com/Masterminds/semver/v3 v3.3.1 // indirect - github.com/anthropics/anthropic-sdk-go v1.4.0 // indirect github.com/aquasecurity/go-version v0.0.1 // indirect github.com/aquasecurity/trivy v0.58.2 // indirect github.com/aws/aws-sdk-go v1.55.7 // indirect