@@ -37,6 +37,7 @@ import (
37
37
"github.com/coder/coder/v2/codersdk/workspacesdk"
38
38
"github.com/coder/coder/v2/cryptorand"
39
39
"github.com/coder/coder/v2/pty"
40
+ "github.com/coder/quartz"
40
41
"github.com/coder/retry"
41
42
"github.com/coder/serpent"
42
43
)
@@ -48,6 +49,8 @@ const (
48
49
var (
49
50
workspacePollInterval = time .Minute
50
51
autostopNotifyCountdown = []time.Duration {30 * time .Minute }
52
+ // gracefulShutdownTimeout is the timeout, per item in the stack of things to close
53
+ gracefulShutdownTimeout = 2 * time .Second
51
54
)
52
55
53
56
func (r * RootCmd ) ssh () * serpent.Command {
@@ -153,7 +156,7 @@ func (r *RootCmd) ssh() *serpent.Command {
153
156
// log HTTP requests
154
157
client .SetLogger (logger )
155
158
}
156
- stack := newCloserStack (ctx , logger )
159
+ stack := newCloserStack (ctx , logger , quartz . NewReal () )
157
160
defer stack .close (nil )
158
161
159
162
for _ , remoteForward := range remoteForwards {
@@ -936,11 +939,18 @@ type closerStack struct {
936
939
closed bool
937
940
logger slog.Logger
938
941
err error
939
- wg sync.WaitGroup
942
+ allDone chan struct {}
943
+
944
+ // for testing
945
+ clock quartz.Clock
940
946
}
941
947
942
- func newCloserStack (ctx context.Context , logger slog.Logger ) * closerStack {
943
- cs := & closerStack {logger : logger }
948
+ func newCloserStack (ctx context.Context , logger slog.Logger , clock quartz.Clock ) * closerStack {
949
+ cs := & closerStack {
950
+ logger : logger ,
951
+ allDone : make (chan struct {}),
952
+ clock : clock ,
953
+ }
944
954
go cs .closeAfterContext (ctx )
945
955
return cs
946
956
}
@@ -954,20 +964,58 @@ func (c *closerStack) close(err error) {
954
964
c .Lock ()
955
965
if c .closed {
956
966
c .Unlock ()
957
- c . wg . Wait ()
967
+ <- c . allDone
958
968
return
959
969
}
960
970
c .closed = true
961
971
c .err = err
962
- c .wg .Add (1 )
963
- defer c .wg .Done ()
964
972
c .Unlock ()
973
+ defer close (c .allDone )
974
+ if len (c .closers ) == 0 {
975
+ return
976
+ }
965
977
966
- for i := len (c .closers ) - 1 ; i >= 0 ; i -- {
967
- cwn := c .closers [i ]
968
- cErr := cwn .closer .Close ()
969
- c .logger .Debug (context .Background (),
970
- "closed item from stack" , slog .F ("name" , cwn .name ), slog .Error (cErr ))
978
+ // We are going to work down the stack in order. If things close quickly, we trigger the
979
+ // closers serially, in order. `done` is a channel that indicates the nth closer is done
980
+ // closing, and we should trigger the (n-1) closer. However, if things take too long we don't
981
+ // want to wait, so we also start a ticker that works down the stack and sends on `done` as
982
+ // well.
983
+ next := len (c .closers ) - 1
984
+ // here we make the buffer 2x the number of closers because we could write once for it being
985
+ // actually done and once via the countdown for each closer
986
+ done := make (chan int , len (c .closers )* 2 )
987
+ startNext := func () {
988
+ go func (i int ) {
989
+ defer func () { done <- i }()
990
+ cwn := c .closers [i ]
991
+ cErr := cwn .closer .Close ()
992
+ c .logger .Debug (context .Background (),
993
+ "closed item from stack" , slog .F ("name" , cwn .name ), slog .Error (cErr ))
994
+ }(next )
995
+ next --
996
+ }
997
+ done <- len (c .closers ) // kick us off right away
998
+
999
+ // start a ticking countdown in case we hang/don't close quickly
1000
+ countdown := len (c .closers ) - 1
1001
+ ctx , cancel := context .WithCancel (context .Background ())
1002
+ defer cancel ()
1003
+ c .clock .TickerFunc (ctx , gracefulShutdownTimeout , func () error {
1004
+ if countdown < 0 {
1005
+ return nil
1006
+ }
1007
+ done <- countdown
1008
+ countdown --
1009
+ return nil
1010
+ }, "closerStack" )
1011
+
1012
+ for n := range done { // the nth closer is done
1013
+ if n == 0 {
1014
+ return
1015
+ }
1016
+ if n - 1 == next {
1017
+ startNext ()
1018
+ }
971
1019
}
972
1020
}
973
1021
0 commit comments