Skip to content
Merged
Prev Previous commit
Next Next commit
feat: Add config-ssh and tests for resiliency
  • Loading branch information
kylecarbs committed Mar 27, 2022
commit 2152758b8753d29d8001d91d0428f1e52eedad98
10 changes: 8 additions & 2 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ func TestAgent(t *testing.T) {
t.Cleanup(func() {
_ = conn.Close()
})
client := agent.Conn{conn}
client := agent.Conn{
Negotiator: api,
Conn: conn,
}
sshClient, err := client.SSHClient()
require.NoError(t, err)
session, err := sshClient.NewSession()
Expand All @@ -65,7 +68,10 @@ func TestAgent(t *testing.T) {
t.Cleanup(func() {
_ = conn.Close()
})
client := &agent.Conn{conn}
client := &agent.Conn{
Negotiator: api,
Conn: conn,
}
sshClient, err := client.SSHClient()
require.NoError(t, err)
session, err := sshClient.NewSession()
Expand Down
9 changes: 9 additions & 0 deletions agent/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@ import (
"golang.org/x/xerrors"

"github.com/coder/coder/peer"
"github.com/coder/coder/peerbroker/proto"
)

// Conn wraps a peer connection with helper functions to
// communicate with the agent.
type Conn struct {
// Negotiator is responsible for exchanging messages.
Negotiator proto.DRPCPeerBrokerClient

*peer.Conn
}

Expand Down Expand Up @@ -48,3 +52,8 @@ func (c *Conn) SSHClient() (*ssh.Client, error) {
}
return ssh.NewClient(sshConn, channels, requests), nil
}

func (c *Conn) Close() error {
_ = c.Negotiator.DRPCConn().Close()
return c.Conn.Close()
}
2 changes: 1 addition & 1 deletion cli/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func Root() *cobra.Command {
projects(),
users(),
workspaces(),
workspaceSSH(),
ssh(),
workspaceTunnel(),
)

Expand Down
20 changes: 9 additions & 11 deletions cli/ssh.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
package cli

import (
"os"

"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
"golang.org/x/term"
gossh "golang.org/x/crypto/ssh"
"golang.org/x/xerrors"

"github.com/coder/coder/coderd/database"
"github.com/coder/coder/codersdk"
)

func workspaceSSH() *cobra.Command {
func ssh() *cobra.Command {
cmd := &cobra.Command{
Use: "ssh <workspace> [resource]",
RunE: func(cmd *cobra.Command, args []string) error {
Expand Down Expand Up @@ -68,6 +65,7 @@ func workspaceSSH() *cobra.Command {
if err != nil {
return err
}
defer conn.Close()
sshClient, err := conn.SSHClient()
if err != nil {
return err
Expand All @@ -77,16 +75,16 @@ func workspaceSSH() *cobra.Command {
if err != nil {
return err
}
_, _ = term.MakeRaw(int(os.Stdin.Fd()))
err = sshSession.RequestPty("xterm-256color", 128, 128, ssh.TerminalModes{
ssh.OCRNL: 1,

err = sshSession.RequestPty("xterm-256color", 128, 128, gossh.TerminalModes{
gossh.OCRNL: 1,
})
if err != nil {
return err
}
sshSession.Stdin = os.Stdin
sshSession.Stdout = os.Stdout
sshSession.Stderr = os.Stderr
sshSession.Stdin = cmd.InOrStdin()
sshSession.Stdout = cmd.OutOrStdout()
sshSession.Stderr = cmd.OutOrStdout()
err = sshSession.Shell()
if err != nil {
return err
Expand Down
77 changes: 77 additions & 0 deletions cli/ssh_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package cli_test

import (
"testing"

"github.com/google/uuid"
"github.com/stretchr/testify/require"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/agent"
"github.com/coder/coder/cli/clitest"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/peer"
"github.com/coder/coder/provisioner/echo"
"github.com/coder/coder/provisionersdk/proto"
"github.com/coder/coder/pty/ptytest"
)

func TestSSH(t *testing.T) {
t.Parallel()
t.Run("Echo", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
daemonCloser := coderdtest.NewProvisionerDaemon(t, client)
agentToken := uuid.NewString()
version := coderdtest.CreateProjectVersion(t, client, user.OrganizationID, &echo.Responses{
Parse: echo.ParseComplete,
ProvisionDryRun: echo.ProvisionComplete,
Provision: []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Resources: []*proto.Resource{{
Name: "dev",
Type: "google_compute_instance",
Agent: &proto.Agent{
Id: uuid.NewString(),
Auth: &proto.Agent_Token{
Token: agentToken,
},
},
}},
},
},
}},
})
coderdtest.AwaitProjectVersionJob(t, client, version.ID)
project := coderdtest.CreateProject(t, client, user.OrganizationID, version.ID)
workspace := coderdtest.CreateWorkspace(t, client, "", project.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
daemonCloser.Close()
agentClient := codersdk.New(client.URL)
agentClient.SessionToken = agentToken
agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &peer.ConnOptions{
Logger: slogtest.Make(t, nil).Leveled(slog.LevelDebug),
})
defer agentCloser.Close()
coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)

cmd, root := clitest.New(t, "ssh", workspace.Name)
clitest.SetupConfig(t, client, root)
doneChan := make(chan struct{})
pty := ptytest.New(t)
cmd.SetIn(pty.Input())
cmd.SetOut(pty.Output())
go func() {
defer close(doneChan)
err := cmd.Execute()
require.NoError(t, err)
}()
// Shells on Mac, Windows, and Linux all exit shells with the "exit" command.
pty.WriteLine("exit")
<-doneChan
})
}
2 changes: 1 addition & 1 deletion cli/workspaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func workspaces() *cobra.Command {
cmd.AddCommand(workspaceShow())
cmd.AddCommand(workspaceStop())
cmd.AddCommand(workspaceStart())
cmd.AddCommand(workspaceSSH())
cmd.AddCommand(ssh())
cmd.AddCommand(workspaceUpdate())

return cmd
Expand Down
1 change: 1 addition & 0 deletions coderd/coderdtest/coderdtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ func AwaitWorkspaceAgents(t *testing.T, client *codersdk.Client, build uuid.UUID
if resource.Agent == nil {
continue
}
// fmt.Printf("resources: %+v\n", resource.Agent)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like a debug line

if resource.Agent.FirstConnectedAt == nil {
return false
}
Expand Down
10 changes: 2 additions & 8 deletions codersdk/workspaceresources.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,15 +133,9 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, resource uuid.UUID, ice
if err != nil {
return nil, xerrors.Errorf("dial peer: %w", err)
}
go func() {
// The stream is kept alive to renegotiate the RTC connection
// if need-be. The calling context can be canceled to end
// the negotiation stream, but not the peer connection.
<-peerConn.Closed()
_ = conn.Close(websocket.StatusNormalClosure, "")
}()
return &agent.Conn{
Conn: peerConn,
Negotiator: client,
Conn: peerConn,
}, nil
}

Expand Down
8 changes: 8 additions & 0 deletions peerbroker/dial.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package peerbroker

import (
"context"
"errors"
"io"
"reflect"

"github.com/pion/webrtc/v3"
Expand Down Expand Up @@ -54,6 +57,11 @@ func Dial(stream proto.DRPCPeerBroker_NegotiateConnectionClient, iceServers []we
for {
serverToClientMessage, err := stream.Recv()
if err != nil {
// p2p connections should never die if this stream does due
// to proper closure or context cancellation!
if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
return
}
_ = peerConn.CloseWithError(xerrors.Errorf("recv: %w", err))
return
}
Expand Down
8 changes: 4 additions & 4 deletions peerbroker/listen.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,10 @@ func (b *peerBrokerService) NegotiateConnection(stream proto.DRPCPeerBroker_Nego
for {
clientToServerMessage, err := stream.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
break
// p2p connections should never die if this stream does due
// to proper closure or context cancellation!
if errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) {
return nil
}
return peerConn.CloseWithError(xerrors.Errorf("recv: %w", err))
}
Expand All @@ -186,6 +188,4 @@ func (b *peerBrokerService) NegotiateConnection(stream proto.DRPCPeerBroker_Nego
return peerConn.CloseWithError(xerrors.Errorf("unhandled message: %s", reflect.TypeOf(clientToServerMessage).String()))
}
}

return nil
}