|
8 | 8 | "context"
|
9 | 9 | "fmt"
|
10 | 10 | "net"
|
| 11 | + "os" |
11 | 12 | "os/user"
|
| 13 | + "path/filepath" |
12 | 14 | "runtime"
|
13 | 15 | "strings"
|
14 | 16 | "sync"
|
@@ -403,6 +405,81 @@ func TestNewServer_Signal(t *testing.T) {
|
403 | 405 | })
|
404 | 406 | }
|
405 | 407 |
|
| 408 | +func TestSSHServer_ClosesStdin(t *testing.T) { |
| 409 | + t.Parallel() |
| 410 | + if runtime.GOOS == "windows" { |
| 411 | + t.Skip("bash doesn't exist on Windows") |
| 412 | + } |
| 413 | + |
| 414 | + ctx := testutil.Context(t, testutil.WaitMedium) |
| 415 | + logger := testutil.Logger(t) |
| 416 | + s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil) |
| 417 | + require.NoError(t, err) |
| 418 | + defer s.Close() |
| 419 | + err = s.UpdateHostSigner(42) |
| 420 | + assert.NoError(t, err) |
| 421 | + |
| 422 | + ln, err := net.Listen("tcp", "127.0.0.1:0") |
| 423 | + require.NoError(t, err) |
| 424 | + |
| 425 | + done := make(chan struct{}) |
| 426 | + go func() { |
| 427 | + defer close(done) |
| 428 | + err := s.Serve(ln) |
| 429 | + assert.Error(t, err) // Server is closed. |
| 430 | + }() |
| 431 | + defer func() { |
| 432 | + err := s.Close() |
| 433 | + require.NoError(t, err) |
| 434 | + <-done |
| 435 | + }() |
| 436 | + |
| 437 | + c := sshClient(t, ln.Addr().String()) |
| 438 | + |
| 439 | + sess, err := c.NewSession() |
| 440 | + require.NoError(t, err) |
| 441 | + stdout, err := sess.StdoutPipe() |
| 442 | + require.NoError(t, err) |
| 443 | + stdin, err := sess.StdinPipe() |
| 444 | + require.NoError(t, err) |
| 445 | + defer stdin.Close() |
| 446 | + |
| 447 | + dir := t.TempDir() |
| 448 | + err = os.MkdirAll(dir, 0o755) |
| 449 | + require.NoError(t, err) |
| 450 | + filePath := filepath.Join(dir, "result.txt") |
| 451 | + |
| 452 | + // the shell command `read` will block until data is written to stdin, or closed. It will return |
| 453 | + // exit code 1 if it hits EOF, which is what we want to test. |
| 454 | + cmdErrCh := make(chan error, 1) |
| 455 | + go func() { |
| 456 | + cmdErrCh <- sess.Start(fmt.Sprintf("echo started; read; echo \"read exit code: $?\" > %s", filePath)) |
| 457 | + }() |
| 458 | + |
| 459 | + cmdErr := testutil.RequireReceive(ctx, t, cmdErrCh) |
| 460 | + require.NoError(t, cmdErr) |
| 461 | + |
| 462 | + readCh := make(chan error, 1) |
| 463 | + go func() { |
| 464 | + buf := make([]byte, 8) |
| 465 | + _, err := stdout.Read(buf) |
| 466 | + assert.Equal(t, "started\n", string(buf)) |
| 467 | + readCh <- err |
| 468 | + }() |
| 469 | + err = testutil.RequireReceive(ctx, t, readCh) |
| 470 | + require.NoError(t, err) |
| 471 | + |
| 472 | + sess.Close() |
| 473 | + |
| 474 | + var content []byte |
| 475 | + require.Eventually(t, func() bool { |
| 476 | + content, err = os.ReadFile(filePath) |
| 477 | + return err == nil |
| 478 | + }, testutil.WaitMedium, testutil.IntervalFast) |
| 479 | + require.NoError(t, err) |
| 480 | + require.Equal(t, "read exit code: 1\n", string(content)) |
| 481 | +} |
| 482 | + |
406 | 483 | func sshClient(t *testing.T, addr string) *ssh.Client {
|
407 | 484 | conn, err := net.Dial("tcp", addr)
|
408 | 485 | require.NoError(t, err)
|
|
0 commit comments