Skip to content

Commit ad275dc

Browse files
committed
fix(agent/agentssh): pin random seed for RSA key generation
Change-Id: I8c7e3070324e5d558374fd6891eea9d48660e1e9 Signed-off-by: Thomas Kosiewski <tk@coder.com>
1 parent fcc9b05 commit ad275dc

File tree

6 files changed

+226
-17
lines changed

6 files changed

+226
-17
lines changed

agent/agent.go

+41-3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"encoding/json"
77
"errors"
88
"fmt"
9+
"hash/fnv"
910
"io"
1011
"net/http"
1112
"net/netip"
@@ -994,7 +995,6 @@ func (a *agent) createOrUpdateNetwork(manifestOK, networkOK *checkpoint) func(co
994995
if err := manifestOK.wait(ctx); err != nil {
995996
return xerrors.Errorf("no manifest: %w", err)
996997
}
997-
var err error
998998
defer func() {
999999
networkOK.complete(retErr)
10001000
}()
@@ -1003,9 +1003,20 @@ func (a *agent) createOrUpdateNetwork(manifestOK, networkOK *checkpoint) func(co
10031003
network := a.network
10041004
a.closeMutex.Unlock()
10051005
if network == nil {
1006+
keySeed, err := WorkspaceKeySeed(manifest.WorkspaceID, manifest.AgentName)
1007+
if err != nil {
1008+
return xerrors.Errorf("generate seed from workspace id: %w", err)
1009+
}
10061010
// use the graceful context here, because creating the tailnet is not itself tied to the
10071011
// agent API.
1008-
network, err = a.createTailnet(a.gracefulCtx, manifest.AgentID, manifest.DERPMap, manifest.DERPForceWebSockets, manifest.DisableDirectConnections)
1012+
network, err = a.createTailnet(
1013+
a.gracefulCtx,
1014+
manifest.AgentID,
1015+
manifest.DERPMap,
1016+
manifest.DERPForceWebSockets,
1017+
manifest.DisableDirectConnections,
1018+
keySeed,
1019+
)
10091020
if err != nil {
10101021
return xerrors.Errorf("create tailnet: %w", err)
10111022
}
@@ -1145,7 +1156,13 @@ func (a *agent) trackGoroutine(fn func()) error {
11451156
return nil
11461157
}
11471158

1148-
func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *tailcfg.DERPMap, derpForceWebSockets, disableDirectConnections bool) (_ *tailnet.Conn, err error) {
1159+
func (a *agent) createTailnet(
1160+
ctx context.Context,
1161+
agentID uuid.UUID,
1162+
derpMap *tailcfg.DERPMap,
1163+
derpForceWebSockets, disableDirectConnections bool,
1164+
keySeed int64,
1165+
) (_ *tailnet.Conn, err error) {
11491166
// Inject `CODER_AGENT_HEADER` into the DERP header.
11501167
var header http.Header
11511168
if client, ok := a.client.(*agentsdk.Client); ok {
@@ -1172,6 +1189,10 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t
11721189
}
11731190
}()
11741191

1192+
if err := a.sshServer.UpdateHostSigner(keySeed); err != nil {
1193+
return nil, xerrors.Errorf("update host signer: %w", err)
1194+
}
1195+
11751196
sshListener, err := network.Listen("tcp", ":"+strconv.Itoa(workspacesdk.AgentSSHPort))
11761197
if err != nil {
11771198
return nil, xerrors.Errorf("listen on the ssh port: %w", err)
@@ -1849,3 +1870,20 @@ func PrometheusMetricsHandler(prometheusRegistry *prometheus.Registry, logger sl
18491870
}
18501871
})
18511872
}
1873+
1874+
// WorkspaceKeySeed converts a WorkspaceID UUID and agent name to an int64 hash.
1875+
// This uses the FNV-1a hash algorithm which provides decent distribution and collision
1876+
// resistance for string inputs.
1877+
func WorkspaceKeySeed(workspaceID uuid.UUID, agentName string) (int64, error) {
1878+
h := fnv.New64a()
1879+
_, err := h.Write(workspaceID[:])
1880+
if err != nil {
1881+
return 42, err
1882+
}
1883+
_, err = h.Write([]byte(agentName))
1884+
if err != nil {
1885+
return 42, err
1886+
}
1887+
1888+
return int64(h.Sum64()), nil
1889+
}

agent/agentssh/agentssh.go

+108-14
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@ package agentssh
33
import (
44
"bufio"
55
"context"
6-
"crypto/rand"
76
"crypto/rsa"
87
"errors"
98
"fmt"
109
"io"
10+
"math/big"
11+
"math/rand"
1112
"net"
1213
"os"
1314
"os/exec"
@@ -128,17 +129,6 @@ type Server struct {
128129
}
129130

130131
func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prometheus.Registry, fs afero.Fs, execer agentexec.Execer, config *Config) (*Server, error) {
131-
// Clients' should ignore the host key when connecting.
132-
// The agent needs to authenticate with coderd to SSH,
133-
// so SSH authentication doesn't improve security.
134-
randomHostKey, err := rsa.GenerateKey(rand.Reader, 2048)
135-
if err != nil {
136-
return nil, err
137-
}
138-
randomSigner, err := gossh.NewSignerFromKey(randomHostKey)
139-
if err != nil {
140-
return nil, err
141-
}
142132
if config == nil {
143133
config = &Config{}
144134
}
@@ -205,8 +195,10 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
205195
slog.F("local_addr", conn.LocalAddr()),
206196
slog.Error(err))
207197
},
208-
Handler: s.sessionHandler,
209-
HostSigners: []ssh.Signer{randomSigner},
198+
Handler: s.sessionHandler,
199+
// HostSigners are intentionally empty, as the host key will
200+
// be set before we start listening.
201+
HostSigners: []ssh.Signer{},
210202
LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
211203
// Allow local port forwarding all!
212204
s.logger.Debug(ctx, "local port forward",
@@ -844,7 +836,13 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string,
844836
return cmd, nil
845837
}
846838

