Skip to content

fix: Guard against CLI cmd running after test exit #1658

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
May 23, 2022
39 changes: 21 additions & 18 deletions cli/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,12 @@ func TestWorkspaceAgent(t *testing.T) {
cmd, _ := clitest.New(t, "agent", "--auth", "azure-instance-identity", "--agent-url", client.URL.String())
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
errC := make(chan error)
go func() {
// A linting error occurs for weakly typing the context value here,
// but it seems reasonable for a one-off test.
// nolint
ctx = context.WithValue(ctx, "azure-client", metadataClient)
err := cmd.ExecuteContext(ctx)
require.NoError(t, err)
// A linting error occurs for weakly typing the context value here.
//nolint // The above seems reasonable for a one-off test.
ctx := context.WithValue(ctx, "azure-client", metadataClient)
errC <- cmd.ExecuteContext(ctx)
}()
coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
resources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID)
Expand All @@ -66,6 +65,8 @@ func TestWorkspaceAgent(t *testing.T) {
_, err = dialer.Ping()
require.NoError(t, err)
cancelFunc()
err = <-errC
require.NoError(t, err)
})

t.Run("AWS", func(t *testing.T) {
Expand Down Expand Up @@ -103,13 +104,12 @@ func TestWorkspaceAgent(t *testing.T) {
cmd, _ := clitest.New(t, "agent", "--auth", "aws-instance-identity", "--agent-url", client.URL.String())
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
errC := make(chan error)
go func() {
// A linting error occurs for weakly typing the context value here,
// but it seems reasonable for a one-off test.
// nolint
ctx = context.WithValue(ctx, "aws-client", metadataClient)
err := cmd.ExecuteContext(ctx)
require.NoError(t, err)
// A linting error occurs for weakly typing the context value here.
//nolint // The above seems reasonable for a one-off test.
ctx := context.WithValue(ctx, "aws-client", metadataClient)
errC <- cmd.ExecuteContext(ctx)
}()
coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
resources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID)
Expand All @@ -120,6 +120,8 @@ func TestWorkspaceAgent(t *testing.T) {
_, err = dialer.Ping()
require.NoError(t, err)
cancelFunc()
err = <-errC
require.NoError(t, err)
})

t.Run("GoogleCloud", func(t *testing.T) {
Expand Down Expand Up @@ -157,13 +159,12 @@ func TestWorkspaceAgent(t *testing.T) {
cmd, _ := clitest.New(t, "agent", "--auth", "google-instance-identity", "--agent-url", client.URL.String())
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
errC := make(chan error)
go func() {
// A linting error occurs for weakly typing the context value here,
// but it seems reasonable for a one-off test.
// nolint
ctx = context.WithValue(ctx, "gcp-client", metadata)
err := cmd.ExecuteContext(ctx)
require.NoError(t, err)
// A linting error occurs for weakly typing the context value here.
//nolint // The above seems reasonable for a one-off test.
ctx := context.WithValue(ctx, "gcp-client", metadata)
errC <- cmd.ExecuteContext(ctx)
}()
coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
resources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID)
Expand All @@ -174,5 +175,7 @@ func TestWorkspaceAgent(t *testing.T) {
_, err = dialer.Ping()
require.NoError(t, err)
cancelFunc()
err = <-errC
require.NoError(t, err)
})
}
4 changes: 1 addition & 3 deletions cli/clitest/clitest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package clitest_test
import (
"testing"

"github.com/stretchr/testify/require"
"go.uber.org/goleak"

"github.com/coder/coder/cli/clitest"
Expand All @@ -25,8 +24,7 @@ func TestCli(t *testing.T) {
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
go func() {
err := cmd.Execute()
require.NoError(t, err)
_ = cmd.Execute()
}()
pty.ExpectMatch("coder")
}
19 changes: 13 additions & 6 deletions cli/gitssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ func TestGitSSH(t *testing.T) {
clitest.SetupConfig(t, agentClient, root)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
agentErrC := make(chan error)
go func() {
err := cmd.ExecuteContext(ctx)
require.NoError(t, err)
agentErrC <- cmd.ExecuteContext(ctx)
}()

coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
Expand All @@ -85,23 +85,30 @@ func TestGitSSH(t *testing.T) {
return ssh.KeysEqual(publicKey, key)
})
var inc int64
sshErrC := make(chan error)
go func() {
// as long as we get a successful session we don't care if the server errors
_ = ssh.Serve(l, func(s ssh.Session) {
atomic.AddInt64(&inc, 1)
t.Log("got authenticated session")
err := s.Exit(0)
require.NoError(t, err)
sshErrC <- s.Exit(0)
}, publicKeyOption)
}()

// start ssh session
addr, ok := l.Addr().(*net.TCPAddr)
require.True(t, ok)
// set to agent config dir
cmd, _ = clitest.New(t, "gitssh", "--agent-url", agentClient.URL.String(), "--agent-token", agentToken, "--", fmt.Sprintf("-p%d", addr.Port), "-o", "StrictHostKeyChecking=no", "127.0.0.1")
err = cmd.ExecuteContext(context.Background())
gitsshCmd, _ := clitest.New(t, "gitssh", "--agent-url", agentClient.URL.String(), "--agent-token", agentToken, "--", fmt.Sprintf("-p%d", addr.Port), "-o", "StrictHostKeyChecking=no", "-o", "IdentitiesOnly=yes", "127.0.0.1")
err = gitsshCmd.ExecuteContext(context.Background())
require.NoError(t, err)
require.EqualValues(t, 1, inc)

err = <-sshErrC
require.NoError(t, err, "error in ssh session exit")

cancelFunc()
err = <-agentErrC
require.NoError(t, err, "error in agent execute")
})
}
13 changes: 8 additions & 5 deletions cli/list_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package cli_test

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/require"

