diff --git a/coderd/apidoc/docs.go b/coderd/apidoc/docs.go index dd6c958448391..e4edd51bc9cb1 100644 --- a/coderd/apidoc/docs.go +++ b/coderd/apidoc/docs.go @@ -4693,6 +4693,44 @@ const docTemplate = `{ } } }, + "/workspaceagents/{workspaceagent}/legacy": { + "get": { + "security": [ + { + "CoderSessionToken": [] + } + ], + "produces": [ + "application/json" + ], + "tags": [ + "Enterprise" + ], + "summary": "Agent is legacy", + "operationId": "agent-is-legacy", + "parameters": [ + { + "type": "string", + "format": "uuid", + "description": "Workspace Agent ID", + "name": "workspaceagent", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/wsproxysdk.AgentIsLegacyResponse" + } + } + }, + "x-apidocgen": { + "skip": true + } + } + }, "/workspaceagents/{workspaceagent}/listening-ports": { "get": { "security": [ @@ -5147,6 +5185,28 @@ const docTemplate = `{ } } }, + "/workspaceproxies/me/coordinate": { + "get": { + "security": [ + { + "CoderSessionToken": [] + } + ], + "tags": [ + "Enterprise" + ], + "summary": "Workspace Proxy Coordinate", + "operationId": "workspace-proxy-coordinate", + "responses": { + "101": { + "description": "Switching Protocols" + } + }, + "x-apidocgen": { + "skip": true + } + } + }, "/workspaceproxies/me/goingaway": { "post": { "security": [ @@ -10881,6 +10941,17 @@ const docTemplate = `{ } } }, + "wsproxysdk.AgentIsLegacyResponse": { + "type": "object", + "properties": { + "found": { + "type": "boolean" + }, + "legacy": { + "type": "boolean" + } + } + }, "wsproxysdk.IssueSignedAppTokenResponse": { "type": "object", "properties": { diff --git a/coderd/apidoc/swagger.json b/coderd/apidoc/swagger.json index aa2fedba869a6..283c2290c8437 100644 --- a/coderd/apidoc/swagger.json +++ b/coderd/apidoc/swagger.json @@ -4129,6 +4129,40 @@ } } }, + "/workspaceagents/{workspaceagent}/legacy": { + "get": { + "security": [ + { + "CoderSessionToken": [] + } + ], + "produces": 
["application/json"], + "tags": ["Enterprise"], + "summary": "Agent is legacy", + "operationId": "agent-is-legacy", + "parameters": [ + { + "type": "string", + "format": "uuid", + "description": "Workspace Agent ID", + "name": "workspaceagent", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/wsproxysdk.AgentIsLegacyResponse" + } + } + }, + "x-apidocgen": { + "skip": true + } + } + }, "/workspaceagents/{workspaceagent}/listening-ports": { "get": { "security": [ @@ -4537,6 +4571,26 @@ } } }, + "/workspaceproxies/me/coordinate": { + "get": { + "security": [ + { + "CoderSessionToken": [] + } + ], + "tags": ["Enterprise"], + "summary": "Workspace Proxy Coordinate", + "operationId": "workspace-proxy-coordinate", + "responses": { + "101": { + "description": "Switching Protocols" + } + }, + "x-apidocgen": { + "skip": true + } + } + }, "/workspaceproxies/me/goingaway": { "post": { "security": [ @@ -9912,6 +9966,17 @@ } } }, + "wsproxysdk.AgentIsLegacyResponse": { + "type": "object", + "properties": { + "found": { + "type": "boolean" + }, + "legacy": { + "type": "boolean" + } + } + }, "wsproxysdk.IssueSignedAppTokenResponse": { "type": "object", "properties": { diff --git a/coderd/coderd.go b/coderd/coderd.go index 3d41d62fded00..abfeaf89c9e86 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -199,7 +199,7 @@ func New(options *Options) *API { options.Authorizer, options.Logger.Named("authz_querier"), ) - experiments := initExperiments( + experiments := ReadExperiments( options.Logger, options.DeploymentValues.Experiments.Value(), ) if options.AppHostname != "" && options.AppHostnameRegex == nil || options.AppHostname == "" && options.AppHostnameRegex != nil { @@ -370,7 +370,9 @@ func New(options *Options) *API { options.Logger, options.DERPServer, options.DERPMap, - &api.TailnetCoordinator, + func(context.Context) (tailnet.MultiAgentConn, error) { + return 
(*api.TailnetCoordinator.Load()).ServeMultiAgent(uuid.New()), nil + }, wsconncache.New(api._dialWorkspaceAgentTailnet, 0), ) if err != nil { @@ -1081,7 +1083,7 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, debounce ti } // nolint:revive -func initExperiments(log slog.Logger, raw []string) codersdk.Experiments { +func ReadExperiments(log slog.Logger, raw []string) codersdk.Experiments { exps := make([]codersdk.Experiment, 0, len(raw)) for _, v := range raw { switch v { diff --git a/coderd/coderdtest/coderdtest.go b/coderd/coderdtest/coderdtest.go index f03741a46e648..71882acec4a10 100644 --- a/coderd/coderdtest/coderdtest.go +++ b/coderd/coderdtest/coderdtest.go @@ -384,6 +384,7 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can TemplateScheduleStore: &templateScheduleStore, TLSCertificates: options.TLSCertificates, TrialGenerator: options.TrialGenerator, + TailnetCoordinator: options.Coordinator, DERPMap: derpMap, MetricsCacheRefreshInterval: options.MetricsCacheRefreshInterval, AgentStatsRefreshInterval: options.AgentStatsRefreshInterval, diff --git a/coderd/httpapi/websocket.go b/coderd/httpapi/websocket.go index 62ff0431107d3..60904396099a1 100644 --- a/coderd/httpapi/websocket.go +++ b/coderd/httpapi/websocket.go @@ -25,3 +25,24 @@ func Heartbeat(ctx context.Context, conn *websocket.Conn) { } } } + +// HeartbeatClose loops to ping a WebSocket to keep it alive. It closes the +// connection and calls exit on ping failure. 
+func HeartbeatClose(ctx context.Context, exit func(), conn *websocket.Conn) { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + } + err := conn.Ping(ctx) + if err != nil { + _ = conn.Close(websocket.StatusGoingAway, "Ping failed") + exit() + return + } + } +} diff --git a/coderd/httpmw/groupparam.go b/coderd/httpmw/groupparam.go index db226c263b6d9..5b6d3bfe2dd15 100644 --- a/coderd/httpmw/groupparam.go +++ b/coderd/httpmw/groupparam.go @@ -64,7 +64,7 @@ func ExtractGroupParam(db database.Store) func(http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() - groupID, parsed := parseUUID(rw, r, "group") + groupID, parsed := ParseUUIDParam(rw, r, "group") if !parsed { return } diff --git a/coderd/httpmw/httpmw.go b/coderd/httpmw/httpmw.go index 74dd987248b87..f6a0dac8b0b65 100644 --- a/coderd/httpmw/httpmw.go +++ b/coderd/httpmw/httpmw.go @@ -11,8 +11,8 @@ import ( "github.com/coder/coder/codersdk" ) -// parseUUID consumes a url parameter and parses it as a UUID. -func parseUUID(rw http.ResponseWriter, r *http.Request, param string) (uuid.UUID, bool) { +// ParseUUIDParam consumes a url parameter and parses it as a UUID. 
+func ParseUUIDParam(rw http.ResponseWriter, r *http.Request, param string) (uuid.UUID, bool) { rawID := chi.URLParam(r, param) if rawID == "" { httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ diff --git a/coderd/httpmw/httpmw_internal_test.go b/coderd/httpmw/httpmw_internal_test.go index 381c8608d2649..87aa3a6960822 100644 --- a/coderd/httpmw/httpmw_internal_test.go +++ b/coderd/httpmw/httpmw_internal_test.go @@ -29,7 +29,7 @@ func TestParseUUID_Valid(t *testing.T) { ctx.URLParams.Add(testParam, testWorkspaceAgentID) r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx)) - parsed, ok := parseUUID(rw, r, "workspaceagent") + parsed, ok := ParseUUIDParam(rw, r, "workspaceagent") assert.True(t, ok, "UUID should be parsed") assert.Equal(t, testWorkspaceAgentID, parsed.String()) } @@ -44,7 +44,7 @@ func TestParseUUID_Invalid(t *testing.T) { ctx.URLParams.Add(testParam, "wrong-id") r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx)) - _, ok := parseUUID(rw, r, "workspaceagent") + _, ok := ParseUUIDParam(rw, r, "workspaceagent") assert.False(t, ok, "UUID should not be parsed") assert.Equal(t, http.StatusBadRequest, rw.Code) diff --git a/coderd/httpmw/organizationparam.go b/coderd/httpmw/organizationparam.go index ce2e4f483c5b4..55ceec57387ff 100644 --- a/coderd/httpmw/organizationparam.go +++ b/coderd/httpmw/organizationparam.go @@ -39,7 +39,7 @@ func ExtractOrganizationParam(db database.Store) func(http.Handler) http.Handler return func(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() - orgID, ok := parseUUID(rw, r, "organization") + orgID, ok := ParseUUIDParam(rw, r, "organization") if !ok { return } diff --git a/coderd/httpmw/templateparam.go b/coderd/httpmw/templateparam.go index 1ba57167d5483..eadb072d50131 100644 --- a/coderd/httpmw/templateparam.go +++ b/coderd/httpmw/templateparam.go @@ -27,7 +27,7 @@ func 
ExtractTemplateParam(db database.Store) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() - templateID, parsed := parseUUID(rw, r, "template") + templateID, parsed := ParseUUIDParam(rw, r, "template") if !parsed { return } diff --git a/coderd/httpmw/templateversionparam.go b/coderd/httpmw/templateversionparam.go index de86a5d1ac5f0..9f8f1c58561c6 100644 --- a/coderd/httpmw/templateversionparam.go +++ b/coderd/httpmw/templateversionparam.go @@ -29,7 +29,7 @@ func ExtractTemplateVersionParam(db database.Store) func(http.Handler) http.Hand return func(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() - templateVersionID, parsed := parseUUID(rw, r, "templateversion") + templateVersionID, parsed := ParseUUIDParam(rw, r, "templateversion") if !parsed { return } diff --git a/coderd/httpmw/workspaceagentparam.go b/coderd/httpmw/workspaceagentparam.go index cc0f372dc4c04..7e31c9e15be31 100644 --- a/coderd/httpmw/workspaceagentparam.go +++ b/coderd/httpmw/workspaceagentparam.go @@ -29,7 +29,7 @@ func ExtractWorkspaceAgentParam(db database.Store) func(http.Handler) http.Handl return func(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() - agentUUID, parsed := parseUUID(rw, r, "workspaceagent") + agentUUID, parsed := ParseUUIDParam(rw, r, "workspaceagent") if !parsed { return } diff --git a/coderd/httpmw/workspacebuildparam.go b/coderd/httpmw/workspacebuildparam.go index 285c3ffae7a97..518029465eb12 100644 --- a/coderd/httpmw/workspacebuildparam.go +++ b/coderd/httpmw/workspacebuildparam.go @@ -27,7 +27,7 @@ func ExtractWorkspaceBuildParam(db database.Store) func(http.Handler) http.Handl return func(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { ctx := 
r.Context() - workspaceBuildID, parsed := parseUUID(rw, r, "workspacebuild") + workspaceBuildID, parsed := ParseUUIDParam(rw, r, "workspacebuild") if !parsed { return } diff --git a/coderd/httpmw/workspaceparam.go b/coderd/httpmw/workspaceparam.go index fc7b1ade08316..b0f264abe3619 100644 --- a/coderd/httpmw/workspaceparam.go +++ b/coderd/httpmw/workspaceparam.go @@ -30,7 +30,7 @@ func ExtractWorkspaceParam(db database.Store) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() - workspaceID, parsed := parseUUID(rw, r, "workspace") + workspaceID, parsed := ParseUUIDParam(rw, r, "workspace") if !parsed { return } diff --git a/coderd/httpmw/workspaceresourceparam.go b/coderd/httpmw/workspaceresourceparam.go index ecc1acd67614b..41d19a4ea0519 100644 --- a/coderd/httpmw/workspaceresourceparam.go +++ b/coderd/httpmw/workspaceresourceparam.go @@ -29,7 +29,7 @@ func ExtractWorkspaceResourceParam(db database.Store) func(http.Handler) http.Ha return func(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() - resourceUUID, parsed := parseUUID(rw, r, "workspaceresource") + resourceUUID, parsed := ParseUUIDParam(rw, r, "workspaceresource") if !parsed { return } diff --git a/coderd/tailnet.go b/coderd/tailnet.go index a1559e4efcd52..bfbfcabdebc19 100644 --- a/coderd/tailnet.go +++ b/coderd/tailnet.go @@ -22,6 +22,7 @@ import ( "github.com/coder/coder/codersdk" "github.com/coder/coder/site" "github.com/coder/coder/tailnet" + "github.com/coder/retry" ) var tailnetTransport *http.Transport @@ -41,7 +42,7 @@ func NewServerTailnet( logger slog.Logger, derpServer *derp.Server, derpMap *tailcfg.DERPMap, - coord *atomic.Pointer[tailnet.Coordinator], + getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error), cache *wsconncache.Cache, ) (*ServerTailnet, error) { logger = 
logger.Named("servertailnet") @@ -56,20 +57,23 @@ func NewServerTailnet( serverCtx, cancel := context.WithCancel(ctx) tn := &ServerTailnet{ - ctx: serverCtx, - cancel: cancel, - logger: logger, - conn: conn, - coord: coord, - cache: cache, - agentNodes: map[uuid.UUID]time.Time{}, - agentTickets: map[uuid.UUID]map[uuid.UUID]struct{}{}, - transport: tailnetTransport.Clone(), + ctx: serverCtx, + cancel: cancel, + logger: logger, + conn: conn, + getMultiAgent: getMultiAgent, + cache: cache, + agentNodes: map[uuid.UUID]time.Time{}, + agentTickets: map[uuid.UUID]map[uuid.UUID]struct{}{}, + transport: tailnetTransport.Clone(), } tn.transport.DialContext = tn.dialContext tn.transport.MaxIdleConnsPerHost = 10 tn.transport.MaxIdleConns = 0 - agentConn := (*coord.Load()).ServeMultiAgent(uuid.New()) + agentConn, err := getMultiAgent(ctx) + if err != nil { + return nil, xerrors.Errorf("get initial multi agent: %w", err) + } tn.agentConn.Store(&agentConn) err = tn.getAgentConn().UpdateSelf(conn.Node()) @@ -86,19 +90,21 @@ func NewServerTailnet( // This is set to allow local DERP traffic to be proxied through memory // instead of needing to hit the external access URL. Don't use the ctx // given in this callback, it's only valid while connecting. 
- conn.SetDERPRegionDialer(func(_ context.Context, region *tailcfg.DERPRegion) net.Conn { - if !region.EmbeddedRelay { - return nil - } - left, right := net.Pipe() - go func() { - defer left.Close() - defer right.Close() - brw := bufio.NewReadWriter(bufio.NewReader(right), bufio.NewWriter(right)) - derpServer.Accept(ctx, right, brw, "internal") - }() - return left - }) + if derpServer != nil { + conn.SetDERPRegionDialer(func(_ context.Context, region *tailcfg.DERPRegion) net.Conn { + if !region.EmbeddedRelay { + return nil + } + left, right := net.Pipe() + go func() { + defer left.Close() + defer right.Close() + brw := bufio.NewReadWriter(bufio.NewReader(right), bufio.NewWriter(right)) + derpServer.Accept(ctx, right, brw, "internal") + }() + return left + }) + } go tn.watchAgentUpdates() go tn.expireOldAgents() @@ -167,30 +173,38 @@ func (s *ServerTailnet) getAgentConn() tailnet.MultiAgentConn { } func (s *ServerTailnet) reinitCoordinator() { - s.nodesMu.Lock() - agentConn := (*s.coord.Load()).ServeMultiAgent(uuid.New()) - s.agentConn.Store(&agentConn) - - // Resubscribe to all of the agents we're tracking. - for agentID := range s.agentNodes { - err := agentConn.SubscribeAgent(agentID) + for retrier := retry.New(25*time.Millisecond, 5*time.Second); retrier.Wait(s.ctx); { + s.nodesMu.Lock() + agentConn, err := s.getMultiAgent(s.ctx) if err != nil { - s.logger.Warn(s.ctx, "resubscribe to agent", slog.Error(err), slog.F("agent_id", agentID)) + s.nodesMu.Unlock() + s.logger.Error(s.ctx, "reinit multi agent", slog.Error(err)) + continue } + s.agentConn.Store(&agentConn) + + // Resubscribe to all of the agents we're tracking. 
+ for agentID := range s.agentNodes { + err := agentConn.SubscribeAgent(agentID) + if err != nil { + s.logger.Warn(s.ctx, "resubscribe to agent", slog.Error(err), slog.F("agent_id", agentID)) + } + } + s.nodesMu.Unlock() + return } - s.nodesMu.Unlock() } type ServerTailnet struct { ctx context.Context cancel func() - logger slog.Logger - conn *tailnet.Conn - coord *atomic.Pointer[tailnet.Coordinator] - agentConn atomic.Pointer[tailnet.MultiAgentConn] - cache *wsconncache.Cache - nodesMu sync.Mutex + logger slog.Logger + conn *tailnet.Conn + getMultiAgent func(context.Context) (tailnet.MultiAgentConn, error) + agentConn atomic.Pointer[tailnet.MultiAgentConn] + cache *wsconncache.Cache + nodesMu sync.Mutex // agentNodes is a map of agent tailnetNodes the server wants to keep a // connection to. It contains the last time the agent was connected to. agentNodes map[uuid.UUID]time.Time diff --git a/coderd/tailnet_test.go b/coderd/tailnet_test.go index 16d597607312c..d6341391934a7 100644 --- a/coderd/tailnet_test.go +++ b/coderd/tailnet_test.go @@ -8,7 +8,6 @@ import ( "net/http/httptest" "net/netip" "net/url" - "sync/atomic" "testing" "github.com/google/uuid" @@ -133,9 +132,7 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A DERPMap: derpMap, } - var coordPtr atomic.Pointer[tailnet.Coordinator] coord := tailnet.NewCoordinator(logger) - coordPtr.Store(&coord) t.Cleanup(func() { _ = coord.Close() }) @@ -194,7 +191,7 @@ func setupAgent(t *testing.T, agentAddresses []netip.Prefix) (uuid.UUID, agent.A logger, derpServer, manifest.DERPMap, - &coordPtr, + func(context.Context) (tailnet.MultiAgentConn, error) { return coord.ServeMultiAgent(uuid.New()), nil }, cache, ) require.NoError(t, err) diff --git a/codersdk/agentsdk/agentsdk.go b/codersdk/agentsdk/agentsdk.go index 6c60cd2303315..90f15eff649e3 100644 --- a/codersdk/agentsdk/agentsdk.go +++ b/codersdk/agentsdk/agentsdk.go @@ -13,17 +13,14 @@ import ( "time" 
"cloud.google.com/go/compute/metadata" + "github.com/google/uuid" "golang.org/x/xerrors" "nhooyr.io/websocket" "tailscale.com/tailcfg" - "github.com/coder/retry" - "cdr.dev/slog" - - "github.com/google/uuid" - "github.com/coder/coder/codersdk" + "github.com/coder/retry" ) // New returns a client that is used to interact with the diff --git a/docs/api/schemas.md b/docs/api/schemas.md index 51fe80fdab213..c7860c4d22a0a 100644 --- a/docs/api/schemas.md +++ b/docs/api/schemas.md @@ -6919,6 +6919,22 @@ _None_ | `username_or_id` | string | false | | For the following fields, if the AccessMethod is AccessMethodTerminal, then only AgentNameOrID may be set and it must be a UUID. The other fields must be left blank. | | `workspace_name_or_id` | string | false | | | +## wsproxysdk.AgentIsLegacyResponse + +```json +{ + "found": true, + "legacy": true +} +``` + +### Properties + +| Name | Type | Required | Restrictions | Description | +| -------- | ------- | -------- | ------------ | ----------- | +| `found` | boolean | false | | | +| `legacy` | boolean | false | | | + ## wsproxysdk.IssueSignedAppTokenResponse ```json diff --git a/enterprise/cli/licenses.go b/enterprise/cli/licenses.go index 4258081df3e24..e4bf3e0731636 100644 --- a/enterprise/cli/licenses.go +++ b/enterprise/cli/licenses.go @@ -10,12 +10,12 @@ import ( "strings" "time" + "github.com/google/uuid" "golang.org/x/xerrors" "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/codersdk" - "github.com/google/uuid" ) var jwtRegexp = regexp.MustCompile(`^[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+$`) diff --git a/enterprise/cli/proxyserver.go b/enterprise/cli/proxyserver.go index e04162fa9f196..822bebc699940 100644 --- a/enterprise/cli/proxyserver.go +++ b/enterprise/cli/proxyserver.go @@ -25,6 +25,7 @@ import ( "github.com/coder/coder/cli" "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" + "github.com/coder/coder/coderd" 
"github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/httpmw" "github.com/coder/coder/codersdk" @@ -220,6 +221,7 @@ func (*RootCmd) proxyServer() *clibase.Cmd { proxy, err := wsproxy.New(ctx, &wsproxy.Options{ Logger: logger, + Experiments: coderd.ReadExperiments(logger, cfg.Experiments.Value()), HTTPClient: httpClient, DashboardURL: primaryAccessURL.Value(), AccessURL: cfg.AccessURL.Value(), diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index ab517375a457f..d815100af7f25 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -125,6 +125,15 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { r.Use(apiKeyMiddleware) r.Post("/", api.reconnectingPTYSignedToken) }) + + r.With( + apiKeyMiddlewareOptional, + httpmw.ExtractWorkspaceProxy(httpmw.ExtractWorkspaceProxyConfig{ + DB: options.Database, + Optional: true, + }), + httpmw.RequireAPIKeyOrWorkspaceProxyAuth(), + ).Get("/workspaceagents/{workspaceagent}/legacy", api.agentIsLegacy) r.Route("/workspaceproxies", func(r chi.Router) { r.Use( api.moonsEnabledMW, @@ -143,6 +152,7 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { Optional: false, }), ) + r.Get("/coordinate", api.workspaceProxyCoordinate) r.Post("/issue-signed-app-token", api.workspaceProxyIssueSignedAppToken) r.Post("/register", api.workspaceProxyRegister) r.Post("/goingaway", api.workspaceProxyGoingAway) diff --git a/enterprise/coderd/coderdenttest/proxytest.go b/enterprise/coderd/coderdenttest/proxytest.go index 86b2538f7a673..baaa9a308b89a 100644 --- a/enterprise/coderd/coderdenttest/proxytest.go +++ b/enterprise/coderd/coderdenttest/proxytest.go @@ -25,7 +25,8 @@ import ( ) type ProxyOptions struct { - Name string + Name string + Experiments codersdk.Experiments TLSCertificates []tls.Certificate AppHostname string @@ -118,6 +119,7 @@ func NewWorkspaceProxy(t *testing.T, coderdAPI *coderd.API, owner *codersdk.Clie wssrv, err := wsproxy.New(ctx, 
&wsproxy.Options{ Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug), + Experiments: options.Experiments, DashboardURL: coderdAPI.AccessURL, AccessURL: accessURL, AppHostname: options.AppHostname, diff --git a/enterprise/coderd/workspaceproxycoordinate.go b/enterprise/coderd/workspaceproxycoordinate.go new file mode 100644 index 0000000000000..919098a3d8b6a --- /dev/null +++ b/enterprise/coderd/workspaceproxycoordinate.go @@ -0,0 +1,78 @@ +package coderd + +import ( + "net/http" + + "github.com/google/uuid" + "nhooyr.io/websocket" + + "github.com/coder/coder/coderd/httpapi" + "github.com/coder/coder/coderd/httpmw" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/enterprise/tailnet" + "github.com/coder/coder/enterprise/wsproxy/wsproxysdk" +) + +// @Summary Agent is legacy +// @ID agent-is-legacy +// @Security CoderSessionToken +// @Produce json +// @Tags Enterprise +// @Param workspaceagent path string true "Workspace Agent ID" format(uuid) +// @Success 200 {object} wsproxysdk.AgentIsLegacyResponse +// @Router /workspaceagents/{workspaceagent}/legacy [get] +// @x-apidocgen {"skip": true} +func (api *API) agentIsLegacy(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + agentID, ok := httpmw.ParseUUIDParam(rw, r, "workspaceagent") + if !ok { + httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ + Message: "Missing UUID in URL.", + }) + return + } + + node := (*api.AGPL.TailnetCoordinator.Load()).Node(agentID) + httpapi.Write(ctx, rw, http.StatusOK, wsproxysdk.AgentIsLegacyResponse{ + Found: node != nil, + Legacy: node != nil && + len(node.Addresses) > 0 && + node.Addresses[0].Addr() == codersdk.WorkspaceAgentIP, + }) +} + +// @Summary Workspace Proxy Coordinate +// @ID workspace-proxy-coordinate +// @Security CoderSessionToken +// @Tags Enterprise +// @Success 101 +// @Router /workspaceproxies/me/coordinate [get] +// @x-apidocgen {"skip": true} +func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r 
*http.Request) { + ctx := r.Context() + + api.AGPL.WebsocketWaitMutex.Lock() + api.AGPL.WebsocketWaitGroup.Add(1) + api.AGPL.WebsocketWaitMutex.Unlock() + defer api.AGPL.WebsocketWaitGroup.Done() + + conn, err := websocket.Accept(rw, r, nil) + if err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Failed to accept websocket.", + Detail: err.Error(), + }) + return + } + + id := uuid.New() + sub := (*api.AGPL.TailnetCoordinator.Load()).ServeMultiAgent(id) + nc := websocket.NetConn(ctx, conn, websocket.MessageText) + defer nc.Close() + + err = tailnet.ServeWorkspaceProxy(ctx, nc, sub) + if err != nil { + _ = conn.Close(websocket.StatusInternalError, err.Error()) + } +} diff --git a/enterprise/coderd/workspaceproxycoordinator_test.go b/enterprise/coderd/workspaceproxycoordinator_test.go new file mode 100644 index 0000000000000..6a2df0d6cd279 --- /dev/null +++ b/enterprise/coderd/workspaceproxycoordinator_test.go @@ -0,0 +1,158 @@ +package coderd_test + +import ( + "context" + "net/netip" + "testing" + "time" + + "github.com/google/uuid" + "github.com/moby/moby/pkg/namesgenerator" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/types/key" + + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/coderd/coderdtest" + "github.com/coder/coder/coderd/database/dbtestutil" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/enterprise/coderd/coderdenttest" + "github.com/coder/coder/enterprise/coderd/license" + "github.com/coder/coder/enterprise/wsproxy/wsproxysdk" + agpl "github.com/coder/coder/tailnet" + "github.com/coder/coder/testutil" +) + +// workspaceProxyCoordinate and agentIsLegacy are both tested by wsproxy tests. 
+ +func Test_agentIsLegacy(t *testing.T) { + t.Parallel() + + t.Run("Legacy", func(t *testing.T) { + t.Parallel() + + dv := coderdtest.DeploymentValues(t) + dv.Experiments = []string{ + string(codersdk.ExperimentMoons), + "*", + } + + var ( + ctx, cancel = context.WithTimeout(context.Background(), testutil.WaitShort) + db, pubsub = dbtestutil.NewDB(t) + logger = slogtest.Make(t, nil) + coordinator = agpl.NewCoordinator(logger) + client, _ = coderdenttest.New(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + Database: db, + Pubsub: pubsub, + DeploymentValues: dv, + Coordinator: coordinator, + }, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureWorkspaceProxy: 1, + }, + }, + }) + ) + defer cancel() + + nodeID := uuid.New() + ma := coordinator.ServeMultiAgent(nodeID) + defer ma.Close() + require.NoError(t, ma.UpdateSelf(&agpl.Node{ + ID: 55, + AsOf: time.Unix(1689653252, 0), + Key: key.NewNode().Public(), + DiscoKey: key.NewDisco().Public(), + PreferredDERP: 0, + DERPLatency: map[string]float64{ + "0": 1.0, + }, + DERPForcedWebsocket: map[int]string{}, + Addresses: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)}, + AllowedIPs: []netip.Prefix{netip.PrefixFrom(codersdk.WorkspaceAgentIP, 128)}, + Endpoints: []string{"192.168.1.1:18842"}, + })) + + proxyRes, err := client.CreateWorkspaceProxy(ctx, codersdk.CreateWorkspaceProxyRequest{ + Name: namesgenerator.GetRandomName(1), + Icon: "/emojis/flag.png", + }) + require.NoError(t, err) + + proxyClient := wsproxysdk.New(client.URL) + proxyClient.SetSessionToken(proxyRes.ProxyToken) + + legacyRes, err := proxyClient.AgentIsLegacy(ctx, nodeID) + require.NoError(t, err) + + assert.True(t, legacyRes.Found) + assert.True(t, legacyRes.Legacy) + }) + + t.Run("NotLegacy", func(t *testing.T) { + t.Parallel() + + dv := coderdtest.DeploymentValues(t) + dv.Experiments = []string{ + string(codersdk.ExperimentMoons), + "*", + } + + var ( + ctx, cancel = 
context.WithTimeout(context.Background(), testutil.WaitShort) + db, pubsub = dbtestutil.NewDB(t) + logger = slogtest.Make(t, nil) + coordinator = agpl.NewCoordinator(logger) + client, _ = coderdenttest.New(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + Database: db, + Pubsub: pubsub, + DeploymentValues: dv, + Coordinator: coordinator, + }, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureWorkspaceProxy: 1, + }, + }, + }) + ) + defer cancel() + + nodeID := uuid.New() + ma := coordinator.ServeMultiAgent(nodeID) + defer ma.Close() + require.NoError(t, ma.UpdateSelf(&agpl.Node{ + ID: 55, + AsOf: time.Unix(1689653252, 0), + Key: key.NewNode().Public(), + DiscoKey: key.NewDisco().Public(), + PreferredDERP: 0, + DERPLatency: map[string]float64{ + "0": 1.0, + }, + DERPForcedWebsocket: map[int]string{}, + Addresses: []netip.Prefix{netip.PrefixFrom(agpl.IPFromUUID(nodeID), 128)}, + AllowedIPs: []netip.Prefix{netip.PrefixFrom(agpl.IPFromUUID(nodeID), 128)}, + Endpoints: []string{"192.168.1.1:18842"}, + })) + + proxyRes, err := client.CreateWorkspaceProxy(ctx, codersdk.CreateWorkspaceProxyRequest{ + Name: namesgenerator.GetRandomName(1), + Icon: "/emojis/flag.png", + }) + require.NoError(t, err) + + proxyClient := wsproxysdk.New(client.URL) + proxyClient.SetSessionToken(proxyRes.ProxyToken) + + legacyRes, err := proxyClient.AgentIsLegacy(ctx, nodeID) + require.NoError(t, err) + + assert.True(t, legacyRes.Found) + assert.False(t, legacyRes.Legacy) + }) +} diff --git a/enterprise/tailnet/coordinator.go b/enterprise/tailnet/coordinator.go index 889df136710c5..672095eb3a989 100644 --- a/enterprise/tailnet/coordinator.go +++ b/enterprise/tailnet/coordinator.go @@ -56,7 +56,6 @@ func NewCoordinator(logger slog.Logger, ps pubsub.Pubsub) (agpl.Coordinator, err func (c *haCoordinator) ServeMultiAgent(id uuid.UUID) agpl.MultiAgentConn { m := (&agpl.MultiAgent{ ID: id, - Logger: c.log, AgentIsLegacyFunc: c.agentIsLegacy, 
OnSubscribe: c.clientSubscribeToAgent, OnNodeUpdate: c.clientNodeUpdate, diff --git a/enterprise/tailnet/workspaceproxy.go b/enterprise/tailnet/workspaceproxy.go new file mode 100644 index 0000000000000..13e1f6663a2c0 --- /dev/null +++ b/enterprise/tailnet/workspaceproxy.go @@ -0,0 +1,95 @@ +package tailnet + +import ( + "bytes" + "context" + "encoding/json" + "net" + "time" + + "golang.org/x/xerrors" + + "github.com/coder/coder/enterprise/wsproxy/wsproxysdk" + agpl "github.com/coder/coder/tailnet" +) + +func ServeWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentConn) error { + go func() { + err := forwardNodesToWorkspaceProxy(ctx, conn, ma) + if err != nil { + _ = conn.Close() + } + }() + + decoder := json.NewDecoder(conn) + for { + var msg wsproxysdk.CoordinateMessage + err := decoder.Decode(&msg) + if err != nil { + return xerrors.Errorf("read json: %w", err) + } + + switch msg.Type { + case wsproxysdk.CoordinateMessageTypeSubscribe: + err := ma.SubscribeAgent(msg.AgentID) + if err != nil { + return xerrors.Errorf("subscribe agent: %w", err) + } + case wsproxysdk.CoordinateMessageTypeUnsubscribe: + err := ma.UnsubscribeAgent(msg.AgentID) + if err != nil { + return xerrors.Errorf("unsubscribe agent: %w", err) + } + case wsproxysdk.CoordinateMessageTypeNodeUpdate: + err := ma.UpdateSelf(msg.Node) + if err != nil { + return xerrors.Errorf("update self: %w", err) + } + + default: + return xerrors.Errorf("unknown message type %q", msg.Type) + } + } +} + +func forwardNodesToWorkspaceProxy(ctx context.Context, conn net.Conn, ma agpl.MultiAgentConn) error { + var lastData []byte + for { + nodes, ok := ma.NextUpdate(ctx) + if !ok { + return xerrors.New("multiagent is closed") + } + + data, err := json.Marshal(wsproxysdk.CoordinateNodes{Nodes: nodes}) + if err != nil { + return err + } + if bytes.Equal(lastData, data) { + continue + } + + // Set a deadline so that hung connections don't put back pressure on the system. 
+ // Node updates are tiny, so even the dinkiest connection can handle them if it's not hung. + err = conn.SetWriteDeadline(time.Now().Add(agpl.WriteTimeout)) + if err != nil { + // often, this is just because the connection is closed/broken, so just return the error. + return err + } + _, err = conn.Write(data) + if err != nil { + // often, this is just because the connection is closed/broken, so just return the error. + return err + } + + // nhooyr.io/websocket has a bugged implementation of deadlines on a websocket net.Conn. What they are + // *supposed* to do is set a deadline for any subsequent writes to complete, otherwise the call to Write() + // fails. What nhooyr.io/websocket does is set a timer, after which it expires the websocket write context. + // If this timer fires, then the next write will fail *even if we set a new write deadline*. So, after + // our successful write, it is important that we reset the deadline before it fires. + err = conn.SetWriteDeadline(time.Time{}) + if err != nil { + return err + } + lastData = data + } +} diff --git a/enterprise/wsproxy/wsproxy.go b/enterprise/wsproxy/wsproxy.go index ae5da832054e2..e68083095b9b0 100644 --- a/enterprise/wsproxy/wsproxy.go +++ b/enterprise/wsproxy/wsproxy.go @@ -27,10 +27,12 @@ import ( "github.com/coder/coder/codersdk" "github.com/coder/coder/enterprise/wsproxy/wsproxysdk" "github.com/coder/coder/site" + agpl "github.com/coder/coder/tailnet" ) type Options struct { - Logger slog.Logger + Logger slog.Logger + Experiments codersdk.Experiments HTTPClient *http.Client // DashboardURL is the URL of the primary coderd instance. 
@@ -168,6 +170,30 @@ func New(ctx context.Context, opts *Options) (*Server, error) { cancel: cancel, } + connInfo, err := client.SDKClient.WorkspaceAgentConnectionInfo(ctx) + if err != nil { + return nil, xerrors.Errorf("get derpmap: %w", err) + } + + var agentProvider workspaceapps.AgentProvider + if opts.Experiments.Enabled(codersdk.ExperimentSingleTailnet) { + stn, err := coderd.NewServerTailnet(ctx, + s.Logger.Named("server_tailnet"), + nil, + connInfo.DERPMap, + s.DialCoordinator, + wsconncache.New(s.DialWorkspaceAgent, 0), + ) + if err != nil { + return nil, xerrors.Errorf("create server tailnet: %w", err) + } + agentProvider = stn + } else { + agentProvider = &wsconncache.AgentProvider{ + Cache: wsconncache.New(s.DialWorkspaceAgent, 0), + } + } + s.AppServer = &workspaceapps.Server{ Logger: opts.Logger.Named("workspaceapps"), DashboardURL: opts.DashboardURL, @@ -185,10 +211,7 @@ func New(ctx context.Context, opts *Options) (*Server, error) { }, AppSecurityKey: secKey, - // TODO: Convert wsproxy to use coderd.ServerTailnet. 
- AgentProvider: &wsconncache.AgentProvider{ - Cache: wsconncache.New(s.DialWorkspaceAgent, 0), - }, + AgentProvider: agentProvider, DisablePathApps: opts.DisablePathApps, SecureAuthCookie: opts.SecureAuthCookie, } @@ -285,6 +308,10 @@ func (s *Server) DialWorkspaceAgent(id uuid.UUID) (*codersdk.WorkspaceAgentConn, return s.SDKClient.DialWorkspaceAgent(s.ctx, id, nil) } +func (s *Server) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, error) { + return s.SDKClient.DialCoordinator(ctx) +} + func (s *Server) buildInfo(rw http.ResponseWriter, r *http.Request) { httpapi.Write(r.Context(), rw, http.StatusOK, codersdk.BuildInfoResponse{ ExternalURL: buildinfo.ExternalURL(), diff --git a/enterprise/wsproxy/wsproxy_test.go b/enterprise/wsproxy/wsproxy_test.go index fa4a168dba5e1..f918daa82736a 100644 --- a/enterprise/wsproxy/wsproxy_test.go +++ b/enterprise/wsproxy/wsproxy_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/coder/coder/cli/clibase" + "github.com/coder/coder/coderd" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/coderd/httpmw" "github.com/coder/coder/coderd/workspaceapps/apptest" @@ -13,7 +14,7 @@ import ( "github.com/coder/coder/enterprise/coderd/license" ) -func TestWorkspaceProxyWorkspaceApps(t *testing.T) { +func TestWorkspaceProxyWorkspaceApps_Wsconncache(t *testing.T) { t.Parallel() apptest.Run(t, false, func(t *testing.T, opts *apptest.DeploymentOptions) *apptest.Deployment { @@ -66,3 +67,59 @@ func TestWorkspaceProxyWorkspaceApps(t *testing.T) { } }) } + +func TestWorkspaceProxyWorkspaceApps_SingleTailnet(t *testing.T) { + t.Parallel() + + apptest.Run(t, false, func(t *testing.T, opts *apptest.DeploymentOptions) *apptest.Deployment { + deploymentValues := coderdtest.DeploymentValues(t) + deploymentValues.DisablePathApps = clibase.Bool(opts.DisablePathApps) + deploymentValues.Dangerous.AllowPathAppSharing = clibase.Bool(opts.DangerousAllowPathAppSharing) + deploymentValues.Dangerous.AllowPathAppSiteOwnerAccess = 
clibase.Bool(opts.DangerousAllowPathAppSiteOwnerAccess) + deploymentValues.Experiments = []string{ + string(codersdk.ExperimentMoons), + string(codersdk.ExperimentSingleTailnet), + "*", + } + + client, _, api, user := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + DeploymentValues: deploymentValues, + AppHostname: "*.primary.test.coder.com", + IncludeProvisionerDaemon: true, + RealIPConfig: &httpmw.RealIPConfig{ + TrustedOrigins: []*net.IPNet{{ + IP: net.ParseIP("127.0.0.1"), + Mask: net.CIDRMask(8, 32), + }}, + TrustedHeaders: []string{ + "CF-Connecting-IP", + }, + }, + }, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureWorkspaceProxy: 1, + }, + }, + }) + + // Create the external proxy + if opts.DisableSubdomainApps { + opts.AppHost = "" + } + proxyAPI := coderdenttest.NewWorkspaceProxy(t, api, client, &coderdenttest.ProxyOptions{ + Name: "best-proxy", + Experiments: coderd.ReadExperiments(api.Logger, deploymentValues.Experiments.Value()), + AppHostname: opts.AppHost, + DisablePathApps: opts.DisablePathApps, + }) + + return &apptest.Deployment{ + Options: opts, + SDKClient: client, + FirstUser: user, + PathAppBaseURL: proxyAPI.Options.AccessURL, + } + }) +} diff --git a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go index 592610ec73afd..5703114281096 100644 --- a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go +++ b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go @@ -3,16 +3,25 @@ package wsproxysdk import ( "context" "encoding/json" + "fmt" "io" + "net" "net/http" "net/url" + "sync" + "time" "github.com/google/uuid" "golang.org/x/xerrors" + "nhooyr.io/websocket" + "tailscale.com/util/singleflight" + "cdr.dev/slog" + "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/httpmw" "github.com/coder/coder/coderd/workspaceapps" "github.com/coder/coder/codersdk" + agpl "github.com/coder/coder/tailnet" ) // Client is a HTTP client 
for a subset of Coder API routes that external @@ -186,3 +195,206 @@ func (c *Client) WorkspaceProxyGoingAway(ctx context.Context) error { } return nil } + +type CoordinateMessageType int + +const ( + CoordinateMessageTypeSubscribe CoordinateMessageType = 1 + iota + CoordinateMessageTypeUnsubscribe + CoordinateMessageTypeNodeUpdate +) + +type CoordinateMessage struct { + Type CoordinateMessageType `json:"type"` + AgentID uuid.UUID `json:"agent_id"` + Node *agpl.Node `json:"node"` +} + +type CoordinateNodes struct { + Nodes []*agpl.Node +} + +func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, error) { + ctx, cancel := context.WithCancel(ctx) + + coordinateURL, err := c.SDKClient.URL.Parse("/api/v2/workspaceproxies/me/coordinate") + if err != nil { + cancel() + return nil, xerrors.Errorf("parse url: %w", err) + } + coordinateHeaders := make(http.Header) + tokenHeader := codersdk.SessionTokenHeader + if c.SDKClient.SessionTokenHeader != "" { + tokenHeader = c.SDKClient.SessionTokenHeader + } + coordinateHeaders.Set(tokenHeader, c.SessionToken()) + + //nolint:bodyclose + conn, _, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{ + HTTPClient: c.SDKClient.HTTPClient, + HTTPHeader: coordinateHeaders, + }) + if err != nil { + cancel() + return nil, xerrors.Errorf("dial coordinate websocket: %w", err) + } + + go httpapi.HeartbeatClose(ctx, cancel, conn) + + nc := websocket.NetConn(ctx, conn, websocket.MessageText) + rma := remoteMultiAgentHandler{ + sdk: c, + nc: nc, + legacyAgentCache: map[uuid.UUID]bool{}, + } + + ma := (&agpl.MultiAgent{ + ID: uuid.New(), + AgentIsLegacyFunc: rma.AgentIsLegacy, + OnSubscribe: rma.OnSubscribe, + OnUnsubscribe: rma.OnUnsubscribe, + OnNodeUpdate: rma.OnNodeUpdate, + OnRemove: func(uuid.UUID) { conn.Close(websocket.StatusGoingAway, "closed") }, + }).Init() + + go func() { + defer cancel() + dec := json.NewDecoder(nc) + for { + var msg CoordinateNodes + err := dec.Decode(&msg) + if err != nil 
{ + if xerrors.Is(err, io.EOF) { + return + } + + c.SDKClient.Logger().Error(ctx, "failed to decode coordinator nodes", slog.Error(err)) + return + } + + err = ma.Enqueue(msg.Nodes) + if err != nil { + c.SDKClient.Logger().Error(ctx, "enqueue nodes from coordinator", slog.Error(err)) + continue + } + } + }() + + return ma, nil +} + +type remoteMultiAgentHandler struct { + sdk *Client + nc net.Conn + + legacyMu sync.RWMutex + legacyAgentCache map[uuid.UUID]bool + legacySingleflight singleflight.Group[uuid.UUID, AgentIsLegacyResponse] +} + +func (a *remoteMultiAgentHandler) writeJSON(v interface{}) error { + data, err := json.Marshal(v) + if err != nil { + return xerrors.Errorf("json marshal message: %w", err) + } + + // Set a deadline so that hung connections don't put back pressure on the system. + // Node updates are tiny, so even the dinkiest connection can handle them if it's not hung. + err = a.nc.SetWriteDeadline(time.Now().Add(agpl.WriteTimeout)) + if err != nil { + return xerrors.Errorf("set write deadline: %w", err) + } + _, err = a.nc.Write(data) + if err != nil { + return xerrors.Errorf("write message: %w", err) + } + + // nhooyr.io/websocket has a bugged implementation of deadlines on a websocket net.Conn. What they are + // *supposed* to do is set a deadline for any subsequent writes to complete, otherwise the call to Write() + // fails. What nhooyr.io/websocket does is set a timer, after which it expires the websocket write context. + // If this timer fires, then the next write will fail *even if we set a new write deadline*. So, after + // our successful write, it is important that we reset the deadline before it fires. 
+ err = a.nc.SetWriteDeadline(time.Time{}) + if err != nil { + return xerrors.Errorf("clear write deadline: %w", err) + } + + return nil +} + +func (a *remoteMultiAgentHandler) OnNodeUpdate(_ uuid.UUID, node *agpl.Node) error { + return a.writeJSON(CoordinateMessage{ + Type: CoordinateMessageTypeNodeUpdate, + Node: node, + }) +} + +func (a *remoteMultiAgentHandler) OnSubscribe(_ agpl.Queue, agentID uuid.UUID) (*agpl.Node, error) { + return nil, a.writeJSON(CoordinateMessage{ + Type: CoordinateMessageTypeSubscribe, + AgentID: agentID, + }) +} + +func (a *remoteMultiAgentHandler) OnUnsubscribe(_ agpl.Queue, agentID uuid.UUID) error { + return a.writeJSON(CoordinateMessage{ + Type: CoordinateMessageTypeUnsubscribe, + AgentID: agentID, + }) +} + +func (a *remoteMultiAgentHandler) AgentIsLegacy(agentID uuid.UUID) bool { + a.legacyMu.RLock() + if isLegacy, ok := a.legacyAgentCache[agentID]; ok { + a.legacyMu.RUnlock() + return isLegacy + } + a.legacyMu.RUnlock() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + resp, err, _ := a.legacySingleflight.Do(agentID, func() (AgentIsLegacyResponse, error) { + return a.sdk.AgentIsLegacy(ctx, agentID) + }) + if err != nil { + a.sdk.SDKClient.Logger().Error(ctx, "failed to check agent legacy status", slog.Error(err)) + + // Assume that the agent is legacy since this failed, while less + // efficient it will always work. + return true + } + // Assume legacy since the agent didn't exist. 
+ if !resp.Found { + return true + } + + a.legacyMu.Lock() + a.legacyAgentCache[agentID] = resp.Legacy + a.legacyMu.Unlock() + + return resp.Legacy +} + +type AgentIsLegacyResponse struct { + Found bool `json:"found"` + Legacy bool `json:"legacy"` +} + +func (c *Client) AgentIsLegacy(ctx context.Context, agentID uuid.UUID) (AgentIsLegacyResponse, error) { + res, err := c.Request(ctx, http.MethodGet, + fmt.Sprintf("/api/v2/workspaceagents/%s/legacy", agentID.String()), + nil, + ) + if err != nil { + return AgentIsLegacyResponse{}, xerrors.Errorf("make request: %w", err) + } + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + return AgentIsLegacyResponse{}, codersdk.ReadBodyAsError(res) + } + + var resp AgentIsLegacyResponse + return resp, json.NewDecoder(res.Body).Decode(&resp) +} diff --git a/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go b/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go index a266d607bba13..207283a098532 100644 --- a/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go +++ b/enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go @@ -1,21 +1,36 @@ package wsproxysdk_test import ( + "context" "encoding/json" "io" "net/http" "net/http/httptest" "net/http/httputil" + "net/netip" "net/url" "sync/atomic" "testing" + "time" + "github.com/go-chi/chi/v5" + "github.com/golang/mock/gomock" + "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + "nhooyr.io/websocket" + "tailscale.com/types/key" + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/httpmw" "github.com/coder/coder/coderd/workspaceapps" + "github.com/coder/coder/enterprise/tailnet" "github.com/coder/coder/enterprise/wsproxy/wsproxysdk" + agpl "github.com/coder/coder/tailnet" + "github.com/coder/coder/tailnet/tailnettest" "github.com/coder/coder/testutil" ) @@ -136,6 +151,135 @@ func Test_IssueSignedAppTokenHTML(t *testing.T) { }) } +func 
TestDialCoordinator(t *testing.T) { + t.Parallel() + t.Run("OK", func(t *testing.T) { + t.Parallel() + var ( + ctx, cancel = context.WithTimeout(context.Background(), testutil.WaitShort) + logger = slogtest.Make(t, nil).Leveled(slog.LevelDebug) + agentID = uuid.New() + serverMultiAgent = tailnettest.NewMockMultiAgentConn(gomock.NewController(t)) + r = chi.NewRouter() + srv = httptest.NewServer(r) + ) + defer cancel() + + r.Get("/api/v2/workspaceproxies/me/coordinate", func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + require.NoError(t, err) + nc := websocket.NetConn(r.Context(), conn, websocket.MessageText) + defer serverMultiAgent.Close() + + err = tailnet.ServeWorkspaceProxy(ctx, nc, serverMultiAgent) + if !xerrors.Is(err, io.EOF) { + assert.NoError(t, err) + } + }) + r.Get("/api/v2/workspaceagents/{workspaceagent}/legacy", func(w http.ResponseWriter, r *http.Request) { + httpapi.Write(ctx, w, http.StatusOK, wsproxysdk.AgentIsLegacyResponse{ + Found: true, + Legacy: true, + }) + }) + + u, err := url.Parse(srv.URL) + require.NoError(t, err) + client := wsproxysdk.New(u) + client.SDKClient.SetLogger(logger) + + expected := []*agpl.Node{{ + ID: 55, + AsOf: time.Unix(1689653252, 0), + Key: key.NewNode().Public(), + DiscoKey: key.NewDisco().Public(), + PreferredDERP: 0, + DERPLatency: map[string]float64{ + "0": 1.0, + }, + DERPForcedWebsocket: map[int]string{}, + Addresses: []netip.Prefix{netip.PrefixFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4}), 128)}, + AllowedIPs: []netip.Prefix{netip.PrefixFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4}), 128)}, + Endpoints: []string{"192.168.1.1:18842"}, + }} + sendNode := make(chan struct{}) + + serverMultiAgent.EXPECT().NextUpdate(gomock.Any()).AnyTimes(). 
+ DoAndReturn(func(ctx context.Context) ([]*agpl.Node, bool) { + select { + case <-sendNode: + return expected, true + case <-ctx.Done(): + return nil, false + } + }) + + rma, err := client.DialCoordinator(ctx) + require.NoError(t, err) + + // Subscribe + { + ch := make(chan struct{}) + serverMultiAgent.EXPECT().SubscribeAgent(agentID).Do(func(uuid.UUID) { + close(ch) + }) + require.NoError(t, rma.SubscribeAgent(agentID)) + waitOrCancel(ctx, t, ch) + } + // Read updated agent node + { + sendNode <- struct{}{} + got, ok := rma.NextUpdate(ctx) + assert.True(t, ok) + got[0].AsOf = got[0].AsOf.In(time.Local) + assert.Equal(t, *expected[0], *got[0]) + } + // Check legacy + { + isLegacy := rma.AgentIsLegacy(agentID) + assert.True(t, isLegacy) + } + // UpdateSelf + { + ch := make(chan struct{}) + serverMultiAgent.EXPECT().UpdateSelf(gomock.Any()).Do(func(node *agpl.Node) { + node.AsOf = node.AsOf.In(time.Local) + assert.Equal(t, expected[0], node) + close(ch) + }) + require.NoError(t, rma.UpdateSelf(expected[0])) + waitOrCancel(ctx, t, ch) + } + // Unsubscribe + { + ch := make(chan struct{}) + serverMultiAgent.EXPECT().UnsubscribeAgent(agentID).Do(func(uuid.UUID) { + close(ch) + }) + require.NoError(t, rma.UnsubscribeAgent(agentID)) + waitOrCancel(ctx, t, ch) + } + // Close + { + ch := make(chan struct{}) + serverMultiAgent.EXPECT().Close().Do(func() { + close(ch) + }) + require.NoError(t, rma.Close()) + waitOrCancel(ctx, t, ch) + } + }) +} + +func waitOrCancel(ctx context.Context, t testing.TB, ch <-chan struct{}) { + t.Helper() + select { + case <-ch: + case <-ctx.Done(): + t.Fatal("timed out waiting for channel") + } +} + type ResponseRecorder struct { rw *httptest.ResponseRecorder wasWritten atomic.Bool diff --git a/go.mod b/go.mod index da02ddabca648..d49fe2211c86d 100644 --- a/go.mod +++ b/go.mod @@ -135,6 +135,7 @@ require ( github.com/mitchellh/go-wordwrap v1.0.1 github.com/mitchellh/mapstructure v1.5.0 github.com/moby/moby v24.0.1+incompatible + 
github.com/muesli/termenv v0.15.1 github.com/open-policy-agent/opa v0.51.0 github.com/ory/dockertest/v3 v3.10.0 github.com/pion/udp v0.1.2 @@ -305,7 +306,6 @@ require ( github.com/muesli/ansi v0.0.0-20221106050444-61f0cd9a192a // indirect github.com/muesli/cancelreader v0.2.2 // indirect github.com/muesli/reflow v0.3.0 // indirect - github.com/muesli/termenv v0.15.1 github.com/niklasfasching/go-org v1.7.0 // indirect github.com/nu7hatch/gouuid v0.0.0-20131221200532-179d4d0c4d8d // indirect github.com/olekukonko/tablewriter v0.0.5 // indirect diff --git a/tailnet/coordinator.go b/tailnet/coordinator.go index 93cf8c67af56b..51c95aca4d2e6 100644 --- a/tailnet/coordinator.go +++ b/tailnet/coordinator.go @@ -140,7 +140,6 @@ type coordinator struct { func (c *coordinator) ServeMultiAgent(id uuid.UUID) MultiAgentConn { m := (&MultiAgent{ ID: id, - Logger: c.core.logger, AgentIsLegacyFunc: c.core.agentIsLegacy, OnSubscribe: c.core.clientSubscribeToAgent, OnUnsubscribe: c.core.clientUnsubscribeFromAgent, diff --git a/tailnet/multiagent.go b/tailnet/multiagent.go index 13300fdce677a..ee76e4b88d8aa 100644 --- a/tailnet/multiagent.go +++ b/tailnet/multiagent.go @@ -8,8 +8,6 @@ import ( "github.com/google/uuid" "golang.org/x/xerrors" - - "cdr.dev/slog" ) type MultiAgentConn interface { @@ -25,10 +23,7 @@ type MultiAgentConn interface { type MultiAgent struct { mu sync.RWMutex - closed bool - - ID uuid.UUID - Logger slog.Logger + ID uuid.UUID AgentIsLegacyFunc func(agentID uuid.UUID) bool OnSubscribe func(enq Queue, agent uuid.UUID) (*Node, error) @@ -36,6 +31,7 @@ type MultiAgent struct { OnNodeUpdate func(id uuid.UUID, node *Node) error OnRemove func(id uuid.UUID) + closed bool updates chan []*Node closeOnce sync.Once start int64 diff --git a/tailnet/tailnettest/multiagentmock.go b/tailnet/tailnettest/multiagentmock.go new file mode 100644 index 0000000000000..e6bde4f8c2367 --- /dev/null +++ b/tailnet/tailnettest/multiagentmock.go @@ -0,0 +1,150 @@ +// Code generated by 
MockGen. DO NOT EDIT. +// Source: github.com/coder/coder/tailnet/tailnettest (interfaces: MultiAgentConn) + +// Package tailnettest is a generated GoMock package. +package tailnettest + +import ( + context "context" + reflect "reflect" + + tailnet "github.com/coder/coder/tailnet" + gomock "github.com/golang/mock/gomock" + uuid "github.com/google/uuid" +) + +// MockMultiAgentConn is a mock of MultiAgentConn interface. +type MockMultiAgentConn struct { + ctrl *gomock.Controller + recorder *MockMultiAgentConnMockRecorder +} + +// MockMultiAgentConnMockRecorder is the mock recorder for MockMultiAgentConn. +type MockMultiAgentConnMockRecorder struct { + mock *MockMultiAgentConn +} + +// NewMockMultiAgentConn creates a new mock instance. +func NewMockMultiAgentConn(ctrl *gomock.Controller) *MockMultiAgentConn { + mock := &MockMultiAgentConn{ctrl: ctrl} + mock.recorder = &MockMultiAgentConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockMultiAgentConn) EXPECT() *MockMultiAgentConnMockRecorder { + return m.recorder +} + +// AgentIsLegacy mocks base method. +func (m *MockMultiAgentConn) AgentIsLegacy(arg0 uuid.UUID) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AgentIsLegacy", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// AgentIsLegacy indicates an expected call of AgentIsLegacy. +func (mr *MockMultiAgentConnMockRecorder) AgentIsLegacy(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AgentIsLegacy", reflect.TypeOf((*MockMultiAgentConn)(nil).AgentIsLegacy), arg0) +} + +// Close mocks base method. +func (m *MockMultiAgentConn) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. 
+func (mr *MockMultiAgentConnMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockMultiAgentConn)(nil).Close)) +} + +// Enqueue mocks base method. +func (m *MockMultiAgentConn) Enqueue(arg0 []*tailnet.Node) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Enqueue", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Enqueue indicates an expected call of Enqueue. +func (mr *MockMultiAgentConnMockRecorder) Enqueue(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Enqueue", reflect.TypeOf((*MockMultiAgentConn)(nil).Enqueue), arg0) +} + +// IsClosed mocks base method. +func (m *MockMultiAgentConn) IsClosed() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsClosed") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsClosed indicates an expected call of IsClosed. +func (mr *MockMultiAgentConnMockRecorder) IsClosed() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClosed", reflect.TypeOf((*MockMultiAgentConn)(nil).IsClosed)) +} + +// NextUpdate mocks base method. +func (m *MockMultiAgentConn) NextUpdate(arg0 context.Context) ([]*tailnet.Node, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NextUpdate", arg0) + ret0, _ := ret[0].([]*tailnet.Node) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// NextUpdate indicates an expected call of NextUpdate. +func (mr *MockMultiAgentConnMockRecorder) NextUpdate(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextUpdate", reflect.TypeOf((*MockMultiAgentConn)(nil).NextUpdate), arg0) +} + +// SubscribeAgent mocks base method. 
+func (m *MockMultiAgentConn) SubscribeAgent(arg0 uuid.UUID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SubscribeAgent", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// SubscribeAgent indicates an expected call of SubscribeAgent. +func (mr *MockMultiAgentConnMockRecorder) SubscribeAgent(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubscribeAgent", reflect.TypeOf((*MockMultiAgentConn)(nil).SubscribeAgent), arg0) +} + +// UnsubscribeAgent mocks base method. +func (m *MockMultiAgentConn) UnsubscribeAgent(arg0 uuid.UUID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UnsubscribeAgent", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// UnsubscribeAgent indicates an expected call of UnsubscribeAgent. +func (mr *MockMultiAgentConnMockRecorder) UnsubscribeAgent(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnsubscribeAgent", reflect.TypeOf((*MockMultiAgentConn)(nil).UnsubscribeAgent), arg0) +} + +// UpdateSelf mocks base method. +func (m *MockMultiAgentConn) UpdateSelf(arg0 *tailnet.Node) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateSelf", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateSelf indicates an expected call of UpdateSelf. 
+func (mr *MockMultiAgentConnMockRecorder) UpdateSelf(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSelf", reflect.TypeOf((*MockMultiAgentConn)(nil).UpdateSelf), arg0) +} diff --git a/tailnet/tailnettest/tailnettest.go b/tailnet/tailnettest/tailnettest.go index 0cb7dbd330ed3..655568a341ccb 100644 --- a/tailnet/tailnettest/tailnettest.go +++ b/tailnet/tailnettest/tailnettest.go @@ -21,6 +21,8 @@ import ( "github.com/coder/coder/tailnet" ) +//go:generate mockgen -destination ./multiagentmock.go -package tailnettest github.com/coder/coder/tailnet MultiAgentConn + // RunDERPAndSTUN creates a DERP mapping for tests. func RunDERPAndSTUN(t *testing.T) (*tailcfg.DERPMap, *derp.Server) { logf := tailnet.Logger(slogtest.Make(t, nil))