diff --git a/enterprise/coderd/coderdenttest/proxytest.go b/enterprise/coderd/coderdenttest/proxytest.go index 9b43cbe6c316d..831c4be86f640 100644 --- a/enterprise/coderd/coderdenttest/proxytest.go +++ b/enterprise/coderd/coderdenttest/proxytest.go @@ -38,15 +38,29 @@ type ProxyOptions struct { // ProxyURL is optional ProxyURL *url.URL + // Token is optional. If specified, a new workspace proxy region will not be + // created, and the proxy will become a replica of the existing proxy + // region. + Token string + // FlushStats is optional FlushStats chan chan<- struct{} } -// NewWorkspaceProxy will configure a wsproxy.Server with the given options. -// The new wsproxy will register itself with the given coderd.API instance. -// The first user owner client is required to create the wsproxy on the coderd -// api server. -func NewWorkspaceProxy(t *testing.T, coderdAPI *coderd.API, owner *codersdk.Client, options *ProxyOptions) *wsproxy.Server { +type WorkspaceProxy struct { + *wsproxy.Server + + ServerURL *url.URL +} + +// NewWorkspaceProxyReplica will configure a wsproxy.Server with the given +// options. The new wsproxy replica will register itself with the given +// coderd.API instance. +// +// If a token is not provided, a new workspace proxy region is created using the +// owner client. If a token is provided, the proxy will become a replica of the +// existing proxy region. +func NewWorkspaceProxyReplica(t *testing.T, coderdAPI *coderd.API, owner *codersdk.Client, options *ProxyOptions) WorkspaceProxy { ctx, cancelFunc := context.WithCancel(context.Background()) t.Cleanup(cancelFunc) @@ -107,11 +121,15 @@ func NewWorkspaceProxy(t *testing.T, coderdAPI *coderd.API, owner *codersdk.Clie options.Name = namesgenerator.GetRandomName(1) } - proxyRes, err := owner.CreateWorkspaceProxy(ctx, codersdk.CreateWorkspaceProxyRequest{ - Name: options.Name, - Icon: "/emojis/flag.png", - }) - require.NoError(t, err, "failed to create workspace proxy") + token := options.Token + if token == "" { + proxyRes, err := owner.CreateWorkspaceProxy(ctx, codersdk.CreateWorkspaceProxyRequest{ + Name: options.Name, + Icon: "/emojis/flag.png", + }) + require.NoError(t, err, "failed to create workspace proxy") + token = proxyRes.ProxyToken + } // Inherit collector options from coderd, but keep the wsproxy reporter. statsCollectorOptions := coderdAPI.Options.WorkspaceAppsStatsCollectorOptions @@ -121,7 +139,7 @@ func NewWorkspaceProxy(t *testing.T, coderdAPI *coderd.API, owner *codersdk.Clie } wssrv, err := wsproxy.New(ctx, &wsproxy.Options{ - Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug), + Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug).With(slog.F("server_url", serverURL.String())), Experiments: options.Experiments, DashboardURL: coderdAPI.AccessURL, AccessURL: accessURL, @@ -131,14 +149,14 @@ func NewWorkspaceProxy(t *testing.T, coderdAPI *coderd.API, owner *codersdk.Clie Tracing: coderdAPI.TracerProvider, APIRateLimit: coderdAPI.APIRateLimit, SecureAuthCookie: coderdAPI.SecureAuthCookie, - ProxySessionToken: proxyRes.ProxyToken, + ProxySessionToken: token, DisablePathApps: options.DisablePathApps, // We need a new registry to not conflict with the coderd internal // proxy metrics. PrometheusRegistry: prometheus.NewRegistry(), DERPEnabled: !options.DerpDisabled, DERPOnly: options.DerpOnly, - DERPServerRelayAddress: accessURL.String(), + DERPServerRelayAddress: serverURL.String(), StatsCollectorOptions: statsCollectorOptions, }) require.NoError(t, err) @@ -151,5 +169,8 @@ func NewWorkspaceProxy(t *testing.T, coderdAPI *coderd.API, owner *codersdk.Clie handler = wssrv.Handler mutex.Unlock() - return wssrv + return WorkspaceProxy{ + Server: wssrv, + ServerURL: serverURL, + } } diff --git a/enterprise/coderd/workspaceproxy_test.go b/enterprise/coderd/workspaceproxy_test.go index 17e17240dcace..b7d4e8cf2f8f9 100644 --- a/enterprise/coderd/workspaceproxy_test.go +++ b/enterprise/coderd/workspaceproxy_test.go @@ -99,7 +99,7 @@ func TestRegions(t *testing.T) { require.NoError(t, err) const proxyName = "hello" - _ = coderdenttest.NewWorkspaceProxy(t, api, client, &coderdenttest.ProxyOptions{ + _ = coderdenttest.NewWorkspaceProxyReplica(t, api, client, &coderdenttest.ProxyOptions{ Name: proxyName, AppHostname: appHostname + ".proxy", }) @@ -734,7 +734,7 @@ func TestReconnectingPTYSignedToken(t *testing.T) { proxyURL, err := url.Parse(fmt.Sprintf("https://%s.com", namesgenerator.GetRandomName(1))) require.NoError(t, err) - _ = coderdenttest.NewWorkspaceProxy(t, api, client, &coderdenttest.ProxyOptions{ + _ = coderdenttest.NewWorkspaceProxyReplica(t, api, client, &coderdenttest.ProxyOptions{ Name: namesgenerator.GetRandomName(1), ProxyURL: proxyURL, AppHostname: "*.sub.example.com", diff --git a/enterprise/wsproxy/wsproxy.go b/enterprise/wsproxy/wsproxy.go index 68693f4633871..17fae2b791695 100644 --- a/enterprise/wsproxy/wsproxy.go +++ b/enterprise/wsproxy/wsproxy.go @@ -128,7 +128,7 @@ type Server struct { ctx context.Context cancel context.CancelFunc derpCloseFunc func() - registerDone <-chan struct{} + registerLoop *wsproxysdk.RegisterWorkspaceProxyLoop } // New creates a new workspace proxy server. This requires a primary coderd @@ -210,7 +210,7 @@ func New(ctx context.Context, opts *Options) (*Server, error) { // goroutine to periodically re-register. replicaID := uuid.New() osHostname := cliutil.Hostname() - regResp, registerDone, err := client.RegisterWorkspaceProxyLoop(ctx, wsproxysdk.RegisterWorkspaceProxyLoopOpts{ + registerLoop, regResp, err := client.RegisterWorkspaceProxyLoop(ctx, wsproxysdk.RegisterWorkspaceProxyLoopOpts{ Logger: opts.Logger, Request: wsproxysdk.RegisterWorkspaceProxyRequest{ AccessURL: opts.AccessURL.String(), @@ -230,12 +230,13 @@ func New(ctx context.Context, opts *Options) (*Server, error) { if err != nil { return nil, xerrors.Errorf("register proxy: %w", err) } - s.registerDone = registerDone - err = s.handleRegister(ctx, regResp) + s.registerLoop = registerLoop + + derpServer.SetMeshKey(regResp.DERPMeshKey) + err = s.handleRegister(regResp) if err != nil { return nil, xerrors.Errorf("handle register: %w", err) } - derpServer.SetMeshKey(regResp.DERPMeshKey) secKey, err := workspaceapps.KeyFromString(regResp.AppSecurityKey) if err != nil { @@ -409,16 +410,16 @@ func New(ctx context.Context, opts *Options) (*Server, error) { return s, nil } +func (s *Server) RegisterNow() error { + _, err := s.registerLoop.RegisterNow() + return err +} + func (s *Server) Close() error { s.cancel() var err error - registerDoneWaitTicker := time.NewTicker(11 * time.Second) // the attempt timeout is 10s - select { - case <-registerDoneWaitTicker.C: - err = multierror.Append(err, xerrors.New("timed out waiting for registerDone")) - case <-s.registerDone: - } + s.registerLoop.Close() s.derpCloseFunc() appServerErr := s.AppServer.Close() if appServerErr != nil { @@ -437,11 +438,12 @@ func (*Server) mutateRegister(_ *wsproxysdk.RegisterWorkspaceProxyRequest) { // package in the primary and update req.ReplicaError accordingly. } -func (s *Server) handleRegister(_ context.Context, res wsproxysdk.RegisterWorkspaceProxyResponse) error { +func (s *Server) handleRegister(res wsproxysdk.RegisterWorkspaceProxyResponse) error { addresses := make([]string, len(res.SiblingReplicas)) for i, replica := range res.SiblingReplicas { addresses[i] = replica.RelayAddress } + s.Logger.Debug(s.ctx, "setting DERP mesh sibling addresses", slog.F("addresses", addresses)) s.derpMesh.SetAddresses(addresses, false) s.latestDERPMap.Store(res.DERPMap) diff --git a/enterprise/wsproxy/wsproxy_test.go b/enterprise/wsproxy/wsproxy_test.go index 0d440165dfb16..e8fed4f35c594 100644 --- a/enterprise/wsproxy/wsproxy_test.go +++ b/enterprise/wsproxy/wsproxy_test.go @@ -1,14 +1,18 @@ package wsproxy_test import ( + "context" "fmt" "net" + "net/url" "testing" + "time" "github.com/davecgh/go-spew/spew" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "tailscale.com/derp" "tailscale.com/derp/derphttp" "tailscale.com/tailcfg" "tailscale.com/types/key" @@ -22,6 +26,7 @@ import ( "github.com/coder/coder/v2/coderd/httpmw" "github.com/coder/coder/v2/coderd/workspaceapps/apptest" "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/cryptorand" "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" "github.com/coder/coder/v2/enterprise/coderd/license" "github.com/coder/coder/v2/provisioner/echo" @@ -62,7 +67,7 @@ func TestDERPOnly(t *testing.T) { }) // Create an external proxy. - _ = coderdenttest.NewWorkspaceProxy(t, api, client, &coderdenttest.ProxyOptions{ + _ = coderdenttest.NewWorkspaceProxyReplica(t, api, client, &coderdenttest.ProxyOptions{ Name: "best-proxy", DerpOnly: true, }) @@ -109,15 +114,15 @@ func TestDERP(t *testing.T) { }) // Create two running external proxies. - proxyAPI1 := coderdenttest.NewWorkspaceProxy(t, api, client, &coderdenttest.ProxyOptions{ + proxyAPI1 := coderdenttest.NewWorkspaceProxyReplica(t, api, client, &coderdenttest.ProxyOptions{ Name: "best-proxy", }) - proxyAPI2 := coderdenttest.NewWorkspaceProxy(t, api, client, &coderdenttest.ProxyOptions{ + proxyAPI2 := coderdenttest.NewWorkspaceProxyReplica(t, api, client, &coderdenttest.ProxyOptions{ Name: "worst-proxy", }) // Create a running external proxy with DERP disabled. - proxyAPI3 := coderdenttest.NewWorkspaceProxy(t, api, client, &coderdenttest.ProxyOptions{ + proxyAPI3 := coderdenttest.NewWorkspaceProxyReplica(t, api, client, &coderdenttest.ProxyOptions{ Name: "no-derp-proxy", DerpDisabled: true, }) @@ -340,7 +345,7 @@ func TestDERPEndToEnd(t *testing.T) { _ = closer.Close() }) - coderdenttest.NewWorkspaceProxy(t, api, client, &coderdenttest.ProxyOptions{ + coderdenttest.NewWorkspaceProxyReplica(t, api, client, &coderdenttest.ProxyOptions{ Name: "best-proxy", }) @@ -430,6 +435,105 @@ resourceLoop: require.False(t, p2p) } +// TestDERPMesh spawns 6 workspace proxy replicas and tries to connect to a +// single DERP peer via every single one. +func TestDERPMesh(t *testing.T) { + t.Parallel() + + deploymentValues := coderdtest.DeploymentValues(t) + deploymentValues.Experiments = []string{ + "*", + } + + client, closer, api, _ := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + DeploymentValues: deploymentValues, + AppHostname: "*.primary.test.coder.com", + IncludeProvisionerDaemon: true, + RealIPConfig: &httpmw.RealIPConfig{ + TrustedOrigins: []*net.IPNet{{ + IP: net.ParseIP("127.0.0.1"), + Mask: net.CIDRMask(8, 32), + }}, + TrustedHeaders: []string{ + "CF-Connecting-IP", + }, + }, + }, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureWorkspaceProxy: 1, + }, + }, + }) + t.Cleanup(func() { + _ = closer.Close() + }) + + proxyURL, err := url.Parse("https://proxy.test.coder.com") + require.NoError(t, err) + + // Create 6 proxy replicas. + const count = 6 + var ( + sessionToken = "" + proxies = [count]coderdenttest.WorkspaceProxy{} + derpURLs = [count]string{} + ) + for i := range proxies { + proxies[i] = coderdenttest.NewWorkspaceProxyReplica(t, api, client, &coderdenttest.ProxyOptions{ + Name: "best-proxy", + Token: sessionToken, + ProxyURL: proxyURL, + }) + if i == 0 { + sessionToken = proxies[i].Options.ProxySessionToken + } + + derpURL := *proxies[i].ServerURL + derpURL.Path = "/derp" + derpURLs[i] = derpURL.String() + } + + // Force all proxies to re-register immediately. This ensures the DERP mesh + // is up-to-date. In production this will happen automatically after about + // 15 seconds. + for i, proxy := range proxies { + err := proxy.RegisterNow() + require.NoErrorf(t, err, "failed to force proxy %d to re-register", i) + } + + // Generate cases. We have a case for: + // - Each proxy to itself. + // - Each proxy to each other proxy (one way, no duplicates). + cases := [][2]string{} + for i, derpURL := range derpURLs { + cases = append(cases, [2]string{derpURL, derpURL}) + for j := i + 1; j < len(derpURLs); j++ { + cases = append(cases, [2]string{derpURL, derpURLs[j]}) + } + } + require.Len(t, cases, (count*(count+1))/2) // triangle number + + for i, c := range cases { + i, c := i, c + t.Run(fmt.Sprintf("Proxy%d", i), func(t *testing.T) { + t.Parallel() + + t.Logf("derp1=%s, derp2=%s", c[0], c[1]) + ctx := testutil.Context(t, testutil.WaitLong) + client1, client1Recv := createDERPClient(t, ctx, "client1", c[0]) + client2, client2Recv := createDERPClient(t, ctx, "client2", c[1]) + + // Send a packet from client 1 to client 2. + testDERPSend(t, ctx, client2.SelfPublicKey(), client2Recv, client1) + + // Send a packet from client 2 to client 1. + testDERPSend(t, ctx, client1.SelfPublicKey(), client1Recv, client2) + }) + } +} + func TestWorkspaceProxyWorkspaceApps(t *testing.T) { t.Parallel() @@ -482,7 +586,7 @@ func TestWorkspaceProxyWorkspaceApps(t *testing.T) { if opts.DisableSubdomainApps { opts.AppHost = "" } - proxyAPI := coderdenttest.NewWorkspaceProxy(t, api, client, &coderdenttest.ProxyOptions{ + proxyAPI := coderdenttest.NewWorkspaceProxyReplica(t, api, client, &coderdenttest.ProxyOptions{ Name: "best-proxy", AppHostname: opts.AppHost, DisablePathApps: opts.DisablePathApps, @@ -498,3 +602,84 @@ func TestWorkspaceProxyWorkspaceApps(t *testing.T) { } }) } + +// createDERPClient creates a DERP client and spawns a goroutine that reads from +// the client and sends the received packets to a channel. +// +//nolint:revive +func createDERPClient(t *testing.T, ctx context.Context, name string, derpURL string) (*derphttp.Client, <-chan derp.ReceivedPacket) { + t.Helper() + + client, err := derphttp.NewClient(key.NewNode(), derpURL, func(format string, args ...any) { + t.Logf(name+": "+format, args...) + }) + require.NoError(t, err, "create client") + t.Cleanup(func() { + _ = client.Close() + }) + err = client.Connect(ctx) + require.NoError(t, err, "connect to DERP server") + + ch := make(chan derp.ReceivedPacket, 1) + go func() { + defer close(ch) + for { + msg, err := client.Recv() + if err != nil { + t.Logf("Recv error: %v", err) + return + } + switch msg := msg.(type) { + case derp.ReceivedPacket: + ch <- msg + return + default: + // We don't care about other messages. + } + } + }() + + return client, ch +} + +// testDERPSend sends a message from src to dstKey and waits for it to be +// received on dstCh. +// +// If the packet doesn't arrive within 500ms, it will try to send it again until +// testutil.WaitLong is reached. +// +//nolint:revive +func testDERPSend(t *testing.T, ctx context.Context, dstKey key.NodePublic, dstCh <-chan derp.ReceivedPacket, src *derphttp.Client) { + t.Helper() + + // The prefix helps identify where the packet starts if you get garbled data + // in logs. + const msgStrPrefix = "test_packet_" + msgStr, err := cryptorand.String(64 - len(msgStrPrefix)) + require.NoError(t, err, "generate random msg string") + msg := []byte(msgStrPrefix + msgStr) + + err = src.Send(dstKey, msg) + require.NoError(t, err, "send message via DERP") + + ticker := time.NewTicker(time.Millisecond * 500) + defer ticker.Stop() + for { + select { + case pkt := <-dstCh: + require.Equal(t, src.SelfPublicKey(), pkt.Source, "packet came from wrong source") + require.Equal(t, msg, pkt.Data, "packet data is wrong") + return + case <-ctx.Done(): + t.Fatal("timed out waiting for packet") + return + case <-ticker.C: + } + + // Send another packet. Since we're sending packets immediately + // after opening the clients, they might not be meshed together + // properly yet. + err = src.Send(dstKey, msg) + require.NoError(t, err, "send message via DERP") + } +} diff --git a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go index 1163f7c435001..37636102bb413 100644 --- a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go +++ b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go @@ -277,135 +277,214 @@ type RegisterWorkspaceProxyLoopOpts struct { // called in a blocking manner, so it should avoid blocking for too long. If // the callback returns an error, the loop will stop immediately and the // error will be returned to the FailureFn. - CallbackFn func(ctx context.Context, res RegisterWorkspaceProxyResponse) error + CallbackFn func(res RegisterWorkspaceProxyResponse) error // FailureFn is called with the last error returned from the server if the // context is canceled, registration fails for more than MaxFailureCount, // or if any permanent values in the response change. FailureFn func(err error) } -// RegisterWorkspaceProxyLoop will register the workspace proxy and then start a -// goroutine to keep registering periodically in the background. -// -// The first response is returned immediately, and subsequent responses will be -// notified to the given CallbackFn. When the context is canceled the loop will -// stop immediately and the context error will be returned to the FailureFn. -// -// The returned channel will be closed when the loop stops and can be used to -// ensure the loop is dead before continuing. When a fatal error is encountered, -// the proxy will be deregistered (with the same ReplicaID and AttemptTimeout) -// before calling the FailureFn. -func (c *Client) RegisterWorkspaceProxyLoop(ctx context.Context, opts RegisterWorkspaceProxyLoopOpts) (RegisterWorkspaceProxyResponse, <-chan struct{}, error) { - if opts.Interval == 0 { - opts.Interval = 30 * time.Second - } - if opts.MaxFailureCount == 0 { - opts.MaxFailureCount = 10 - } - if opts.AttemptTimeout == 0 { - opts.AttemptTimeout = 10 * time.Second - } - if opts.MutateFn == nil { - opts.MutateFn = func(_ *RegisterWorkspaceProxyRequest) {} - } - if opts.CallbackFn == nil { - opts.CallbackFn = func(_ context.Context, _ RegisterWorkspaceProxyResponse) error { - return nil - } +type RegisterWorkspaceProxyLoop struct { + opts RegisterWorkspaceProxyLoopOpts + c *Client + + // runLoopNow takes a response channel to send the response to and triggers + // the loop to run immediately if it's waiting. + runLoopNow chan chan RegisterWorkspaceProxyResponse + closedCtx context.Context + close context.CancelFunc + done chan struct{} +} + +func (l *RegisterWorkspaceProxyLoop) register(ctx context.Context) (RegisterWorkspaceProxyResponse, error) { + registerCtx, registerCancel := context.WithTimeout(ctx, l.opts.AttemptTimeout) + res, err := l.c.RegisterWorkspaceProxy(registerCtx, l.opts.Request) + registerCancel() + if err != nil { + return RegisterWorkspaceProxyResponse{}, xerrors.Errorf("register workspace proxy: %w", err) } - failureFn := func(err error) { - // We have to use background context here because the original context - // may be canceled. - deregisterCtx, cancel := context.WithTimeout(context.Background(), opts.AttemptTimeout) - defer cancel() - deregisterErr := c.DeregisterWorkspaceProxy(deregisterCtx, DeregisterWorkspaceProxyRequest{ - ReplicaID: opts.Request.ReplicaID, - }) - if deregisterErr != nil { - opts.Logger.Error(ctx, - "failed to deregister workspace proxy with Coder primary (it will be automatically deregistered shortly)", - slog.Error(deregisterErr), - ) - } + return res, nil +} - if opts.FailureFn != nil { - opts.FailureFn(err) - } +// Start starts the proxy registration loop. The provided context is only used +// for the initial registration. Use Close() to stop. +func (l *RegisterWorkspaceProxyLoop) Start(ctx context.Context) (RegisterWorkspaceProxyResponse, error) { + if l.opts.Interval == 0 { + l.opts.Interval = 15 * time.Second + } + if l.opts.MaxFailureCount == 0 { + l.opts.MaxFailureCount = 10 + } + if l.opts.AttemptTimeout == 0 { + l.opts.AttemptTimeout = 10 * time.Second } - originalRes, err := c.RegisterWorkspaceProxy(ctx, opts.Request) + var err error + originalRes, err := l.register(ctx) if err != nil { - return RegisterWorkspaceProxyResponse{}, nil, xerrors.Errorf("register workspace proxy: %w", err) + return RegisterWorkspaceProxyResponse{}, xerrors.Errorf("initial registration: %w", err) } - done := make(chan struct{}) go func() { - defer close(done) + defer close(l.done) var ( failedAttempts = 0 - ticker = time.NewTicker(opts.Interval) + ticker = time.NewTicker(l.opts.Interval) ) for { + var respCh chan RegisterWorkspaceProxyResponse select { - case <-ctx.Done(): - failureFn(ctx.Err()) + case <-l.closedCtx.Done(): + l.failureFn(xerrors.Errorf("proxy registration loop closed")) return + case respCh = <-l.runLoopNow: case <-ticker.C: } - opts.Logger.Debug(ctx, + l.opts.Logger.Debug(context.Background(), "re-registering workspace proxy with Coder primary", - slog.F("req", opts.Request), - slog.F("timeout", opts.AttemptTimeout), + slog.F("req", l.opts.Request), + slog.F("timeout", l.opts.AttemptTimeout), slog.F("failed_attempts", failedAttempts), ) - opts.MutateFn(&opts.Request) - registerCtx, cancel := context.WithTimeout(ctx, opts.AttemptTimeout) - res, err := c.RegisterWorkspaceProxy(registerCtx, opts.Request) - cancel() + + l.mutateFn(&l.opts.Request) + resp, err := l.register(l.closedCtx) if err != nil { failedAttempts++ - opts.Logger.Warn(ctx, + l.opts.Logger.Warn(context.Background(), "failed to re-register workspace proxy with Coder primary", - slog.F("req", opts.Request), - slog.F("timeout", opts.AttemptTimeout), + slog.F("req", l.opts.Request), + slog.F("timeout", l.opts.AttemptTimeout), slog.F("failed_attempts", failedAttempts), slog.Error(err), ) - if failedAttempts > opts.MaxFailureCount { - failureFn(xerrors.Errorf("exceeded re-registration failure count of %d: last error: %w", opts.MaxFailureCount, err)) + if failedAttempts > l.opts.MaxFailureCount { + l.failureFn(xerrors.Errorf("exceeded re-registration failure count of %d: last error: %w", l.opts.MaxFailureCount, err)) return } continue } failedAttempts = 0 - if res.AppSecurityKey != originalRes.AppSecurityKey { - failureFn(xerrors.New("app security key has changed, proxy must be restarted")) + // Check for consistency. + if originalRes.AppSecurityKey != resp.AppSecurityKey { + l.failureFn(xerrors.New("app security key has changed, proxy must be restarted")) return } - if res.DERPMeshKey != originalRes.DERPMeshKey { - failureFn(xerrors.New("DERP mesh key has changed, proxy must be restarted")) + if originalRes.DERPMeshKey != resp.DERPMeshKey { + l.failureFn(xerrors.New("DERP mesh key has changed, proxy must be restarted")) return } - if res.DERPRegionID != originalRes.DERPRegionID { - failureFn(xerrors.New("DERP region ID has changed, proxy must be restarted")) + if originalRes.DERPRegionID != resp.DERPRegionID { + l.failureFn(xerrors.New("DERP region ID has changed, proxy must be restarted")) + return } - err = opts.CallbackFn(ctx, res) + err = l.callbackFn(resp) if err != nil { - failureFn(xerrors.Errorf("callback fn returned error: %w", err)) + l.failureFn(xerrors.Errorf("callback function returned an error: %w", err)) return } - ticker.Reset(opts.Interval) + // If we were triggered by RegisterNow(), send the response back. + if respCh != nil { + respCh <- resp + close(respCh) + } + + ticker.Reset(l.opts.Interval) } }() - return originalRes, done, nil + return originalRes, nil +} + +// RegisterNow asks the registration loop to register immediately. A timeout of +// 2x the attempt timeout is used to wait for the response. +func (l *RegisterWorkspaceProxyLoop) RegisterNow() (RegisterWorkspaceProxyResponse, error) { + // The channel is closed by the loop after sending the response. + respCh := make(chan RegisterWorkspaceProxyResponse, 1) + select { + case <-l.done: + return RegisterWorkspaceProxyResponse{}, xerrors.New("proxy registration loop closed") + case l.runLoopNow <- respCh: + } + select { + case <-l.done: + return RegisterWorkspaceProxyResponse{}, xerrors.New("proxy registration loop closed") + case resp := <-respCh: + return resp, nil + } +} + +func (l *RegisterWorkspaceProxyLoop) Close() { + l.close() + <-l.done +} + +func (l *RegisterWorkspaceProxyLoop) mutateFn(req *RegisterWorkspaceProxyRequest) { + if l.opts.MutateFn != nil { + l.opts.MutateFn(req) + } +} + +func (l *RegisterWorkspaceProxyLoop) callbackFn(res RegisterWorkspaceProxyResponse) error { + if l.opts.CallbackFn != nil { + return l.opts.CallbackFn(res) + } + return nil +} + +func (l *RegisterWorkspaceProxyLoop) failureFn(err error) { + // We have to use background context here because the original context may + // be canceled. + deregisterCtx, cancel := context.WithTimeout(context.Background(), l.opts.AttemptTimeout) + defer cancel() + deregisterErr := l.c.DeregisterWorkspaceProxy(deregisterCtx, DeregisterWorkspaceProxyRequest{ + ReplicaID: l.opts.Request.ReplicaID, + }) + if deregisterErr != nil { + l.opts.Logger.Error(context.Background(), + "failed to deregister workspace proxy with Coder primary (it will be automatically deregistered shortly)", + slog.Error(deregisterErr), + ) + } + + if l.opts.FailureFn != nil { + l.opts.FailureFn(err) + } +} + +// RegisterWorkspaceProxyLoop will register the workspace proxy and then start a +// goroutine to keep registering periodically in the background. +// +// The first response is returned immediately, and subsequent responses will be +// notified to the given CallbackFn. When the loop is Close()d it will stop +// immediately and an error will be returned to the FailureFn. +// +// When a fatal error is encountered (or the proxy is closed), the proxy will be +// deregistered (with the same ReplicaID and AttemptTimeout) before calling the +// FailureFn. +func (c *Client) RegisterWorkspaceProxyLoop(ctx context.Context, opts RegisterWorkspaceProxyLoopOpts) (*RegisterWorkspaceProxyLoop, RegisterWorkspaceProxyResponse, error) { + closedCtx, closeFn := context.WithCancel(context.Background()) + loop := &RegisterWorkspaceProxyLoop{ + opts: opts, + c: c, + runLoopNow: make(chan chan RegisterWorkspaceProxyResponse), + closedCtx: closedCtx, + close: closeFn, + done: make(chan struct{}), + } + + regResp, err := loop.Start(ctx) + if err != nil { + return nil, RegisterWorkspaceProxyResponse{}, xerrors.Errorf("start loop: %w", err) + } + return loop, regResp, nil } type CoordinateMessageType int