Skip to content

Commit 6607464

Browse files
authored
fix(agent/agentssh): use deterministic host key for SSH server (coder#16626)
Fixes: coder#16490 The Agent's SSH server now initially generates fixed host keys and, once it receives its manifest, generates and replaces that host key with the one derived from the workspace ID, ensuring consistency across agent restarts. This prevents SSH warnings and host key verification errors when connecting to workspaces through Coder Desktop. While deterministic keys might seem insecure, the underlying Wireguard tunnel already provides encryption and anti-spoofing protection at the network layer, making this approach acceptable for our use case. --- Change-Id: I8c7e3070324e5d558374fd6891eea9d48660e1e9 Signed-off-by: Thomas Kosiewski <tk@coder.com>
1 parent e8a7b7e commit 6607464

File tree

6 files changed

+226
-17
lines changed

6 files changed

+226
-17
lines changed

agent/agent.go

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"encoding/json"
77
"errors"
88
"fmt"
9+
"hash/fnv"
910
"io"
1011
"net/http"
1112
"net/netip"
@@ -994,7 +995,6 @@ func (a *agent) createOrUpdateNetwork(manifestOK, networkOK *checkpoint) func(co
994995
if err := manifestOK.wait(ctx); err != nil {
995996
return xerrors.Errorf("no manifest: %w", err)
996997
}
997-
var err error
998998
defer func() {
999999
networkOK.complete(retErr)
10001000
}()
@@ -1003,9 +1003,20 @@ func (a *agent) createOrUpdateNetwork(manifestOK, networkOK *checkpoint) func(co
10031003
network := a.network
10041004
a.closeMutex.Unlock()
10051005
if network == nil {
1006+
keySeed, err := WorkspaceKeySeed(manifest.WorkspaceID, manifest.AgentName)
1007+
if err != nil {
1008+
return xerrors.Errorf("generate seed from workspace id: %w", err)
1009+
}
10061010
// use the graceful context here, because creating the tailnet is not itself tied to the
10071011
// agent API.
1008-
network, err = a.createTailnet(a.gracefulCtx, manifest.AgentID, manifest.DERPMap, manifest.DERPForceWebSockets, manifest.DisableDirectConnections)
1012+
network, err = a.createTailnet(
1013+
a.gracefulCtx,
1014+
manifest.AgentID,
1015+
manifest.DERPMap,
1016+
manifest.DERPForceWebSockets,
1017+
manifest.DisableDirectConnections,
1018+
keySeed,
1019+
)
10091020
if err != nil {
10101021
return xerrors.Errorf("create tailnet: %w", err)
10111022
}
@@ -1145,7 +1156,13 @@ func (a *agent) trackGoroutine(fn func()) error {
11451156
return nil
11461157
}
11471158

1148-
func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *tailcfg.DERPMap, derpForceWebSockets, disableDirectConnections bool) (_ *tailnet.Conn, err error) {
1159+
func (a *agent) createTailnet(
1160+
ctx context.Context,
1161+
agentID uuid.UUID,
1162+
derpMap *tailcfg.DERPMap,
1163+
derpForceWebSockets, disableDirectConnections bool,
1164+
keySeed int64,
1165+
) (_ *tailnet.Conn, err error) {
11491166
// Inject `CODER_AGENT_HEADER` into the DERP header.
11501167
var header http.Header
11511168
if client, ok := a.client.(*agentsdk.Client); ok {
@@ -1172,6 +1189,10 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t
11721189
}
11731190
}()
11741191

1192+
if err := a.sshServer.UpdateHostSigner(keySeed); err != nil {
1193+
return nil, xerrors.Errorf("update host signer: %w", err)
1194+
}
1195+
11751196
sshListener, err := network.Listen("tcp", ":"+strconv.Itoa(workspacesdk.AgentSSHPort))
11761197
if err != nil {
11771198
return nil, xerrors.Errorf("listen on the ssh port: %w", err)
@@ -1849,3 +1870,20 @@ func PrometheusMetricsHandler(prometheusRegistry *prometheus.Registry, logger sl
18491870
}
18501871
})
18511872
}
1873+
1874+
// WorkspaceKeySeed converts a WorkspaceID UUID and agent name to an int64 hash.
1875+
// This uses the FNV-1a hash algorithm which provides decent distribution and collision
1876+
// resistance for string inputs.
1877+
func WorkspaceKeySeed(workspaceID uuid.UUID, agentName string) (int64, error) {
1878+
h := fnv.New64a()
1879+
_, err := h.Write(workspaceID[:])
1880+
if err != nil {
1881+
return 42, err
1882+
}
1883+
_, err = h.Write([]byte(agentName))
1884+
if err != nil {
1885+
return 42, err
1886+
}
1887+
1888+
return int64(h.Sum64()), nil
1889+
}

agent/agentssh/agentssh.go

Lines changed: 108 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@ package agentssh
33
import (
44
"bufio"
55
"context"
6-
"crypto/rand"
76
"crypto/rsa"
87
"errors"
98
"fmt"
109
"io"
10+
"math/big"
11+
"math/rand"
1112
"net"
1213
"os"
1314
"os/exec"
@@ -128,17 +129,6 @@ type Server struct {
128129
}
129130

130131
func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prometheus.Registry, fs afero.Fs, execer agentexec.Execer, config *Config) (*Server, error) {
131-
// Clients' should ignore the host key when connecting.
132-
// The agent needs to authenticate with coderd to SSH,
133-
// so SSH authentication doesn't improve security.
134-
randomHostKey, err := rsa.GenerateKey(rand.Reader, 2048)
135-
if err != nil {
136-
return nil, err
137-
}
138-
randomSigner, err := gossh.NewSignerFromKey(randomHostKey)
139-
if err != nil {
140-
return nil, err
141-
}
142132
if config == nil {
143133
config = &Config{}
144134
}
@@ -205,8 +195,10 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
205195
slog.F("local_addr", conn.LocalAddr()),
206196
slog.Error(err))
207197
},
208-
Handler: s.sessionHandler,
209-
HostSigners: []ssh.Signer{randomSigner},
198+
Handler: s.sessionHandler,
199+
// HostSigners are intentionally empty, as the host key will
200+
// be set before we start listening.
201+
HostSigners: []ssh.Signer{},
210202
LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
211203
// Allow local port forwarding all!
212204
s.logger.Debug(ctx, "local port forward",
@@ -844,7 +836,13 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string,
844836
return cmd, nil
845837
}
846838

839+
// Serve starts the server to handle incoming connections on the provided listener.
840+
// It returns an error if no host keys are set or if there is an issue accepting connections.
847841
func (s *Server) Serve(l net.Listener) (retErr error) {
842+
if len(s.srv.HostSigners) == 0 {
843+
return xerrors.New("no host keys set")
844+
}
845+
848846
s.logger.Info(context.Background(), "started serving listener", slog.F("listen_addr", l.Addr()))
849847
defer func() {
850848
s.logger.Info(context.Background(), "stopped serving listener",
@@ -1099,3 +1097,99 @@ func userHomeDir() (string, error) {
10991097
}
11001098
return u.HomeDir, nil
11011099
}
1100+
1101+
// UpdateHostSigner updates the host signer with a new key generated from the provided seed.
1102+
// If an existing host key exists with the same algorithm, it is overwritten
1103+
func (s *Server) UpdateHostSigner(seed int64) error {
1104+
key, err := CoderSigner(seed)
1105+
if err != nil {
1106+
return err
1107+
}
1108+
1109+
s.mu.Lock()
1110+
defer s.mu.Unlock()
1111+
1112+
s.srv.AddHostKey(key)
1113+
1114+
return nil
1115+
}
1116+
1117+
// CoderSigner generates a deterministic SSH signer based on the provided seed.
1118+
// It uses RSA with a key size of 2048 bits.
1119+
func CoderSigner(seed int64) (gossh.Signer, error) {
1120+
// Clients should ignore the host key when connecting.
1121+
// The agent needs to authenticate with coderd to SSH,
1122+
// so SSH authentication doesn't improve security.
1123+
1124+
// Since the standard lib purposefully does not generate
1125+
// deterministic rsa keys, we need to do it ourselves.
1126+
coderHostKey := func() *rsa.PrivateKey {
1127+
// Create deterministic random source
1128+
// nolint: gosec
1129+
deterministicRand := rand.New(rand.NewSource(seed))
1130+
1131+
// Use fixed values for p and q based on the seed
1132+
p := big.NewInt(0)
1133+
q := big.NewInt(0)
1134+
e := big.NewInt(65537) // Standard RSA public exponent
1135+
1136+
// Generate deterministic primes using the seeded random
1137+
// Each prime should be ~1024 bits to get a 2048-bit key
1138+
for {
1139+
p.SetBit(p, 1024, 1) // Ensure it's large enough
1140+
for i := 0; i < 1024; i++ {
1141+
if deterministicRand.Int63()%2 == 1 {
1142+
p.SetBit(p, i, 1)
1143+
} else {
1144+
p.SetBit(p, i, 0)
1145+
}
1146+
}
1147+
if p.ProbablyPrime(20) {
1148+
break
1149+
}
1150+
}
1151+
1152+
for {
1153+
q.SetBit(q, 1024, 1) // Ensure it's large enough
1154+
for i := 0; i < 1024; i++ {
1155+
if deterministicRand.Int63()%2 == 1 {
1156+
q.SetBit(q, i, 1)
1157+
} else {
1158+
q.SetBit(q, i, 0)
1159+
}
1160+
}
1161+
if q.ProbablyPrime(20) && p.Cmp(q) != 0 {
1162+
break
1163+
}
1164+
}
1165+
1166+
// Calculate n = p * q
1167+
n := new(big.Int).Mul(p, q)
1168+
1169+
// Calculate phi = (p-1) * (q-1)
1170+
p1 := new(big.Int).Sub(p, big.NewInt(1))
1171+
q1 := new(big.Int).Sub(q, big.NewInt(1))
1172+
phi := new(big.Int).Mul(p1, q1)
1173+
1174+
// Calculate private exponent d
1175+
d := new(big.Int).ModInverse(e, phi)
1176+
1177+
// Create the private key
1178+
privateKey := &rsa.PrivateKey{
1179+
PublicKey: rsa.PublicKey{
1180+
N: n,
1181+
E: int(e.Int64()),
1182+
},
1183+
D: d,
1184+
Primes: []*big.Int{p, q},
1185+
}
1186+
1187+
// Compute precomputed values
1188+
privateKey.Precompute()
1189+
1190+
return privateKey
1191+
}()
1192+
1193+
coderSigner, err := gossh.NewSignerFromKey(coderHostKey)
1194+
return coderSigner, err
1195+
}

agent/agentssh/agentssh_internal_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ func Test_sessionStart_orphan(t *testing.T) {
3939
s, err := NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
4040
require.NoError(t, err)
4141
defer s.Close()
42+
err = s.UpdateHostSigner(42)
43+
assert.NoError(t, err)
4244

4345
// Here we're going to call the handler directly with a faked SSH session
4446
// that just uses io.Pipes instead of a network socket. There is a large

agent/agentssh/agentssh_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ func TestNewServer_ServeClient(t *testing.T) {
4141
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
4242
require.NoError(t, err)
4343
defer s.Close()
44+
err = s.UpdateHostSigner(42)
45+
assert.NoError(t, err)
4446

4547
ln, err := net.Listen("tcp", "127.0.0.1:0")
4648
require.NoError(t, err)
@@ -146,6 +148,8 @@ func TestNewServer_CloseActiveConnections(t *testing.T) {
146148
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
147149
require.NoError(t, err)
148150
defer s.Close()
151+
err = s.UpdateHostSigner(42)
152+
assert.NoError(t, err)
149153

150154
ln, err := net.Listen("tcp", "127.0.0.1:0")
151155
require.NoError(t, err)
@@ -197,6 +201,8 @@ func TestNewServer_Signal(t *testing.T) {
197201
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
198202
require.NoError(t, err)
199203
defer s.Close()
204+
err = s.UpdateHostSigner(42)
205+
assert.NoError(t, err)
200206

201207
ln, err := net.Listen("tcp", "127.0.0.1:0")
202208
require.NoError(t, err)
@@ -262,6 +268,8 @@ func TestNewServer_Signal(t *testing.T) {
262268
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
263269
require.NoError(t, err)
264270
defer s.Close()
271+
err = s.UpdateHostSigner(42)
272+
assert.NoError(t, err)
265273

266274
ln, err := net.Listen("tcp", "127.0.0.1:0")
267275
require.NoError(t, err)

agent/agentssh/x11_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ func TestServer_X11(t *testing.T) {
3838
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, &agentssh.Config{})
3939
require.NoError(t, err)
4040
defer s.Close()
41+
err = s.UpdateHostSigner(42)
42+
assert.NoError(t, err)
4143

4244
ln, err := net.Listen("tcp", "127.0.0.1:0")
4345
require.NoError(t, err)

cli/ssh_test.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,71 @@ func TestSSH(t *testing.T) {
453453
<-cmdDone
454454
})
455455

456+
t.Run("DeterministicHostKey", func(t *testing.T) {
457+
t.Parallel()
458+
client, workspace, agentToken := setupWorkspaceForAgent(t)
459+
_, _ = tGoContext(t, func(ctx context.Context) {
460+
// Run this async so the SSH command has to wait for
461+
// the build and agent to connect!
462+
_ = agenttest.New(t, client.URL, agentToken)
463+
<-ctx.Done()
464+
})
465+
466+
clientOutput, clientInput := io.Pipe()
467+
serverOutput, serverInput := io.Pipe()
468+
defer func() {
469+
for _, c := range []io.Closer{clientOutput, clientInput, serverOutput, serverInput} {
470+
_ = c.Close()
471+
}
472+
}()
473+
474+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
475+
defer cancel()
476+
477+
inv, root := clitest.New(t, "ssh", "--stdio", workspace.Name)
478+
clitest.SetupConfig(t, client, root)
479+
inv.Stdin = clientOutput
480+
inv.Stdout = serverInput
481+
inv.Stderr = io.Discard
482+
483+
cmdDone := tGo(t, func() {
484+
err := inv.WithContext(ctx).Run()
485+
assert.NoError(t, err)
486+
})
487+
488+
keySeed, err := agent.WorkspaceKeySeed(workspace.ID, "dev")
489+
assert.NoError(t, err)
490+
491+
signer, err := agentssh.CoderSigner(keySeed)
492+
assert.NoError(t, err)
493+
494+
conn, channels, requests, err := ssh.NewClientConn(&stdioConn{
495+
Reader: serverOutput,
496+
Writer: clientInput,
497+
}, "", &ssh.ClientConfig{
498+
HostKeyCallback: ssh.FixedHostKey(signer.PublicKey()),
499+
})
500+
require.NoError(t, err)
501+
defer conn.Close()
502+
503+
sshClient := ssh.NewClient(conn, channels, requests)
504+
session, err := sshClient.NewSession()
505+
require.NoError(t, err)
506+
defer session.Close()
507+
508+
command := "sh -c exit"
509+
if runtime.GOOS == "windows" {
510+
command = "cmd.exe /c exit"
511+
}
512+
err = session.Run(command)
513+
require.NoError(t, err)
514+
err = sshClient.Close()
515+
require.NoError(t, err)
516+
_ = clientOutput.Close()
517+
518+
<-cmdDone
519+
})
520+
456521
t.Run("NetworkInfo", func(t *testing.T) {
457522
t.Parallel()
458523
client, workspace, agentToken := setupWorkspaceForAgent(t)

0 commit comments

Comments
 (0)