diff --git a/cli/server.go b/cli/server.go
index 2a9d7f2a9149e..41e5905b0f976 100644
--- a/cli/server.go
+++ b/cli/server.go
@@ -20,6 +20,7 @@ import (
     "path/filepath"
     "strconv"
     "strings"
+    "sync"
     "time"
 
     "github.com/coreos/go-systemd/daemon"
@@ -111,26 +112,34 @@ func server() *cobra.Command {
             logger = logger.Leveled(slog.LevelDebug)
         }
 
+        // Main command context for managing cancellation
+        // of running services.
+        ctx, cancel := context.WithCancel(cmd.Context())
+        defer cancel()
+
+        // Clean up idle connections at the end, e.g.
+        // embedded-postgres can leave an idle connection
+        // which is caught by goleak.
+        defer http.DefaultClient.CloseIdleConnections()
+
         var (
             tracerProvider *sdktrace.TracerProvider
             err            error
             sqlDriver      = "postgres"
         )
         if trace {
-            tracerProvider, err = tracing.TracerProvider(cmd.Context(), "coderd")
+            tracerProvider, err = tracing.TracerProvider(ctx, "coderd")
             if err != nil {
-                logger.Warn(cmd.Context(), "failed to start telemetry exporter", slog.Error(err))
+                logger.Warn(ctx, "failed to start telemetry exporter", slog.Error(err))
             } else {
+                // allow time for traces to flush even if command context is canceled
                 defer func() {
-                    // allow time for traces to flush even if command context is canceled
-                    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
-                    defer cancel()
-                    _ = tracerProvider.Shutdown(ctx)
+                    _ = shutdownWithTimeout(tracerProvider, 5*time.Second)
                 }()
 
                 d, err := tracing.PostgresDriver(tracerProvider, "coderd.database")
                 if err != nil {
-                    logger.Warn(cmd.Context(), "failed to start postgres tracing driver", slog.Error(err))
+                    logger.Warn(ctx, "failed to start postgres tracing driver", slog.Error(err))
                 } else {
                     sqlDriver = d
                 }
@@ -143,14 +152,16 @@ func server() *cobra.Command {
         if !inMemoryDatabase && postgresURL == "" {
             var closeFunc func() error
             cmd.Printf("Using built-in PostgreSQL (%s)\n", config.PostgresPath())
-            postgresURL, closeFunc, err = startBuiltinPostgres(cmd.Context(), config, logger)
+            postgresURL, closeFunc, err = startBuiltinPostgres(ctx, config, logger)
             if err != nil {
                 return err
             }
             builtinPostgres = true
             defer func() {
+                cmd.Printf("Stopping built-in PostgreSQL...\n")
                 // Gracefully shut PostgreSQL down!
                 _ = closeFunc()
+                cmd.Printf("Stopped built-in PostgreSQL\n")
             }()
         }
 
@@ -189,9 +200,9 @@ func server() *cobra.Command {
         }
 
         var (
-            ctxTunnel, closeTunnel = context.WithCancel(cmd.Context())
-            devTunnel              = (*devtunnel.Tunnel)(nil)
-            devTunnelErrChan       = make(<-chan error, 1)
+            ctxTunnel, closeTunnel = context.WithCancel(ctx)
+            devTunnel              *devtunnel.Tunnel
+            devTunnelErr           <-chan error
         )
         defer closeTunnel()
 
@@ -199,7 +210,7 @@ func server() *cobra.Command {
         // needs to be changed to use the tunnel.
         if tunnel {
             cmd.Printf("Opening tunnel so workspaces can connect to your deployment\n")
-            devTunnel, devTunnelErrChan, err = devtunnel.New(ctxTunnel, logger.Named("devtunnel"))
+            devTunnel, devTunnelErr, err = devtunnel.New(ctxTunnel, logger.Named("devtunnel"))
             if err != nil {
                 return xerrors.Errorf("create tunnel: %w", err)
             }
@@ -207,7 +218,7 @@ func server() *cobra.Command {
         }
 
         // Warn the user if the access URL appears to be a loopback address.
-        isLocal, err := isLocalURL(cmd.Context(), accessURL)
+        isLocal, err := isLocalURL(ctx, accessURL)
         if isLocal || err != nil {
             reason := "could not be resolved"
             if isLocal {
@@ -224,7 +235,7 @@ func server() *cobra.Command {
         }
 
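
The `defer http.DefaultClient.CloseIdleConnections()` added above exists to satisfy goroutine-leak detection: an idle keep-alive connection owns a background reader goroutine that go.uber.org/goleak reports after the command returns. As a hedged illustration (not part of this patch), a goleak-guarded test binary looks like:

    package cli_test

    import (
        "testing"

        "go.uber.org/goleak"
    )

    func TestMain(m *testing.M) {
        // Fails the package's tests when unexpected goroutines are
        // still alive afterwards, e.g. the background reader of an
        // idle HTTP keep-alive connection that CloseIdleConnections
        // would have torn down.
        goleak.VerifyTestMain(m)
    }
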
         // Used for zero-trust instance identity with Google Cloud.
-        googleTokenValidator, err := idtoken.NewValidator(cmd.Context(), option.WithoutAuthentication())
+        googleTokenValidator, err := idtoken.NewValidator(ctx, option.WithoutAuthentication())
         if err != nil {
             return err
         }
@@ -241,6 +252,7 @@
         if err != nil {
             return xerrors.Errorf("create turn server: %w", err)
         }
+        defer turnServer.Close()
 
         iceServers := make([]webrtc.ICEServer, 0)
         for _, stunServer := range stunServers {
@@ -278,6 +290,8 @@
             if err != nil {
                 return xerrors.Errorf("dial postgres: %w", err)
             }
+            defer sqlDB.Close()
+
             err = sqlDB.Ping()
             if err != nil {
                 return xerrors.Errorf("ping postgres: %w", err)
             }
@@ -287,13 +301,14 @@
                 return xerrors.Errorf("migrate up: %w", err)
             }
             options.Database = database.New(sqlDB)
-            options.Pubsub, err = database.NewPubsub(cmd.Context(), sqlDB, postgresURL)
+            options.Pubsub, err = database.NewPubsub(ctx, sqlDB, postgresURL)
             if err != nil {
                 return xerrors.Errorf("create pubsub: %w", err)
             }
+            defer options.Pubsub.Close()
         }
 
-        deploymentID, err := options.Database.GetDeploymentID(cmd.Context())
+        deploymentID, err := options.Database.GetDeploymentID(ctx)
         if errors.Is(err, sql.ErrNoRows) {
             err = nil
         }
@@ -302,7 +317,7 @@
         }
         if deploymentID == "" {
             deploymentID = uuid.NewString()
-            err = options.Database.InsertDeploymentID(cmd.Context(), deploymentID)
+            err = options.Database.InsertDeploymentID(ctx, deploymentID)
             if err != nil {
                 return xerrors.Errorf("set deployment id: %w", err)
             }
@@ -336,6 +351,8 @@
         }
         coderAPI := coderd.New(options)
+        defer coderAPI.Close()
+
         client := codersdk.New(localURL)
         if tlsEnable {
             // Secure transport isn't needed for locally communicating!
@@ -351,64 +368,75 @@
         _ = pprof.Handler
         if pprofEnabled {
             //nolint:revive
-            defer serveHandler(cmd.Context(), logger, nil, pprofAddress, "pprof")()
+            defer serveHandler(ctx, logger, nil, pprofAddress, "pprof")()
         }
         if promEnabled {
             //nolint:revive
-            defer serveHandler(cmd.Context(), logger, promhttp.Handler(), promAddress, "prometheus")()
+            defer serveHandler(ctx, logger, promhttp.Handler(), promAddress, "prometheus")()
         }
 
+        // Since errCh only has one buffered slot, all routines
+        // sending on it must be wrapped in a select/default to
+        // avoid leaving dangling goroutines waiting for the
+        // channel to be consumed.
         errCh := make(chan error, 1)
         provisionerDaemons := make([]*provisionerd.Server, 0)
+        defer func() {
+            // We have no graceful shutdown of provisionerDaemons
+            // here because that's handled at the end of main; this
+            // is here in case the program exits early.
+            for _, daemon := range provisionerDaemons {
+                _ = daemon.Close()
+            }
+        }()
         for i := 0; uint8(i) < provisionerDaemonCount; i++ {
-            daemonClose, err := newProvisionerDaemon(cmd.Context(), coderAPI, logger, cacheDir, errCh, false)
+            daemon, err := newProvisionerDaemon(ctx, coderAPI, logger, cacheDir, errCh, false)
             if err != nil {
                 return xerrors.Errorf("create provisioner daemon: %w", err)
             }
-            provisionerDaemons = append(provisionerDaemons, daemonClose)
+            provisionerDaemons = append(provisionerDaemons, daemon)
+        }
+
+        shutdownConnsCtx, shutdownConns := context.WithCancel(ctx)
+        defer shutdownConns()
+        server := &http.Server{
+            // These errors are typically noise like "TLS: EOF". Vault does similar:
+            // https://github.com/hashicorp/vault/blob/e2490059d0711635e529a4efcbaa1b26998d6e1c/command/server.go#L2714
+            ErrorLog: log.New(io.Discard, "", 0),
+            Handler:  coderAPI.Handler,
+            BaseContext: func(_ net.Listener) context.Context {
+                return shutdownConnsCtx
+            },
         }
         defer func() {
-            for _, provisionerDaemon := range provisionerDaemons {
-                _ = provisionerDaemon.Close()
-            }
+            _ = shutdownWithTimeout(server, 5*time.Second)
         }()
 
-        shutdownConnsCtx, shutdownConns := context.WithCancel(cmd.Context())
-        defer shutdownConns()
-        go func() {
-            server := http.Server{
-                // These errors are typically noise like "TLS: EOF". Vault does similar:
-                // https://github.com/hashicorp/vault/blob/e2490059d0711635e529a4efcbaa1b26998d6e1c/command/server.go#L2714
-                ErrorLog: log.New(io.Discard, "", 0),
-                Handler:  coderAPI.Handler,
-                BaseContext: func(_ net.Listener) context.Context {
-                    return shutdownConnsCtx
-                },
+        eg := errgroup.Group{}
+        eg.Go(func() error {
+            // Make sure to close the tunnel listener if we exit so the
+            // errgroup doesn't wait forever!
+            if tunnel {
+                defer devTunnel.Listener.Close()
             }
-            wg := errgroup.Group{}
-            wg.Go(func() error {
-                // Make sure to close the tunnel listener if we exit so the
-                // errgroup doesn't wait forever!
-                if tunnel {
-                    defer devTunnel.Listener.Close()
-                }
+            return server.Serve(listener)
+        })
+        if tunnel {
+            eg.Go(func() error {
+                defer listener.Close()
 
-                return server.Serve(listener)
+                return server.Serve(devTunnel.Listener)
             })
-
-            if tunnel {
-                wg.Go(func() error {
-                    defer listener.Close()
-
-                    return server.Serve(devTunnel.Listener)
-                })
+        }
+        go func() {
+            select {
+            case errCh <- eg.Wait():
+            default:
             }
-
-            errCh <- wg.Wait()
         }()
 
-        hasFirstUser, err := client.HasFirstUser(cmd.Context())
+        hasFirstUser, err := client.HasFirstUser(ctx)
         if !hasFirstUser && err == nil {
             cmd.Println()
             cmd.Println("Get started by creating the first user (in a new terminal):")
@@ -425,75 +453,117 @@
         autobuildPoller := time.NewTicker(autobuildPollInterval)
         defer autobuildPoller.Stop()
-        autobuildExecutor := executor.New(cmd.Context(), options.Database, logger, autobuildPoller.C)
+        autobuildExecutor := executor.New(ctx, options.Database, logger, autobuildPoller.C)
         autobuildExecutor.Run()
 
+        // This is helpful for tests, but can be silently ignored.
+        // Coder may be run as users that don't have permission to write in the homedir,
+        // such as via the systemd service.
+        _ = config.URL().Write(client.URL.String())
+
         // Because the graceful shutdown includes cleaning up workspaces in dev mode, we're
         // going to make it harder to accidentally skip the graceful shutdown by hitting ctrl+c
         // two or more times. So the stopChan is unlimited in size and we don't call
         // signal.Stop() until graceful shutdown finished--this means we swallow additional
         // SIGINT after the first. To get out of a graceful shutdown, the user can send SIGQUIT
         // with ctrl+\ or SIGTERM with `kill`.
-        stopChan := make(chan os.Signal, 1)
-        defer signal.Stop(stopChan)
-        signal.Notify(stopChan, os.Interrupt)
-
-        // This is helpful for tests, but can be silently ignored.
-        // Coder may be ran as users that don't have permission to write in the homedir,
-        // such as via the systemd service.
-        _ = config.URL().Write(client.URL.String())
+        ctx, stop := signal.NotifyContext(ctx, os.Interrupt)
+        defer stop()
+
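
The `select`/`default` around `errCh <- eg.Wait()` above is the non-blocking send that the earlier comment prescribes: `errCh` has one buffered slot, the first sender wins it, and later senders drop their value instead of blocking forever. A standalone sketch of the pattern (names illustrative):

    package main

    import (
        "fmt"
        "time"
    )

    func main() {
        errCh := make(chan error, 1)

        for i := 0; i < 3; i++ {
            i := i
            go func() {
                select {
                case errCh <- fmt.Errorf("worker %d failed", i):
                    // Won the single buffered slot.
                default:
                    // Slot already taken; return instead of leaking a
                    // goroutine blocked on a channel nobody will read.
                }
            }()
        }

        time.Sleep(100 * time.Millisecond)
        fmt.Println(<-errCh) // exactly one error, no matter when it is read
    }
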
+        // Currently there is no way to ask the server to shut
+        // itself down, so any exit signal will result in a non-zero
+        // exit of the server.
+        var exitErr error
         select {
-        case <-cmd.Context().Done():
-            coderAPI.Close()
-            return cmd.Context().Err()
-        case err := <-devTunnelErrChan:
-            if err != nil {
-                return err
+        case <-ctx.Done():
+            exitErr = ctx.Err()
+            _, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Bold.Render(
+                "Interrupt caught, gracefully exiting. Use ctrl+\\ to force quit",
+            ))
+        case exitErr = <-devTunnelErr:
+            if exitErr == nil {
+                exitErr = xerrors.New("dev tunnel closed unexpectedly")
             }
-        case err := <-errCh:
-            shutdownConns()
-            coderAPI.Close()
-            return err
-        case <-stopChan:
+        case exitErr = <-errCh:
         }
+        if exitErr != nil && !xerrors.Is(exitErr, context.Canceled) {
+            cmd.Printf("Unexpected error, shutting down server: %s\n", exitErr)
+        }
+
+        // Begin the clean shutdown stage; we try to shut down services
+        // gracefully in an order that gives the best experience.
+        // This procedure should not differ greatly from the order
+        // of `defer`s in this function, but allows us to inform
+        // the user about what's going on and handle errors more
+        // explicitly.
+
         _, err = daemon.SdNotify(false, daemon.SdNotifyStopping)
         if err != nil {
-            return xerrors.Errorf("notify systemd: %w", err)
+            cmd.Printf("Notify systemd failed: %s\n", err)
         }
-        _, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Bold.Render(
-            "Interrupt caught, gracefully exiting. Use ctrl+\\ to force quit"))
 
-        for _, provisionerDaemon := range provisionerDaemons {
-            if verbose {
-                cmd.Println("Shutting down provisioner daemon...")
-            }
-            err = provisionerDaemon.Shutdown(cmd.Context())
-            if err != nil {
-                cmd.PrintErrf("Failed to shutdown provisioner daemon: %s\n", err)
-                continue
-            }
-            err = provisionerDaemon.Close()
-            if err != nil {
-                return xerrors.Errorf("close provisioner daemon: %w", err)
-            }
-            if verbose {
-                cmd.Println("Gracefully shut down provisioner daemon!")
-            }
+        // Stop accepting new connections without interrupting
+        // in-flight requests; give in-flight requests 5 seconds to
+        // complete.
+        cmd.Println("Shutting down API server...")
+        err = shutdownWithTimeout(server, 5*time.Second)
+        if err != nil {
+            cmd.Printf("API server shutdown took longer than 5s: %s\n", err)
+        } else {
+            cmd.Printf("Gracefully shut down API server\n")
         }
 
+        // Cancel any remaining in-flight requests.
+        shutdownConns()
+
+        // Shut down provisioners before waiting for WebSocket
+        // connections to close.
+        var wg sync.WaitGroup
+        for i, provisionerDaemon := range provisionerDaemons {
+            id := i + 1
+            provisionerDaemon := provisionerDaemon
+            wg.Add(1)
+            go func() {
+                defer wg.Done()
+
+                if verbose {
+                    cmd.Printf("Shutting down provisioner daemon %d...\n", id)
+                }
+                err := shutdownWithTimeout(provisionerDaemon, 5*time.Second)
+                if err != nil {
+                    cmd.PrintErrf("Failed to shutdown provisioner daemon %d: %s\n", id, err)
+                    return
+                }
+                err = provisionerDaemon.Close()
+                if err != nil {
+                    cmd.PrintErrf("Close provisioner daemon %d: %s\n", id, err)
+                    return
+                }
+                if verbose {
+                    cmd.Printf("Gracefully shut down provisioner daemon %d\n", id)
+                }
+            }()
+        }
+        wg.Wait()
+
+        cmd.Println("Waiting for WebSocket connections to close...")
+        _ = coderAPI.Close()
+        cmd.Println("Done waiting for WebSocket connections")
+
         // Close tunnel after we no longer have in-flight connections.
         if tunnel {
             cmd.Println("Waiting for tunnel to close...")
             closeTunnel()
-            <-devTunnelErrChan
+            <-devTunnelErr
+            cmd.Println("Done waiting for tunnel")
         }
 
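
The provisioner shutdown above fans out one goroutine per daemon and joins them with a `sync.WaitGroup`; the `provisionerDaemon := provisionerDaemon` line pins the loop variable so each goroutine closes its own daemon (necessary before Go 1.22's per-iteration loop variables). The shape of that fan-out, reduced to a runnable sketch with a stand-in type:

    package main

    import (
        "fmt"
        "sync"
    )

    type daemon struct{ id int }

    func (d daemon) close() { fmt.Printf("daemon %d closed\n", d.id) }

    func main() {
        daemons := []daemon{{1}, {2}, {3}}

        var wg sync.WaitGroup
        for _, d := range daemons {
            d := d // pin the loop variable for the goroutine below
            wg.Add(1)
            go func() {
                defer wg.Done()
                d.close() // daemons shut down concurrently, not one by one
            }()
        }
        wg.Wait() // proceed only after every daemon has shut down
    }
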
         // Ensures a last report can be sent before exit!
         options.Telemetry.Close()
-        cmd.Println("Waiting for WebSocket connections to close...")
-        shutdownConns()
-        coderAPI.Close()
-        return nil
+
+        // Trigger context cancellation for any remaining services.
+        cancel()
+
+        return exitErr
     },
 }
@@ -602,16 +672,37 @@ func server() *cobra.Command {
     return root
 }
 
+func shutdownWithTimeout(s interface{ Shutdown(context.Context) error }, timeout time.Duration) error {
+    ctx, cancel := context.WithTimeout(context.Background(), timeout)
+    defer cancel()
+    return s.Shutdown(ctx)
+}
+
 // nolint:revive
 func newProvisionerDaemon(ctx context.Context, coderAPI *coderd.API,
-    logger slog.Logger, cacheDir string, errChan chan error, dev bool) (*provisionerd.Server, error) {
-    err := os.MkdirAll(cacheDir, 0700)
+    logger slog.Logger, cacheDir string, errCh chan error, dev bool,
+) (srv *provisionerd.Server, err error) {
+    ctx, cancel := context.WithCancel(ctx)
+    defer func() {
+        if err != nil {
+            cancel()
+        }
+    }()
+
+    err = os.MkdirAll(cacheDir, 0o700)
     if err != nil {
         return nil, xerrors.Errorf("mkdir %q: %w", cacheDir, err)
     }
 
     terraformClient, terraformServer := provisionersdk.TransportPipe()
     go func() {
+        <-ctx.Done()
+        _ = terraformClient.Close()
+        _ = terraformServer.Close()
+    }()
+    go func() {
+        defer cancel()
+
         err := terraform.Serve(ctx, &terraform.ServeOptions{
             ServeOptions: &provisionersdk.ServeOptions{
                 Listener: terraformServer,
@@ -620,7 +711,10 @@
             Logger: logger,
         })
         if err != nil && !xerrors.Is(err, context.Canceled) {
-            errChan <- err
+            select {
+            case errCh <- err:
+            default:
+            }
         }
     }()
 
@@ -636,9 +730,19 @@
     if dev {
         echoClient, echoServer := provisionersdk.TransportPipe()
         go func() {
+            <-ctx.Done()
+            _ = echoClient.Close()
+            _ = echoServer.Close()
+        }()
+        go func() {
+            defer cancel()
+
             err := echo.Serve(ctx, afero.NewOsFs(), &provisionersdk.ServeOptions{Listener: echoServer})
             if err != nil {
-                errChan <- err
+                select {
+                case errCh <- err:
+                default:
+                }
             }
         }()
         provisioners[string(database.ProvisionerTypeEcho)] = proto.NewDRPCProvisionerClient(provisionersdk.Conn(echoClient))
diff --git a/cli/server_test.go b/cli/server_test.go
index 4571c40a33c1e..ae70324b33741 100644
--- a/cli/server_test.go
+++ b/cli/server_test.go
@@ -17,7 +17,6 @@ import (
     "net/url"
     "os"
     "runtime"
-    "strings"
    "testing"
     "time"
 
@@ -30,6 +29,7 @@ import (
     "github.com/coder/coder/coderd/database/postgres"
     "github.com/coder/coder/coderd/telemetry"
     "github.com/coder/coder/codersdk"
+    "github.com/coder/coder/pty/ptytest"
 )
 
 // This cannot be ran in parallel because it uses a signal.
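
The `shutdownWithTimeout` helper added to server.go above leans on Go's structural typing: the anonymous interface `interface{ Shutdown(context.Context) error }` is satisfied implicitly by `*http.Server`, `*provisionerd.Server`, and the tracer provider alike, so one helper bounds every graceful shutdown. A self-contained illustration (the slow service is invented for the example):

    package main

    import (
        "context"
        "fmt"
        "net/http"
        "time"
    )

    func shutdownWithTimeout(s interface{ Shutdown(context.Context) error }, timeout time.Duration) error {
        ctx, cancel := context.WithTimeout(context.Background(), timeout)
        defer cancel()
        return s.Shutdown(ctx)
    }

    type slowService struct{}

    // Shutdown satisfies the anonymous interface above without any
    // explicit interface declaration anywhere.
    func (slowService) Shutdown(ctx context.Context) error {
        <-ctx.Done() // pretend to drain work until the deadline expires
        return ctx.Err()
    }

    func main() {
        _ = shutdownWithTimeout(&http.Server{}, time.Second) // *http.Server qualifies too
        fmt.Println(shutdownWithTimeout(slowService{}, 10*time.Millisecond))
    }
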
@@ -45,13 +45,14 @@ func TestServer(t *testing.T) {
         defer closeFunc()
         ctx, cancelFunc := context.WithCancel(context.Background())
         defer cancelFunc()
+
         root, cfg := clitest.New(t,
             "server",
             "--address", ":0",
             "--postgres-url", connectionURL,
             "--cache-dir", t.TempDir(),
         )
-        errC := make(chan error)
+        errC := make(chan error, 1)
         go func() {
             errC <- root.ExecuteContext(ctx)
         }()
@@ -80,12 +81,17 @@
             t.SkipNow()
         }
         ctx, cancelFunc := context.WithCancel(context.Background())
+        defer cancelFunc()
+
         root, cfg := clitest.New(t,
             "server",
             "--address", ":0",
             "--cache-dir", t.TempDir(),
         )
-        errC := make(chan error)
+        pty := ptytest.New(t)
+        root.SetOutput(pty.Output())
+        root.SetErr(pty.Output())
+        errC := make(chan error, 1)
         go func() {
             errC <- root.ExecuteContext(ctx)
         }()
@@ -99,11 +105,12 @@
     t.Run("BuiltinPostgresURL", func(t *testing.T) {
         t.Parallel()
         root, _ := clitest.New(t, "server", "postgres-builtin-url")
-        var buf strings.Builder
-        root.SetOutput(&buf)
+        pty := ptytest.New(t)
+        root.SetOutput(pty.Output())
         err := root.Execute()
         require.NoError(t, err)
-        require.Contains(t, buf.String(), "psql")
+
+        pty.ExpectMatch("psql")
     })
 
     t.Run("NoWarningWithRemoteAccessURL", func(t *testing.T) {
@@ -118,9 +125,9 @@
             "--access-url", "http://1.2.3.4:3000/",
             "--cache-dir", t.TempDir(),
         )
-        var buf strings.Builder
-        errC := make(chan error)
-        root.SetOutput(&buf)
+        buf := newThreadSafeBuffer()
+        root.SetOutput(buf)
+        errC := make(chan error, 1)
         go func() {
             errC <- root.ExecuteContext(ctx)
         }()
@@ -142,6 +149,7 @@
         t.Parallel()
         ctx, cancelFunc := context.WithCancel(context.Background())
         defer cancelFunc()
+
         root, _ := clitest.New(t,
             "server",
             "--in-memory",
@@ -157,6 +165,7 @@
         t.Parallel()
         ctx, cancelFunc := context.WithCancel(context.Background())
         defer cancelFunc()
+
         root, _ := clitest.New(t,
             "server",
             "--in-memory",
@@ -172,6 +181,7 @@
         t.Parallel()
         ctx, cancelFunc := context.WithCancel(context.Background())
         defer cancelFunc()
+
         root, _ := clitest.New(t,
             "server",
             "--in-memory",
@@ -197,7 +207,7 @@
             "--tls-key-file", keyPath,
             "--cache-dir", t.TempDir(),
         )
-        errC := make(chan error)
+        errC := make(chan error, 1)
         go func() {
             errC <- root.ExecuteContext(ctx)
         }()
@@ -236,6 +246,7 @@
         }
         ctx, cancelFunc := context.WithCancel(context.Background())
         defer cancelFunc()
+
         root, cfg := clitest.New(t,
             "server",
             "--in-memory",
@@ -243,7 +254,7 @@
             "--provisioner-daemons", "1",
             "--cache-dir", t.TempDir(),
         )
-        serverErr := make(chan error)
+        serverErr := make(chan error, 1)
         go func() {
             serverErr <- root.ExecuteContext(ctx)
         }()
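
`newThreadSafeBuffer` above is referenced but not defined in this diff; it presumably lives elsewhere in server_test.go. The need is real, though: the test goroutine reads the buffer while the server goroutine writes it, which is a data race on a bare `strings.Builder`. A plausible minimal implementation (an assumption, not the repository's actual code):

    package cli_test

    import (
        "bytes"
        "sync"
    )

    // threadSafeBuffer serializes access to an underlying
    // bytes.Buffer so writer and reader goroutines can share it.
    type threadSafeBuffer struct {
        mu  sync.Mutex
        buf bytes.Buffer
    }

    func newThreadSafeBuffer() *threadSafeBuffer { return &threadSafeBuffer{} }

    func (b *threadSafeBuffer) Write(p []byte) (int, error) {
        b.mu.Lock()
        defer b.mu.Unlock()
        return b.buf.Write(p)
    }

    func (b *threadSafeBuffer) String() string {
        b.mu.Lock()
        defer b.mu.Unlock()
        return b.buf.String()
    }
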
@@ -259,12 +270,13 @@
         // We cannot send more signals here, because it's possible Coder
         // has already exited, which could cause the test to fail due to interrupt.
         err = <-serverErr
-        require.NoError(t, err)
+        require.ErrorIs(t, err, context.Canceled)
     })
     t.Run("TracerNoLeak", func(t *testing.T) {
         t.Parallel()
         ctx, cancelFunc := context.WithCancel(context.Background())
         defer cancelFunc()
+
         root, _ := clitest.New(t,
             "server",
             "--in-memory",
@@ -272,7 +284,7 @@
             "--trace=true",
             "--cache-dir", t.TempDir(),
         )
-        errC := make(chan error)
+        errC := make(chan error, 1)
         go func() {
             errC <- root.ExecuteContext(ctx)
         }()
@@ -310,7 +322,7 @@
             "--telemetry-url", server.URL,
             "--cache-dir", t.TempDir(),
         )
-        errC := make(chan error)
+        errC := make(chan error, 1)
         go func() {
             errC <- root.ExecuteContext(ctx)
         }()
diff --git a/coderd/autobuild/executor/lifecycle_executor.go b/coderd/autobuild/executor/lifecycle_executor.go
index 56dff0a8e6d2e..97582f10de6ca 100644
--- a/coderd/autobuild/executor/lifecycle_executor.go
+++ b/coderd/autobuild/executor/lifecycle_executor.go
@@ -54,15 +54,27 @@ func (e *Executor) WithStatsChannel(ch chan<- Stats) *Executor {
 // its channel is closed.
 func (e *Executor) Run() {
     go func() {
-        for t := range e.tick {
-            stats := e.runOnce(t)
-            if stats.Error != nil {
-                e.log.Error(e.ctx, "error running once", slog.Error(stats.Error))
+        for {
+            select {
+            case <-e.ctx.Done():
+                return
+            case t, ok := <-e.tick:
+                if !ok {
+                    return
+                }
+                stats := e.runOnce(t)
+                if stats.Error != nil {
+                    e.log.Error(e.ctx, "error running once", slog.Error(stats.Error))
+                }
+                if e.statsCh != nil {
+                    select {
+                    case <-e.ctx.Done():
+                        return
+                    case e.statsCh <- stats:
+                    }
+                }
+                e.log.Debug(e.ctx, "run stats", slog.F("elapsed", stats.Elapsed), slog.F("transitions", stats.Transitions))
             }
-            if e.statsCh != nil {
-                e.statsCh <- stats
-            }
-            e.log.Debug(e.ctx, "run stats", slog.F("elapsed", stats.Elapsed), slog.F("transitions", stats.Transitions))
         }
     }()
 }
diff --git a/coderd/devtunnel/tunnel.go b/coderd/devtunnel/tunnel.go
index b177e6fb88141..898a03f899202 100644
--- a/coderd/devtunnel/tunnel.go
+++ b/coderd/devtunnel/tunnel.go
@@ -104,7 +104,7 @@ allowed_ip=%s/128`,
         return nil, nil, xerrors.Errorf("wireguard device listen: %w", err)
     }
 
-    ch := make(chan error)
+    ch := make(chan error, 1)
     go func() {
         select {
         case <-ctx.Done():
diff --git a/provisionersdk/serve.go b/provisionersdk/serve.go
index 0ecbf4d841eec..befcf2c1283e5 100644
--- a/provisionersdk/serve.go
+++ b/provisionersdk/serve.go
@@ -38,6 +38,10 @@ func Serve(ctx context.Context, server proto.DRPCProvisionerServer, options *Ser
         if err != nil {
             return xerrors.Errorf("create yamux: %w", err)
         }
+        go func() {
+            <-ctx.Done()
+            _ = stdio.Close()
+        }()
         options.Listener = stdio
     }
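
The goroutine added to provisionersdk.Serve is the usual way to make a blocked listener context-aware: nothing interrupts a pending `Accept` except closing the listener, so a watcher goroutine converts cancellation into `Close`. The same pattern in isolation (a sketch, using a plain TCP listener):

    package main

    import (
        "context"
        "errors"
        "fmt"
        "net"
        "time"
    )

    // serve accepts connections until ctx is canceled; the watcher
    // goroutine turns cancellation into ln.Close(), which unblocks
    // Accept with net.ErrClosed.
    func serve(ctx context.Context, ln net.Listener) error {
        go func() {
            <-ctx.Done()
            _ = ln.Close()
        }()
        for {
            conn, err := ln.Accept()
            if err != nil {
                if errors.Is(err, net.ErrClosed) {
                    return ctx.Err()
                }
                return err
            }
            _ = conn.Close()
        }
    }

    func main() {
        ln, err := net.Listen("tcp", "127.0.0.1:0")
        if err != nil {
            panic(err)
        }
        ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
        defer cancel()
        fmt.Println(serve(ctx, ln)) // prints context.DeadlineExceeded
    }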