Skip to content

Commit 735dc5d

Browse files
feat(agent): add second SSH listener on port 22 (cherry-pick #16627) (#16763)
Cherry-picked feat(agent): add second SSH listener on port 22 (#16627) Fixes: coder/internal#377 Added an additional SSH listener on port 22, so the agent now listens on both, port one and port 22. --- Change-Id: Ifd986b260f8ac317e37d65111cd4e0bd1dc38af8 Signed-off-by: Thomas Kosiewski <tk@coder.com>
1 parent 114cf57 commit 735dc5d

File tree

6 files changed

+153
-95
lines changed

6 files changed

+153
-95
lines changed

agent/agent.go

+14-11
Original file line numberDiff line numberDiff line change
@@ -1193,19 +1193,22 @@ func (a *agent) createTailnet(
11931193
return nil, xerrors.Errorf("update host signer: %w", err)
11941194
}
11951195

1196-
sshListener, err := network.Listen("tcp", ":"+strconv.Itoa(workspacesdk.AgentSSHPort))
1197-
if err != nil {
1198-
return nil, xerrors.Errorf("listen on the ssh port: %w", err)
1199-
}
1200-
defer func() {
1196+
for _, port := range []int{workspacesdk.AgentSSHPort, workspacesdk.AgentStandardSSHPort} {
1197+
sshListener, err := network.Listen("tcp", ":"+strconv.Itoa(port))
12011198
if err != nil {
1202-
_ = sshListener.Close()
1199+
return nil, xerrors.Errorf("listen on the ssh port (%v): %w", port, err)
1200+
}
1201+
// nolint:revive // We do want to run the deferred functions when createTailnet returns.
1202+
defer func() {
1203+
if err != nil {
1204+
_ = sshListener.Close()
1205+
}
1206+
}()
1207+
if err = a.trackGoroutine(func() {
1208+
_ = a.sshServer.Serve(sshListener)
1209+
}); err != nil {
1210+
return nil, err
12031211
}
1204-
}()
1205-
if err = a.trackGoroutine(func() {
1206-
_ = a.sshServer.Serve(sshListener)
1207-
}); err != nil {
1208-
return nil, err
12091212
}
12101213

12111214
reconnectingPTYListener, err := network.Listen("tcp", ":"+strconv.Itoa(workspacesdk.AgentReconnectingPTYPort))

agent/agent_test.go

+120-79
Original file line numberDiff line numberDiff line change
@@ -61,38 +61,48 @@ func TestMain(m *testing.M) {
6161
goleak.VerifyTestMain(m, testutil.GoleakOptions...)
6262
}
6363

64+
var sshPorts = []uint16{workspacesdk.AgentSSHPort, workspacesdk.AgentStandardSSHPort}
65+
6466
// NOTE: These tests only work when your default shell is bash for some reason.
6567

6668
func TestAgent_Stats_SSH(t *testing.T) {
6769
t.Parallel()
68-
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
69-
defer cancel()
7070

71-
//nolint:dogsled
72-
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
71+
for _, port := range sshPorts {
72+
port := port
73+
t.Run(fmt.Sprintf("(:%d)", port), func(t *testing.T) {
74+
t.Parallel()
7375

74-
sshClient, err := conn.SSHClient(ctx)
75-
require.NoError(t, err)
76-
defer sshClient.Close()
77-
session, err := sshClient.NewSession()
78-
require.NoError(t, err)
79-
defer session.Close()
80-
stdin, err := session.StdinPipe()
81-
require.NoError(t, err)
82-
err = session.Shell()
83-
require.NoError(t, err)
76+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
77+
defer cancel()
8478

85-
var s *proto.Stats
86-
require.Eventuallyf(t, func() bool {
87-
var ok bool
88-
s, ok = <-stats
89-
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && s.SessionCountSsh == 1
90-
}, testutil.WaitLong, testutil.IntervalFast,
91-
"never saw stats: %+v", s,
92-
)
93-
_ = stdin.Close()
94-
err = session.Wait()
95-
require.NoError(t, err)
79+
//nolint:dogsled
80+
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
81+
82+
sshClient, err := conn.SSHClientOnPort(ctx, port)
83+
require.NoError(t, err)
84+
defer sshClient.Close()
85+
session, err := sshClient.NewSession()
86+
require.NoError(t, err)
87+
defer session.Close()
88+
stdin, err := session.StdinPipe()
89+
require.NoError(t, err)
90+
err = session.Shell()
91+
require.NoError(t, err)
92+
93+
var s *proto.Stats
94+
require.Eventuallyf(t, func() bool {
95+
var ok bool
96+
s, ok = <-stats
97+
return ok && s.ConnectionCount > 0 && s.RxBytes > 0 && s.TxBytes > 0 && s.SessionCountSsh == 1
98+
}, testutil.WaitLong, testutil.IntervalFast,
99+
"never saw stats: %+v", s,
100+
)
101+
_ = stdin.Close()
102+
err = session.Wait()
103+
require.NoError(t, err)
104+
})
105+
}
96106
}
97107

98108
func TestAgent_Stats_ReconnectingPTY(t *testing.T) {
@@ -266,15 +276,23 @@ func TestAgent_Stats_Magic(t *testing.T) {
266276

267277
func TestAgent_SessionExec(t *testing.T) {
268278
t.Parallel()
269-
session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil)
270279

271-
command := "echo test"
272-
if runtime.GOOS == "windows" {
273-
command = "cmd.exe /c echo test"
280+
for _, port := range sshPorts {
281+
port := port
282+
t.Run(fmt.Sprintf("(:%d)", port), func(t *testing.T) {
283+
t.Parallel()
284+
285+
session := setupSSHSessionOnPort(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil, port)
286+
287+
command := "echo test"
288+
if runtime.GOOS == "windows" {
289+
command = "cmd.exe /c echo test"
290+
}
291+
output, err := session.Output(command)
292+
require.NoError(t, err)
293+
require.Equal(t, "test", strings.TrimSpace(string(output)))
294+
})
274295
}
275-
output, err := session.Output(command)
276-
require.NoError(t, err)
277-
require.Equal(t, "test", strings.TrimSpace(string(output)))
278296
}
279297

280298
//nolint:tparallel // Sub tests need to run sequentially.
@@ -384,25 +402,33 @@ func TestAgent_SessionTTYShell(t *testing.T) {
384402
// it seems like it could be either.
385403
t.Skip("ConPTY appears to be inconsistent on Windows.")
386404
}
387-
session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil)
388-
command := "sh"
389-
if runtime.GOOS == "windows" {
390-
command = "cmd.exe"
405+
406+
for _, port := range sshPorts {
407+
port := port
408+
t.Run(fmt.Sprintf("(%d)", port), func(t *testing.T) {
409+
t.Parallel()
410+
411+
session := setupSSHSessionOnPort(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil, port)
412+
command := "sh"
413+
if runtime.GOOS == "windows" {
414+
command = "cmd.exe"
415+
}
416+
err := session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
417+
require.NoError(t, err)
418+
ptty := ptytest.New(t)
419+
session.Stdout = ptty.Output()
420+
session.Stderr = ptty.Output()
421+
session.Stdin = ptty.Input()
422+
err = session.Start(command)
423+
require.NoError(t, err)
424+
_ = ptty.Peek(ctx, 1) // wait for the prompt
425+
ptty.WriteLine("echo test")
426+
ptty.ExpectMatch("test")
427+
ptty.WriteLine("exit")
428+
err = session.Wait()
429+
require.NoError(t, err)
430+
})
391431
}
392-
err := session.RequestPty("xterm", 128, 128, ssh.TerminalModes{})
393-
require.NoError(t, err)
394-
ptty := ptytest.New(t)
395-
session.Stdout = ptty.Output()
396-
session.Stderr = ptty.Output()
397-
session.Stdin = ptty.Input()
398-
err = session.Start(command)
399-
require.NoError(t, err)
400-
_ = ptty.Peek(ctx, 1) // wait for the prompt
401-
ptty.WriteLine("echo test")
402-
ptty.ExpectMatch("test")
403-
ptty.WriteLine("exit")
404-
err = session.Wait()
405-
require.NoError(t, err)
406432
}
407433

408434
func TestAgent_SessionTTYExitCode(t *testing.T) {
@@ -596,37 +622,41 @@ func TestAgent_Session_TTY_MOTD_Update(t *testing.T) {
596622
//nolint:dogsled // Allow the blank identifiers.
597623
conn, client, _, _, _ := setupAgent(t, agentsdk.Manifest{}, 0, setSBInterval)
598624

599-
sshClient, err := conn.SSHClient(ctx)
600-
require.NoError(t, err)
601-
t.Cleanup(func() {
602-
_ = sshClient.Close()
603-
})
604-
605625
//nolint:paralleltest // These tests need to swap the banner func.
606-
for i, test := range tests {
607-
test := test
608-
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
609-
// Set new banner func and wait for the agent to call it to update the
610-
// banner.
611-
ready := make(chan struct{}, 2)
612-
client.SetAnnouncementBannersFunc(func() ([]codersdk.BannerConfig, error) {
613-
select {
614-
case ready <- struct{}{}:
615-
default:
616-
}
617-
return []codersdk.BannerConfig{test.banner}, nil
618-
})
619-
<-ready
620-
<-ready // Wait for two updates to ensure the value has propagated.
621-
622-
session, err := sshClient.NewSession()
623-
require.NoError(t, err)
624-
t.Cleanup(func() {
625-
_ = session.Close()
626-
})
626+
for _, port := range sshPorts {
627+
port := port
627628

628-
testSessionOutput(t, session, test.expected, test.unexpected, nil)
629+
sshClient, err := conn.SSHClientOnPort(ctx, port)
630+
require.NoError(t, err)
631+
t.Cleanup(func() {
632+
_ = sshClient.Close()
629633
})
634+
635+
for i, test := range tests {
636+
test := test
637+
t.Run(fmt.Sprintf("(:%d)/%d", port, i), func(t *testing.T) {
638+
// Set new banner func and wait for the agent to call it to update the
639+
// banner.
640+
ready := make(chan struct{}, 2)
641+
client.SetAnnouncementBannersFunc(func() ([]codersdk.BannerConfig, error) {
642+
select {
643+
case ready <- struct{}{}:
644+
default:
645+
}
646+
return []codersdk.BannerConfig{test.banner}, nil
647+
})
648+
<-ready
649+
<-ready // Wait for two updates to ensure the value has propagated.
650+
651+
session, err := sshClient.NewSession()
652+
require.NoError(t, err)
653+
t.Cleanup(func() {
654+
_ = session.Close()
655+
})
656+
657+
testSessionOutput(t, session, test.expected, test.unexpected, nil)
658+
})
659+
}
630660
}
631661
}
632662

@@ -2313,6 +2343,17 @@ func setupSSHSession(
23132343
banner codersdk.BannerConfig,
23142344
prepareFS func(fs afero.Fs),
23152345
opts ...func(*agenttest.Client, *agent.Options),
2346+
) *ssh.Session {
2347+
return setupSSHSessionOnPort(t, manifest, banner, prepareFS, workspacesdk.AgentSSHPort, opts...)
2348+
}
2349+
2350+
func setupSSHSessionOnPort(
2351+
t *testing.T,
2352+
manifest agentsdk.Manifest,
2353+
banner codersdk.BannerConfig,
2354+
prepareFS func(fs afero.Fs),
2355+
port uint16,
2356+
opts ...func(*agenttest.Client, *agent.Options),
23162357
) *ssh.Session {
23172358
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
23182359
defer cancel()
@@ -2326,7 +2367,7 @@ func setupSSHSession(
23262367
if prepareFS != nil {
23272368
prepareFS(fs)
23282369
}
2329-
sshClient, err := conn.SSHClient(ctx)
2370+
sshClient, err := conn.SSHClientOnPort(ctx, port)
23302371
require.NoError(t, err)
23312372
t.Cleanup(func() {
23322373
_ = sshClient.Close()

agent/usershell/usershell_darwin.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ func Get(username string) (string, error) {
1717
return "", xerrors.Errorf("username is nonlocal path: %s", username)
1818
}
1919
//nolint: gosec // input checked above
20-
out, _ := exec.Command("dscl", ".", "-read", filepath.Join("/Users", username), "UserShell").Output()
20+
out, _ := exec.Command("dscl", ".", "-read", filepath.Join("/Users", username), "UserShell").Output() //nolint:gocritic
2121
s, ok := strings.CutPrefix(string(out), "UserShell: ")
2222
if ok {
2323
return strings.TrimSpace(s), nil

codersdk/workspacesdk/agentconn.go

+15-3
Original file line numberDiff line numberDiff line change
@@ -143,24 +143,36 @@ func (c *AgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, w
143143
// SSH pipes the SSH protocol over the returned net.Conn.
144144
// This connects to the built-in SSH server in the workspace agent.
145145
func (c *AgentConn) SSH(ctx context.Context) (*gonet.TCPConn, error) {
146+
return c.SSHOnPort(ctx, AgentSSHPort)
147+
}
148+
149+
// SSHOnPort pipes the SSH protocol over the returned net.Conn.
150+
// This connects to the built-in SSH server in the workspace agent on the specified port.
151+
func (c *AgentConn) SSHOnPort(ctx context.Context, port uint16) (*gonet.TCPConn, error) {
146152
ctx, span := tracing.StartSpan(ctx)
147153
defer span.End()
148154

149155
if !c.AwaitReachable(ctx) {
150156
return nil, xerrors.Errorf("workspace agent not reachable in time: %v", ctx.Err())
151157
}
152158

153-
c.Conn.SendConnectedTelemetry(c.agentAddress(), tailnet.TelemetryApplicationSSH)
154-
return c.Conn.DialContextTCP(ctx, netip.AddrPortFrom(c.agentAddress(), AgentSSHPort))
159+
c.SendConnectedTelemetry(c.agentAddress(), tailnet.TelemetryApplicationSSH)
160+
return c.DialContextTCP(ctx, netip.AddrPortFrom(c.agentAddress(), port))
155161
}
156162

157163
// SSHClient calls SSH to create a client that uses a weak cipher
158164
// to improve throughput.
159165
func (c *AgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) {
166+
return c.SSHClientOnPort(ctx, AgentSSHPort)
167+
}
168+
169+
// SSHClientOnPort calls SSH to create a client on a specific port
170+
// that uses a weak cipher to improve throughput.
171+
func (c *AgentConn) SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Client, error) {
160172
ctx, span := tracing.StartSpan(ctx)
161173
defer span.End()
162174

163-
netConn, err := c.SSH(ctx)
175+
netConn, err := c.SSHOnPort(ctx, port)
164176
if err != nil {
165177
return nil, xerrors.Errorf("ssh: %w", err)
166178
}

codersdk/workspacesdk/workspacesdk.go

+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ var ErrSkipClose = xerrors.New("skip tailnet close")
2929

3030
const (
3131
AgentSSHPort = tailnet.WorkspaceAgentSSHPort
32+
AgentStandardSSHPort = tailnet.WorkspaceAgentStandardSSHPort
3233
AgentReconnectingPTYPort = tailnet.WorkspaceAgentReconnectingPTYPort
3334
AgentSpeedtestPort = tailnet.WorkspaceAgentSpeedtestPort
3435
// AgentHTTPAPIServerPort serves a HTTP server with endpoints for e.g.

tailnet/conn.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ const (
5252
WorkspaceAgentSSHPort = 1
5353
WorkspaceAgentReconnectingPTYPort = 2
5454
WorkspaceAgentSpeedtestPort = 3
55+
WorkspaceAgentStandardSSHPort = 22
5556
)
5657

5758
// EnvMagicsockDebugLogging enables super-verbose logging for the magicsock
@@ -745,7 +746,7 @@ func (c *Conn) forwardTCP(src, dst netip.AddrPort) (handler func(net.Conn), opts
745746
return nil, nil, false
746747
}
747748
// See: https://github.com/tailscale/tailscale/blob/c7cea825aea39a00aca71ea02bab7266afc03e7c/wgengine/netstack/netstack.go#L888
748-
if dst.Port() == WorkspaceAgentSSHPort || dst.Port() == 22 {
749+
if dst.Port() == WorkspaceAgentSSHPort || dst.Port() == WorkspaceAgentStandardSSHPort {
749750
opt := tcpip.KeepaliveIdleOption(72 * time.Hour)
750751
opts = append(opts, &opt)
751752
}

0 commit comments

Comments
 (0)