// Dialer opens a new RPC client to coderd. The connect loop calls it with the
// server's close context.
type Dialer func(ctx context.Context) (aibridge.APIClient, error)

// Server is the aibridged daemon. A background goroutine (connect) keeps one
// client connection to coderd alive and offers it on clientCh; Client and the
// RPC wrapper methods consume it from there.
type Server struct {
	// clientDialer (re)establishes the connection to coderd.
	clientDialer Dialer
	// clientCh is unbuffered; the connect loop sends the currently connected
	// client on it whenever a caller of Client is receiving.
	clientCh chan aibridge.APIClient

	logger slog.Logger
	// wg tracks the connect goroutine so that closing can wait for it.
	wg sync.WaitGroup

	// initConnectionCh is closed (not sent on) when the daemon connects to
	// coderd for the first time; initConnectionOnce guards that close.
	initConnectionCh   chan struct{}
	initConnectionOnce sync.Once

	// mutex protects all subsequent fields
	mutex sync.Mutex
	// closeContext is canceled when we start closing.
	closeContext context.Context
	closeCancel  context.CancelFunc
	// closeError stores the error when closing to return to subsequent callers
	closeError error
	// closingB is set to true when we start closing
	closingB bool
	// closedCh is closed when we complete closing
	closedCh chan struct{}
	// shuttingDownB is set to true when we start graceful shutdown.
	// NOTE(review): nothing in the code shown here sets it yet.
	shuttingDownB bool
	// shuttingDownCh is meant to be closed when we start graceful shutdown;
	// Client selects on it. NOTE(review): nothing in the code shown here
	// closes it yet.
	shuttingDownCh chan struct{}
}

// Compile-time check that *Server satisfies aibridge.APIServer.
var _ aibridge.APIServer = &Server{}
// connect runs the connection loop for the lifetime of the server: it dials
// coderd (with back-off between failed attempts), then serves the resulting
// client to callers of Client until either the server is closed or the
// connection drops, in which case it dials again. It returns when the close
// context is canceled, when dialing reports context.Canceled, or when coderd
// answers the dial with 403 Forbidden.
func (s *Server) connect() {
	defer s.logger.Debug(s.closeContext, "connect loop exited")
	defer s.wg.Done()
	logConnect := s.logger.Debug
	// An exponential back-off occurs when the connection is failing to dial.
	// This is to prevent server spam in case of a coderd outage.
connectLoop:
	for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(s.closeContext); {
		// TODO(dannyk): handle premature close. The server may be closed
		// before the back-off wait completes; an explicit
		// `if s.isClosed() { return }` here would cover that case.

		s.logger.Debug(s.closeContext, "dialing coderd")
		client, err := s.clientDialer(s.closeContext)
		if err != nil {
			// A canceled dial means we are closing; stop silently.
			if errors.Is(err, context.Canceled) {
				return
			}
			var sdkErr *codersdk.Error
			// If something is wrong with our auth, stop trying to connect.
			if errors.As(err, &sdkErr) && sdkErr.StatusCode() == http.StatusForbidden {
				s.logger.Error(s.closeContext, "not authorized to dial coderd", slog.Error(err))
				return
			}
			if s.isClosed() {
				return
			}
			s.logger.Warn(s.closeContext, "coderd client failed to dial", slog.Error(err))
			continue
		}

		// Connected: log it, restart the back-off schedule for the next
		// outage, and unblock anyone waiting for the first connection (the
		// Once makes the close safe across reconnects).
		logConnect(s.closeContext, "successfully connected to coderd")
		retrier.Reset()
		s.initConnectionOnce.Do(func() {
			close(s.initConnectionCh)
		})

		// serve the client until we are closed or it disconnects
		for {
			select {
			case <-s.closeContext.Done():
				client.DRPCConn().Close()
				return
			case <-client.DRPCConn().Closed():
				logConnect(s.closeContext, "connection to coderd closed")
				continue connectLoop
			case s.clientCh <- client:
				// A caller of Client took the client; keep offering it.
				continue
			}
		}
	}
}

// Client blocks until the connect loop hands over a connected client. It
// returns an error instead if the close context is canceled or
// shuttingDownCh is closed first.
func (s *Server) Client() (aibridge.APIClient, error) {
	select {
	case <-s.closeContext.Done():
		return nil, xerrors.New("context closed")
	case <-s.shuttingDownCh:
		// Shutting down should return a nil client and unblock
		return nil, xerrors.New("shutting down")
	case client := <-s.clientCh:
		return client, nil
	}
}

// StoreSession forwards the request to coderd through the current client,
// retrying via clientDoWithRetries. On error the response is always nil.
func (s *Server) StoreSession(ctx context.Context, in *proto.StoreSessionRequest) (*proto.StoreSessionResponse, error) {
	out, err := clientDoWithRetries(ctx, s.Client, func(ctx context.Context, client aibridge.APIClient) (*proto.StoreSessionResponse, error) {
		return client.StoreSession(ctx, in)
	})
	if err != nil {
		return nil, err
	}
	return out, nil
}