839+
// Serve starts the server to handle incoming connections on the provided listener.
840+
// It returns an error if no host keys are set or if there is an issue accepting connections.
847841
func (s *Server) Serve(l net.Listener) (retErr error) {
842+
if len(s.srv.HostSigners) == 0 {
843+
return xerrors.New("no host keys set")
844+
}
845+
848846
s.logger.Info(context.Background(), "started serving listener", slog.F("listen_addr", l.Addr()))
849847
defer func() {
850848
s.logger.Info(context.Background(), "stopped serving listener",
@@ -1099,3 +1097,99 @@ func userHomeDir() (string, error) {
10991097
}
11001098
return u.HomeDir, nil
11011099
}
1100+
1101+
// UpdateHostSigner updates the host signer with a new key generated from the provided seed.
1102+
// If an existing host key exists with the same algorithm, it is overwritten
1103+
func (s *Server) UpdateHostSigner(seed int64) error {
1104+
key, err := CoderSigner(seed)
1105+
if err != nil {
1106+
return err
1107+
}
1108+
1109+
s.mu.Lock()
1110+
defer s.mu.Unlock()
1111+
1112+
s.srv.AddHostKey(key)
1113+
1114+
return nil
1115+
}
1116+
1117+
// CoderSigner generates a deterministic SSH signer based on the provided seed.
1118+
// It uses RSA with a key size of 2048 bits.
1119+
func CoderSigner(seed int64) (gossh.Signer, error) {
1120+
// Clients should ignore the host key when connecting.
1121+
// The agent needs to authenticate with coderd to SSH,
1122+
// so SSH authentication doesn't improve security.
1123+
1124+
// Since the standard lib purposefully does not generate
1125+
// deterministic rsa keys, we need to do it ourselves.
1126+
coderHostKey := func() *rsa.PrivateKey {
1127+
// Create deterministic random source
1128+
// nolint: gosec
1129+
deterministicRand := rand.New(rand.NewSource(seed))
1130+
1131+
// Use fixed values for p and q based on the seed
1132+
p := big.NewInt(0)
1133+
q := big.NewInt(0)
1134+
e := big.NewInt(65537) // Standard RSA public exponent
1135+
1136+
// Generate deterministic primes using the seeded random
1137+
// Each prime should be ~1024 bits to get a 2048-bit key
1138+
for {
1139+
p.SetBit(p, 1024, 1) // Ensure it's large enough
1140+
for i := 0; i < 1024; i++ {
1141+
if deterministicRand.Int63()%2 == 1 {
1142+
p.SetBit(p, i, 1)
1143+
} else {
1144+
p.SetBit(p, i, 0)
1145+
}
1146+
}
1147+
if p.ProbablyPrime(20) {
1148+
break
1149+
}
1150+
}
1151+
1152+
for {
1153+
q.SetBit(q, 1024, 1) // Ensure it's large enough
1154+
for i := 0; i < 1024; i++ {
1155+
if deterministicRand.Int63()%2 == 1 {
1156+
q.SetBit(q, i, 1)
1157+
} else {
1158+
q.SetBit(q, i, 0)
1159+
}
1160+
}
1161+
if q.ProbablyPrime(20) && p.Cmp(q) != 0 {
1162+
break
1163+
}
1164+
}
1165+
1166+
// Calculate n = p * q
1167+
n := new(big.Int).Mul(p, q)
1168+
1169+
// Calculate phi = (p-1) * (q-1)
1170+
p1 := new(big.Int).Sub(p, big.NewInt(1))
1171+
q1 := new(big.Int).Sub(q, big.NewInt(1))
1172+
phi := new(big.Int).Mul(p1, q1)
1173+
1174+
// Calculate private exponent d
1175+
d := new(big.Int).ModInverse(e, phi)
1176+
1177+
// Create the private key
1178+
privateKey := &rsa.PrivateKey{
1179+
PublicKey: rsa.PublicKey{
1180+
N: n,
1181+
E: int(e.Int64()),
1182+
},
1183+
D: d,
1184+
Primes: []*big.Int{p, q},
1185+
}
1186+
1187+
// Compute precomputed values
1188+
privateKey.Precompute()
1189+
1190+
return privateKey
1191+
}()
1192+
1193+
coderSigner, err := gossh.NewSignerFromKey(coderHostKey)
1194+
return coderSigner, err
1195+
}

agent/agentssh/agentssh_internal_test.go

+2
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ func Test_sessionStart_orphan(t *testing.T) {
3939
s, err := NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
4040
require.NoError(t, err)
4141
defer s.Close()
42+
err = s.UpdateHostSigner(42)
43+
assert.NoError(t, err)
4244

4345
// Here we're going to call the handler directly with a faked SSH session
4446
// that just uses io.Pipes instead of a network socket. There is a large

agent/agentssh/agentssh_test.go

+8
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ func TestNewServer_ServeClient(t *testing.T) {
4141
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
4242
require.NoError(t, err)
4343
defer s.Close()
44+
err = s.UpdateHostSigner(42)
45+
assert.NoError(t, err)
4446

4547
ln, err := net.Listen("tcp", "127.0.0.1:0")
4648
require.NoError(t, err)
@@ -146,6 +148,8 @@ func TestNewServer_CloseActiveConnections(t *testing.T) {
146148
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
147149
require.NoError(t, err)
148150
defer s.Close()
151+
err = s.UpdateHostSigner(42)
152+
assert.NoError(t, err)
149153

150154
ln, err := net.Listen("tcp", "127.0.0.1:0")
151155
require.NoError(t, err)
@@ -197,6 +201,8 @@ func TestNewServer_Signal(t *testing.T) {
197201
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
198202
require.NoError(t, err)
199203
defer s.Close()
204+
err = s.UpdateHostSigner(42)
205+
assert.NoError(t, err)
200206

201207
ln, err := net.Listen("tcp", "127.0.0.1:0")
202208
require.NoError(t, err)
@@ -262,6 +268,8 @@ func TestNewServer_Signal(t *testing.T) {
262268
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
263269
require.NoError(t, err)
264270
defer s.Close()
271+
err = s.UpdateHostSigner(42)
272+
assert.NoError(t, err)
265273

266274
ln, err := net.Listen("tcp", "127.0.0.1:0")
267275
require.NoError(t, err)

agent/agentssh/x11_test.go

+2
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ func TestServer_X11(t *testing.T) {
3838
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, &agentssh.Config{})
3939
require.NoError(t, err)
4040
defer s.Close()
41+
err = s.UpdateHostSigner(42)
42+
assert.NoError(t, err)
4143

4244
ln, err := net.Listen("tcp", "127.0.0.1:0")
4345
require.NoError(t, err)

cli/ssh_test.go

+65
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,71 @@ func TestSSH(t *testing.T) {
453453
<-cmdDone
454454
})
455455

456+
t.Run("DeterministicHostKey", func(t *testing.T) {
457+
t.Parallel()
458+
client, workspace, agentToken := setupWorkspaceForAgent(t)
459+
_, _ = tGoContext(t, func(ctx context.Context) {
460+
// Run this async so the SSH command has to wait for
461+
// the build and agent to connect!
462+
_ = agenttest.New(t, client.URL, agentToken)
463+
<-ctx.Done()
464+
})
465+
466+
clientOutput, clientInput := io.Pipe()
467+
serverOutput, serverInput := io.Pipe()
468+
defer func() {
469+
for _, c := range []io.Closer{clientOutput, clientInput, serverOutput, serverInput} {
470+
_ = c.Close()
471+
}
472+
}()
473+
474+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
475+
defer cancel()
476+
477+
inv, root := clitest.New(t, "ssh", "--stdio", workspace.Name)
478+
clitest.SetupConfig(t, client, root)
479+
inv.Stdin = clientOutput
480+
inv.Stdout = serverInput
481+
inv.Stderr = io.Discard
482+
483+
cmdDone := tGo(t, func() {
484+
err := inv.WithContext(ctx).Run()
485+
assert.NoError(t, err)
486+
})
487+
488+
keySeed, err := agent.WorkspaceKeySeed(workspace.ID, "dev")
489+
assert.NoError(t, err)
490+
491+
signer, err := agentssh.CoderSigner(keySeed)
492+
assert.NoError(t, err)
493+
494+
conn, channels, requests, err := ssh.NewClientConn(&stdioConn{
495+
Reader: serverOutput,
496+
Writer: clientInput,
497+
}, "", &ssh.ClientConfig{
498+
HostKeyCallback: ssh.FixedHostKey(signer.PublicKey()),
499+
})
500+
require.NoError(t, err)
501+
defer conn.Close()
502+
503+
sshClient := ssh.NewClient(conn, channels, requests)
504+
session, err := sshClient.NewSession()
505+
require.NoError(t, err)
506+
defer session.Close()
507+
508+
command := "sh -c exit"
509+
if runtime.GOOS == "windows" {
510+
command = "cmd.exe /c exit"
511+
}
512+
err = session.Run(command)
513+
require.NoError(t, err)
514+
err = sshClient.Close()
515+
require.NoError(t, err)
516+
_ = clientOutput.Close()
517+
518+
<-cmdDone
519+
})
520+
456521
t.Run("NetworkInfo", func(t *testing.T) {
457522
t.Parallel()
458523
client, workspace, agentToken := setupWorkspaceForAgent(t)

0 commit comments

Comments
 (0)