Skip to content

Commit cab51d9

Browse files
committed
fix(agent/agentssh): pin random seed for RSA key generation
Change-Id: I8c7e3070324e5d558374fd6891eea9d48660e1e9 Signed-off-by: Thomas Kosiewski <tk@coder.com>
1 parent dedc32f commit cab51d9

File tree

5 files changed

+89
-14
lines changed

5 files changed

+89
-14
lines changed

agent/agent.go

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"encoding/json"
77
"errors"
88
"fmt"
9+
"hash/fnv"
910
"io"
1011
"net/http"
1112
"net/netip"
@@ -372,7 +373,6 @@ func (a *agent) collectMetadata(ctx context.Context, md codersdk.WorkspaceAgentM
372373
// Important: if the command times out, we may see a misleading error like
373374
// "exit status 1", so it's important to include the context error.
374375
err = errors.Join(err, ctx.Err())
375-
376376
if err != nil {
377377
result.Error = fmt.Sprintf("run cmd: %+v", err)
378378
}
@@ -995,7 +995,6 @@ func (a *agent) createOrUpdateNetwork(manifestOK, networkOK *checkpoint) func(co
995995
if err := manifestOK.wait(ctx); err != nil {
996996
return xerrors.Errorf("no manifest: %w", err)
997997
}
998-
var err error
999998
defer func() {
1000999
networkOK.complete(retErr)
10011000
}()
@@ -1004,9 +1003,20 @@ func (a *agent) createOrUpdateNetwork(manifestOK, networkOK *checkpoint) func(co
10041003
network := a.network
10051004
a.closeMutex.Unlock()
10061005
if network == nil {
1006+
keySeed, err := workspaceSeed(manifest.WorkspaceID, manifest.AgentName)
1007+
if err != nil {
1008+
return xerrors.Errorf("generate seed from workspace id: %w", err)
1009+
}
10071010
// use the graceful context here, because creating the tailnet is not itself tied to the
10081011
// agent API.
1009-
network, err = a.createTailnet(a.gracefulCtx, manifest.AgentID, manifest.DERPMap, manifest.DERPForceWebSockets, manifest.DisableDirectConnections)
1012+
network, err = a.createTailnet(
1013+
a.gracefulCtx,
1014+
manifest.AgentID,
1015+
manifest.DERPMap,
1016+
manifest.DERPForceWebSockets,
1017+
manifest.DisableDirectConnections,
1018+
keySeed,
1019+
)
10101020
if err != nil {
10111021
return xerrors.Errorf("create tailnet: %w", err)
10121022
}
@@ -1146,7 +1156,13 @@ func (a *agent) trackGoroutine(fn func()) error {
11461156
return nil
11471157
}
11481158

1149-
func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *tailcfg.DERPMap, derpForceWebSockets, disableDirectConnections bool) (_ *tailnet.Conn, err error) {
1159+
func (a *agent) createTailnet(
1160+
ctx context.Context,
1161+
agentID uuid.UUID,
1162+
derpMap *tailcfg.DERPMap,
1163+
derpForceWebSockets, disableDirectConnections bool,
1164+
keySeed int64,
1165+
) (_ *tailnet.Conn, err error) {
11501166
// Inject `CODER_AGENT_HEADER` into the DERP header.
11511167
var header http.Header
11521168
if client, ok := a.client.(*agentsdk.Client); ok {
@@ -1173,6 +1189,10 @@ func (a *agent) createTailnet(ctx context.Context, agentID uuid.UUID, derpMap *t
11731189
}
11741190
}()
11751191

1192+
if err := a.sshServer.UpdateHostSigner(keySeed); err != nil {
1193+
return nil, xerrors.Errorf("update host signer: %w", err)
1194+
}
1195+
11761196
sshListener, err := network.Listen("tcp", ":"+strconv.Itoa(workspacesdk.AgentSSHPort))
11771197
if err != nil {
11781198
return nil, xerrors.Errorf("listen on the ssh port: %w", err)
@@ -1850,3 +1870,20 @@ func PrometheusMetricsHandler(prometheusRegistry *prometheus.Registry, logger sl
18501870
}
18511871
})
18521872
}
1873+
1874+
// workspaceSeed converts a WorkspaceID UUID to an int64 hash.
1875+
// This uses the FNV-1a hash algorithm which provides decent distribution and collision
1876+
// resistance for string inputs.
1877+
func workspaceSeed(workspaceID uuid.UUID, agentName string) (int64, error) {
1878+
h := fnv.New64a()
1879+
_, err := h.Write(workspaceID[:])
1880+
if err != nil {
1881+
return 42, err
1882+
}
1883+
_, err = h.Write([]byte(agentName))
1884+
if err != nil {
1885+
return 42, err
1886+
}
1887+
1888+
return int64(h.Sum64()), nil
1889+
}

agent/agentssh/agentssh.go

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@ package agentssh
33
import (
44
"bufio"
55
"context"
6-
"crypto/rand"
76
"crypto/rsa"
87
"errors"
98
"fmt"
109
"io"
10+
"math/rand"
1111
"net"
1212
"os"
1313
"os/exec"
@@ -131,14 +131,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
131131
// Clients' should ignore the host key when connecting.
132132
// The agent needs to authenticate with coderd to SSH,
133133
// so SSH authentication doesn't improve security.
134-
randomHostKey, err := rsa.GenerateKey(rand.Reader, 2048)
135-
if err != nil {
136-
return nil, err
137-
}
138-
randomSigner, err := gossh.NewSignerFromKey(randomHostKey)
139-
if err != nil {
140-
return nil, err
141-
}
134+
142135
if config == nil {
143136
config = &Config{}
144137
}
@@ -206,7 +199,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
206199
slog.Error(err))
207200
},
208201
Handler: s.sessionHandler,
209-
HostSigners: []ssh.Signer{randomSigner},
202+
HostSigners: []ssh.Signer{},
210203
LocalPortForwardingCallback: func(ctx ssh.Context, destinationHost string, destinationPort uint32) bool {
211204
// Allow local port forwarding all!
212205
s.logger.Debug(ctx, "local port forward",
@@ -845,6 +838,10 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string,
845838
}
846839

847840
func (s *Server) Serve(l net.Listener) (retErr error) {
841+
if len(s.srv.HostSigners) == 0 {
842+
return xerrors.New("no host keys set")
843+
}
844+
848845
s.logger.Info(context.Background(), "started serving listener", slog.F("listen_addr", l.Addr()))
849846
defer func() {
850847
s.logger.Info(context.Background(), "stopped serving listener",
@@ -1099,3 +1096,32 @@ func userHomeDir() (string, error) {
10991096
}
11001097
return u.HomeDir, nil
11011098
}
1099+
1100+
// UpdateHostSigner updates the host signer with a new key generated from the provided seed.
1101+
// If an existing host key exists with the same algorithm, it is overwritten
1102+
func (s *Server) UpdateHostSigner(seed int64) error {
1103+
key, err := coderSigner(seed)
1104+
if err != nil {
1105+
return err
1106+
}
1107+
1108+
s.mu.Lock()
1109+
defer s.mu.Unlock()
1110+
1111+
s.srv.AddHostKey(key)
1112+
1113+
return nil
1114+
}
1115+
1116+
// coderSigner generates a deterministic SSH signer based on the provided seed.
1117+
// It uses RSA with a key size of 2048 bits.
1118+
func coderSigner(seed int64) (gossh.Signer, error) {
1119+
// nolint: gosec
1120+
deterministicRand := rand.New(rand.NewSource(seed))
1121+
coderHostKey, err := rsa.GenerateKey(deterministicRand, 2048)
1122+
if err != nil {
1123+
return nil, err
1124+
}
1125+
coderSigner, err := gossh.NewSignerFromKey(coderHostKey)
1126+
return coderSigner, err
1127+
}

agent/agentssh/agentssh_internal_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ func Test_sessionStart_orphan(t *testing.T) {
3939
s, err := NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
4040
require.NoError(t, err)
4141
defer s.Close()
42+
err = s.UpdateHostSigner(42)
43+
assert.NoError(t, err)
4244

4345
// Here we're going to call the handler directly with a faked SSH session
4446
// that just uses io.Pipes instead of a network socket. There is a large

agent/agentssh/agentssh_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ func TestNewServer_ServeClient(t *testing.T) {
4141
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
4242
require.NoError(t, err)
4343
defer s.Close()
44+
err = s.UpdateHostSigner(42)
45+
assert.NoError(t, err)
4446

4547
ln, err := net.Listen("tcp", "127.0.0.1:0")
4648
require.NoError(t, err)
@@ -146,6 +148,8 @@ func TestNewServer_CloseActiveConnections(t *testing.T) {
146148
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
147149
require.NoError(t, err)
148150
defer s.Close()
151+
err = s.UpdateHostSigner(42)
152+
assert.NoError(t, err)
149153

150154
ln, err := net.Listen("tcp", "127.0.0.1:0")
151155
require.NoError(t, err)
@@ -197,6 +201,8 @@ func TestNewServer_Signal(t *testing.T) {
197201
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
198202
require.NoError(t, err)
199203
defer s.Close()
204+
err = s.UpdateHostSigner(42)
205+
assert.NoError(t, err)
200206

201207
ln, err := net.Listen("tcp", "127.0.0.1:0")
202208
require.NoError(t, err)
@@ -262,6 +268,8 @@ func TestNewServer_Signal(t *testing.T) {
262268
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
263269
require.NoError(t, err)
264270
defer s.Close()
271+
err = s.UpdateHostSigner(42)
272+
assert.NoError(t, err)
265273

266274
ln, err := net.Listen("tcp", "127.0.0.1:0")
267275
require.NoError(t, err)

agent/agentssh/x11_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ func TestServer_X11(t *testing.T) {
3838
s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), fs, agentexec.DefaultExecer, &agentssh.Config{})
3939
require.NoError(t, err)
4040
defer s.Close()
41+
err = s.UpdateHostSigner(42)
42+
assert.NoError(t, err)
4143

4244
ln, err := net.Listen("tcp", "127.0.0.1:0")
4345
require.NoError(t, err)

0 commit comments

Comments
 (0)