Expand All @@ -14,6 +16,8 @@ func TestList(t *testing.T) {
t.Parallel()
t.Run("Single", func(t *testing.T) {
t.Parallel()
ctx, cancelFunc := context.WithTimeout(context.Background(), 5*time.Second)
defer cancelFunc()
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true})
user := coderdtest.CreateFirstUser(t, client)
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil)
Expand All @@ -23,17 +27,16 @@ func TestList(t *testing.T) {
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
cmd, root := clitest.New(t, "ls")
clitest.SetupConfig(t, client, root)
doneChan := make(chan struct{})
pty := ptytest.New(t)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
errC := make(chan error)
go func() {
defer close(doneChan)
err := cmd.Execute()
require.NoError(t, err)
errC <- cmd.ExecuteContext(ctx)
}()
pty.ExpectMatch(workspace.Name)
pty.ExpectMatch("Running")
<-doneChan
cancelFunc()
require.NoError(t, <-errC)
})
}
81 changes: 39 additions & 42 deletions cli/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import (
"os"
"runtime"
"strings"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -45,12 +44,11 @@ func TestServer(t *testing.T) {
require.NoError(t, err)
defer closeFunc()
ctx, cancelFunc := context.WithCancel(context.Background())
done := make(chan struct{})
defer cancelFunc()
root, cfg := clitest.New(t, "server", "--address", ":0", "--postgres-url", connectionURL)
errC := make(chan error)
go func() {
defer close(done)
err = root.ExecuteContext(ctx)
require.ErrorIs(t, err, context.Canceled)
errC <- root.ExecuteContext(ctx)
}()
var client *codersdk.Client
require.Eventually(t, func() bool {
Expand All @@ -71,8 +69,9 @@ func TestServer(t *testing.T) {
})
require.NoError(t, err)
cancelFunc()
<-done
require.ErrorIs(t, <-errC, context.Canceled)
})

t.Run("Development", func(t *testing.T) {
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
Expand All @@ -82,26 +81,12 @@ func TestServer(t *testing.T) {

root, cfg := clitest.New(t, "server", "--dev", "--tunnel=false", "--address", ":0")
var buf strings.Builder
errC := make(chan error)
root.SetOutput(&buf)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()

err := root.ExecuteContext(ctx)
require.ErrorIs(t, err, context.Canceled)

// Verify that credentials were output to the terminal.
assert.Contains(t, buf.String(), fmt.Sprintf("email: %s", wantEmail), "expected output %q; got no match", wantEmail)
// Check that the password line is output and that it's non-empty.
if _, after, found := strings.Cut(buf.String(), "password: "); found {
before, _, _ := strings.Cut(after, "\n")
before = strings.Trim(before, "\r") // Ensure no control character is left.
assert.NotEmpty(t, before, "expected non-empty password; got empty")
} else {
t.Error("expected password line output; got no match")
}
errC <- root.ExecuteContext(ctx)
}()

var token string
require.Eventually(t, func() bool {
var err error
Expand All @@ -119,8 +104,20 @@ func TestServer(t *testing.T) {
require.NoError(t, err)

cancelFunc()
wg.Wait()
require.ErrorIs(t, <-errC, context.Canceled)

// Verify that credentials were output to the terminal.
assert.Contains(t, buf.String(), fmt.Sprintf("email: %s", wantEmail), "expected output %q; got no match", wantEmail)
// Check that the password line is output and that it's non-empty.
if _, after, found := strings.Cut(buf.String(), "password: "); found {
before, _, _ := strings.Cut(after, "\n")
before = strings.Trim(before, "\r") // Ensure no control character is left.
assert.NotEmpty(t, before, "expected non-empty password; got empty")
} else {
t.Error("expected password line output; got no match")
}
})

// Duplicated test from "Development" above to test setting email/password via env.
// Cannot run parallel due to os.Setenv.
//nolint:paralleltest
Expand All @@ -136,18 +133,11 @@ func TestServer(t *testing.T) {
root, cfg := clitest.New(t, "server", "--dev", "--tunnel=false", "--address", ":0")
var buf strings.Builder
root.SetOutput(&buf)
var wg sync.WaitGroup
wg.Add(1)
errC := make(chan error)
go func() {
defer wg.Done()

err := root.ExecuteContext(ctx)
require.ErrorIs(t, err, context.Canceled)

// Verify that credentials were output to the terminal.
assert.Contains(t, buf.String(), fmt.Sprintf("email: %s", wantEmail), "expected output %q; got no match", wantEmail)
assert.Contains(t, buf.String(), fmt.Sprintf("password: %s", wantPassword), "expected output %q; got no match", wantPassword)
errC <- root.ExecuteContext(ctx)
}()

var token string
require.Eventually(t, func() bool {
var err error
Expand All @@ -165,8 +155,12 @@ func TestServer(t *testing.T) {
require.NoError(t, err)

cancelFunc()
wg.Wait()
require.ErrorIs(t, <-errC, context.Canceled)
// Verify that credentials were output to the terminal.
assert.Contains(t, buf.String(), fmt.Sprintf("email: %s", wantEmail), "expected output %q; got no match", wantEmail)
assert.Contains(t, buf.String(), fmt.Sprintf("password: %s", wantPassword), "expected output %q; got no match", wantPassword)
})

t.Run("TLSBadVersion", func(t *testing.T) {
t.Parallel()
ctx, cancelFunc := context.WithCancel(context.Background())
Expand Down Expand Up @@ -202,10 +196,12 @@ func TestServer(t *testing.T) {
certPath, keyPath := generateTLSCertificate(t)
root, cfg := clitest.New(t, "server", "--dev", "--tunnel=false", "--address", ":0",
"--tls-enable", "--tls-cert-file", certPath, "--tls-key-file", keyPath)
errC := make(chan error)
go func() {
err := root.ExecuteContext(ctx)
require.ErrorIs(t, err, context.Canceled)
errC <- root.ExecuteContext(ctx)
}()

// Verify HTTPS
var accessURLRaw string
require.Eventually(t, func() bool {
var err error
Expand All @@ -226,6 +222,9 @@ func TestServer(t *testing.T) {
}
_, err = client.HasFirstUser(ctx)
require.NoError(t, err)

cancelFunc()
require.ErrorIs(t, <-errC, context.Canceled)
})
// This cannot be ran in parallel because it uses a signal.
//nolint:paralleltest
Expand Down Expand Up @@ -284,14 +283,12 @@ func TestServer(t *testing.T) {
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
root, _ := clitest.New(t, "server", "--dev", "--tunnel=false", "--address", ":0", "--trace=true")
done := make(chan struct{})
errC := make(chan error)
go func() {
defer close(done)
err := root.ExecuteContext(ctx)
require.ErrorIs(t, err, context.Canceled)
errC <- root.ExecuteContext(ctx)
}()
cancelFunc()
<-done
require.ErrorIs(t, <-errC, context.Canceled)
require.Error(t, goleak.Find())
})
}
Expand Down
9 changes: 2 additions & 7 deletions cli/templateinit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,11 @@ func TestTemplateInit(t *testing.T) {
t.Parallel()
tempDir := t.TempDir()
cmd, _ := clitest.New(t, "templates", "init", tempDir)
doneChan := make(chan struct{})
pty := ptytest.New(t)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
go func() {
defer close(doneChan)
err := cmd.Execute()
require.NoError(t, err)
}()
<-doneChan
err := cmd.Execute()
require.NoError(t, err)
files, err := os.ReadDir(tempDir)
require.NoError(t, err)
require.Greater(t, len(files), 0)
Expand Down
8 changes: 3 additions & 5 deletions cli/userlist_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,13 @@ func TestUserList(t *testing.T) {
coderdtest.CreateFirstUser(t, client)
cmd, root := clitest.New(t, "users", "list")
clitest.SetupConfig(t, client, root)
doneChan := make(chan struct{})
pty := ptytest.New(t)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
errC := make(chan error)
go func() {
defer close(doneChan)
err := cmd.Execute()
require.NoError(t, err)
errC <- cmd.Execute()
}()
require.NoError(t, <-errC)
pty.ExpectMatch("coder.com")
<-doneChan
}