Skip to content

Commit 3a5c2d7

Browse files
stirbyethanndicksonibetitsmikedannykoppingevgeniy-scherbina
authored
chore: cherry pick for release 2.22 (#17842)
Co-authored-by: Ethan <39577870+ethanndickson@users.noreply.github.com> Co-authored-by: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Co-authored-by: Danny Kopping <danny@coder.com> Co-authored-by: Yevhenii Shcherbina <evgeniy.shcherbina.es@gmail.com> Co-authored-by: Dean Sheather <dean@deansheather.com> Co-authored-by: Mathias Fredriksson <mafredri@gmail.com> Co-authored-by: Danny Kopping <dannykopping@gmail.com> Co-authored-by: Steven Masley <stevenmasley@gmail.com>
1 parent 2e96160 commit 3a5c2d7

28 files changed

+1512
-286
lines changed

cli/server.go

+9
Original file line numberDiff line numberDiff line change
@@ -739,6 +739,15 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
739739
_ = sqlDB.Close()
740740
}()
741741

742+
if options.DeploymentValues.Prometheus.Enable {
743+
// At this stage we don't think the database name serves much purpose in these metrics.
744+
// It requires parsing the DSN to determine it, which requires pulling in another dependency
745+
// (i.e. https://github.com/jackc/pgx), but it's rather heavy.
746+
// The conn string (https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING) can
747+
// take different forms, which make parsing non-trivial.
748+
options.PrometheusRegistry.MustRegister(collectors.NewDBStatsCollector(sqlDB, ""))
749+
}
750+
742751
options.Database = database.New(sqlDB)
743752
ps, err := pubsub.New(ctx, logger.Named("pubsub"), sqlDB, dbURL)
744753
if err != nil {

cli/ssh.go

+127-9
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"fmt"
99
"io"
1010
"log"
11+
"net"
1112
"net/http"
1213
"net/url"
1314
"os"
@@ -66,6 +67,7 @@ func (r *RootCmd) ssh() *serpent.Command {
6667
stdio bool
6768
hostPrefix string
6869
hostnameSuffix string
70+
forceNewTunnel bool
6971
forwardAgent bool
7072
forwardGPG bool
7173
identityAgent string
@@ -85,6 +87,7 @@ func (r *RootCmd) ssh() *serpent.Command {
8587
containerUser string
8688
)
8789
client := new(codersdk.Client)
90+
wsClient := workspacesdk.New(client)
8891
cmd := &serpent.Command{
8992
Annotations: workspaceCommand,
9093
Use: "ssh <workspace>",
@@ -203,14 +206,14 @@ func (r *RootCmd) ssh() *serpent.Command {
203206
parsedEnv = append(parsedEnv, [2]string{k, v})
204207
}
205208

206-
deploymentSSHConfig := codersdk.SSHConfigResponse{
209+
cliConfig := codersdk.SSHConfigResponse{
207210
HostnamePrefix: hostPrefix,
208211
HostnameSuffix: hostnameSuffix,
209212
}
210213

211214
workspace, workspaceAgent, err := findWorkspaceAndAgentByHostname(
212215
ctx, inv, client,
213-
inv.Args[0], deploymentSSHConfig, disableAutostart)
216+
inv.Args[0], cliConfig, disableAutostart)
214217
if err != nil {
215218
return err
216219
}
@@ -275,10 +278,44 @@ func (r *RootCmd) ssh() *serpent.Command {
275278
return err
276279
}
277280

281+
// If we're in stdio mode, check to see if we can use Coder Connect.
282+
// We don't support Coder Connect over non-stdio coder ssh yet.
283+
if stdio && !forceNewTunnel {
284+
connInfo, err := wsClient.AgentConnectionInfoGeneric(ctx)
285+
if err != nil {
286+
return xerrors.Errorf("get agent connection info: %w", err)
287+
}
288+
coderConnectHost := fmt.Sprintf("%s.%s.%s.%s",
289+
workspaceAgent.Name, workspace.Name, workspace.OwnerName, connInfo.HostnameSuffix)
290+
exists, _ := workspacesdk.ExistsViaCoderConnect(ctx, coderConnectHost)
291+
if exists {
292+
defer cancel()
293+
294+
if networkInfoDir != "" {
295+
if err := writeCoderConnectNetInfo(ctx, networkInfoDir); err != nil {
296+
logger.Error(ctx, "failed to write coder connect net info file", slog.Error(err))
297+
}
298+
}
299+
300+
stopPolling := tryPollWorkspaceAutostop(ctx, client, workspace)
301+
defer stopPolling()
302+
303+
usageAppName := getUsageAppName(usageApp)
304+
if usageAppName != "" {
305+
closeUsage := client.UpdateWorkspaceUsageWithBodyContext(ctx, workspace.ID, codersdk.PostWorkspaceUsageRequest{
306+
AgentID: workspaceAgent.ID,
307+
AppName: usageAppName,
308+
})
309+
defer closeUsage()
310+
}
311+
return runCoderConnectStdio(ctx, fmt.Sprintf("%s:22", coderConnectHost), stdioReader, stdioWriter, stack)
312+
}
313+
}
314+
278315
if r.disableDirect {
279316
_, _ = fmt.Fprintln(inv.Stderr, "Direct connections disabled.")
280317
}
281-
conn, err := workspacesdk.New(client).
318+
conn, err := wsClient.
282319
DialAgent(ctx, workspaceAgent.ID, &workspacesdk.DialAgentOptions{
283320
Logger: logger,
284321
BlockEndpoints: r.disableDirect,
@@ -660,6 +697,12 @@ func (r *RootCmd) ssh() *serpent.Command {
660697
Value: serpent.StringOf(&containerUser),
661698
Hidden: true, // Hidden until this features is at least in beta.
662699
},
700+
{
701+
Flag: "force-new-tunnel",
702+
Description: "Force the creation of a new tunnel to the workspace, even if the Coder Connect tunnel is available.",
703+
Value: serpent.BoolOf(&forceNewTunnel),
704+
Hidden: true,
705+
},
663706
sshDisableAutostartOption(serpent.BoolOf(&disableAutostart)),
664707
}
665708
return cmd
@@ -1372,12 +1415,13 @@ func setStatsCallback(
13721415
}
13731416

13741417
type sshNetworkStats struct {
1375-
P2P bool `json:"p2p"`
1376-
Latency float64 `json:"latency"`
1377-
PreferredDERP string `json:"preferred_derp"`
1378-
DERPLatency map[string]float64 `json:"derp_latency"`
1379-
UploadBytesSec int64 `json:"upload_bytes_sec"`
1380-
DownloadBytesSec int64 `json:"download_bytes_sec"`
1418+
P2P bool `json:"p2p"`
1419+
Latency float64 `json:"latency"`
1420+
PreferredDERP string `json:"preferred_derp"`
1421+
DERPLatency map[string]float64 `json:"derp_latency"`
1422+
UploadBytesSec int64 `json:"upload_bytes_sec"`
1423+
DownloadBytesSec int64 `json:"download_bytes_sec"`
1424+
UsingCoderConnect bool `json:"using_coder_connect"`
13811425
}
13821426

13831427
func collectNetworkStats(ctx context.Context, agentConn *workspacesdk.AgentConn, start, end time.Time, counts map[netlogtype.Connection]netlogtype.Counts) (*sshNetworkStats, error) {
@@ -1448,6 +1492,80 @@ func collectNetworkStats(ctx context.Context, agentConn *workspacesdk.AgentConn,
14481492
}, nil
14491493
}
14501494

1495+
type coderConnectDialerContextKey struct{}
1496+
1497+
type coderConnectDialer interface {
1498+
DialContext(ctx context.Context, network, addr string) (net.Conn, error)
1499+
}
1500+
1501+
func WithTestOnlyCoderConnectDialer(ctx context.Context, dialer coderConnectDialer) context.Context {
1502+
return context.WithValue(ctx, coderConnectDialerContextKey{}, dialer)
1503+
}
1504+
1505+
func testOrDefaultDialer(ctx context.Context) coderConnectDialer {
1506+
dialer, ok := ctx.Value(coderConnectDialerContextKey{}).(coderConnectDialer)
1507+
if !ok || dialer == nil {
1508+
return &net.Dialer{}
1509+
}
1510+
return dialer
1511+
}
1512+
1513+
func runCoderConnectStdio(ctx context.Context, addr string, stdin io.Reader, stdout io.Writer, stack *closerStack) error {
1514+
dialer := testOrDefaultDialer(ctx)
1515+
conn, err := dialer.DialContext(ctx, "tcp", addr)
1516+
if err != nil {
1517+
return xerrors.Errorf("dial coder connect host: %w", err)
1518+
}
1519+
if err := stack.push("tcp conn", conn); err != nil {
1520+
return err
1521+
}
1522+
1523+
agentssh.Bicopy(ctx, conn, &StdioRwc{
1524+
Reader: stdin,
1525+
Writer: stdout,
1526+
})
1527+
1528+
return nil
1529+
}
1530+
1531+
type StdioRwc struct {
1532+
io.Reader
1533+
io.Writer
1534+
}
1535+
1536+
func (*StdioRwc) Close() error {
1537+
return nil
1538+
}
1539+
1540+
func writeCoderConnectNetInfo(ctx context.Context, networkInfoDir string) error {
1541+
fs, ok := ctx.Value("fs").(afero.Fs)
1542+
if !ok {
1543+
fs = afero.NewOsFs()
1544+
}
1545+
if err := fs.MkdirAll(networkInfoDir, 0o700); err != nil {
1546+
return xerrors.Errorf("mkdir: %w", err)
1547+
}
1548+
1549+
// The VS Code extension obtains the PID of the SSH process to
1550+
// find the log file associated with a SSH session.
1551+
//
1552+
// We get the parent PID because it's assumed `ssh` is calling this
1553+
// command via the ProxyCommand SSH option.
1554+
networkInfoFilePath := filepath.Join(networkInfoDir, fmt.Sprintf("%d.json", os.Getppid()))
1555+
stats := &sshNetworkStats{
1556+
UsingCoderConnect: true,
1557+
}
1558+
rawStats, err := json.Marshal(stats)
1559+
if err != nil {
1560+
return xerrors.Errorf("marshal network stats: %w", err)
1561+
}
1562+
err = afero.WriteFile(fs, networkInfoFilePath, rawStats, 0o600)
1563+
if err != nil {
1564+
return xerrors.Errorf("write network stats: %w", err)
1565+
}
1566+
return nil
1567+
}
1568+
14511569
// Converts workspace name input to owner/workspace.agent format
14521570
// Possible valid input formats:
14531571
// workspace

cli/ssh_internal_test.go

+85
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,17 @@ package cli
33
import (
44
"context"
55
"fmt"
6+
"io"
7+
"net"
68
"net/url"
79
"sync"
810
"testing"
911
"time"
1012

13+
gliderssh "github.com/gliderlabs/ssh"
1114
"github.com/stretchr/testify/assert"
1215
"github.com/stretchr/testify/require"
16+
"golang.org/x/crypto/ssh"
1317
"golang.org/x/xerrors"
1418

1519
"cdr.dev/slog"
@@ -220,6 +224,87 @@ func TestCloserStack_Timeout(t *testing.T) {
220224
testutil.TryReceive(ctx, t, closed)
221225
}
222226

227+
func TestCoderConnectStdio(t *testing.T) {
228+
t.Parallel()
229+
230+
ctx := testutil.Context(t, testutil.WaitShort)
231+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
232+
stack := newCloserStack(ctx, logger, quartz.NewMock(t))
233+
234+
clientOutput, clientInput := io.Pipe()
235+
serverOutput, serverInput := io.Pipe()
236+
defer func() {
237+
for _, c := range []io.Closer{clientOutput, clientInput, serverOutput, serverInput} {
238+
_ = c.Close()
239+
}
240+
}()
241+
242+
server := newSSHServer("127.0.0.1:0")
243+
ln, err := net.Listen("tcp", server.server.Addr)
244+
require.NoError(t, err)
245+
246+
go func() {
247+
_ = server.Serve(ln)
248+
}()
249+
t.Cleanup(func() {
250+
_ = server.Close()
251+
})
252+
253+
stdioDone := make(chan struct{})
254+
go func() {
255+
err = runCoderConnectStdio(ctx, ln.Addr().String(), clientOutput, serverInput, stack)
256+
assert.NoError(t, err)
257+
close(stdioDone)
258+
}()
259+
260+
conn, channels, requests, err := ssh.NewClientConn(&testutil.ReaderWriterConn{
261+
Reader: serverOutput,
262+
Writer: clientInput,
263+
}, "", &ssh.ClientConfig{
264+
// #nosec
265+
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
266+
})
267+
require.NoError(t, err)
268+
defer conn.Close()
269+
270+
sshClient := ssh.NewClient(conn, channels, requests)
271+
session, err := sshClient.NewSession()
272+
require.NoError(t, err)
273+
defer session.Close()
274+
275+
// We're not connected to a real shell
276+
err = session.Run("")
277+
require.NoError(t, err)
278+
err = sshClient.Close()
279+
require.NoError(t, err)
280+
_ = clientOutput.Close()
281+
282+
<-stdioDone
283+
}
284+
285+
type sshServer struct {
286+
server *gliderssh.Server
287+
}
288+
289+
func newSSHServer(addr string) *sshServer {
290+
return &sshServer{
291+
server: &gliderssh.Server{
292+
Addr: addr,
293+
Handler: func(s gliderssh.Session) {
294+
_, _ = io.WriteString(s.Stderr(), "Connected!")
295+
},
296+
},
297+
}
298+
}
299+
300+
func (s *sshServer) Serve(ln net.Listener) error {
301+
return s.server.Serve(ln)
302+
}
303+
304+
func (s *sshServer) Close() error {
305+
return s.server.Close()
306+
}
307+
223308
type fakeCloser struct {
224309
closes *[]*fakeCloser
225310
err error

0 commit comments

Comments
 (0)