@@ -21,6 +21,7 @@ import (
21
21
"go.uber.org/goleak"
22
22
"golang.org/x/crypto/ssh"
23
23
24
+ "cdr.dev/slog"
24
25
"cdr.dev/slog/sloggers/slogtest"
25
26
26
27
"github.com/coder/coder/v2/agent/agentexec"
@@ -147,51 +148,92 @@ func (*fakeEnvInfoer) ModifyCommand(cmd string, args ...string) (string, []strin
147
148
func TestNewServer_CloseActiveConnections (t * testing.T ) {
148
149
t .Parallel ()
149
150
150
- ctx := context .Background ()
151
- logger := slogtest .Make (t , & slogtest.Options {IgnoreErrors : true })
152
- s , err := agentssh .NewServer (ctx , logger , prometheus .NewRegistry (), afero .NewMemMapFs (), agentexec .DefaultExecer , nil )
153
- require .NoError (t , err )
154
- defer s .Close ()
155
- err = s .UpdateHostSigner (42 )
156
- assert .NoError (t , err )
151
+ prepare := func (ctx context.Context , t * testing.T ) (* agentssh.Server , func ()) {
152
+ t .Helper ()
153
+ logger := slogtest .Make (t , & slogtest.Options {IgnoreErrors : true }).Leveled (slog .LevelDebug )
154
+ s , err := agentssh .NewServer (ctx , logger , prometheus .NewRegistry (), afero .NewMemMapFs (), agentexec .DefaultExecer , nil )
155
+ require .NoError (t , err )
156
+ defer s .Close ()
157
+ err = s .UpdateHostSigner (42 )
158
+ assert .NoError (t , err )
157
159
158
- ln , err := net .Listen ("tcp" , "127.0.0.1:0" )
159
- require .NoError (t , err )
160
+ ln , err := net .Listen ("tcp" , "127.0.0.1:0" )
161
+ require .NoError (t , err )
160
162
161
- var wg sync.WaitGroup
162
- wg .Add (2 )
163
- go func () {
164
- defer wg .Done ()
165
- err := s .Serve (ln )
166
- assert .Error (t , err ) // Server is closed.
167
- }()
163
+ waitConns := make ([]chan struct {}, 4 )
168
164
169
- pty := ptytest .New (t )
165
+ var wg sync.WaitGroup
166
+ wg .Add (1 + len (waitConns ))
170
167
171
- doClose := make (chan struct {})
172
- go func () {
173
- defer wg .Done ()
174
- c := sshClient (t , ln .Addr ().String ())
175
- sess , err := c .NewSession ()
176
- assert .NoError (t , err )
177
- sess .Stdin = pty .Input ()
178
- sess .Stdout = pty .Output ()
179
- sess .Stderr = pty .Output ()
168
+ go func () {
169
+ defer wg .Done ()
170
+ err := s .Serve (ln )
171
+ assert .Error (t , err ) // Server is closed.
172
+ }()
180
173
181
- assert .NoError (t , err )
182
- err = sess .Start ("" )
183
- assert .NoError (t , err )
174
+ for i := 0 ; i < len (waitConns ); i ++ {
175
+ waitConns [i ] = make (chan struct {})
176
+ go func (ch chan struct {}) {
177
+ defer wg .Done ()
178
+ c := sshClient (t , ln .Addr ().String ())
179
+ sess , err := c .NewSession ()
180
+ assert .NoError (t , err )
181
+ pty := ptytest .New (t )
182
+ sess .Stdin = pty .Input ()
183
+ sess .Stdout = pty .Output ()
184
+ sess .Stderr = pty .Output ()
185
+
186
+ // Every other session will request a PTY.
187
+ if i % 2 == 0 {
188
+ err = sess .RequestPty ("xterm" , 80 , 80 , nil )
189
+ assert .NoError (t , err )
190
+ }
191
+ // The 60 seconds here is intended to be longer than the
192
+ // test. The shutdown should propagate.
193
+ err = sess .Start ("/bin/bash -c 'trap \" sleep 60\" SIGTERM; sleep 60'" )
194
+ assert .NoError (t , err )
195
+
196
+ close (ch )
197
+ err = sess .Wait ()
198
+ assert .Error (t , err )
199
+ }(waitConns [i ])
200
+ }
184
201
185
- close (doClose )
186
- err = sess .Wait ()
187
- assert .Error (t , err )
188
- }()
202
+ for _ , ch := range waitConns {
203
+ <- ch
204
+ }
189
205
190
- <- doClose
191
- err = s .Close ()
192
- require .NoError (t , err )
206
+ return s , wg .Wait
207
+ }
208
+
209
+ t .Run ("Close" , func (t * testing.T ) {
210
+ t .Parallel ()
211
+ ctx := testutil .Context (t , testutil .WaitMedium )
212
+ s , wait := prepare (ctx , t )
213
+ err := s .Close ()
214
+ require .NoError (t , err )
215
+ wait ()
216
+ })
193
217
194
- wg .Wait ()
218
+ t .Run ("Shutdown" , func (t * testing.T ) {
219
+ t .Parallel ()
220
+ ctx := testutil .Context (t , testutil .WaitMedium )
221
+ s , wait := prepare (ctx , t )
222
+ err := s .Shutdown (ctx )
223
+ require .NoError (t , err )
224
+ wait ()
225
+ })
226
+
227
+ t .Run ("Shutdown Early" , func (t * testing.T ) {
228
+ t .Parallel ()
229
+ ctx := testutil .Context (t , testutil .WaitMedium )
230
+ s , wait := prepare (ctx , t )
231
+ ctx , cancel := context .WithCancel (ctx )
232
+ cancel ()
233
+ err := s .Shutdown (ctx )
234
+ require .ErrorIs (t , err , context .Canceled )
235
+ wait ()
236
+ })
195
237
}
196
238
197
239
func TestNewServer_Signal (t * testing.T ) {
0 commit comments