Skip to content

Commit 97dfbcf

Browse files
committed
Merge branch 'main' into 9883-add-owner-to-workspace
2 parents 8e6f929 + 38d9ce5 commit 97dfbcf

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+2262
-138
lines changed

.gitattributes

+1
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@ provisionersdk/proto/*.go linguist-generated=true
1212
*.tfstate.dot linguist-generated=true
1313
*.tfplan.dot linguist-generated=true
1414
site/src/api/typesGenerated.ts linguist-generated=true
15+
site/src/pages/SetupPage/countries.tsx linguist-generated=true

.github/workflows/typos.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ darcula = "darcula"
1414
Hashi = "Hashi"
1515
trialer = "trialer"
1616
encrypter = "encrypter"
17-
hel = "hel" # as in helsinki
17+
hel = "hel" # as in helsinki
1818

1919
[files]
2020
extend-exclude = [
@@ -31,4 +31,5 @@ extend-exclude = [
3131
"**/*.test.tsx",
3232
"**/pnpm-lock.yaml",
3333
"tailnet/testdata/**",
34+
"site/src/pages/SetupPage/countries.tsx",
3435
]

agent/agent_test.go

+34-21
Original file line numberDiff line numberDiff line change
@@ -214,46 +214,59 @@ func TestAgent_Stats_Magic(t *testing.T) {
214214
_, b, _, ok := runtime.Caller(0)
215215
require.True(t, ok)
216216
dir := filepath.Join(filepath.Dir(b), "../scripts/echoserver/main.go")
217-
echoServerCmd := exec.Command("go", "run", dir,
218-
"-D", agentssh.MagicProcessCmdlineJetBrains)
219-
stdout, err := echoServerCmd.StdoutPipe()
220-
require.NoError(t, err)
221-
err = echoServerCmd.Start()
222-
require.NoError(t, err)
223-
defer echoServerCmd.Process.Kill()
224217

225-
// The echo server prints its port as the first line.
226-
sc := bufio.NewScanner(stdout)
227-
sc.Scan()
228-
remotePort := sc.Text()
218+
spawnServer := func(network string) (string, *exec.Cmd) {
219+
echoServerCmd := exec.Command("go", "run", dir,
220+
network, "-D", agentssh.MagicProcessCmdlineJetBrains)
221+
stdout, err := echoServerCmd.StdoutPipe()
222+
require.NoError(t, err)
223+
err = echoServerCmd.Start()
224+
require.NoError(t, err)
225+
t.Cleanup(func() {
226+
echoServerCmd.Process.Kill()
227+
})
228+
229+
// The echo server prints its port as the first line.
230+
sc := bufio.NewScanner(stdout)
231+
sc.Scan()
232+
return sc.Text(), echoServerCmd
233+
}
234+
235+
port4, cmd4 := spawnServer("tcp4")
236+
port6, cmd6 := spawnServer("tcp6")
229237

230238
//nolint:dogsled
231239
conn, _, stats, _, _ := setupAgent(t, agentsdk.Manifest{}, 0)
240+
defer conn.Close()
241+
232242
sshClient, err := conn.SSHClient(ctx)
233243
require.NoError(t, err)
234244

235-
tunneledConn, err := sshClient.Dial("tcp", fmt.Sprintf("127.0.0.1:%s", remotePort))
245+
tunnel4, err := sshClient.Dial("tcp4", fmt.Sprintf("127.0.0.1:%s", port4))
236246
require.NoError(t, err)
237-
t.Cleanup(func() {
238-
// always close on failure of test
239-
_ = conn.Close()
240-
_ = tunneledConn.Close()
241-
})
247+
defer tunnel4.Close()
248+
249+
tunnel6, err := sshClient.Dial("tcp6", fmt.Sprintf("[::]:%s", port6))
250+
require.NoError(t, err)
251+
defer tunnel6.Close()
242252

243253
require.Eventuallyf(t, func() bool {
244254
s, ok := <-stats
245255
t.Logf("got stats with conn open: ok=%t, ConnectionCount=%d, SessionCountJetBrains=%d",
246256
ok, s.ConnectionCount, s.SessionCountJetBrains)
247257
return ok && s.ConnectionCount > 0 &&
248-
s.SessionCountJetBrains == 1
258+
s.SessionCountJetBrains == 2
249259
}, testutil.WaitLong, testutil.IntervalFast,
250260
"never saw stats with conn open",
251261
)
252262

253263
// Kill the server and connection after checking for the echo.
254-
requireEcho(t, tunneledConn)
255-
_ = echoServerCmd.Process.Kill()
256-
_ = tunneledConn.Close()
264+
requireEcho(t, tunnel4)
265+
requireEcho(t, tunnel6)
266+
_ = cmd4.Process.Kill()
267+
_ = cmd6.Process.Kill()
268+
_ = tunnel4.Close()
269+
_ = tunnel6.Close()
257270

258271
require.Eventuallyf(t, func() bool {
259272
s, ok := <-stats

agent/agentssh/agentssh.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
9999
}
100100

101101
forwardHandler := &ssh.ForwardedTCPHandler{}
102-
unixForwardHandler := &forwardedUnixHandler{log: logger}
102+
unixForwardHandler := newForwardedUnixHandler(logger)
103103

104104
metrics := newSSHServerMetrics(prometheusRegistry)
105105
s := &Server{

agent/agentssh/forward.go

+74-33
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@ package agentssh
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
7+
"io/fs"
68
"net"
79
"os"
810
"path/filepath"
911
"sync"
12+
"syscall"
1013

1114
"github.com/gliderlabs/ssh"
1215
gossh "golang.org/x/crypto/ssh"
@@ -33,22 +36,29 @@ type forwardedStreamLocalPayload struct {
3336
type forwardedUnixHandler struct {
3437
sync.Mutex
3538
log slog.Logger
36-
forwards map[string]net.Listener
39+
forwards map[forwardKey]net.Listener
40+
}
41+
42+
type forwardKey struct {
43+
sessionID string
44+
addr string
45+
}
46+
47+
func newForwardedUnixHandler(log slog.Logger) *forwardedUnixHandler {
48+
return &forwardedUnixHandler{
49+
log: log,
50+
forwards: make(map[forwardKey]net.Listener),
51+
}
3752
}
3853

3954
func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server, req *gossh.Request) (bool, []byte) {
4055
h.log.Debug(ctx, "handling SSH unix forward")
41-
h.Lock()
42-
if h.forwards == nil {
43-
h.forwards = make(map[string]net.Listener)
44-
}
45-
h.Unlock()
4656
conn, ok := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
4757
if !ok {
4858
h.log.Warn(ctx, "SSH unix forward request from client with no gossh connection")
4959
return false, nil
5060
}
51-
log := h.log.With(slog.F("remote_addr", conn.RemoteAddr()))
61+
log := h.log.With(slog.F("session_id", ctx.SessionID()), slog.F("remote_addr", conn.RemoteAddr()))
5262

5363
switch req.Type {
5464
case "streamlocal-forward@openssh.com":
@@ -62,14 +72,22 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
6272
addr := reqPayload.SocketPath
6373
log = log.With(slog.F("socket_path", addr))
6474
log.Debug(ctx, "request begin SSH unix forward")
75+
76+
key := forwardKey{
77+
sessionID: ctx.SessionID(),
78+
addr: addr,
79+
}
80+
6581
h.Lock()
66-
_, ok := h.forwards[addr]
82+
_, ok := h.forwards[key]
6783
h.Unlock()
6884
if ok {
69-
log.Warn(ctx, "SSH unix forward request for socket path that is already being forwarded (maybe to another client?)",
70-
slog.F("socket_path", addr),
71-
)
72-
return false, nil
85+
// In cases where `ExitOnForwardFailure=yes` is set, returning false
86+
// here will cause the connection to be closed. To avoid this, and
87+
// to match OpenSSH behavior, we silently ignore the second forward
88+
// request.
89+
log.Warn(ctx, "SSH unix forward request for socket path that is already being forwarded on this session, ignoring")
90+
return true, nil
7391
}
7492

7593
// Create socket parent dir if not exists.
@@ -83,12 +101,20 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
83101
return false, nil
84102
}
85103

86-
ln, err := net.Listen("unix", addr)
104+
// Remove existing socket if it exists. We do not use os.Remove() here
105+
// so that directories are kept. Note that it's possible that we will
106+
// overwrite a regular file here. Both of these behaviors match OpenSSH,
107+
// however, which is why we unlink.
108+
err = unlink(addr)
109+
if err != nil && !errors.Is(err, fs.ErrNotExist) {
110+
log.Warn(ctx, "remove existing socket for SSH unix forward request", slog.Error(err))
111+
return false, nil
112+
}
113+
114+
lc := &net.ListenConfig{}
115+
ln, err := lc.Listen(ctx, "unix", addr)
87116
if err != nil {
88-
log.Warn(ctx, "listen on Unix socket for SSH unix forward request",
89-
slog.F("socket_path", addr),
90-
slog.Error(err),
91-
)
117+
log.Warn(ctx, "listen on Unix socket for SSH unix forward request", slog.Error(err))
92118
return false, nil
93119
}
94120
log.Debug(ctx, "SSH unix forward listening on socket")
@@ -99,7 +125,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
99125
//
100126
// This is also what the upstream TCP version of this code does.
101127
h.Lock()
102-
h.forwards[addr] = ln
128+
h.forwards[key] = ln
103129
h.Unlock()
104130
log.Debug(ctx, "SSH unix forward added to cache")
105131

@@ -115,9 +141,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
115141
c, err := ln.Accept()
116142
if err != nil {
117143
if !xerrors.Is(err, net.ErrClosed) {
118-
log.Warn(ctx, "accept on local Unix socket for SSH unix forward request",
119-
slog.Error(err),
120-
)
144+
log.Warn(ctx, "accept on local Unix socket for SSH unix forward request", slog.Error(err))
121145
}
122146
// closed below
123147
log.Debug(ctx, "SSH unix forward listener closed")
@@ -131,10 +155,7 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
131155
go func() {
132156
ch, reqs, err := conn.OpenChannel("forwarded-streamlocal@openssh.com", payload)
133157
if err != nil {
134-
h.log.Warn(ctx, "open SSH unix forward channel to client",
135-
slog.F("socket_path", addr),
136-
slog.Error(err),
137-
)
158+
h.log.Warn(ctx, "open SSH unix forward channel to client", slog.Error(err))
138159
_ = c.Close()
139160
return
140161
}
@@ -144,12 +165,11 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
144165
}
145166

146167
h.Lock()
147-
ln2, ok := h.forwards[addr]
148-
if ok && ln2 == ln {
149-
delete(h.forwards, addr)
168+
if ln2, ok := h.forwards[key]; ok && ln2 == ln {
169+
delete(h.forwards, key)
150170
}
151171
h.Unlock()
152-
log.Debug(ctx, "SSH unix forward listener removed from cache", slog.F("path", addr))
172+
log.Debug(ctx, "SSH unix forward listener removed from cache")
153173
_ = ln.Close()
154174
}()
155175

@@ -162,13 +182,22 @@ func (h *forwardedUnixHandler) HandleSSHRequest(ctx ssh.Context, _ *ssh.Server,
162182
h.log.Warn(ctx, "parse cancel-streamlocal-forward@openssh.com (SSH unix forward) request payload from client", slog.Error(err))
163183
return false, nil
164184
}
165-
log.Debug(ctx, "request to cancel SSH unix forward", slog.F("path", reqPayload.SocketPath))
185+
log.Debug(ctx, "request to cancel SSH unix forward", slog.F("socket_path", reqPayload.SocketPath))
186+
187+
key := forwardKey{
188+
sessionID: ctx.SessionID(),
189+
addr: reqPayload.SocketPath,
190+
}
191+
166192
h.Lock()
167-
ln, ok := h.forwards[reqPayload.SocketPath]
193+
ln, ok := h.forwards[key]
194+
delete(h.forwards, key)
168195
h.Unlock()
169-
if ok {
170-
_ = ln.Close()
196+
if !ok {
197+
log.Warn(ctx, "SSH unix forward not found in cache")
198+
return true, nil
171199
}
200+
_ = ln.Close()
172201
return true, nil
173202

174203
default:
@@ -209,3 +238,15 @@ func directStreamLocalHandler(_ *ssh.Server, _ *gossh.ServerConn, newChan gossh.
209238

210239
Bicopy(ctx, ch, dconn)
211240
}
241+
242+
// unlink removes files and unlike os.Remove, directories are kept.
243+
func unlink(path string) error {
244+
// Ignore EINTR like os.Remove, see ignoringEINTR in os/file_posix.go
245+
// for more details.
246+
for {
247+
err := syscall.Unlink(path)
248+
if !errors.Is(err, syscall.EINTR) {
249+
return err
250+
}
251+
}
252+
}

agent/agentssh/portinspection_supported.go

+19-9
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
package agentssh
44

55
import (
6+
"errors"
67
"fmt"
78
"os"
89

@@ -11,24 +12,33 @@ import (
1112
)
1213

1314
func getListeningPortProcessCmdline(port uint32) (string, error) {
14-
tabs, err := netstat.TCPSocks(func(s *netstat.SockTabEntry) bool {
15+
acceptFn := func(s *netstat.SockTabEntry) bool {
1516
return s.LocalAddr != nil && uint32(s.LocalAddr.Port) == port
16-
})
17-
if err != nil {
18-
return "", xerrors.Errorf("inspect port %d: %w", port, err)
1917
}
20-
if len(tabs) == 0 {
21-
return "", nil
18+
tabs, err := netstat.TCPSocks(acceptFn)
19+
tabs6, err6 := netstat.TCP6Socks(acceptFn)
20+
21+
// Only return the error if the other method found nothing.
22+
if (err != nil && len(tabs6) == 0) || (err6 != nil && len(tabs) == 0) {
23+
return "", xerrors.Errorf("inspect port %d: %w", port, errors.Join(err, err6))
2224
}
2325

24-
// Defensive check.
25-
if tabs[0].Process == nil {
26+
var proc *netstat.Process
27+
if len(tabs) > 0 {
28+
proc = tabs[0].Process
29+
} else if len(tabs6) > 0 {
30+
proc = tabs6[0].Process
31+
}
32+
if proc == nil {
33+
// Either nothing is listening on this port or we were unable to read the
34+
// process details (permission issues reading /proc/$pid/* potentially).
35+
// Or, perhaps /proc/net/tcp{,6} is not listing the port for some reason.
2636
return "", nil
2737
}
2838

2939
// The process name provided by go-netstat does not include the full command
3040
// line so grab that instead.
31-
pid := tabs[0].Process.Pid
41+
pid := proc.Pid
3242
data, err := os.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pid))
3343
if err != nil {
3444
return "", xerrors.Errorf("read /proc/%d/cmdline: %w", pid, err)

0 commit comments

Comments
 (0)