
Commit 70d11a1

Merge branch 'main' into zed-as-coder-app
2 parents 495d678 + e9863ab commit 70d11a1

File tree: 157 files changed (6811 additions, 1556 deletions)


.github/actions/setup-tf/action.yaml

Lines changed: 1 addition & 1 deletion
@@ -7,5 +7,5 @@ runs:
     - name: Install Terraform
       uses: hashicorp/setup-terraform@b9cd54a3c349d3f38e8881555d616ced269862dd # v3.1.2
       with:
-        terraform_version: 1.11.2
+        terraform_version: 1.11.3
         terraform_wrapper: false

agent/agent.go

Lines changed: 53 additions & 19 deletions
@@ -907,7 +907,7 @@ func (a *agent) run() (retErr error) {
 	defer func() {
 		cErr := aAPI.DRPCConn().Close()
 		if cErr != nil {
-			a.logger.Debug(a.hardCtx, "error closing drpc connection", slog.Error(err))
+			a.logger.Debug(a.hardCtx, "error closing drpc connection", slog.Error(cErr))
 		}
 	}()

@@ -1186,9 +1186,9 @@ func (a *agent) createOrUpdateNetwork(manifestOK, networkOK *checkpoint) func(co
 		network := a.network
 		a.closeMutex.Unlock()
 		if network == nil {
-			keySeed, err := WorkspaceKeySeed(manifest.WorkspaceID, manifest.AgentName)
+			keySeed, err := SSHKeySeed(manifest.OwnerName, manifest.WorkspaceName, manifest.AgentName)
 			if err != nil {
-				return xerrors.Errorf("generate seed from workspace id: %w", err)
+				return xerrors.Errorf("generate SSH key seed: %w", err)
 			}
 			// use the graceful context here, because creating the tailnet is not itself tied to the
 			// agent API.
@@ -1518,14 +1518,11 @@ func (a *agent) runCoordinator(ctx context.Context, tClient tailnetproto.DRPCTai
 	a.logger.Info(ctx, "connected to coordination RPC")

 	// This allows the Close() routine to wait for the coordinator to gracefully disconnect.
-	a.closeMutex.Lock()
-	if a.isClosed() {
-		return nil
+	disconnected := a.setCoordDisconnected()
+	if disconnected == nil {
+		return nil // already closed by something else
 	}
-	disconnected := make(chan struct{})
-	a.coordDisconnected = disconnected
 	defer close(disconnected)
-	a.closeMutex.Unlock()

 	ctrl := tailnet.NewAgentCoordinationController(a.logger, network)
 	coordination := ctrl.New(coordinate)
@@ -1547,6 +1544,17 @@ func (a *agent) runCoordinator(ctx context.Context, tClient tailnetproto.DRPCTai
 	return <-errCh
 }

+func (a *agent) setCoordDisconnected() chan struct{} {
+	a.closeMutex.Lock()
+	defer a.closeMutex.Unlock()
+	if a.isClosed() {
+		return nil
+	}
+	disconnected := make(chan struct{})
+	a.coordDisconnected = disconnected
+	return disconnected
+}
+
 // runDERPMapSubscriber runs a coordinator and returns if a reconnect should occur.
 func (a *agent) runDERPMapSubscriber(ctx context.Context, tClient tailnetproto.DRPCTailnetClient24, network *tailnet.Conn) error {
 	defer a.logger.Debug(ctx, "disconnected from derp map RPC")
@@ -1773,15 +1781,22 @@ func (a *agent) Close() error {
 	a.setLifecycle(codersdk.WorkspaceAgentLifecycleShuttingDown)

 	// Attempt to gracefully shut down all active SSH connections and
-	// stop accepting new ones.
-	err := a.sshServer.Shutdown(a.hardCtx)
+	// stop accepting new ones. If all processes have not exited after 5
+	// seconds, we just log it and move on as it's more important to run
+	// the shutdown scripts. A typical shutdown time for containers is
+	// 10 seconds, so this still leaves a bit of time to run the
+	// shutdown scripts in the worst-case.
+	sshShutdownCtx, sshShutdownCancel := context.WithTimeout(a.hardCtx, 5*time.Second)
+	defer sshShutdownCancel()
+	err := a.sshServer.Shutdown(sshShutdownCtx)
 	if err != nil {
-		a.logger.Error(a.hardCtx, "ssh server shutdown", slog.Error(err))
-	}
-	err = a.sshServer.Close()
-	if err != nil {
-		a.logger.Error(a.hardCtx, "ssh server close", slog.Error(err))
+		if errors.Is(err, context.DeadlineExceeded) {
+			a.logger.Warn(sshShutdownCtx, "ssh server shutdown timeout", slog.Error(err))
+		} else {
+			a.logger.Error(sshShutdownCtx, "ssh server shutdown", slog.Error(err))
+		}
 	}
+
 	// wait for SSH to shut down before the general graceful cancel, because
 	// this triggers a disconnect in the tailnet layer, telling all clients to
 	// shut down their wireguard tunnels to us. If SSH sessions are still up,
@@ -2061,12 +2076,31 @@ func PrometheusMetricsHandler(prometheusRegistry *prometheus.Registry, logger sl
 	})
 }

-// WorkspaceKeySeed converts a WorkspaceID UUID and agent name to an int64 hash.
+// SSHKeySeed converts an owner userName, workspaceName and agentName to an int64 hash.
 // This uses the FNV-1a hash algorithm which provides decent distribution and collision
 // resistance for string inputs.
-func WorkspaceKeySeed(workspaceID uuid.UUID, agentName string) (int64, error) {
+//
+// Why owner username, workspace name, and agent name? These are the components that are used in hostnames for the
+// workspace over SSH, and so we want the workspace to have a stable key with respect to these. We don't use the
+// respective UUIDs. The workspace UUID would be different if you delete and recreate a workspace with the same name.
+// The agent UUID is regenerated on each build. Since Coder's Tailnet networking is handling the authentication, we
+// should not be showing users warnings about host SSH keys.
+func SSHKeySeed(userName, workspaceName, agentName string) (int64, error) {
 	h := fnv.New64a()
-	_, err := h.Write(workspaceID[:])
+	_, err := h.Write([]byte(userName))
+	if err != nil {
+		return 42, err
+	}
+	// null separators between strings so that (dog, foodstuff) is distinct from (dogfood, stuff)
+	_, err = h.Write([]byte{0})
+	if err != nil {
+		return 42, err
+	}
+	_, err = h.Write([]byte(workspaceName))
+	if err != nil {
+		return 42, err
+	}
+	_, err = h.Write([]byte{0})
 	if err != nil {
 		return 42, err
 	}
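
A minimal, standalone sketch of the hashing scheme described in the SSHKeySeed comment above: FNV-1a over the owner, workspace, and agent names with NUL separators so that ("dog", "foodstuff") and ("dogfood", "stuff") cannot collide. This is illustrative only, not the agent package's function; FNV writes never return an error, so error handling is dropped here.

package main

import (
	"fmt"
	"hash/fnv"
)

// sshKeySeed hashes the three name components into an int64 seed.
// The seed depends only on names, never on workspace or agent UUIDs,
// so it stays stable across rebuilds and delete/recreate of a workspace.
func sshKeySeed(userName, workspaceName, agentName string) int64 {
	h := fnv.New64a()
	h.Write([]byte(userName))
	h.Write([]byte{0}) // NUL separator avoids concatenation ambiguity
	h.Write([]byte(workspaceName))
	h.Write([]byte{0})
	h.Write([]byte(agentName))
	return int64(h.Sum64())
}

func main() {
	fmt.Println(sshKeySeed("alice", "dev", "main")) // same value on every run
	// Without separators these two would hash identically; with them they differ.
	fmt.Println(sshKeySeed("dog", "foodstuff", "main") == sshKeySeed("dogfood", "stuff", "main")) // false
}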
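The point of a stable seed is the last sentence of that comment: clients should never see changed-host-key warnings, since Coder's tailnet layer already authenticates the peer. The diff does not show how UpdateHostSigner consumes the seed, so the following is a purely hypothetical sketch of how a deterministic seed could yield a stable ed25519 host signer; a deterministically generated key is tolerable only because security comes from the tailnet, not from the SSH host key.

package main

import (
	"crypto/ed25519"
	"fmt"
	"math/rand"

	gossh "golang.org/x/crypto/ssh"
)

// deterministicSigner is hypothetical: it feeds a seeded (non-cryptographic)
// reader into ed25519 key generation so the same seed always produces the
// same host key.
func deterministicSigner(seed int64) (gossh.Signer, error) {
	rng := rand.New(rand.NewSource(seed)) // deterministic by design, not secure randomness
	_, priv, err := ed25519.GenerateKey(rng)
	if err != nil {
		return nil, err
	}
	return gossh.NewSignerFromKey(priv)
}

func main() {
	a, _ := deterministicSigner(42)
	b, _ := deterministicSigner(42)
	// Same seed, same public key: an SSH client's known_hosts entry keeps matching.
	fmt.Println(string(gossh.MarshalAuthorizedKey(a.PublicKey())) ==
		string(gossh.MarshalAuthorizedKey(b.PublicKey()))) // true
}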

agent/agentssh/agentssh.go

Lines changed: 53 additions & 13 deletions
@@ -582,6 +582,12 @@ func (s *Server) sessionStart(logger slog.Logger, session ssh.Session, env []str
 func (s *Server) startNonPTYSession(logger slog.Logger, session ssh.Session, magicTypeLabel string, cmd *exec.Cmd) error {
 	s.metrics.sessionsTotal.WithLabelValues(magicTypeLabel, "no").Add(1)

+	// Create a process group and send SIGHUP to child processes,
+	// otherwise context cancellation will not propagate properly
+	// and SSH server close may be delayed.
+	cmd.SysProcAttr = cmdSysProcAttr()
+	cmd.Cancel = cmdCancel(session.Context(), logger, cmd)
+
 	cmd.Stdout = session
 	cmd.Stderr = session.Stderr()
 	// This blocks forever until stdin is received if we don't
@@ -926,7 +932,12 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string,
 // Serve starts the server to handle incoming connections on the provided listener.
 // It returns an error if no host keys are set or if there is an issue accepting connections.
 func (s *Server) Serve(l net.Listener) (retErr error) {
-	if len(s.srv.HostSigners) == 0 {
+	// Ensure we're not mutating HostSigners as we're reading it.
+	s.mu.RLock()
+	noHostKeys := len(s.srv.HostSigners) == 0
+	s.mu.RUnlock()
+
+	if noHostKeys {
 		return xerrors.New("no host keys set")
 	}

@@ -1054,43 +1065,72 @@ func (s *Server) Close() error {
 	}
 	s.closing = make(chan struct{})

+	ctx := context.Background()
+
+	s.logger.Debug(ctx, "closing server")
+
+	// Stop accepting new connections.
+	s.logger.Debug(ctx, "closing all active listeners", slog.F("count", len(s.listeners)))
+	for l := range s.listeners {
+		_ = l.Close()
+	}
+
 	// Close all active sessions to gracefully
 	// terminate client connections.
+	s.logger.Debug(ctx, "closing all active sessions", slog.F("count", len(s.sessions)))
 	for ss := range s.sessions {
 		// We call Close on the underlying channel here because we don't
 		// want to send an exit status to the client (via Exit()).
 		// Typically OpenSSH clients will return 255 as the exit status.
 		_ = ss.Close()
 	}
-
-	// Close all active listeners and connections.
-	for l := range s.listeners {
-		_ = l.Close()
-	}
+	s.logger.Debug(ctx, "closing all active connections", slog.F("count", len(s.conns)))
 	for c := range s.conns {
 		_ = c.Close()
 	}

-	// Close the underlying SSH server.
+	s.logger.Debug(ctx, "closing SSH server")
 	err := s.srv.Close()

 	s.mu.Unlock()
+
+	s.logger.Debug(ctx, "waiting for all goroutines to exit")
 	s.wg.Wait() // Wait for all goroutines to exit.

 	s.mu.Lock()
 	close(s.closing)
 	s.closing = nil
 	s.mu.Unlock()

+	s.logger.Debug(ctx, "closing server done")
+
 	return err
 }

-// Shutdown gracefully closes all active SSH connections and stops
-// accepting new connections.
-//
-// Shutdown is not implemented.
-func (*Server) Shutdown(_ context.Context) error {
-	// TODO(mafredri): Implement shutdown, SIGHUP running commands, etc.
+// Shutdown stops accepting new connections. The current implementation
+// calls Close() for simplicity instead of waiting for existing
+// connections to close. If the context times out, Shutdown will return
+// but Close() may not have completed.
+func (s *Server) Shutdown(ctx context.Context) error {
+	ch := make(chan error, 1)
+	go func() {
+		// TODO(mafredri): Implement shutdown, SIGHUP running commands, etc.
+		// For now we just close the server.
+		ch <- s.Close()
+	}()
+	var err error
+	select {
+	case <-ctx.Done():
+		err = ctx.Err()
+	case err = <-ch:
+	}
+	// Re-check for context cancellation precedence.
+	if ctx.Err() != nil {
+		err = ctx.Err()
+	}
+	if err != nil {
+		return xerrors.Errorf("close server: %w", err)
+	}
 	return nil
 }
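
The startNonPTYSession change relies on two platform-specific helpers, cmdSysProcAttr and cmdCancel, that are referenced but not included in this diff (the real ones also take a logger and have per-OS variants). Below is a hedged, Linux-only sketch of the pattern the comment describes, assuming the usual Setpgid plus group-wide SIGHUP approach; the trap'd bash command mirrors the one used in the test file that follows, which traps SIGTERM and just sleeps longer.

//go:build linux

package main

import (
	"context"
	"os/exec"
	"syscall"
	"time"
)

// cmdSysProcAttr puts the child in its own process group so the whole group
// (including grandchildren such as the shell's sleep) can be signaled at once.
func cmdSysProcAttr() *syscall.SysProcAttr {
	return &syscall.SysProcAttr{Setpgid: true}
}

// cmdCancel is wired into exec.Cmd.Cancel (Go 1.20+). When the context is
// done it SIGHUPs the process group rather than only the direct child, so
// cancellation propagates and server close is not delayed. The real helper
// also takes a logger and differs per platform; this is a simplified sketch.
func cmdCancel(_ context.Context, cmd *exec.Cmd) func() error {
	return func() error {
		// A negative PID signals the whole process group.
		return syscall.Kill(-cmd.Process.Pid, syscall.SIGHUP)
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	// The trap delays exit on SIGTERM, like the command in the test below;
	// the group-wide SIGHUP still terminates it once the context expires.
	cmd := exec.CommandContext(ctx, "/bin/bash", "-c", `trap "sleep 60" SIGTERM; sleep 60`)
	cmd.SysProcAttr = cmdSysProcAttr()
	cmd.Cancel = cmdCancel(ctx, cmd)
	_ = cmd.Run() // returns once the deadline fires and the group is signaled
}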

agent/agentssh/agentssh_test.go

Lines changed: 79 additions & 37 deletions
@@ -21,6 +21,7 @@ import (
 	"go.uber.org/goleak"
 	"golang.org/x/crypto/ssh"

+	"cdr.dev/slog"
 	"cdr.dev/slog/sloggers/slogtest"

 	"github.com/coder/coder/v2/agent/agentexec"
@@ -147,51 +148,92 @@ func (*fakeEnvInfoer) ModifyCommand(cmd string, args ...string) (string, []strin
 func TestNewServer_CloseActiveConnections(t *testing.T) {
 	t.Parallel()

-	ctx := context.Background()
-	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
-	s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
-	require.NoError(t, err)
-	defer s.Close()
-	err = s.UpdateHostSigner(42)
-	assert.NoError(t, err)
+	prepare := func(ctx context.Context, t *testing.T) (*agentssh.Server, func()) {
+		t.Helper()
+		logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug)
+		s, err := agentssh.NewServer(ctx, logger, prometheus.NewRegistry(), afero.NewMemMapFs(), agentexec.DefaultExecer, nil)
+		require.NoError(t, err)
+		defer s.Close()
+		err = s.UpdateHostSigner(42)
+		assert.NoError(t, err)

-	ln, err := net.Listen("tcp", "127.0.0.1:0")
-	require.NoError(t, err)
+		ln, err := net.Listen("tcp", "127.0.0.1:0")
+		require.NoError(t, err)

-	var wg sync.WaitGroup
-	wg.Add(2)
-	go func() {
-		defer wg.Done()
-		err := s.Serve(ln)
-		assert.Error(t, err) // Server is closed.
-	}()
+		waitConns := make([]chan struct{}, 4)

-	pty := ptytest.New(t)
+		var wg sync.WaitGroup
+		wg.Add(1 + len(waitConns))

-	doClose := make(chan struct{})
-	go func() {
-		defer wg.Done()
-		c := sshClient(t, ln.Addr().String())
-		sess, err := c.NewSession()
-		assert.NoError(t, err)
-		sess.Stdin = pty.Input()
-		sess.Stdout = pty.Output()
-		sess.Stderr = pty.Output()
+		go func() {
+			defer wg.Done()
+			err := s.Serve(ln)
+			assert.Error(t, err) // Server is closed.
+		}()

-		assert.NoError(t, err)
-		err = sess.Start("")
-		assert.NoError(t, err)
+		for i := 0; i < len(waitConns); i++ {
+			waitConns[i] = make(chan struct{})
+			go func(ch chan struct{}) {
+				defer wg.Done()
+				c := sshClient(t, ln.Addr().String())
+				sess, err := c.NewSession()
+				assert.NoError(t, err)
+				pty := ptytest.New(t)
+				sess.Stdin = pty.Input()
+				sess.Stdout = pty.Output()
+				sess.Stderr = pty.Output()
+
+				// Every other session will request a PTY.
+				if i%2 == 0 {
+					err = sess.RequestPty("xterm", 80, 80, nil)
+					assert.NoError(t, err)
+				}
+				// The 60 seconds here is intended to be longer than the
+				// test. The shutdown should propagate.
+				err = sess.Start("/bin/bash -c 'trap \"sleep 60\" SIGTERM; sleep 60'")
+				assert.NoError(t, err)
+
+				close(ch)
+				err = sess.Wait()
+				assert.Error(t, err)
+			}(waitConns[i])
+		}

-		close(doClose)
-		err = sess.Wait()
-		assert.Error(t, err)
-	}()
+		for _, ch := range waitConns {
+			<-ch
+		}

-	<-doClose
-	err = s.Close()
-	require.NoError(t, err)
+		return s, wg.Wait
+	}
+
+	t.Run("Close", func(t *testing.T) {
+		t.Parallel()
+		ctx := testutil.Context(t, testutil.WaitMedium)
+		s, wait := prepare(ctx, t)
+		err := s.Close()
+		require.NoError(t, err)
+		wait()
+	})

-	wg.Wait()
+	t.Run("Shutdown", func(t *testing.T) {
+		t.Parallel()
+		ctx := testutil.Context(t, testutil.WaitMedium)
+		s, wait := prepare(ctx, t)
+		err := s.Shutdown(ctx)
+		require.NoError(t, err)
+		wait()
+	})
+
+	t.Run("Shutdown Early", func(t *testing.T) {
+		t.Parallel()
+		ctx := testutil.Context(t, testutil.WaitMedium)
+		s, wait := prepare(ctx, t)
+		ctx, cancel := context.WithCancel(ctx)
+		cancel()
+		err := s.Shutdown(ctx)
+		require.ErrorIs(t, err, context.Canceled)
+		wait()
+	})
 }

 func TestNewServer_Signal(t *testing.T) {
