diff --git a/tailnet/test/integration/integration.go b/tailnet/test/integration/integration.go index 3877542c8eafc..938ed29e8d555 100644 --- a/tailnet/test/integration/integration.go +++ b/tailnet/test/integration/integration.go @@ -111,6 +111,53 @@ type SimpleServerOptions struct { var _ ServerStarter = SimpleServerOptions{} +type connManager struct { + mu sync.Mutex + conns map[uuid.UUID]net.Conn +} + +func (c *connManager) Add(id uuid.UUID, conn net.Conn) func() { + c.mu.Lock() + defer c.mu.Unlock() + c.conns[id] = conn + return func() { + c.mu.Lock() + defer c.mu.Unlock() + delete(c.conns, id) + } +} + +func (c *connManager) CloseAll() { + c.mu.Lock() + defer c.mu.Unlock() + for _, conn := range c.conns { + _ = conn.Close() + } + c.conns = make(map[uuid.UUID]net.Conn) +} + +type derpServer struct { + http.Handler + srv *derp.Server + closeFn func() +} + +func newDerpServer(t *testing.T, logger slog.Logger) *derpServer { + derpSrv := derp.NewServer(key.NewNode(), tailnet.Logger(logger.Named("derp"))) + derpHandler, derpCloseFunc := tailnet.WithWebsocketSupport(derpSrv, derphttp.Handler(derpSrv)) + t.Cleanup(derpCloseFunc) + return &derpServer{ + srv: derpSrv, + Handler: derpHandler, + closeFn: derpCloseFunc, + } +} + +func (s *derpServer) Close() { + s.srv.Close() + s.closeFn() +} + //nolint:revive func (o SimpleServerOptions) Router(t *testing.T, logger slog.Logger) *chi.Mux { coord := tailnet.NewCoordinator(logger) @@ -118,6 +165,10 @@ func (o SimpleServerOptions) Router(t *testing.T, logger slog.Logger) *chi.Mux { coordPtr.Store(&coord) t.Cleanup(func() { _ = coord.Close() }) + cm := connManager{ + conns: make(map[uuid.UUID]net.Conn), + } + csvc, err := tailnet.NewClientService(logger, &coordPtr, 10*time.Minute, func() *tailcfg.DERPMap { return &tailcfg.DERPMap{ // Clients will set their own based on their custom access URL. @@ -126,9 +177,11 @@ func (o SimpleServerOptions) Router(t *testing.T, logger slog.Logger) *chi.Mux { }) require.NoError(t, err) - derpServer := derp.NewServer(key.NewNode(), tailnet.Logger(logger.Named("derp"))) - derpHandler, derpCloseFunc := tailnet.WithWebsocketSupport(derpServer, derphttp.Handler(derpServer)) - t.Cleanup(derpCloseFunc) + derpServer := atomic.Pointer[derpServer]{} + derpServer.Store(newDerpServer(t, logger)) + t.Cleanup(func() { + derpServer.Load().Close() + }) r := chi.NewRouter() r.Use( @@ -166,11 +219,32 @@ func (o SimpleServerOptions) Router(t *testing.T, logger slog.Logger) *chi.Mux { return } - derpHandler.ServeHTTP(w, r) + derpServer.Load().ServeHTTP(w, r) }) r.Get("/latency-check", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) + r.Post("/restart", func(w http.ResponseWriter, r *http.Request) { + oldServer := derpServer.Swap(newDerpServer(t, logger)) + oldServer.Close() + w.WriteHeader(http.StatusOK) + }) + }) + + // /restart?derp=[true|false]&coordinator=[true|false] + r.Post("/restart", func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("derp") == "true" { + logger.Info(r.Context(), "killing DERP server") + oldServer := derpServer.Swap(newDerpServer(t, logger)) + oldServer.Close() + logger.Info(r.Context(), "restarted DERP server") + } + + if r.URL.Query().Get("coordinator") == "true" { + logger.Info(r.Context(), "simulating coordinator restart") + cm.CloseAll() + } + w.WriteHeader(http.StatusOK) }) r.Get("/api/v2/workspaceagents/{id}/coordinate", func(w http.ResponseWriter, r *http.Request) { @@ -199,6 +273,9 @@ func (o SimpleServerOptions) Router(t *testing.T, logger slog.Logger) *chi.Mux { ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageBinary) defer wsNetConn.Close() + cleanFn := cm.Add(id, wsNetConn) + defer cleanFn() + err = csvc.ServeConnV2(ctx, wsNetConn, tailnet.StreamID{ Name: "client-" + id.String(), ID: id, diff --git a/tailnet/test/integration/suite.go b/tailnet/test/integration/suite.go index 32d9adb2e4a14..e3403da32b359 100644 --- a/tailnet/test/integration/suite.go +++ b/tailnet/test/integration/suite.go @@ -4,8 +4,10 @@ package integration import ( + "net/http" "net/url" "testing" + "time" "github.com/stretchr/testify/require" @@ -14,9 +16,34 @@ import ( "github.com/coder/coder/v2/testutil" ) +// nolint:revive +func sendRestart(t *testing.T, serverURL *url.URL, derp bool, coordinator bool) { + t.Helper() + ctx := testutil.Context(t, 2*time.Second) + + serverURL, err := url.Parse(serverURL.String() + "/restart") + q := serverURL.Query() + if derp { + q.Set("derp", "true") + } + if coordinator { + q.Set("coordinator", "true") + } + serverURL.RawQuery = q.Encode() + require.NoError(t, err) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, serverURL.String(), nil) + require.NoError(t, err) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + require.Equal(t, http.StatusOK, resp.StatusCode, "unexpected status code %d", resp.StatusCode) +} + // TODO: instead of reusing one conn for each suite, maybe we should make a new // one for each subtest? -func TestSuite(t *testing.T, _ slog.Logger, _ *url.URL, conn *tailnet.Conn, _, peer Client) { +func TestSuite(t *testing.T, _ slog.Logger, serverURL *url.URL, conn *tailnet.Conn, _, peer Client) { t.Parallel() t.Run("Connectivity", func(t *testing.T) { @@ -26,5 +53,30 @@ func TestSuite(t *testing.T, _ slog.Logger, _ *url.URL, conn *tailnet.Conn, _, p require.NoError(t, err, "ping peer") }) - // TODO: more + t.Run("RestartDERP", func(t *testing.T) { + peerIP := tailnet.IPFromUUID(peer.ID) + _, _, _, err := conn.Ping(testutil.Context(t, testutil.WaitLong), peerIP) + require.NoError(t, err, "ping peer") + sendRestart(t, serverURL, true, false) + _, _, _, err = conn.Ping(testutil.Context(t, testutil.WaitLong), peerIP) + require.NoError(t, err, "ping peer after derp restart") + }) + + t.Run("RestartCoordinator", func(t *testing.T) { + peerIP := tailnet.IPFromUUID(peer.ID) + _, _, _, err := conn.Ping(testutil.Context(t, testutil.WaitLong), peerIP) + require.NoError(t, err, "ping peer") + sendRestart(t, serverURL, false, true) + _, _, _, err = conn.Ping(testutil.Context(t, testutil.WaitLong), peerIP) + require.NoError(t, err, "ping peer after coordinator restart") + }) + + t.Run("RestartBoth", func(t *testing.T) { + peerIP := tailnet.IPFromUUID(peer.ID) + _, _, _, err := conn.Ping(testutil.Context(t, testutil.WaitLong), peerIP) + require.NoError(t, err, "ping peer") + sendRestart(t, serverURL, true, true) + _, _, _, err = conn.Ping(testutil.Context(t, testutil.WaitLong), peerIP) + require.NoError(t, err, "ping peer after restart") + }) }