Skip to content

Commit 6612e3c

Browse files
authored
feat: Add config-ssh command (coder#735)
* feat: Add config-ssh command Closes coder#254 and coder#499. * Fix Windows support
1 parent 6ab1a68 commit 6612e3c

29 files changed

+554
-115
lines changed

.vscode/settings.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
{
22
"cSpell.words": [
3+
"cliflag",
34
"cliui",
45
"coderd",
56
"coderdtest",

agent/agent.go

Lines changed: 41 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -56,24 +56,24 @@ type agent struct {
5656
sshServer *ssh.Server
5757
}
5858

59-
func (s *agent) run(ctx context.Context) {
59+
func (a *agent) run(ctx context.Context) {
6060
var peerListener *peerbroker.Listener
6161
var err error
6262
// An exponential back-off occurs when the connection is failing to dial.
6363
// This is to prevent server spam in case of a coderd outage.
6464
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
65-
peerListener, err = s.clientDialer(ctx, s.options)
65+
peerListener, err = a.clientDialer(ctx, a.options)
6666
if err != nil {
6767
if errors.Is(err, context.Canceled) {
6868
return
6969
}
70-
if s.isClosed() {
70+
if a.isClosed() {
7171
return
7272
}
73-
s.options.Logger.Warn(context.Background(), "failed to dial", slog.Error(err))
73+
a.options.Logger.Warn(context.Background(), "failed to dial", slog.Error(err))
7474
continue
7575
}
76-
s.options.Logger.Info(context.Background(), "connected")
76+
a.options.Logger.Info(context.Background(), "connected")
7777
break
7878
}
7979
select {
@@ -85,48 +85,48 @@ func (s *agent) run(ctx context.Context) {
8585
for {
8686
conn, err := peerListener.Accept()
8787
if err != nil {
88-
if s.isClosed() {
88+
if a.isClosed() {
8989
return
9090
}
91-
s.options.Logger.Debug(ctx, "peer listener accept exited; restarting connection", slog.Error(err))
92-
s.run(ctx)
91+
a.options.Logger.Debug(ctx, "peer listener accept exited; restarting connection", slog.Error(err))
92+
a.run(ctx)
9393
return
9494
}
95-
s.closeMutex.Lock()
96-
s.connCloseWait.Add(1)
97-
s.closeMutex.Unlock()
98-
go s.handlePeerConn(ctx, conn)
95+
a.closeMutex.Lock()
96+
a.connCloseWait.Add(1)
97+
a.closeMutex.Unlock()
98+
go a.handlePeerConn(ctx, conn)
9999
}
100100
}
101101

102-
func (s *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
102+
func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
103103
go func() {
104104
<-conn.Closed()
105-
s.connCloseWait.Done()
105+
a.connCloseWait.Done()
106106
}()
107107
for {
108108
channel, err := conn.Accept(ctx)
109109
if err != nil {
110-
if errors.Is(err, peer.ErrClosed) || s.isClosed() {
110+
if errors.Is(err, peer.ErrClosed) || a.isClosed() {
111111
return
112112
}
113-
s.options.Logger.Debug(ctx, "accept channel from peer connection", slog.Error(err))
113+
a.options.Logger.Debug(ctx, "accept channel from peer connection", slog.Error(err))
114114
return
115115
}
116116

117117
switch channel.Protocol() {
118118
case "ssh":
119-
s.sshServer.HandleConn(channel.NetConn())
119+
a.sshServer.HandleConn(channel.NetConn())
120120
default:
121-
s.options.Logger.Warn(ctx, "unhandled protocol from channel",
121+
a.options.Logger.Warn(ctx, "unhandled protocol from channel",
122122
slog.F("protocol", channel.Protocol()),
123123
slog.F("label", channel.Label()),
124124
)
125125
}
126126
}
127127
}
128128

129-
func (s *agent) init(ctx context.Context) {
129+
func (a *agent) init(ctx context.Context) {
130130
// Clients' should ignore the host key when connecting.
131131
// The agent needs to authenticate with coderd to SSH,
132132
// so SSH authentication doesn't improve security.
@@ -138,17 +138,17 @@ func (s *agent) init(ctx context.Context) {
138138
if err != nil {
139139
panic(err)
140140
}
141-
sshLogger := s.options.Logger.Named("ssh-server")
141+
sshLogger := a.options.Logger.Named("ssh-server")
142142
forwardHandler := &ssh.ForwardedTCPHandler{}
143-
s.sshServer = &ssh.Server{
143+
a.sshServer = &ssh.Server{
144144
ChannelHandlers: ssh.DefaultChannelHandlers,
145145
ConnectionFailedCallback: func(conn net.Conn, err error) {
146146
sshLogger.Info(ctx, "ssh connection ended", slog.Error(err))
147147
},
148148
Handler: func(session ssh.Session) {
149-
err := s.handleSSHSession(session)
149+
err := a.handleSSHSession(session)
150150
if err != nil {
151-
s.options.Logger.Debug(ctx, "ssh session failed", slog.Error(err))
151+
a.options.Logger.Warn(ctx, "ssh session failed", slog.Error(err))
152152
_ = session.Exit(1)
153153
return
154154
}
@@ -177,35 +177,26 @@ func (s *agent) init(ctx context.Context) {
177177
},
178178
ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig {
179179
return &gossh.ServerConfig{
180-
Config: gossh.Config{
181-
// "arcfour" is the fastest SSH cipher. We prioritize throughput
182-
// over encryption here, because the WebRTC connection is already
183-
// encrypted. If possible, we'd disable encryption entirely here.
184-
Ciphers: []string{"arcfour"},
185-
},
186180
NoClientAuth: true,
187181
}
188182
},
189183
}
190184

191-
go s.run(ctx)
185+
go a.run(ctx)
192186
}
193187

194-
func (*agent) handleSSHSession(session ssh.Session) error {
188+
func (a *agent) handleSSHSession(session ssh.Session) error {
195189
var (
196190
command string
197191
args = []string{}
198192
err error
199193
)
200194

201-
username := session.User()
202-
if username == "" {
203-
currentUser, err := user.Current()
204-
if err != nil {
205-
return xerrors.Errorf("get current user: %w", err)
206-
}
207-
username = currentUser.Username
195+
currentUser, err := user.Current()
196+
if err != nil {
197+
return xerrors.Errorf("get current user: %w", err)
208198
}
199+
username := currentUser.Username
209200

210201
// gliderlabs/ssh returns a command slice of zero
211202
// when a shell is requested.
@@ -249,9 +240,9 @@ func (*agent) handleSSHSession(session ssh.Session) error {
249240
}
250241
go func() {
251242
for win := range windowSize {
252-
err := ptty.Resize(uint16(win.Width), uint16(win.Height))
243+
err = ptty.Resize(uint16(win.Width), uint16(win.Height))
253244
if err != nil {
254-
panic(err)
245+
a.options.Logger.Warn(context.Background(), "failed to resize tty", slog.Error(err))
255246
}
256247
}
257248
}()
@@ -286,24 +277,24 @@ func (*agent) handleSSHSession(session ssh.Session) error {
286277
}
287278

288279
// isClosed returns whether the API is closed or not.
289-
func (s *agent) isClosed() bool {
280+
func (a *agent) isClosed() bool {
290281
select {
291-
case <-s.closed:
282+
case <-a.closed:
292283
return true
293284
default:
294285
return false
295286
}
296287
}
297288

298-
func (s *agent) Close() error {
299-
s.closeMutex.Lock()
300-
defer s.closeMutex.Unlock()
301-
if s.isClosed() {
289+
func (a *agent) Close() error {
290+
a.closeMutex.Lock()
291+
defer a.closeMutex.Unlock()
292+
if a.isClosed() {
302293
return nil
303294
}
304-
close(s.closed)
305-
s.closeCancel()
306-
_ = s.sshServer.Close()
307-
s.connCloseWait.Wait()
295+
close(a.closed)
296+
a.closeCancel()
297+
_ = a.sshServer.Close()
298+
a.connCloseWait.Wait()
308299
return nil
309300
}

agent/conn.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,6 @@ func (c *Conn) SSHClient() (*ssh.Client, error) {
3939
return nil, xerrors.Errorf("ssh: %w", err)
4040
}
4141
sshConn, channels, requests, err := ssh.NewClientConn(netConn, "localhost:22", &ssh.ClientConfig{
42-
Config: ssh.Config{
43-
Ciphers: []string{"arcfour"},
44-
},
4542
// SSH host validation isn't helpful, because obtaining a peer
4643
// connection already signifies user-intent to dial a workspace.
4744
// #nosec

agent/usershell/usershell_other.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,5 @@ func Get(username string) (string, error) {
2727
}
2828
return parts[6], nil
2929
}
30-
return "", xerrors.New("user not found in /etc/passwd and $SHELL not set")
30+
return "", xerrors.Errorf("user %q not found in /etc/passwd", username)
3131
}

cli/cliui/agent.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@ package cliui
33
import (
44
"context"
55
"fmt"
6+
"io"
67
"sync"
78
"time"
89

910
"github.com/briandowns/spinner"
10-
"github.com/spf13/cobra"
1111
"golang.org/x/xerrors"
1212

1313
"github.com/coder/coder/codersdk"
@@ -21,15 +21,15 @@ type AgentOptions struct {
2121
}
2222

2323
// Agent displays a spinning indicator that waits for a workspace agent to connect.
24-
func Agent(cmd *cobra.Command, opts AgentOptions) error {
24+
func Agent(ctx context.Context, writer io.Writer, opts AgentOptions) error {
2525
if opts.FetchInterval == 0 {
2626
opts.FetchInterval = 500 * time.Millisecond
2727
}
2828
if opts.WarnInterval == 0 {
2929
opts.WarnInterval = 30 * time.Second
3030
}
3131
var resourceMutex sync.Mutex
32-
resource, err := opts.Fetch(cmd.Context())
32+
resource, err := opts.Fetch(ctx)
3333
if err != nil {
3434
return xerrors.Errorf("fetch: %w", err)
3535
}
@@ -40,7 +40,8 @@ func Agent(cmd *cobra.Command, opts AgentOptions) error {
4040
opts.WarnInterval = 0
4141
}
4242
spin := spinner.New(spinner.CharSets[78], 100*time.Millisecond, spinner.WithColor("fgHiGreen"))
43-
spin.Writer = cmd.OutOrStdout()
43+
spin.Writer = writer
44+
spin.ForceOutput = true
4445
spin.Suffix = " Waiting for connection from " + Styles.Field.Render(resource.Type+"."+resource.Name) + "..."
4546
spin.Start()
4647
defer spin.Stop()
@@ -51,7 +52,7 @@ func Agent(cmd *cobra.Command, opts AgentOptions) error {
5152
defer timer.Stop()
5253
go func() {
5354
select {
54-
case <-cmd.Context().Done():
55+
case <-ctx.Done():
5556
return
5657
case <-timer.C:
5758
}
@@ -63,17 +64,17 @@ func Agent(cmd *cobra.Command, opts AgentOptions) error {
6364
}
6465
// This saves the cursor position, then defers clearing from the cursor
6566
// position to the end of the screen.
66-
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "\033[s\r\033[2K%s\n\n", Styles.Paragraph.Render(Styles.Prompt.String()+message))
67-
defer fmt.Fprintf(cmd.OutOrStdout(), "\033[u\033[J")
67+
_, _ = fmt.Fprintf(writer, "\033[s\r\033[2K%s\n\n", Styles.Paragraph.Render(Styles.Prompt.String()+message))
68+
defer fmt.Fprintf(writer, "\033[u\033[J")
6869
}()
6970
for {
7071
select {
71-
case <-cmd.Context().Done():
72-
return cmd.Context().Err()
72+
case <-ctx.Done():
73+
return ctx.Err()
7374
case <-ticker.C:
7475
}
7576
resourceMutex.Lock()
76-
resource, err = opts.Fetch(cmd.Context())
77+
resource, err = opts.Fetch(ctx)
7778
if err != nil {
7879
return xerrors.Errorf("fetch: %w", err)
7980
}

cli/cliui/agent_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ func TestAgent(t *testing.T) {
2020
ptty := ptytest.New(t)
2121
cmd := &cobra.Command{
2222
RunE: func(cmd *cobra.Command, args []string) error {
23-
err := cliui.Agent(cmd, cliui.AgentOptions{
23+
err := cliui.Agent(cmd.Context(), cmd.OutOrStdout(), cliui.AgentOptions{
2424
WorkspaceName: "example",
2525
Fetch: func(ctx context.Context) (codersdk.WorkspaceResource, error) {
2626
resource := codersdk.WorkspaceResource{

cli/cliui/log.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package cliui
2+
3+
import (
4+
"fmt"
5+
"io"
6+
"strings"
7+
8+
"github.com/charmbracelet/lipgloss"
9+
)
10+
11+
// cliMessage provides a human-readable message for CLI errors and messages.
12+
type cliMessage struct {
13+
Level string
14+
Style lipgloss.Style
15+
Header string
16+
Lines []string
17+
}
18+
19+
// String formats the CLI message for consumption by a human.
20+
func (m cliMessage) String() string {
21+
var str strings.Builder
22+
_, _ = fmt.Fprintf(&str, "%s\r\n",
23+
Styles.Bold.Render(m.Header))
24+
for _, line := range m.Lines {
25+
_, _ = fmt.Fprintf(&str, " %s %s\r\n", m.Style.Render("|"), line)
26+
}
27+
return str.String()
28+
}
29+
30+
// Warn writes a log to the writer provided.
31+
func Warn(wtr io.Writer, header string, lines ...string) {
32+
_, _ = fmt.Fprint(wtr, cliMessage{
33+
Level: "warning",
34+
Style: Styles.Warn,
35+
Header: header,
36+
Lines: lines,
37+
}.String())
38+
}

cli/cliui/prompt.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package cliui
22

33
import (
44
"bufio"
5+
"bytes"
56
"encoding/json"
67
"fmt"
78
"io"
@@ -62,7 +63,11 @@ func Prompt(cmd *cobra.Command, opts PromptOptions) (string, error) {
6263
var rawMessage json.RawMessage
6364
err := json.NewDecoder(pipeReader).Decode(&rawMessage)
6465
if err == nil {
65-
line = string(rawMessage)
66+
var buf bytes.Buffer
67+
err = json.Compact(&buf, rawMessage)
68+
if err == nil {
69+
line = buf.String()
70+
}
6671
}
6772
}
6873
}

cli/cliui/prompt_test.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,7 @@ func TestPrompt(t *testing.T) {
9393
ptty.WriteLine(`{
9494
"test": "wow"
9595
}`)
96-
require.Equal(t, `{
97-
"test": "wow"
98-
}`, <-doneChan)
96+
require.Equal(t, `{"test":"wow"}`, <-doneChan)
9997
})
10098
}
10199

0 commit comments

Comments
 (0)