// TrackTokenUsage forwards the request to coderd through the current client,
// retrying via clientDoWithRetries. On error the response is always nil.
func (s *Server) TrackTokenUsage(ctx context.Context, in *proto.TrackTokenUsageRequest) (*proto.TrackTokenUsageResponse, error) {
	out, err := clientDoWithRetries(ctx, s.Client, func(ctx context.Context, client aibridge.APIClient) (*proto.TrackTokenUsageResponse, error) {
		return client.TrackTokenUsage(ctx, in)
	})
	if err != nil {
		return nil, err
	}
	return out, nil
}
return client.TrackUserPrompt(ctx, in) + }) + if err != nil { + return nil, err + } + return out, nil +} + +func (s *Server) TrackToolUsage(ctx context.Context, in *proto.TrackToolUsageRequest) (*proto.TrackToolUsageResponse, error) { + out, err := clientDoWithRetries(ctx, s.Client, func(ctx context.Context, client aibridge.APIClient) (*proto.TrackToolUsageResponse, error) { + return client.TrackToolUsage(ctx, in) + }) + if err != nil { + return nil, err + } + return out, nil +} + +// NOTE: mostly copypasta from provisionerd; might be work abstracting. +func retryable(err error) bool { + return xerrors.Is(err, yamux.ErrSessionShutdown) || xerrors.Is(err, io.EOF) || xerrors.Is(err, fasthttputil.ErrInmemoryListenerClosed) || + // annoyingly, dRPC sometimes returns context.Canceled if the transport was closed, even if the context for + // the RPC *is not canceled*. Retrying is fine if the RPC context is not canceled. + xerrors.Is(err, context.Canceled) +} + +// clientDoWithRetries runs the function f with a client, and retries with +// backoff until either the error returned is not retryable() or the context +// expires. +// NOTE: mostly copypasta from provisionerd; might be work abstracting. +func clientDoWithRetries[T any](ctx context.Context, + getClient func() (aibridge.APIClient, error), + f func(context.Context, aibridge.APIClient) (T, error), +) (ret T, _ error) { + for retrier := retry.New(25*time.Millisecond, 5*time.Second); retrier.Wait(ctx); { + var empty T + client, err := getClient() + if err != nil { + if retryable(err) { + continue + } + return empty, err + } + resp, err := f(ctx, client) + if retryable(err) { + continue + } + return resp, err + } + return ret, ctx.Err() +} + +// isClosed returns whether the API is closed or not. +func (s *Server) isClosed() bool { + select { + case <-s.closeContext.Done(): + return true + default: + return false + } +} + +// closeWithError closes the provisioner; subsequent reads/writes will return the error err. 
+func (s *Server) closeWithError(err error) error { + s.mutex.Lock() + first := false + if !s.closingB { + first = true + s.closingB = true + } + // don't hold the mutex while doing I/O. + s.mutex.Unlock() + + if first { + s.closeCancel() + s.logger.Debug(context.Background(), "waiting for goroutines to exit") + s.wg.Wait() + s.logger.Debug(context.Background(), "closing server with error", slog.Error(err)) + s.closeError = err + close(s.closedCh) + return err + } + s.logger.Debug(s.closeContext, "waiting for first closer to complete") + <-s.closedCh + s.logger.Debug(s.closeContext, "first closer completed") + return s.closeError +} + +// Close ends the aibridge daemon. +func (s *Server) Close() error { + if s == nil { + return nil + } + + s.logger.Info(s.closeContext, "closing aibridged") + // TODO: invalidate all running requests (canceling context should be enough?). + errMsg := "aibridged closed gracefully" + err := s.closeWithError(nil) + if err != nil { + errMsg = err.Error() + } + s.logger.Warn(s.closeContext, errMsg) + + return err +} + +func (s *Server) Shutdown(ctx context.Context) error { + // TODO: implement or remove. + return nil +} diff --git a/aibridged/context.go b/aibridged/context.go new file mode 100644 index 0000000000000..7f049d331accc --- /dev/null +++ b/aibridged/context.go @@ -0,0 +1,6 @@ +package aibridged + +type ( + ContextKeyBridgeAPIKey struct{} + ContextKeyBridgeUserID struct{} +) diff --git a/aibridged/middleware.go b/aibridged/middleware.go new file mode 100644 index 0000000000000..de2a511acbd55 --- /dev/null +++ b/aibridged/middleware.go @@ -0,0 +1,71 @@ +package aibridged + +import ( + "bytes" + "context" + "crypto/subtle" + "net/http" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/httpmw" +) + +// AuthMiddleware extracts and validates authorization tokens for AI bridge endpoints. 
+// It supports both Bearer tokens in Authorization headers and Coder session tokens +// from cookies/headers following the same patterns as existing Coder authentication. +func AuthMiddleware(db database.Store) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // Extract token using the same pattern as the bridge + token := extractAuthTokenForBridge(r) + if token == "" { + http.Error(rw, "Authorization token required", http.StatusUnauthorized) + return + } + + // Validate token using httpmw.APIKeyFromRequest + key, _, ok := httpmw.APIKeyFromRequest(ctx, db, func(*http.Request) string { + return token + }, &http.Request{}) + + if !ok { + http.Error(rw, "Invalid or expired session token", http.StatusUnauthorized) + return + } + + ctx = context.WithValue( + context.WithValue(ctx, ContextKeyBridgeUserID{}, key.UserID), + ContextKeyBridgeAPIKey{}, token) + + // Pass request with modify context including the request token. + next.ServeHTTP(rw, r.WithContext(ctx)) + }) + } +} + +// extractAuthTokenForBridge extracts authorization token from HTTP request using multiple sources. +// It checks Authorization header (Bearer token), X-Api-Key header, and Coder session headers and cookies. +func extractAuthTokenForBridge(r *http.Request) string { + // 1. Check Authorization header for Bearer token + authHeader := r.Header.Get("Authorization") + if authHeader != "" { + bearer := []byte("bearer ") + hdr := []byte(authHeader) + + // Use case-insensitive comparison for Bearer token + if len(hdr) >= len(bearer) && subtle.ConstantTimeCompare(bytes.ToLower(hdr[:len(bearer)]), bearer) == 1 { + return string(hdr[len(bearer):]) + } + } + + // 2. Check X-Api-Key header + apiKeyHeader := r.Header.Get("X-Api-Key") + if apiKeyHeader != "" { + return apiKeyHeader + } + + // 3. 
Fall back to Coder's standard token extraction + return httpmw.APITokenFromRequest(r) +} diff --git a/cli/server.go b/cli/server.go index f9e744761b22e..661db8cb28165 100644 --- a/cli/server.go +++ b/cli/server.go @@ -31,6 +31,7 @@ import ( "sync/atomic" "time" + "github.com/ammario/tlru" "github.com/charmbracelet/lipgloss" "github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-systemd/daemon" @@ -62,6 +63,8 @@ import ( "github.com/coder/serpent" "github.com/coder/wgtunnel/tunnelsdk" + "github.com/coder/aibridge" + "github.com/coder/coder/v2/aibridged" "github.com/coder/coder/v2/coderd/entitlements" "github.com/coder/coder/v2/coderd/notifications/reports" "github.com/coder/coder/v2/coderd/runtimeconfig" @@ -1026,6 +1029,12 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. } provisionerdMetrics.Runner.NumDaemons.Set(float64(len(provisionerDaemons))) + // Since errCh only has one buffered slot, all routines + // sending on it must be wrapped in a select/default to + // avoid leaving dangling goroutines waiting for the + // channel to be consumed. + errCh = make(chan error, 1) + shutdownConnsCtx, shutdownConns := context.WithCancel(ctx) defer shutdownConns() @@ -1093,6 +1102,33 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. } }() + // TODO: this shouldn't block. + aiBridgeDaemons := make([]*aibridged.Server, 0) + defer func() { + // We have no graceful shutdown of aiBridgeDaemons + // here because that's handled at the end of main, this + // is here in case the program exits early. + for _, daemon := range aiBridgeDaemons { + _ = daemon.Close() + } + }() + + // Built in aibridge daemons. + for i := int64(0); i < vals.AI.BridgeConfig.Daemons.Value(); i++ { + suffix := fmt.Sprintf("%d", i) + // The suffix is added to the hostname, so we may need to trim to fit into + // the 64 character limit. 
+ hostname := stringutil.Truncate(cliutil.Hostname(), 63-len(suffix)) + name := fmt.Sprintf("%s-%s", hostname, suffix) + daemon, err := newAIBridgeDaemon(ctx, coderAPI, name, vals.AI.BridgeConfig) + if err != nil { + return xerrors.Errorf("create provisioner daemon: %w", err) + } + aiBridgeDaemons = append(aiBridgeDaemons, daemon) + } + coderAPI.AIBridgeDaemons = aiBridgeDaemons + coderAPI.AIBridges = tlru.New[string](tlru.ConstantCost[*aibridge.Bridge], 100) // TODO: configurable. + // Updates the systemd status from activating to activated. _, err = daemon.SdNotify(false, daemon.SdNotifyReady) if err != nil { @@ -1208,6 +1244,34 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. } wg.Wait() + // Shut down aibridge daemons before waiting for WebSockets + // connections to close. + for i, aiBridgeDaemon := range aiBridgeDaemons { + id := i + 1 + wg.Add(1) + go func() { + defer wg.Done() + + r.Verbosef(inv, "Shutting down AI bridge daemon %d...", id) + + err := shutdownWithTimeout(func(ctx context.Context) error { + // We only want to cancel active jobs if we aren't exiting gracefully. 
+ return aiBridgeDaemon.Shutdown(ctx) + }, 5*time.Second) + if err != nil { + cliui.Errorf(inv.Stderr, "Failed to shut down AI bridge daemon %d: %s\n", id, err) + return + } + err = aiBridgeDaemon.Close() + if err != nil { + cliui.Errorf(inv.Stderr, "Close AI bridge daemon %d: %s\n", id, err) + return + } + r.Verbosef(inv, "Gracefully shut down AI bridge daemon %d", id) + }() + } + wg.Wait() + cliui.Info(inv.Stdout, "Waiting for WebSocket connections to close..."+"\n") _ = coderAPICloser.Close() cliui.Info(inv.Stdout, "Done waiting for WebSocket connections"+"\n") @@ -1508,6 +1572,14 @@ func newProvisionerDaemon( }), nil } +func newAIBridgeDaemon(ctx context.Context, coderAPI *coderd.API, name string, bridgeCfg codersdk.AIBridgeConfig) (*aibridged.Server, error) { + return aibridged.New(func(dialCtx context.Context) (aibridge.APIClient, error) { + // This debounces calls to listen every second. + // TODO: is this true / necessary? + return coderAPI.CreateInMemoryAIBridgeDaemon(dialCtx, name) + }, coderAPI.Logger.Named("aibridged").With(slog.F("name", name))) +} + // nolint: revive func PrintLogo(inv *serpent.Invocation, daemonTitle string) { // Only print the logo in TTYs. diff --git a/cli/testdata/server-config.yaml.golden b/cli/testdata/server-config.yaml.golden index e23274e442078..576e16817b298 100644 --- a/cli/testdata/server-config.yaml.golden +++ b/cli/testdata/server-config.yaml.golden @@ -709,3 +709,19 @@ workspace_prebuilds: # limit; disabled when set to zero. # (default: 3, type: int) failure_hard_limit: 3 +ai_bridge: + # TODO. + # (default: 3, type: int) + daemons: 3 + # TODO. + # (default: https://api.openai.com/v1/, type: string) + openai_base_url: https://api.openai.com/v1/ + # TODO. + # (default: , type: string) + openai_key: "" + # TODO. + # (default: https://api.anthropic.com/, type: string) + base_url: https://api.anthropic.com/ + # TODO. 
+ # (default: , type: string) + key: "" diff --git a/coderd/aibridge.go b/coderd/aibridge.go new file mode 100644 index 0000000000000..81b93211dfaea --- /dev/null +++ b/coderd/aibridge.go @@ -0,0 +1,148 @@ +package coderd + +import ( + "context" + "fmt" + "net/http" + "time" + + "golang.org/x/sync/errgroup" + "golang.org/x/xerrors" + + "cdr.dev/slog" + + "github.com/google/uuid" + + "github.com/coder/aibridge" + "github.com/coder/coder/v2/aibridged" + "github.com/coder/coder/v2/coderd/util/slice" +) + +type rt struct { + http.RoundTripper + + server *aibridged.Server +} + +func (r *rt) RoundTrip(req *http.Request) (*http.Response, error) { + start := time.Now() + defer func() { + fmt.Printf("req to %q started %v completed\n", req.URL.String(), start.Local().Format(time.RFC3339Nano)) + }() + + resp, err := r.RoundTripper.RoundTrip(req) + + return resp, err +} + +func (api *API) bridgeAIRequest(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + if len(api.AIBridgeDaemons) == 0 { + http.Error(rw, "no AI bridge daemons running", http.StatusInternalServerError) + return + } + + // Random loadbalancing. + // TODO: introduce better strategy. 
+ server, err := slice.PickRandom(api.AIBridgeDaemons) + if err != nil { + api.Logger.Error(ctx, "failed to pick random AI bridge server", slog.Error(err)) + http.Error(rw, "failed to select AI bridge", http.StatusInternalServerError) + return + } + + key, ok := r.Context().Value(aibridged.ContextKeyBridgeAPIKey{}).(string) + if key == "" || !ok { + http.Error(rw, "unable to retrieve request session key", http.StatusBadRequest) + return + } + + userID, ok := r.Context().Value(aibridged.ContextKeyBridgeUserID{}).(uuid.UUID) + if !ok { + api.Logger.Error(r.Context(), "missing initiator ID in context") + http.Error(rw, "unable to retrieve initiator", http.StatusInternalServerError) + return + } + + r.Header.Set(aibridge.InitiatorHeaderKey, userID.String()) + + bridge, err := api.createOrLoadBridgeForAPIKey(ctx, key, server.Client) + if err != nil { + api.Logger.Error(ctx, "failed to create ai bridge", slog.Error(err)) + http.Error(rw, "failed to create ai bridge", http.StatusInternalServerError) + return + } + http.StripPrefix("/api/v2/aibridge", bridge.Handler()).ServeHTTP(rw, r) +} + +func (api *API) createOrLoadBridgeForAPIKey(ctx context.Context, key string, apiClientFn func() (aibridge.APIClient, error)) (*aibridge.Bridge, error) { + if api.AIBridges == nil { + return nil, xerrors.New("bridge cache storage is not configured") + } + + api.AIBridgesMu.RLock() + val, _, ok := api.AIBridges.Get(key) + api.AIBridgesMu.RUnlock() + + // TODO: TOCTOU potential here + // TODO: track startup time since it adds latency to first request (histogram count will also help us see how often this occurs) + if !ok { + api.AIBridgesMu.Lock() + defer api.AIBridgesMu.Unlock() + + tools, err := api.fetchTools(ctx, api.Logger, key) + if err != nil { + api.Logger.Warn(ctx, "failed to load tools", slog.Error(err)) + } + + // TODO: only instantiate once. 
+ registry := aibridge.ProviderRegistry{ + aibridge.ProviderOpenAI: aibridge.NewOpenAIProvider(api.DeploymentValues.AI.BridgeConfig.OpenAI.BaseURL.String(), api.DeploymentValues.AI.BridgeConfig.OpenAI.Key.String()), + aibridge.ProviderAnthropic: aibridge.NewAnthropicMessagesProvider(api.DeploymentValues.AI.BridgeConfig.Anthropic.BaseURL.String(), api.DeploymentValues.AI.BridgeConfig.Anthropic.Key.String()), + } + bridge, err := aibridge.NewBridge(registry, api.Logger.Named("ai_bridge"), apiClientFn, tools) + if err != nil { + return nil, xerrors.Errorf("create new bridge server: %w", err) + } + + api.Logger.Info(ctx, "created bridge") // TODO: improve usefulness; log user ID? + + api.AIBridges.Set(key, bridge, time.Minute) // TODO: configurable. + val = bridge + } + + return val, nil +} + +func (api *API) fetchTools(ctx context.Context, logger slog.Logger, key string) (map[string][]*aibridge.MCPTool, error) { + url := api.DeploymentValues.AccessURL.String() + "/api/experimental/mcp/http" + coderMCP, err := aibridge.NewMCPToolBridge("coder", url, map[string]string{ + "Coder-Session-Token": key, + }, logger.Named("mcp-bridge-coder")) + if err != nil { + return nil, xerrors.Errorf("coder MCP bridge setup: %w", err) + } + + // TODO: add github mcp if external auth is configured. + var eg errgroup.Group + eg.Go(func() error { + ctx, cancel := context.WithTimeout(ctx, time.Second*30) + defer cancel() + + err := coderMCP.Init(ctx) + if err == nil { + return nil + } + return xerrors.Errorf("coder: %w", err) + }) + + // This must block requests until MCP proxies are setup. 
+ if err := eg.Wait(); err != nil { + return nil, xerrors.Errorf("MCP proxy init: %w", err) + } + + return map[string][]*aibridge.MCPTool{ + "coder": coderMCP.ListTools(), + }, nil +} diff --git a/coderd/aibridgedserver/aibridgedserver.go b/coderd/aibridgedserver/aibridgedserver.go new file mode 100644 index 0000000000000..03de6d182584a --- /dev/null +++ b/coderd/aibridgedserver/aibridgedserver.go @@ -0,0 +1,137 @@ +package aibridgedserver + +import ( + "context" + "encoding/json" + + "github.com/google/uuid" + "golang.org/x/xerrors" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/structpb" + + "cdr.dev/slog" + + proto "github.com/coder/aibridge/proto" + "github.com/coder/coder/v2/coderd/database" +) + +var _ proto.DRPCStoreServer = &Server{} + +type Server struct { + // lifecycleCtx must be tied to the API server's lifecycle + // as when the API server shuts down, we want to cancel any + // long-running operations. + lifecycleCtx context.Context + store database.Store + logger slog.Logger +} + +func NewServer(lifecycleCtx context.Context, store database.Store, logger slog.Logger) (*Server, error) { + return &Server{ + lifecycleCtx: lifecycleCtx, + store: store, + logger: logger.Named("aibridgedserver"), + }, nil +} + +// StoreSession implements proto.DRPCStoreServer. 
+func (s *Server) StoreSession(ctx context.Context, in *proto.StoreSessionRequest) (*proto.StoreSessionResponse, error) { + sessID, err := uuid.Parse(in.GetSessionId()) + if err != nil { + return nil, xerrors.Errorf("invalid session ID %q: %w", in.GetSessionId(), err) + } + initID, err := uuid.Parse(in.GetInitiatorId()) + if err != nil { + return nil, xerrors.Errorf("invalid initiator ID %q: %w", in.GetInitiatorId(), err) + } + + err = s.store.InsertAIBridgeSession(ctx, database.InsertAIBridgeSessionParams{ + ID: sessID, + InitiatorID: initID, + Provider: in.Provider, + Model: in.Model, + }) + if err != nil { + return nil, xerrors.Errorf("start session: %w", err) + } + + return &proto.StoreSessionResponse{}, nil +} + +func (s *Server) TrackTokenUsage(ctx context.Context, in *proto.TrackTokenUsageRequest) (*proto.TrackTokenUsageResponse, error) { + sessID, err := uuid.Parse(in.GetSessionId()) + if err != nil { + return nil, xerrors.Errorf("failed to parse session_id %q: %w", in.GetSessionId(), err) + } + + err = s.store.InsertAIBridgeTokenUsage(ctx, database.InsertAIBridgeTokenUsageParams{ + ID: uuid.New(), + SessionID: sessID, + ProviderID: in.GetMsgId(), + InputTokens: in.GetInputTokens(), + OutputTokens: in.GetOutputTokens(), + Metadata: s.marshalMetadata(in.GetMetadata()), + }) + if err != nil { + return nil, xerrors.Errorf("insert token usage: %w", err) + } + return &proto.TrackTokenUsageResponse{}, nil +} + +func (s *Server) TrackUserPrompt(ctx context.Context, in *proto.TrackUserPromptRequest) (*proto.TrackUserPromptResponse, error) { + sessID, err := uuid.Parse(in.GetSessionId()) + if err != nil { + return nil, xerrors.Errorf("failed to parse session_id %q: %w", in.GetSessionId(), err) + } + + err = s.store.InsertAIBridgeUserPrompt(ctx, database.InsertAIBridgeUserPromptParams{ + ID: uuid.New(), + SessionID: sessID, + ProviderID: in.GetMsgId(), + Prompt: in.GetPrompt(), + Metadata: s.marshalMetadata(in.GetMetadata()), + }) + if err != nil { + return nil, 
xerrors.Errorf("insert user prompt: %w", err) + } + return &proto.TrackUserPromptResponse{}, nil +} + +func (s *Server) TrackToolUsage(ctx context.Context, in *proto.TrackToolUsageRequest) (*proto.TrackToolUsageResponse, error) { + sessID, err := uuid.Parse(in.GetSessionId()) + if err != nil { + return nil, xerrors.Errorf("failed to parse session_id %q: %w", in.GetSessionId(), err) + } + + err = s.store.InsertAIBridgeToolUsage(ctx, database.InsertAIBridgeToolUsageParams{ + ID: uuid.New(), + SessionID: sessID, + ProviderID: in.GetMsgId(), + Tool: in.GetTool(), + Input: in.GetInput(), + Injected: in.GetInjected(), + Metadata: s.marshalMetadata(in.GetMetadata()), + }) + if err != nil { + return nil, xerrors.Errorf("insert tool usage: %w", err) + } + return &proto.TrackToolUsageResponse{}, nil +} + +func (s *Server) marshalMetadata(in map[string]*anypb.Any) []byte { + mdMap := map[string]interface{}{} + for k, v := range in { + if v == nil { + continue + } + var sv structpb.Value + if err := v.UnmarshalTo(&sv); err == nil { + mdMap[k] = sv.AsInterface() + } + } + out, err := json.Marshal(mdMap) + if err != nil { + s.logger.Warn(s.lifecycleCtx, "failed to marshal metadata", slog.Error(err)) + } + return out +} diff --git a/coderd/apidoc/docs.go b/coderd/apidoc/docs.go index 0a8b2c07793c3..fcf1fa8dab96e 100644 --- a/coderd/apidoc/docs.go +++ b/coderd/apidoc/docs.go @@ -10934,6 +10934,50 @@ const docTemplate = `{ } } }, + "codersdk.AIBridgeAnthropicConfig": { + "type": "object", + "properties": { + "base_url": { + "type": "string" + }, + "key": { + "type": "string" + } + } + }, + "codersdk.AIBridgeConfig": { + "type": "object", + "properties": { + "anthropic": { + "$ref": "#/definitions/codersdk.AIBridgeAnthropicConfig" + }, + "daemons": { + "type": "integer" + }, + "openai": { + "$ref": "#/definitions/codersdk.AIBridgeOpenAIConfig" + } + } + }, + "codersdk.AIBridgeOpenAIConfig": { + "type": "object", + "properties": { + "base_url": { + "type": "string" + }, + "key": { + 
"type": "string" + } + } + }, + "codersdk.AIConfig": { + "type": "object", + "properties": { + "bridge": { + "$ref": "#/definitions/codersdk.AIBridgeConfig" + } + } + }, "codersdk.APIKey": { "type": "object", "required": [ @@ -12551,6 +12595,9 @@ const docTemplate = `{ "agent_stat_refresh_interval": { "type": "integer" }, + "ai": { + "$ref": "#/definitions/codersdk.AIConfig" + }, "allow_workspace_renames": { "type": "boolean" }, diff --git a/coderd/apidoc/swagger.json b/coderd/apidoc/swagger.json index cd6537de0e210..eeb2c0f0a5e27 100644 --- a/coderd/apidoc/swagger.json +++ b/coderd/apidoc/swagger.json @@ -9694,6 +9694,50 @@ } } }, + "codersdk.AIBridgeAnthropicConfig": { + "type": "object", + "properties": { + "base_url": { + "type": "string" + }, + "key": { + "type": "string" + } + } + }, + "codersdk.AIBridgeConfig": { + "type": "object", + "properties": { + "anthropic": { + "$ref": "#/definitions/codersdk.AIBridgeAnthropicConfig" + }, + "daemons": { + "type": "integer" + }, + "openai": { + "$ref": "#/definitions/codersdk.AIBridgeOpenAIConfig" + } + } + }, + "codersdk.AIBridgeOpenAIConfig": { + "type": "object", + "properties": { + "base_url": { + "type": "string" + }, + "key": { + "type": "string" + } + } + }, + "codersdk.AIConfig": { + "type": "object", + "properties": { + "bridge": { + "$ref": "#/definitions/codersdk.AIBridgeConfig" + } + } + }, "codersdk.APIKey": { "type": "object", "required": [ @@ -11220,6 +11264,9 @@ "agent_stat_refresh_interval": { "type": "integer" }, + "ai": { + "$ref": "#/definitions/codersdk.AIConfig" + }, "allow_workspace_renames": { "type": "boolean" }, diff --git a/coderd/coderd.go b/coderd/coderd.go index 2aa30c9d7a45c..0b24a46e45385 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -20,6 +20,10 @@ import ( "sync/atomic" "time" + "github.com/ammario/tlru" + + aibridgeproto "github.com/coder/aibridge/proto" + "github.com/coder/coder/v2/aibridged" "github.com/coder/coder/v2/coderd/oauth2provider" 
"github.com/coder/coder/v2/coderd/pproflabel" "github.com/coder/coder/v2/coderd/prebuilds" @@ -43,6 +47,10 @@ import ( "tailscale.com/types/key" "tailscale.com/util/singleflight" + "github.com/coder/aibridge" + "github.com/coder/coder/v2/coderd/aibridgedserver" + provisionerdproto "github.com/coder/coder/v2/provisionerd/proto" + "cdr.dev/slog" "github.com/coder/quartz" "github.com/coder/serpent" @@ -94,7 +102,6 @@ import ( "github.com/coder/coder/v2/coderd/workspacestats" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/healthsdk" - "github.com/coder/coder/v2/provisionerd/proto" "github.com/coder/coder/v2/provisionersdk" "github.com/coder/coder/v2/site" "github.com/coder/coder/v2/tailnet" @@ -624,7 +631,7 @@ func New(options *Options) *API { ExternalURL: buildinfo.ExternalURL(), Version: buildinfo.Version(), AgentAPIVersion: AgentAPIVersionREST, - ProvisionerAPIVersion: proto.CurrentVersion.String(), + ProvisionerAPIVersion: provisionerdproto.CurrentVersion.String(), DashboardURL: api.AccessURL.String(), WorkspaceProxy: false, UpgradeMessage: api.DeploymentValues.CLIUpgradeMessage.String(), @@ -686,7 +693,7 @@ func New(options *Options) *API { }, ProvisionerDaemons: healthcheck.ProvisionerDaemonsReportDeps{ CurrentVersion: buildinfo.Version(), - CurrentAPIMajorVersion: proto.CurrentMajor, + CurrentAPIMajorVersion: provisionerdproto.CurrentMajor, Store: options.Database, StaleInterval: provisionerdserver.StaleInterval, // TimeNow set to default, see healthcheck/provisioner.go @@ -849,7 +856,9 @@ func New(options *Options) *API { expvar.Publish("derp", api.DERPServer.ExpVar()) } }) - cors := httpmw.Cors(options.DeploymentValues.Dangerous.AllowAllCors.Value()) + regularCors := httpmw.Cors(options.DeploymentValues.Dangerous.AllowAllCors.Value()) + permissiveCors := httpmw.PermissiveCors() + cors := httpmw.ConditionalCors("/api/v2/aibridge", regularCors, permissiveCors) prometheusMW := httpmw.Prometheus(options.PrometheusRegistry) r.Use( @@ 
-1566,6 +1575,11 @@ func New(options *Options) *API { r.Use(apiKeyMiddleware) r.Get("/", api.tailnetRPCConn) }) + r.Route("/aibridge", func(r chi.Router) { + r.Use(aibridged.AuthMiddleware(api.Database)) + r.HandleFunc("/openai/*", api.bridgeAIRequest) + r.HandleFunc("/anthropic/*", api.bridgeAIRequest) + }) }) if options.SwaggerEndpoint { @@ -1664,7 +1678,7 @@ type API struct { TailnetClientService *tailnet.ClientService // WebpushDispatcher is a way to send notifications to users via Web Push. WebpushDispatcher webpush.Dispatcher - QuotaCommitter atomic.Pointer[proto.QuotaCommitter] + QuotaCommitter atomic.Pointer[provisionerdproto.QuotaCommitter] AppearanceFetcher atomic.Pointer[appearance.Fetcher] // WorkspaceProxyHostsFn returns the hosts of healthy workspace proxies // for header reasons. @@ -1723,6 +1737,10 @@ type API struct { // dbRolluper rolls up template usage stats from raw agent and app // stats. This is used to provide insights in the WebUI. dbRolluper *dbrollup.Rolluper + + AIBridgeDaemons []*aibridged.Server + AIBridges *tlru.Cache[string, *aibridge.Bridge] + AIBridgesMu sync.RWMutex } // Close waits for all WebSocket connections to drain before returning. @@ -1823,11 +1841,11 @@ type memoryProvisionerDaemonOptions struct { // CreateInMemoryProvisionerDaemon is an in-memory connection to a provisionerd. // Useful when starting coderd and provisionerd in the same process. 
-func (api *API) CreateInMemoryProvisionerDaemon(dialCtx context.Context, name string, provisionerTypes []codersdk.ProvisionerType) (client proto.DRPCProvisionerDaemonClient, err error) { +func (api *API) CreateInMemoryProvisionerDaemon(dialCtx context.Context, name string, provisionerTypes []codersdk.ProvisionerType) (client provisionerdproto.DRPCProvisionerDaemonClient, err error) { return api.CreateInMemoryTaggedProvisionerDaemon(dialCtx, name, provisionerTypes, nil) } -func (api *API) CreateInMemoryTaggedProvisionerDaemon(dialCtx context.Context, name string, provisionerTypes []codersdk.ProvisionerType, provisionerTags map[string]string, opts ...MemoryProvisionerDaemonOption) (client proto.DRPCProvisionerDaemonClient, err error) { +func (api *API) CreateInMemoryTaggedProvisionerDaemon(dialCtx context.Context, name string, provisionerTypes []codersdk.ProvisionerType, provisionerTags map[string]string, opts ...MemoryProvisionerDaemonOption) (client provisionerdproto.DRPCProvisionerDaemonClient, err error) { options := &memoryProvisionerDaemonOptions{} for _, opt := range opts { opt(options) @@ -1859,7 +1877,7 @@ func (api *API) CreateInMemoryTaggedProvisionerDaemon(dialCtx context.Context, n return nil, xerrors.Errorf("failed to parse built-in provisioner key ID: %w", err) } - apiVersion := proto.CurrentVersion.String() + apiVersion := provisionerdproto.CurrentVersion.String() if options.versionOverride != "" && flag.Lookup("test.v") != nil { // This should only be usable for unit testing. 
To fake a different provisioner version apiVersion = options.versionOverride @@ -1914,7 +1932,7 @@ func (api *API) CreateInMemoryTaggedProvisionerDaemon(dialCtx context.Context, n if err != nil { return nil, err } - err = proto.DRPCRegisterProvisionerDaemon(mux, srv) + err = provisionerdproto.DRPCRegisterProvisionerDaemon(mux, srv) if err != nil { return nil, err } @@ -1949,7 +1967,65 @@ func (api *API) CreateInMemoryTaggedProvisionerDaemon(dialCtx context.Context, n _ = serverSession.Close() }() - return proto.NewDRPCProvisionerDaemonClient(clientSession), nil + return provisionerdproto.NewDRPCProvisionerDaemonClient(clientSession), nil +} + +func (api *API) CreateInMemoryAIBridgeDaemon(dialCtx context.Context, name string) (client aibridgeproto.DRPCStoreClient, err error) { + // TODO(dannyk): implement options. + // TODO(dannyk): implement tracing. + + clientSession, serverSession := drpcsdk.MemTransportPipe() + defer func() { + if err != nil { + _ = clientSession.Close() + _ = serverSession.Close() + } + }() + + // TODO(dannyk): implement API versioning. + // TODO(dannyk): implement database tracking of daemons. 
+ + mux := drpcmux.New() + api.Logger.Debug(dialCtx, "starting in-memory AI bridge daemon", slog.F("name", name)) + logger := api.Logger.Named(fmt.Sprintf("inmem-aibridged-%s", name)) + srv, err := aibridgedserver.NewServer(api.ctx, api.Database, logger) + if err != nil { + return nil, err + } + err = aibridgeproto.DRPCRegisterStore(mux, srv) + if err != nil { + return nil, err + } + server := drpcserver.NewWithOptions(&tracing.DRPCHandler{Handler: mux}, + drpcserver.Options{ + Manager: drpcsdk.DefaultDRPCOptions(nil), + Log: func(err error) { + if xerrors.Is(err, io.EOF) { + return + } + logger.Debug(dialCtx, "drpc server error", slog.Error(err)) + }, + }, + ) + // in-mem pipes aren't technically "websockets" but they have the same properties as far as the + // API is concerned: they are long-lived connections that we need to close before completing + // shutdown of the API. + api.WebsocketWaitMutex.Lock() + api.WebsocketWaitGroup.Add(1) + api.WebsocketWaitMutex.Unlock() + go func() { + defer api.WebsocketWaitGroup.Done() + // Here we pass the background context, since we want the server to keep serving until the + // client hangs up. The aibridged is local, in-mem, so there isn't a danger of losing contact with it and + // having a dead connection we don't know the status of. + err := server.Serve(context.Background(), serverSession) + logger.Info(dialCtx, "AI bridge daemon disconnected", slog.Error(err)) + // close the sessions, so we don't leak goroutines serving them.
+ _ = clientSession.Close() + _ = serverSession.Close() + }() + + return aibridgeproto.NewDRPCStoreClient(clientSession), nil } func (api *API) DERPMap() *tailcfg.DERPMap { diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 2cbcf1ec6f0d4..b5412f10d3b27 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -3660,6 +3660,26 @@ func (q *querier) GetWorkspacesEligibleForTransition(ctx context.Context, now ti return q.db.GetWorkspacesEligibleForTransition(ctx, now) } +func (q *querier) InsertAIBridgeSession(ctx context.Context, arg database.InsertAIBridgeSessionParams) error { + // TODO: authz. + return q.db.InsertAIBridgeSession(ctx, arg) +} + +func (q *querier) InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) error { + // TODO: authz. + return q.db.InsertAIBridgeTokenUsage(ctx, arg) +} + +func (q *querier) InsertAIBridgeToolUsage(ctx context.Context, arg database.InsertAIBridgeToolUsageParams) error { + // TODO: authz. + return q.db.InsertAIBridgeToolUsage(ctx, arg) +} + +func (q *querier) InsertAIBridgeUserPrompt(ctx context.Context, arg database.InsertAIBridgeUserPromptParams) error { + // TODO: authz. 
+ return q.db.InsertAIBridgeUserPrompt(ctx, arg) +} + func (q *querier) InsertAPIKey(ctx context.Context, arg database.InsertAPIKeyParams) (database.APIKey, error) { return insert(q.log, q.auth, rbac.ResourceApiKey.WithOwner(arg.UserID.String()), diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 9bfdbf049ac1a..b1f20178eb495 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -2126,6 +2126,34 @@ func (m queryMetricsStore) GetWorkspacesEligibleForTransition(ctx context.Contex return workspaces, err } +func (m queryMetricsStore) InsertAIBridgeSession(ctx context.Context, arg database.InsertAIBridgeSessionParams) error { + start := time.Now() + r0 := m.s.InsertAIBridgeSession(ctx, arg) + m.queryLatencies.WithLabelValues("InsertAIBridgeSession").Observe(time.Since(start).Seconds()) + return r0 +} + +func (m queryMetricsStore) InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) error { + start := time.Now() + r0 := m.s.InsertAIBridgeTokenUsage(ctx, arg) + m.queryLatencies.WithLabelValues("InsertAIBridgeTokenUsage").Observe(time.Since(start).Seconds()) + return r0 +} + +func (m queryMetricsStore) InsertAIBridgeToolUsage(ctx context.Context, arg database.InsertAIBridgeToolUsageParams) error { + start := time.Now() + r0 := m.s.InsertAIBridgeToolUsage(ctx, arg) + m.queryLatencies.WithLabelValues("InsertAIBridgeToolUsage").Observe(time.Since(start).Seconds()) + return r0 +} + +func (m queryMetricsStore) InsertAIBridgeUserPrompt(ctx context.Context, arg database.InsertAIBridgeUserPromptParams) error { + start := time.Now() + r0 := m.s.InsertAIBridgeUserPrompt(ctx, arg) + m.queryLatencies.WithLabelValues("InsertAIBridgeUserPrompt").Observe(time.Since(start).Seconds()) + return r0 +} + func (m queryMetricsStore) InsertAPIKey(ctx context.Context, arg database.InsertAPIKeyParams) (database.APIKey, error) { start := time.Now() key, err 
:= m.s.InsertAPIKey(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 934cd434426b2..708e87232d45a 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -4545,6 +4545,62 @@ func (mr *MockStoreMockRecorder) InTx(arg0, arg1 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InTx", reflect.TypeOf((*MockStore)(nil).InTx), arg0, arg1) } +// InsertAIBridgeSession mocks base method. +func (m *MockStore) InsertAIBridgeSession(ctx context.Context, arg database.InsertAIBridgeSessionParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InsertAIBridgeSession", ctx, arg) + ret0, _ := ret[0].(error) + return ret0 +} + +// InsertAIBridgeSession indicates an expected call of InsertAIBridgeSession. +func (mr *MockStoreMockRecorder) InsertAIBridgeSession(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertAIBridgeSession", reflect.TypeOf((*MockStore)(nil).InsertAIBridgeSession), ctx, arg) +} + +// InsertAIBridgeTokenUsage mocks base method. +func (m *MockStore) InsertAIBridgeTokenUsage(ctx context.Context, arg database.InsertAIBridgeTokenUsageParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InsertAIBridgeTokenUsage", ctx, arg) + ret0, _ := ret[0].(error) + return ret0 +} + +// InsertAIBridgeTokenUsage indicates an expected call of InsertAIBridgeTokenUsage. +func (mr *MockStoreMockRecorder) InsertAIBridgeTokenUsage(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertAIBridgeTokenUsage", reflect.TypeOf((*MockStore)(nil).InsertAIBridgeTokenUsage), ctx, arg) +} + +// InsertAIBridgeToolUsage mocks base method. 
+func (m *MockStore) InsertAIBridgeToolUsage(ctx context.Context, arg database.InsertAIBridgeToolUsageParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InsertAIBridgeToolUsage", ctx, arg) + ret0, _ := ret[0].(error) + return ret0 +} + +// InsertAIBridgeToolUsage indicates an expected call of InsertAIBridgeToolUsage. +func (mr *MockStoreMockRecorder) InsertAIBridgeToolUsage(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertAIBridgeToolUsage", reflect.TypeOf((*MockStore)(nil).InsertAIBridgeToolUsage), ctx, arg) +} + +// InsertAIBridgeUserPrompt mocks base method. +func (m *MockStore) InsertAIBridgeUserPrompt(ctx context.Context, arg database.InsertAIBridgeUserPromptParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "InsertAIBridgeUserPrompt", ctx, arg) + ret0, _ := ret[0].(error) + return ret0 +} + +// InsertAIBridgeUserPrompt indicates an expected call of InsertAIBridgeUserPrompt. +func (mr *MockStoreMockRecorder) InsertAIBridgeUserPrompt(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertAIBridgeUserPrompt", reflect.TypeOf((*MockStore)(nil).InsertAIBridgeUserPrompt), ctx, arg) +} + // InsertAPIKey mocks base method. 
func (m *MockStore) InsertAPIKey(ctx context.Context, arg database.InsertAPIKeyParams) (database.APIKey, error) { m.ctrl.T.Helper() diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 7bea770248310..0de816b466f49 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -814,6 +814,44 @@ BEGIN END; $$; +CREATE TABLE aibridge_sessions ( + id uuid NOT NULL, + initiator_id uuid NOT NULL, + provider text NOT NULL, + model text NOT NULL, + created_at timestamp with time zone DEFAULT now() +); + +CREATE TABLE aibridge_token_usages ( + id uuid NOT NULL, + session_id uuid NOT NULL, + provider_id text NOT NULL, + input_tokens bigint NOT NULL, + output_tokens bigint NOT NULL, + metadata jsonb, + created_at timestamp with time zone DEFAULT now() +); + +CREATE TABLE aibridge_tool_usages ( + id uuid NOT NULL, + session_id uuid NOT NULL, + provider_id text NOT NULL, + tool text NOT NULL, + input text NOT NULL, + injected boolean DEFAULT false NOT NULL, + metadata jsonb, + created_at timestamp with time zone DEFAULT now() +); + +CREATE TABLE aibridge_user_prompts ( + id uuid NOT NULL, + session_id uuid NOT NULL, + provider_id text NOT NULL, + prompt text NOT NULL, + metadata jsonb, + created_at timestamp with time zone DEFAULT now() +); + CREATE TABLE api_keys ( id text NOT NULL, hashed_secret bytea NOT NULL, @@ -2498,6 +2536,18 @@ ALTER TABLE ONLY workspace_resource_metadata ALTER COLUMN id SET DEFAULT nextval ALTER TABLE ONLY workspace_agent_stats ADD CONSTRAINT agent_stats_pkey PRIMARY KEY (id); +ALTER TABLE ONLY aibridge_sessions + ADD CONSTRAINT aibridge_sessions_pkey PRIMARY KEY (id); + +ALTER TABLE ONLY aibridge_token_usages + ADD CONSTRAINT aibridge_token_usages_pkey PRIMARY KEY (id); + +ALTER TABLE ONLY aibridge_tool_usages + ADD CONSTRAINT aibridge_tool_usages_pkey PRIMARY KEY (id); + +ALTER TABLE ONLY aibridge_user_prompts + ADD CONSTRAINT aibridge_user_prompts_pkey PRIMARY KEY (id); + ALTER TABLE ONLY api_keys ADD CONSTRAINT 
api_keys_pkey PRIMARY KEY (id); @@ -2784,6 +2834,24 @@ CREATE INDEX idx_agent_stats_created_at ON workspace_agent_stats USING btree (cr CREATE INDEX idx_agent_stats_user_id ON workspace_agent_stats USING btree (user_id); +CREATE INDEX idx_aibridge_sessions_model ON aibridge_sessions USING btree (model); + +CREATE INDEX idx_aibridge_sessions_provider ON aibridge_sessions USING btree (provider); + +CREATE INDEX idx_aibridge_token_usages_session_id ON aibridge_token_usages USING btree (session_id); + +CREATE INDEX idx_aibridge_token_usages_session_provider_id ON aibridge_token_usages USING btree (session_id, provider_id); + +CREATE INDEX idx_aibridge_tool_usages_session_id ON aibridge_tool_usages USING btree (session_id); + +CREATE INDEX idx_aibridge_tool_usages_session_provider_id ON aibridge_tool_usages USING btree (session_id, provider_id); + +CREATE INDEX idx_aibridge_tool_usages_tool ON aibridge_tool_usages USING btree (tool); + +CREATE INDEX idx_aibridge_user_prompts_session_id ON aibridge_user_prompts USING btree (session_id); + +CREATE INDEX idx_aibridge_user_prompts_session_provider_id ON aibridge_user_prompts USING btree (session_id, provider_id); + CREATE UNIQUE INDEX idx_api_key_name ON api_keys USING btree (user_id, token_name) WHERE (login_type = 'token'::login_type); CREATE INDEX idx_api_keys_user ON api_keys USING btree (user_id); @@ -3018,6 +3086,18 @@ COMMENT ON TRIGGER workspace_agent_name_unique_trigger ON workspace_agents IS 'U the uniqueness requirement. 
A trigger allows us to enforce uniqueness going forward without requiring a migration to clean up historical data.'; +ALTER TABLE ONLY aibridge_sessions + ADD CONSTRAINT aibridge_sessions_initiator_id_fkey FOREIGN KEY (initiator_id) REFERENCES users(id) ON DELETE CASCADE; + +ALTER TABLE ONLY aibridge_token_usages + ADD CONSTRAINT aibridge_token_usages_session_id_fkey FOREIGN KEY (session_id) REFERENCES aibridge_sessions(id) ON DELETE CASCADE; + +ALTER TABLE ONLY aibridge_tool_usages + ADD CONSTRAINT aibridge_tool_usages_session_id_fkey FOREIGN KEY (session_id) REFERENCES aibridge_sessions(id) ON DELETE CASCADE; + +ALTER TABLE ONLY aibridge_user_prompts + ADD CONSTRAINT aibridge_user_prompts_session_id_fkey FOREIGN KEY (session_id) REFERENCES aibridge_sessions(id) ON DELETE CASCADE; + ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; diff --git a/coderd/database/foreign_key_constraint.go b/coderd/database/foreign_key_constraint.go index 33aa8edd69032..d2bf59c1c356d 100644 --- a/coderd/database/foreign_key_constraint.go +++ b/coderd/database/foreign_key_constraint.go @@ -6,6 +6,10 @@ type ForeignKeyConstraint string // ForeignKeyConstraint enums. 
const ( + ForeignKeyAibridgeSessionsInitiatorID ForeignKeyConstraint = "aibridge_sessions_initiator_id_fkey" // ALTER TABLE ONLY aibridge_sessions ADD CONSTRAINT aibridge_sessions_initiator_id_fkey FOREIGN KEY (initiator_id) REFERENCES users(id) ON DELETE CASCADE; + ForeignKeyAibridgeTokenUsagesSessionID ForeignKeyConstraint = "aibridge_token_usages_session_id_fkey" // ALTER TABLE ONLY aibridge_token_usages ADD CONSTRAINT aibridge_token_usages_session_id_fkey FOREIGN KEY (session_id) REFERENCES aibridge_sessions(id) ON DELETE CASCADE; + ForeignKeyAibridgeToolUsagesSessionID ForeignKeyConstraint = "aibridge_tool_usages_session_id_fkey" // ALTER TABLE ONLY aibridge_tool_usages ADD CONSTRAINT aibridge_tool_usages_session_id_fkey FOREIGN KEY (session_id) REFERENCES aibridge_sessions(id) ON DELETE CASCADE; + ForeignKeyAibridgeUserPromptsSessionID ForeignKeyConstraint = "aibridge_user_prompts_session_id_fkey" // ALTER TABLE ONLY aibridge_user_prompts ADD CONSTRAINT aibridge_user_prompts_session_id_fkey FOREIGN KEY (session_id) REFERENCES aibridge_sessions(id) ON DELETE CASCADE; ForeignKeyAPIKeysUserIDUUID ForeignKeyConstraint = "api_keys_user_id_uuid_fkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_user_id_uuid_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; ForeignKeyConnectionLogsOrganizationID ForeignKeyConstraint = "connection_logs_organization_id_fkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE; ForeignKeyConnectionLogsWorkspaceID ForeignKeyConstraint = "connection_logs_workspace_id_fkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_workspace_id_fkey FOREIGN KEY (workspace_id) REFERENCES workspaces(id) ON DELETE CASCADE; diff --git a/coderd/database/migrations/000358_aibridge.down.sql b/coderd/database/migrations/000358_aibridge.down.sql new file mode 100644 index 0000000000000..ef4b81bd08b0c --- 
/dev/null +++ b/coderd/database/migrations/000358_aibridge.down.sql @@ -0,0 +1,4 @@ +DROP TABLE IF EXISTS aibridge_tool_usages CASCADE; +DROP TABLE IF EXISTS aibridge_user_prompts CASCADE; +DROP TABLE IF EXISTS aibridge_token_usages CASCADE; +DROP TABLE IF EXISTS aibridge_sessions CASCADE; diff --git a/coderd/database/migrations/000358_aibridge.up.sql b/coderd/database/migrations/000358_aibridge.up.sql new file mode 100644 index 0000000000000..71362d7d79dfd --- /dev/null +++ b/coderd/database/migrations/000358_aibridge.up.sql @@ -0,0 +1,50 @@ +CREATE TABLE IF NOT EXISTS aibridge_sessions ( + id UUID PRIMARY KEY, + initiator_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + provider TEXT NOT NULL, + model TEXT NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +CREATE INDEX idx_aibridge_sessions_provider ON aibridge_sessions(provider); +CREATE INDEX idx_aibridge_sessions_model ON aibridge_sessions(model); + +CREATE TABLE IF NOT EXISTS aibridge_token_usages ( + id UUID PRIMARY KEY, + session_id UUID NOT NULL REFERENCES aibridge_sessions(id) ON DELETE CASCADE, + provider_id TEXT NOT NULL, + input_tokens BIGINT NOT NULL, + output_tokens BIGINT NOT NULL, + metadata JSONB DEFAULT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +CREATE INDEX idx_aibridge_token_usages_session_id ON aibridge_token_usages(session_id); +CREATE INDEX idx_aibridge_token_usages_session_provider_id ON aibridge_token_usages(session_id, provider_id); + +CREATE TABLE IF NOT EXISTS aibridge_user_prompts ( + id UUID PRIMARY KEY, + session_id UUID NOT NULL REFERENCES aibridge_sessions(id) ON DELETE CASCADE, + provider_id TEXT NOT NULL, + prompt TEXT NOT NULL, + metadata JSONB DEFAULT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +CREATE INDEX idx_aibridge_user_prompts_session_id ON aibridge_user_prompts(session_id); +CREATE INDEX idx_aibridge_user_prompts_session_provider_id ON aibridge_user_prompts(session_id, provider_id); + +CREATE TABLE IF 
NOT EXISTS aibridge_tool_usages ( + id UUID PRIMARY KEY, + session_id UUID NOT NULL REFERENCES aibridge_sessions(id) ON DELETE CASCADE, + provider_id TEXT NOT NULL, + tool TEXT NOT NULL, + input TEXT NOT NULL, + injected BOOLEAN NOT NULL DEFAULT FALSE, + metadata JSONB DEFAULT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() +); + +CREATE INDEX idx_aibridge_tool_usages_session_id ON aibridge_tool_usages(session_id); +CREATE INDEX idx_aibridge_tool_usages_tool ON aibridge_tool_usages(tool); +CREATE INDEX idx_aibridge_tool_usages_session_provider_id ON aibridge_tool_usages(session_id, provider_id); diff --git a/coderd/database/models.go b/coderd/database/models.go index 75d2b941dab3c..9f9061bab4d8f 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -2968,6 +2968,44 @@ type APIKey struct { TokenName string `db:"token_name" json:"token_name"` } +type AibridgeSession struct { + ID uuid.UUID `db:"id" json:"id"` + InitiatorID uuid.UUID `db:"initiator_id" json:"initiator_id"` + Provider string `db:"provider" json:"provider"` + Model string `db:"model" json:"model"` + CreatedAt sql.NullTime `db:"created_at" json:"created_at"` +} + +type AibridgeTokenUsage struct { + ID uuid.UUID `db:"id" json:"id"` + SessionID uuid.UUID `db:"session_id" json:"session_id"` + ProviderID string `db:"provider_id" json:"provider_id"` + InputTokens int64 `db:"input_tokens" json:"input_tokens"` + OutputTokens int64 `db:"output_tokens" json:"output_tokens"` + Metadata pqtype.NullRawMessage `db:"metadata" json:"metadata"` + CreatedAt sql.NullTime `db:"created_at" json:"created_at"` +} + +type AibridgeToolUsage struct { + ID uuid.UUID `db:"id" json:"id"` + SessionID uuid.UUID `db:"session_id" json:"session_id"` + ProviderID string `db:"provider_id" json:"provider_id"` + Tool string `db:"tool" json:"tool"` + Input string `db:"input" json:"input"` + Injected bool `db:"injected" json:"injected"` + Metadata pqtype.NullRawMessage `db:"metadata" json:"metadata"` + CreatedAt 
sql.NullTime `db:"created_at" json:"created_at"` +} + +type AibridgeUserPrompt struct { + ID uuid.UUID `db:"id" json:"id"` + SessionID uuid.UUID `db:"session_id" json:"session_id"` + ProviderID string `db:"provider_id" json:"provider_id"` + Prompt string `db:"prompt" json:"prompt"` + Metadata pqtype.NullRawMessage `db:"metadata" json:"metadata"` + CreatedAt sql.NullTime `db:"created_at" json:"created_at"` +} + type AuditLog struct { ID uuid.UUID `db:"id" json:"id"` Time time.Time `db:"time" json:"time"` diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 9c179351b26e3..82506246ad873 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -477,6 +477,10 @@ type sqlcQuerier interface { GetWorkspacesAndAgentsByOwnerID(ctx context.Context, ownerID uuid.UUID) ([]GetWorkspacesAndAgentsByOwnerIDRow, error) GetWorkspacesByTemplateID(ctx context.Context, templateID uuid.UUID) ([]WorkspaceTable, error) GetWorkspacesEligibleForTransition(ctx context.Context, now time.Time) ([]GetWorkspacesEligibleForTransitionRow, error) + InsertAIBridgeSession(ctx context.Context, arg InsertAIBridgeSessionParams) error + InsertAIBridgeTokenUsage(ctx context.Context, arg InsertAIBridgeTokenUsageParams) error + InsertAIBridgeToolUsage(ctx context.Context, arg InsertAIBridgeToolUsageParams) error + InsertAIBridgeUserPrompt(ctx context.Context, arg InsertAIBridgeUserPromptParams) error InsertAPIKey(ctx context.Context, arg InsertAPIKeyParams) (APIKey, error) // We use the organization_id as the id // for simplicity since all users is diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index c039b7f94e8d5..442dbcc06fd96 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -107,6 +107,115 @@ func (q *sqlQuerier) ActivityBumpWorkspace(ctx context.Context, arg ActivityBump return err } +const insertAIBridgeSession = `-- name: InsertAIBridgeSession :exec +INSERT INTO aibridge_sessions (id, initiator_id, 
provider, model) +VALUES ($1::uuid, $2::uuid, $3, $4) +` + +type InsertAIBridgeSessionParams struct { + ID uuid.UUID `db:"id" json:"id"` + InitiatorID uuid.UUID `db:"initiator_id" json:"initiator_id"` + Provider string `db:"provider" json:"provider"` + Model string `db:"model" json:"model"` +} + +func (q *sqlQuerier) InsertAIBridgeSession(ctx context.Context, arg InsertAIBridgeSessionParams) error { + _, err := q.db.ExecContext(ctx, insertAIBridgeSession, + arg.ID, + arg.InitiatorID, + arg.Provider, + arg.Model, + ) + return err +} + +const insertAIBridgeTokenUsage = `-- name: InsertAIBridgeTokenUsage :exec +INSERT INTO aibridge_token_usages ( + id, session_id, provider_id, input_tokens, output_tokens, metadata +) VALUES ( + $1, $2, $3, $4, $5, COALESCE($6::jsonb, '{}'::jsonb) +) +` + +type InsertAIBridgeTokenUsageParams struct { + ID uuid.UUID `db:"id" json:"id"` + SessionID uuid.UUID `db:"session_id" json:"session_id"` + ProviderID string `db:"provider_id" json:"provider_id"` + InputTokens int64 `db:"input_tokens" json:"input_tokens"` + OutputTokens int64 `db:"output_tokens" json:"output_tokens"` + Metadata json.RawMessage `db:"metadata" json:"metadata"` +} + +func (q *sqlQuerier) InsertAIBridgeTokenUsage(ctx context.Context, arg InsertAIBridgeTokenUsageParams) error { + _, err := q.db.ExecContext(ctx, insertAIBridgeTokenUsage, + arg.ID, + arg.SessionID, + arg.ProviderID, + arg.InputTokens, + arg.OutputTokens, + arg.Metadata, + ) + return err +} + +const insertAIBridgeToolUsage = `-- name: InsertAIBridgeToolUsage :exec +INSERT INTO aibridge_tool_usages ( + id, session_id, provider_id, tool, input, injected, metadata +) VALUES ( + $1, $2, $3, $4, $5, $6, COALESCE($7::jsonb, '{}'::jsonb) +) +` + +type InsertAIBridgeToolUsageParams struct { + ID uuid.UUID `db:"id" json:"id"` + SessionID uuid.UUID `db:"session_id" json:"session_id"` + ProviderID string `db:"provider_id" json:"provider_id"` + Tool string `db:"tool" json:"tool"` + Input string `db:"input" json:"input"` 
+ Injected bool `db:"injected" json:"injected"` + Metadata json.RawMessage `db:"metadata" json:"metadata"` +} + +func (q *sqlQuerier) InsertAIBridgeToolUsage(ctx context.Context, arg InsertAIBridgeToolUsageParams) error { + _, err := q.db.ExecContext(ctx, insertAIBridgeToolUsage, + arg.ID, + arg.SessionID, + arg.ProviderID, + arg.Tool, + arg.Input, + arg.Injected, + arg.Metadata, + ) + return err +} + +const insertAIBridgeUserPrompt = `-- name: InsertAIBridgeUserPrompt :exec +INSERT INTO aibridge_user_prompts ( + id, session_id, provider_id, prompt, metadata +) VALUES ( + $1, $2, $3, $4, COALESCE($5::jsonb, '{}'::jsonb) +) +` + +type InsertAIBridgeUserPromptParams struct { + ID uuid.UUID `db:"id" json:"id"` + SessionID uuid.UUID `db:"session_id" json:"session_id"` + ProviderID string `db:"provider_id" json:"provider_id"` + Prompt string `db:"prompt" json:"prompt"` + Metadata json.RawMessage `db:"metadata" json:"metadata"` +} + +func (q *sqlQuerier) InsertAIBridgeUserPrompt(ctx context.Context, arg InsertAIBridgeUserPromptParams) error { + _, err := q.db.ExecContext(ctx, insertAIBridgeUserPrompt, + arg.ID, + arg.SessionID, + arg.ProviderID, + arg.Prompt, + arg.Metadata, + ) + return err +} + const deleteAPIKeyByID = `-- name: DeleteAPIKeyByID :exec DELETE FROM api_keys diff --git a/coderd/database/queries/aibridge.sql b/coderd/database/queries/aibridge.sql new file mode 100644 index 0000000000000..4089b5c9dec32 --- /dev/null +++ b/coderd/database/queries/aibridge.sql @@ -0,0 +1,24 @@ +-- name: InsertAIBridgeSession :exec +INSERT INTO aibridge_sessions (id, initiator_id, provider, model) +VALUES (@id::uuid, @initiator_id::uuid, @provider, @model); + +-- name: InsertAIBridgeTokenUsage :exec +INSERT INTO aibridge_token_usages ( + id, session_id, provider_id, input_tokens, output_tokens, metadata +) VALUES ( + @id, @session_id, @provider_id, @input_tokens, @output_tokens, COALESCE(@metadata::jsonb, '{}'::jsonb) +); + +-- name: InsertAIBridgeUserPrompt :exec +INSERT INTO 
aibridge_user_prompts ( + id, session_id, provider_id, prompt, metadata +) VALUES ( + @id, @session_id, @provider_id, @prompt, COALESCE(@metadata::jsonb, '{}'::jsonb) +); + +-- name: InsertAIBridgeToolUsage :exec +INSERT INTO aibridge_tool_usages ( + id, session_id, provider_id, tool, input, injected, metadata +) VALUES ( + @id, @session_id, @provider_id, @tool, @input, @injected, COALESCE(@metadata::jsonb, '{}'::jsonb) +); diff --git a/coderd/database/unique_constraint.go b/coderd/database/unique_constraint.go index 3ed326102b18c..0208ac962ceda 100644 --- a/coderd/database/unique_constraint.go +++ b/coderd/database/unique_constraint.go @@ -7,6 +7,10 @@ type UniqueConstraint string // UniqueConstraint enums. const ( UniqueAgentStatsPkey UniqueConstraint = "agent_stats_pkey" // ALTER TABLE ONLY workspace_agent_stats ADD CONSTRAINT agent_stats_pkey PRIMARY KEY (id); + UniqueAibridgeSessionsPkey UniqueConstraint = "aibridge_sessions_pkey" // ALTER TABLE ONLY aibridge_sessions ADD CONSTRAINT aibridge_sessions_pkey PRIMARY KEY (id); + UniqueAibridgeTokenUsagesPkey UniqueConstraint = "aibridge_token_usages_pkey" // ALTER TABLE ONLY aibridge_token_usages ADD CONSTRAINT aibridge_token_usages_pkey PRIMARY KEY (id); + UniqueAibridgeToolUsagesPkey UniqueConstraint = "aibridge_tool_usages_pkey" // ALTER TABLE ONLY aibridge_tool_usages ADD CONSTRAINT aibridge_tool_usages_pkey PRIMARY KEY (id); + UniqueAibridgeUserPromptsPkey UniqueConstraint = "aibridge_user_prompts_pkey" // ALTER TABLE ONLY aibridge_user_prompts ADD CONSTRAINT aibridge_user_prompts_pkey PRIMARY KEY (id); UniqueAPIKeysPkey UniqueConstraint = "api_keys_pkey" // ALTER TABLE ONLY api_keys ADD CONSTRAINT api_keys_pkey PRIMARY KEY (id); UniqueAuditLogsPkey UniqueConstraint = "audit_logs_pkey" // ALTER TABLE ONLY audit_logs ADD CONSTRAINT audit_logs_pkey PRIMARY KEY (id); UniqueConnectionLogsPkey UniqueConstraint = "connection_logs_pkey" // ALTER TABLE ONLY connection_logs ADD CONSTRAINT connection_logs_pkey PRIMARY 
KEY (id); diff --git a/coderd/httpmw/cors.go b/coderd/httpmw/cors.go index 218aab6609f60..13250112c1ede 100644 --- a/coderd/httpmw/cors.go +++ b/coderd/httpmw/cors.go @@ -120,3 +120,41 @@ func WorkspaceAppCors(regex *regexp.Regexp, app appurl.ApplicationURL) func(next AllowCredentials: true, }) } + +// PermissiveCors creates a very permissive CORS middleware that allows all origins, +// methods, and headers. This bypasses go-chi's CORS library for maximum compatibility. +func PermissiveCors() func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS, HEAD, PATCH") + // The "*" wildcard never covers Authorization, so it has to be listed explicitly. + w.Header().Set("Access-Control-Allow-Headers", "*, Authorization") + w.Header().Set("Access-Control-Max-Age", "86400") + + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusOK) + return + } + + next.ServeHTTP(w, r) + }) + } +} + +// ConditionalCors applies permissive CORS for requests with the specified prefix, +// and regular CORS for all other requests. Both handler chains are built once per +// wrapped handler instead of once per request. +func ConditionalCors(prefix string, regularCors, permissiveCors func(next http.Handler) http.Handler) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + permissive := permissiveCors(next) + regular := regularCors(next) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.HasPrefix(r.URL.Path, prefix) { + permissive.ServeHTTP(w, r) + return + } + regular.ServeHTTP(w, r) + }) + } +} diff --git a/coderd/util/slice/slice.go b/coderd/util/slice/slice.go index bb2011c05d1b2..6fc5975ae47d3 100644 --- a/coderd/util/slice/slice.go +++ b/coderd/util/slice/slice.go @@ -1,7 +1,10 @@ package slice import ( + "math/rand" + "golang.org/x/exp/constraints" + "golang.org/x/xerrors" ) // ToStrings works for any type where the base type is a string.
@@ -90,6 +93,25 @@ func Find[T any](haystack []T, cond func(T) bool) (T, bool) { return empty, false } +// PickRandom returns one element at a random index. +// NOTE: callers MUST protect haystack from modification while this function is called to prevent panics. +func PickRandom[T any](haystack []T) (T, error) { + var zero T + + length := len(haystack) + if length == 0 { + return zero, xerrors.New("cannot pick from empty slice") + } + + index := rand.Intn(length) + + if length != len(haystack) { + return zero, xerrors.New("slice was modified during operation") + } + + return haystack[index], nil +} + // Filter returns all elements that satisfy the condition. func Filter[T any](haystack []T, cond func(T) bool) []T { out := make([]T, 0, len(haystack)) diff --git a/codersdk/deployment.go b/codersdk/deployment.go index 1d6fa4572772e..e83251316d2ad 100644 --- a/codersdk/deployment.go +++ b/codersdk/deployment.go @@ -497,6 +497,7 @@ type DeploymentValues struct { WorkspaceHostnameSuffix serpent.String `json:"workspace_hostname_suffix,omitempty" typescript:",notnull"` Prebuilds PrebuildsConfig `json:"workspace_prebuilds,omitempty" typescript:",notnull"` HideAITasks serpent.Bool `json:"hide_ai_tasks,omitempty" typescript:",notnull"` + AI AIConfig `json:"ai,omitempty"` Config serpent.YAMLConfigPath `json:"config,omitempty" typescript:",notnull"` WriteConfig serpent.Bool `json:"write_config,omitempty" typescript:",notnull"` @@ -1155,6 +1156,10 @@ func (c *DeploymentValues) Options() serpent.OptionSet { Parent: &deploymentGroupNotifications, YAML: "inbox", } + deploymentGroupAIBridge = serpent.Group{ + Name: "AI Bridge", + YAML: "ai_bridge", + } ) httpAddress := serpent.Option{ @@ -3205,11 +3210,88 @@ Write out the current server config as YAML to stdout.`, Group: &deploymentGroupClient, YAML: "hideAITasks", }, + + // AI Bridge Options + { + Name: "AI Bridge Daemons", + Description: "TODO.", + Flag: "ai-bridge-daemons", + Env: "CODER_AI_BRIDGE_DAEMONS", + Value: 
&c.AI.BridgeConfig.Daemons, + Default: "3", + Group: &deploymentGroupAIBridge, + YAML: "daemons", + Hidden: true, + }, + { + Name: "AI Bridge OpenAI Base URL", + Description: "TODO.", + Flag: "ai-bridge-openai-base-url", + Env: "CODER_AI_BRIDGE_OPENAI_BASE_URL", + Value: &c.AI.BridgeConfig.OpenAI.BaseURL, + Default: "https://api.openai.com/v1/", + Group: &deploymentGroupAIBridge, + YAML: "openai_base_url", + Hidden: true, + }, + { + Name: "AI Bridge OpenAI Key", + Description: "TODO.", + Flag: "ai-bridge-openai-key", + Env: "CODER_AI_BRIDGE_OPENAI_KEY", + Value: &c.AI.BridgeConfig.OpenAI.Key, + Default: "", + Group: &deploymentGroupAIBridge, + YAML: "openai_key", + Hidden: true, + }, + { + Name: "AI Bridge Anthropic Base URL", + Description: "TODO.", + Flag: "ai-bridge-anthropic-base-url", + Env: "CODER_AI_BRIDGE_ANTHROPIC_BASE_URL", + Value: &c.AI.BridgeConfig.Anthropic.BaseURL, + Default: "https://api.anthropic.com/", + Group: &deploymentGroupAIBridge, + YAML: "base_url", + Hidden: true, + }, + { + Name: "AI Bridge Anthropic KEY", + Description: "TODO.", + Flag: "ai-bridge-anthropic-key", + Env: "CODER_AI_BRIDGE_ANTHROPIC_KEY", + Value: &c.AI.BridgeConfig.Anthropic.Key, + Default: "", + Group: &deploymentGroupAIBridge, + YAML: "key", + Hidden: true, + }, } return opts } +type AIBridgeConfig struct { + Daemons serpent.Int64 `json:"daemons" typescript:",notnull"` + OpenAI AIBridgeOpenAIConfig `json:"openai" typescript:",notnull"` + Anthropic AIBridgeAnthropicConfig `json:"anthropic" typescript:",notnull"` +} + +type AIBridgeOpenAIConfig struct { + BaseURL serpent.String `json:"base_url" typescript:",notnull"` + Key serpent.String `json:"key" typescript:",notnull"` +} + +type AIBridgeAnthropicConfig struct { + BaseURL serpent.String `json:"base_url" typescript:",notnull"` + Key serpent.String `json:"key" typescript:",notnull"` +} + +type AIConfig struct { + BridgeConfig AIBridgeConfig `json:"bridge,omitempty"` +} + type SupportConfig struct { Links 
serpent.Struct[[]LinkConfig] `json:"links" typescript:",notnull"` } diff --git a/docs/reference/api/general.md b/docs/reference/api/general.md index 72543f6774dfd..df1422eb9cf5b 100644 --- a/docs/reference/api/general.md +++ b/docs/reference/api/general.md @@ -161,6 +161,19 @@ curl -X GET http://coder-server:8080/api/v2/deployment/config \ "user": {} }, "agent_stat_refresh_interval": 0, + "ai": { + "bridge": { + "anthropic": { + "base_url": "string", + "key": "string" + }, + "daemons": 0, + "openai": { + "base_url": "string", + "key": "string" + } + } + }, "allow_workspace_renames": true, "autobuild_poll_interval": 0, "browser_only": true, diff --git a/docs/reference/api/schemas.md b/docs/reference/api/schemas.md index b3824d0c9b9b8..6d8e791b955ea 100644 --- a/docs/reference/api/schemas.md +++ b/docs/reference/api/schemas.md @@ -335,6 +335,86 @@ | `groups` | array of [codersdk.Group](#codersdkgroup) | false | | | | `users` | array of [codersdk.ReducedUser](#codersdkreduceduser) | false | | | +## codersdk.AIBridgeAnthropicConfig + +```json +{ + "base_url": "string", + "key": "string" +} +``` + +### Properties + +| Name | Type | Required | Restrictions | Description | +|------------|--------|----------|--------------|-------------| +| `base_url` | string | false | | | +| `key` | string | false | | | + +## codersdk.AIBridgeConfig + +```json +{ + "anthropic": { + "base_url": "string", + "key": "string" + }, + "daemons": 0, + "openai": { + "base_url": "string", + "key": "string" + } +} +``` + +### Properties + +| Name | Type | Required | Restrictions | Description | +|-------------|----------------------------------------------------------------------|----------|--------------|-------------| +| `anthropic` | [codersdk.AIBridgeAnthropicConfig](#codersdkaibridgeanthropicconfig) | false | | | +| `daemons` | integer | false | | | +| `openai` | [codersdk.AIBridgeOpenAIConfig](#codersdkaibridgeopenaiconfig) | false | | | + +## codersdk.AIBridgeOpenAIConfig + +```json +{ + 
"base_url": "string", + "key": "string" +} +``` + +### Properties + +| Name | Type | Required | Restrictions | Description | +|------------|--------|----------|--------------|-------------| +| `base_url` | string | false | | | +| `key` | string | false | | | + +## codersdk.AIConfig + +```json +{ + "bridge": { + "anthropic": { + "base_url": "string", + "key": "string" + }, + "daemons": 0, + "openai": { + "base_url": "string", + "key": "string" + } + } +} +``` + +### Properties + +| Name | Type | Required | Restrictions | Description | +|----------|----------------------------------------------------|----------|--------------|-------------| +| `bridge` | [codersdk.AIBridgeConfig](#codersdkaibridgeconfig) | false | | | + ## codersdk.APIKey ```json @@ -2152,6 +2232,19 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o "user": {} }, "agent_stat_refresh_interval": 0, + "ai": { + "bridge": { + "anthropic": { + "base_url": "string", + "key": "string" + }, + "daemons": 0, + "openai": { + "base_url": "string", + "key": "string" + } + } + }, "allow_workspace_renames": true, "autobuild_poll_interval": 0, "browser_only": true, @@ -2639,6 +2732,19 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o "user": {} }, "agent_stat_refresh_interval": 0, + "ai": { + "bridge": { + "anthropic": { + "base_url": "string", + "key": "string" + }, + "daemons": 0, + "openai": { + "base_url": "string", + "key": "string" + } + } + }, "allow_workspace_renames": true, "autobuild_poll_interval": 0, "browser_only": true, @@ -3017,6 +3123,7 @@ CreateWorkspaceRequest provides options for creating a new workspace. Only one o | `address` | [serpent.HostPort](#serpenthostport) | false | | Deprecated: Use HTTPAddress or TLS.Address instead. 
| | `agent_fallback_troubleshooting_url` | [serpent.URL](#serpenturl) | false | | | | `agent_stat_refresh_interval` | integer | false | | | +| `ai` | [codersdk.AIConfig](#codersdkaiconfig) | false | | | | `allow_workspace_renames` | boolean | false | | | | `autobuild_poll_interval` | integer | false | | | | `browser_only` | boolean | false | | | diff --git a/go.mod b/go.mod index e10c7a248db7e..f7ce0e9c6ca0e 100644 --- a/go.mod +++ b/go.mod @@ -196,7 +196,7 @@ require ( go.uber.org/mock v0.5.0 go4.org/netipx v0.0.0-20230728180743-ad4cb58a6516 golang.org/x/crypto v0.41.0 - golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 + golang.org/x/exp v0.0.0-20250811191247-51f88131bc50 golang.org/x/mod v0.27.0 golang.org/x/net v0.43.0 golang.org/x/oauth2 v0.30.0 @@ -208,14 +208,14 @@ require ( golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da google.golang.org/api v0.246.0 google.golang.org/grpc v1.74.2 - google.golang.org/protobuf v1.36.6 + google.golang.org/protobuf v1.36.7 gopkg.in/DataDog/dd-trace-go.v1 v1.74.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/yaml.v3 v3.0.1 gvisor.dev/gvisor v0.0.0-20240509041132-65b30f7869dc kernel.org/pub/linux/libs/security/libcap/cap v1.2.73 - storj.io/drpc v0.0.33 - tailscale.com v1.80.3 + storj.io/drpc v0.0.34 + tailscale.com v1.86.4 ) require ( @@ -466,7 +466,10 @@ require ( require github.com/coder/clistat v1.0.0 -require github.com/SherClockHolmes/webpush-go v1.4.0 +require ( + github.com/SherClockHolmes/webpush-go v1.4.0 + github.com/coder/aibridge v0.0.0 +) require ( github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect @@ -484,8 +487,12 @@ require ( github.com/fsnotify/fsnotify v1.9.0 github.com/go-git/go-git/v5 v5.16.2 github.com/mark3labs/mcp-go v0.37.0 + github.com/tidwall/sjson v1.2.5 // indirect ) +// aibridge-related deps and directives. 
+replace github.com/coder/aibridge v0.0.0 => /home/coder/aibridge + require ( cel.dev/expr v0.24.0 // indirect cloud.google.com/go v0.120.0 // indirect @@ -524,15 +531,14 @@ require ( github.com/klauspost/cpuid/v2 v2.2.10 // indirect github.com/moby/sys/user v0.4.0 // indirect github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 // indirect - github.com/openai/openai-go v1.7.0 // indirect + github.com/openai/openai-go v1.12.0 // indirect github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect github.com/puzpuzpuz/xsync/v3 v3.5.1 // indirect github.com/samber/lo v1.50.0 // indirect github.com/sergeymakinen/go-bmp v1.0.0 // indirect github.com/sergeymakinen/go-ico v1.0.0-beta.0 // indirect github.com/spiffe/go-spiffe/v2 v2.5.0 // indirect - github.com/tidwall/sjson v1.2.5 // indirect - github.com/tmaxmax/go-sse v0.10.0 // indirect + github.com/tmaxmax/go-sse v0.11.0 // indirect github.com/ulikunitz/xz v0.5.12 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect diff --git a/go.sum b/go.sum index 3575f35177154..c91ed089ebf8a 100644 --- a/go.sum +++ b/go.sum @@ -1625,8 +1625,8 @@ github.com/open-telemetry/opentelemetry-collector-contrib/pkg/sampling v0.120.1 github.com/open-telemetry/opentelemetry-collector-contrib/pkg/sampling v0.120.1/go.mod h1:01TvyaK8x640crO2iFwW/6CFCZgNsOvOGH3B5J239m0= github.com/open-telemetry/opentelemetry-collector-contrib/processor/probabilisticsamplerprocessor v0.120.1 h1:TCyOus9tym82PD1VYtthLKMVMlVyRwtDI4ck4SR2+Ok= github.com/open-telemetry/opentelemetry-collector-contrib/processor/probabilisticsamplerprocessor v0.120.1/go.mod h1:Z/S1brD5gU2Ntht/bHxBVnGxXKTvZDr0dNv/riUzPmY= -github.com/openai/openai-go v1.7.0 h1:M1JfDjQgo3d3PsLyZgpGUG0wUAaUAitqJPM4Rl56dCA= -github.com/openai/openai-go v1.7.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= +github.com/openai/openai-go v1.12.0 h1:NBQCnXzqOTv5wsgNC36PrFEiskGfO5wccfCWDo9S1U0= 
+github.com/openai/openai-go v1.12.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= @@ -1825,8 +1825,8 @@ github.com/tklauser/go-sysconf v0.3.15 h1:VE89k0criAymJ/Os65CSn1IXaol+1wrsFHEB8O github.com/tklauser/go-sysconf v0.3.15/go.mod h1:Dmjwr6tYFIseJw7a3dRLJfsHAMXZ3nEnL/aZY+0IuI4= github.com/tklauser/numcpus v0.10.0 h1:18njr6LDBk1zuna922MgdjQuJFjrdppsZG60sHGfjso= github.com/tklauser/numcpus v0.10.0/go.mod h1:BiTKazU708GQTYF4mB+cmlpT2Is1gLk7XVuEeem8LsQ= -github.com/tmaxmax/go-sse v0.10.0 h1:j9F93WB4Hxt8wUf6oGffMm4dutALvUPoDDxfuDQOSqA= -github.com/tmaxmax/go-sse v0.10.0/go.mod h1:u/2kZQR1tyngo1lKaNCj1mJmhXGZWS1Zs5yiSOD+Eg8= +github.com/tmaxmax/go-sse v0.11.0 h1:nogmJM6rJUoOLoAwEKeQe5XlVpt9l7N82SS1jI7lWFg= +github.com/tmaxmax/go-sse v0.11.0/go.mod h1:u/2kZQR1tyngo1lKaNCj1mJmhXGZWS1Zs5yiSOD+Eg8= github.com/u-root/gobusybox/src v0.0.0-20240225013946-a274a8d5d83a h1:eg5FkNoQp76ZsswyGZ+TjYqA/rhKefxK8BW7XOlQsxo= github.com/u-root/gobusybox/src v0.0.0-20240225013946-a274a8d5d83a/go.mod h1:e/8TmrdreH0sZOw2DFKBaUV7bvDWRq6SeM9PzkuVM68= github.com/u-root/u-root v0.14.0 h1:Ka4T10EEML7dQ5XDvO9c3MBN8z4nuSnGjcd1jmU2ivg= @@ -2031,8 +2031,8 @@ golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= golang.org/x/exp v0.0.0-20220827204233-334a2380cb91/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE= -golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 h1:R84qjqJb5nVJMxqWYb3np9L5ZsaDtB+a39EqjV0JSUM= -golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0/go.mod 
h1:S9Xr4PYopiDyqSyp5NjCrhFrqg6A5zA2E/iPHPhqnS8= +golang.org/x/exp v0.0.0-20250811191247-51f88131bc50 h1:3yiSh9fhy5/RhCSntf4Sy0Tnx50DmMpQ4MQdKKk4yg4= +golang.org/x/exp v0.0.0-20250811191247-51f88131bc50/go.mod h1:rT6SFzZ7oxADUDx58pcaKFTcZ+inxAa9fTrYx/uVYwg= golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= @@ -2707,8 +2707,8 @@ google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqw google.golang.org/protobuf v1.29.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= -google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= -google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= +google.golang.org/protobuf v1.36.7 h1:IgrO7UwFQGJdRNXH/sQux4R1Dj1WAKcLElzeeRaXV2A= +google.golang.org/protobuf v1.36.7/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= gopkg.in/DataDog/dd-trace-go.v1 v1.74.0 h1:wScziU1ff6Bnyr8MEyxATPSLJdnLxKz3p6RsA8FUaek= gopkg.in/DataDog/dd-trace-go.v1 v1.74.0/go.mod h1:ReNBsNfnsjVC7GsCe80zRcykL/n+nxvsNrg3NbjuleM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -2800,5 +2800,5 @@ sigs.k8s.io/yaml v1.4.0 h1:Mk1wCc2gy/F0THH0TAp1QYyJNzRm2KCLy3o5ASXVI5E= sigs.k8s.io/yaml v1.4.0/go.mod h1:Ejl7/uTz7PSA4eKMyQCUTnhZYNmLIl+5c2lQPGR2BPY= software.sslmate.com/src/go-pkcs12 v0.2.0 h1:nlFkj7bTysH6VkC4fGphtjXRbezREPgrHuJG20hBGPE= software.sslmate.com/src/go-pkcs12 v0.2.0/go.mod h1:23rNcYsMabIc1otwLpTkCCPwUq6kQsTyowttG/as0kQ= -storj.io/drpc v0.0.33 
h1:yCGZ26r66ZdMP0IcTYsj7WDAUIIjzXk6DJhbhvt9FHI= -storj.io/drpc v0.0.33/go.mod h1:vR804UNzhBa49NOJ6HeLjd2H3MakC1j5Gv8bsOQT6N4= +storj.io/drpc v0.0.34 h1:q9zlQKfJ5A7x8NQNFk8x7eKUF78FMhmAbZLnFK+og7I= +storj.io/drpc v0.0.34/go.mod h1:Y9LZaa8esL1PW2IDMqJE7CFSNq7d5bQ3RI7mGPtmKMg= diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 6f5ab307a2fa8..848c396274355 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -6,6 +6,30 @@ export interface ACLAvailable { readonly groups: readonly Group[]; } +// From codersdk/deployment.go +export interface AIBridgeAnthropicConfig { + readonly base_url: string; + readonly key: string; +} + +// From codersdk/deployment.go +export interface AIBridgeConfig { + readonly daemons: number; + readonly openai: AIBridgeOpenAIConfig; + readonly anthropic: AIBridgeAnthropicConfig; +} + +// From codersdk/deployment.go +export interface AIBridgeOpenAIConfig { + readonly base_url: string; + readonly key: string; +} + +// From codersdk/deployment.go +export interface AIConfig { + readonly bridge?: AIBridgeConfig; +} + // From codersdk/aitasks.go export const AITaskPromptParameterName = "AI Prompt"; @@ -828,6 +852,7 @@ export interface DeploymentValues { readonly workspace_hostname_suffix?: string; readonly workspace_prebuilds?: PrebuildsConfig; readonly hide_ai_tasks?: boolean; + readonly ai?: AIConfig; readonly config?: string; readonly write_config?: boolean; readonly address?: string;