diff --git a/cli/root.go b/cli/root.go index 2e94559e2d40f..cc9cae9488ba3 100644 --- a/cli/root.go +++ b/cli/root.go @@ -98,6 +98,7 @@ func Core() []*cobra.Command { users(), versionCmd(), workspaceAgent(), + vscodeipcCmd(), } } diff --git a/cli/speedtest.go b/cli/speedtest.go index 0761b558ef39f..7b4d29cfeb789 100644 --- a/cli/speedtest.go +++ b/cli/speedtest.go @@ -71,7 +71,7 @@ func speedtest() *cobra.Command { return ctx.Err() case <-ticker.C: } - dur, err := conn.Ping(ctx) + dur, p2p, err := conn.Ping(ctx) if err != nil { continue } @@ -80,7 +80,7 @@ func speedtest() *cobra.Command { continue } peer := status.Peer[status.Peers()[0]] - if peer.CurAddr == "" && direct { + if !p2p && direct { cmd.Printf("Waiting for a direct connection... (%dms via %s)\n", dur.Milliseconds(), peer.Relay) continue } diff --git a/cli/ssh_test.go b/cli/ssh_test.go index 27d1335bdfb57..8e828b4e047d5 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -65,6 +65,8 @@ func setupWorkspaceForAgent(t *testing.T, mutate func([]*proto.Agent) []*proto.A template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) + workspace, err := client.Workspace(context.Background(), workspace.ID) + require.NoError(t, err) return client, workspace, agentToken } diff --git a/cli/vscodeipc.go b/cli/vscodeipc.go new file mode 100644 index 0000000000000..262fbf69aae8e --- /dev/null +++ b/cli/vscodeipc.go @@ -0,0 +1,88 @@ +package cli + +import ( + "fmt" + "net" + "net/http" + "net/url" + + "github.com/google/uuid" + "github.com/spf13/cobra" + "golang.org/x/xerrors" + + "github.com/coder/coder/cli/cliflag" + "github.com/coder/coder/cli/vscodeipc" + "github.com/coder/coder/codersdk" +) + +// vscodeipcCmd spawns a local HTTP server on the provided port that listens to messages. +// It's made for use by the Coder VS Code extension. See: https://github.com/coder/vscode-coder +func vscodeipcCmd() *cobra.Command { + var ( + rawURL string + token string + port uint16 + ) + cmd := &cobra.Command{ + Use: "vscodeipc ", + Args: cobra.ExactArgs(1), + SilenceUsage: true, + Hidden: true, + RunE: func(cmd *cobra.Command, args []string) error { + if rawURL == "" { + return xerrors.New("CODER_URL must be set!") + } + // token is validated in a header on each request to prevent + // unauthenticated clients from connecting. + if token == "" { + return xerrors.New("CODER_TOKEN must be set!") + } + listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + if err != nil { + return xerrors.Errorf("listen: %w", err) + } + defer listener.Close() + addr, ok := listener.Addr().(*net.TCPAddr) + if !ok { + return xerrors.Errorf("listener.Addr() is not a *net.TCPAddr: %T", listener.Addr()) + } + url, err := url.Parse(rawURL) + if err != nil { + return err + } + agentID, err := uuid.Parse(args[0]) + if err != nil { + return err + } + client := codersdk.New(url) + client.SetSessionToken(token) + + handler, closer, err := vscodeipc.New(cmd.Context(), client, agentID, nil) + if err != nil { + return err + } + defer closer.Close() + // nolint:gosec + server := http.Server{ + Handler: handler, + } + defer server.Close() + _, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s\n", addr.String()) + errChan := make(chan error, 1) + go func() { + err := server.Serve(listener) + errChan <- err + }() + select { + case <-cmd.Context().Done(): + return cmd.Context().Err() + case err := <-errChan: + return err + } + }, + } + cliflag.StringVarP(cmd.Flags(), &rawURL, "url", "u", "CODER_URL", "", "The URL of the Coder instance!") + cliflag.StringVarP(cmd.Flags(), &token, "token", "t", "CODER_TOKEN", "", "The session token of the user!") + cmd.Flags().Uint16VarP(&port, "port", "p", 0, "The port to listen on!") + return cmd +} diff --git a/cli/vscodeipc/vscodeipc.go b/cli/vscodeipc/vscodeipc.go new file mode 100644 index 0000000000000..9d4e094564da2 --- /dev/null +++ b/cli/vscodeipc/vscodeipc.go @@ -0,0 +1,313 @@ +package vscodeipc + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "sync" + "time" + + "github.com/go-chi/chi/v5" + "github.com/google/uuid" + "golang.org/x/crypto/ssh" + "golang.org/x/xerrors" + "tailscale.com/tailcfg" + + "github.com/coder/coder/agent" + "github.com/coder/coder/coderd/httpapi" + "github.com/coder/coder/codersdk" +) + +const AuthHeader = "Coder-IPC-Token" + +// New creates a VS Code IPC client that can be used to communicate with workspaces. +// +// Creating this IPC was required instead of using SSH, because we're unable to get +// connection information to display in the bottom-bar when using SSH. It's possible +// we could jank around this (maybe by using a temporary SSH host), but that's not +// ideal. +// +// This persists a single workspace connection, and lets you execute commands, check +// for network information, and forward ports. +// +// The VS Code extension is located at https://github.com/coder/vscode-coder. The +// extension downloads the slim binary from `/bin/*` and executes `coder vscodeipc` +// which calls this function. This API must maintain backward compatibility with +// the extension to support prior versions of Coder. +func New(ctx context.Context, client *codersdk.Client, agentID uuid.UUID, options *codersdk.DialWorkspaceAgentOptions) (http.Handler, io.Closer, error) { + if options == nil { + options = &codersdk.DialWorkspaceAgentOptions{} + } + // We need this to track upload and download! + options.EnableTrafficStats = true + + agentConn, err := client.DialWorkspaceAgent(ctx, agentID, options) + if err != nil { + return nil, nil, err + } + api := &api{ + agentConn: agentConn, + } + r := chi.NewRouter() + // This is to prevent unauthorized clients on the same machine from executing + // requests on behalf of the workspace. + r.Use(sessionTokenMiddleware(client.SessionToken())) + r.Route("/v1", func(r chi.Router) { + r.Get("/port/{port}", api.port) + r.Get("/network", api.network) + r.Post("/execute", api.execute) + }) + return r, api, nil +} + +type api struct { + agentConn *codersdk.AgentConn + sshClient *ssh.Client + sshClientErr error + sshClientOnce sync.Once + + lastNetwork time.Time +} + +func (api *api) Close() error { + if api.sshClient != nil { + api.sshClient.Close() + } + return api.agentConn.Close() +} + +type NetworkResponse struct { + P2P bool `json:"p2p"` + Latency float64 `json:"latency"` + PreferredDERP string `json:"preferred_derp"` + DERPLatency map[string]float64 `json:"derp_latency"` + UploadBytesSec int64 `json:"upload_bytes_sec"` + DownloadBytesSec int64 `json:"download_bytes_sec"` +} + +// port accepts an HTTP request to dial a port on the workspace agent. +// It uses an HTTP connection upgrade to transfer the connection to TCP. +func (api *api) port(w http.ResponseWriter, r *http.Request) { + port, err := strconv.Atoi(chi.URLParam(r, "port")) + if err != nil { + httpapi.Write(r.Context(), w, http.StatusBadRequest, codersdk.Response{ + Message: "Port must be an integer!", + }) + return + } + remoteConn, err := api.agentConn.DialContext(r.Context(), "tcp", fmt.Sprintf("127.0.0.1:%d", port)) + if err != nil { + httpapi.InternalServerError(w, err) + return + } + defer remoteConn.Close() + + // Upgrade an switch to TCP! + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "tcp") + w.WriteHeader(http.StatusSwitchingProtocols) + + hijacker, ok := w.(http.Hijacker) + if !ok { + httpapi.InternalServerError(w, xerrors.Errorf("unable to hijack connection: %T", w)) + return + } + + localConn, brw, err := hijacker.Hijack() + if err != nil { + httpapi.InternalServerError(w, err) + return + } + defer localConn.Close() + + _ = brw.Flush() + agent.Bicopy(r.Context(), localConn, remoteConn) +} + +// network returns network information about the workspace. +func (api *api) network(w http.ResponseWriter, r *http.Request) { + // Ping the workspace agent to get the latency. + latency, p2p, err := api.agentConn.Ping(r.Context()) + if err != nil { + httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to ping the workspace agent.", + Detail: err.Error(), + }) + return + } + + node := api.agentConn.Node() + derpMap := api.agentConn.DERPMap() + derpLatency := map[string]float64{} + + // Convert DERP region IDs to friendly names for display in the UI. + for rawRegion, latency := range node.DERPLatency { + regionParts := strings.SplitN(rawRegion, "-", 2) + regionID, err := strconv.Atoi(regionParts[0]) + if err != nil { + continue + } + region, found := derpMap.Regions[regionID] + if !found { + // It's possible that a workspace agent is using an old DERPMap + // and reports regions that do not exist. If that's the case, + // report the region as unknown! + region = &tailcfg.DERPRegion{ + RegionID: regionID, + RegionName: fmt.Sprintf("Unnamed %d", regionID), + } + } + // Convert the microseconds to milliseconds. + derpLatency[region.RegionName] = latency * 1000 + } + + totalRx := uint64(0) + totalTx := uint64(0) + for _, stat := range api.agentConn.ExtractTrafficStats() { + totalRx += stat.RxBytes + totalTx += stat.TxBytes + } + // Tracking the time since last request is required because + // ExtractTrafficStats() resets its counters after each call. + dur := time.Since(api.lastNetwork) + uploadSecs := float64(totalTx) / dur.Seconds() + downloadSecs := float64(totalRx) / dur.Seconds() + + api.lastNetwork = time.Now() + + httpapi.Write(r.Context(), w, http.StatusOK, NetworkResponse{ + P2P: p2p, + Latency: float64(latency.Microseconds()) / 1000, + PreferredDERP: derpMap.Regions[node.PreferredDERP].RegionName, + DERPLatency: derpLatency, + UploadBytesSec: int64(uploadSecs), + DownloadBytesSec: int64(downloadSecs), + }) +} + +type ExecuteRequest struct { + Command string `json:"command"` + Stdin string `json:"stdin"` +} + +type ExecuteResponse struct { + Data string `json:"data"` + ExitCode *int `json:"exit_code"` +} + +// execute runs the command provided, streams the output back, and returns the exit code. +func (api *api) execute(w http.ResponseWriter, r *http.Request) { + var req ExecuteRequest + if !httpapi.Read(r.Context(), w, r, &req) { + return + } + api.sshClientOnce.Do(func() { + // The SSH client is lazily created because it's not needed for + // all requests. It's only needed for the execute endpoint. + // + // It's alright if this fails on the first execution, because + // a new instance of this API is created for each remote SSH request. + api.sshClient, api.sshClientErr = api.agentConn.SSHClient(context.Background()) + }) + if api.sshClientErr != nil { + httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to create SSH client.", + Detail: api.sshClientErr.Error(), + }) + return + } + session, err := api.sshClient.NewSession() + if err != nil { + httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to create SSH session.", + Detail: err.Error(), + }) + return + } + defer session.Close() + f, ok := w.(http.Flusher) + if !ok { + httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{ + Message: fmt.Sprintf("http.ResponseWriter is not http.Flusher: %T", w), + }) + return + } + + execWriter := &execWriter{w, f} + session.Stdout = execWriter + session.Stderr = execWriter + session.Stdin = strings.NewReader(req.Stdin) + err = session.Start(req.Command) + if err != nil { + httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to start SSH session.", + Detail: err.Error(), + }) + return + } + err = session.Wait() + + writeExit := func(exitCode int) { + data, _ := json.Marshal(&ExecuteResponse{ + ExitCode: &exitCode, + }) + _, _ = w.Write(data) + f.Flush() + } + + if err != nil { + var exitError *ssh.ExitError + if errors.As(err, &exitError) { + writeExit(exitError.ExitStatus()) + return + } + } + writeExit(0) +} + +type execWriter struct { + w http.ResponseWriter + f http.Flusher +} + +func (e *execWriter) Write(data []byte) (int, error) { + js, err := json.Marshal(&ExecuteResponse{ + Data: string(data), + }) + if err != nil { + return 0, err + } + _, err = e.w.Write(js) + if err != nil { + return 0, err + } + e.f.Flush() + return len(data), nil +} + +func sessionTokenMiddleware(sessionToken string) func(h http.Handler) http.Handler { + return func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token := r.Header.Get(AuthHeader) + if token == "" { + httpapi.Write(r.Context(), w, http.StatusUnauthorized, codersdk.Response{ + Message: fmt.Sprintf("A session token must be provided in the `%s` header.", AuthHeader), + }) + return + } + if token != sessionToken { + httpapi.Write(r.Context(), w, http.StatusUnauthorized, codersdk.Response{ + Message: "The session token provided doesn't match the one used to create the client.", + }) + return + } + w.Header().Set("Access-Control-Allow-Origin", "*") + h.ServeHTTP(w, r) + }) + } +} diff --git a/cli/vscodeipc/vscodeipc_test.go b/cli/vscodeipc/vscodeipc_test.go new file mode 100644 index 0000000000000..5213c2422be6a --- /dev/null +++ b/cli/vscodeipc/vscodeipc_test.go @@ -0,0 +1,202 @@ +package vscodeipc_test + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "net" + "net/http" + "net/http/httptest" + "net/url" + "runtime" + "testing" + + "github.com/google/uuid" + "github.com/spf13/afero" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + "nhooyr.io/websocket" + + "github.com/coder/coder/agent" + "github.com/coder/coder/cli/vscodeipc" + "github.com/coder/coder/coderd/httpapi" + "github.com/coder/coder/codersdk" + "github.com/coder/coder/tailnet" + "github.com/coder/coder/tailnet/tailnettest" + "github.com/coder/coder/testutil" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + +func TestVSCodeIPC(t *testing.T) { + t.Parallel() + ctx := context.Background() + + id := uuid.New() + derpMap := tailnettest.RunDERPAndSTUN(t) + coordinator := tailnet.NewCoordinator() + t.Cleanup(func() { + _ = coordinator.Close() + }) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case fmt.Sprintf("/api/v2/workspaceagents/%s/connection", id): + assert.Equal(t, r.Method, http.MethodGet) + httpapi.Write(ctx, w, http.StatusOK, codersdk.WorkspaceAgentConnectionInfo{ + DERPMap: derpMap, + }) + return + case fmt.Sprintf("/api/v2/workspaceagents/%s/coordinate", id): + assert.Equal(t, r.Method, http.MethodGet) + ws, err := websocket.Accept(w, r, nil) + require.NoError(t, err) + conn := websocket.NetConn(ctx, ws, websocket.MessageBinary) + _ = coordinator.ServeClient(conn, uuid.New(), id) + return + case "/api/v2/workspaceagents/me/version": + assert.Equal(t, r.Method, http.MethodPost) + w.WriteHeader(http.StatusOK) + return + case "/api/v2/workspaceagents/me/metadata": + assert.Equal(t, r.Method, http.MethodGet) + httpapi.Write(ctx, w, http.StatusOK, codersdk.WorkspaceAgentMetadata{ + DERPMap: derpMap, + }) + return + case "/api/v2/workspaceagents/me/coordinate": + assert.Equal(t, r.Method, http.MethodGet) + ws, err := websocket.Accept(w, r, nil) + require.NoError(t, err) + conn := websocket.NetConn(ctx, ws, websocket.MessageBinary) + _ = coordinator.ServeAgent(conn, id) + return + case "/api/v2/workspaceagents/me/report-stats": + assert.Equal(t, r.Method, http.MethodPost) + w.WriteHeader(http.StatusOK) + return + case "/": + w.WriteHeader(http.StatusOK) + return + default: + t.Fatalf("unexpected request %s", r.URL.Path) + } + })) + t.Cleanup(srv.Close) + srvURL, _ := url.Parse(srv.URL) + + client := codersdk.New(srvURL) + token := uuid.New().String() + client.SetSessionToken(token) + agentConn := agent.New(agent.Options{ + Client: client, + Filesystem: afero.NewMemMapFs(), + TempDir: t.TempDir(), + }) + t.Cleanup(func() { + _ = agentConn.Close() + }) + + handler, closer, err := vscodeipc.New(ctx, client, id, nil) + require.NoError(t, err) + t.Cleanup(func() { + _ = closer.Close() + }) + + // Ensure that we're actually connected! + require.Eventually(t, func() bool { + res := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/v1/network", nil) + req.Header.Set(vscodeipc.AuthHeader, token) + handler.ServeHTTP(res, req) + network := &vscodeipc.NetworkResponse{} + err = json.NewDecoder(res.Body).Decode(&network) + assert.NoError(t, err) + return network.Latency != 0 + }, testutil.WaitLong, testutil.IntervalFast) + + _, port, err := net.SplitHostPort(srvURL.Host) + require.NoError(t, err) + + t.Run("NoSessionToken", func(t *testing.T) { + t.Parallel() + res := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/v1/port/%s", port), nil) + handler.ServeHTTP(res, req) + require.Equal(t, http.StatusUnauthorized, res.Code) + }) + + t.Run("MismatchedSessionToken", func(t *testing.T) { + t.Parallel() + res := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/v1/port/%s", port), nil) + req.Header.Set(vscodeipc.AuthHeader, uuid.NewString()) + handler.ServeHTTP(res, req) + require.Equal(t, http.StatusUnauthorized, res.Code) + }) + + t.Run("Port", func(t *testing.T) { + // Tests that the port endpoint can be used for forward traffic. + // For this test, we simply use the already listening httptest server. + t.Parallel() + input, output := net.Pipe() + defer input.Close() + defer output.Close() + res := &hijackable{httptest.NewRecorder(), output} + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/v1/port/%s", port), nil) + req.Header.Set(vscodeipc.AuthHeader, token) + go handler.ServeHTTP(res, req) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://127.0.0.1/", nil) + require.NoError(t, err) + client := http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return input, nil + }, + }, + } + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusOK, resp.StatusCode) + }) + + t.Run("Execute", func(t *testing.T) { + t.Parallel() + if runtime.GOOS == "windows" { + t.Skip("Execute isn't supported on Windows yet!") + return + } + + res := httptest.NewRecorder() + data, _ := json.Marshal(vscodeipc.ExecuteRequest{ + Command: "echo test", + }) + req := httptest.NewRequest(http.MethodPost, "/v1/execute", bytes.NewReader(data)) + req.Header.Set(vscodeipc.AuthHeader, token) + handler.ServeHTTP(res, req) + + decoder := json.NewDecoder(res.Body) + var msg vscodeipc.ExecuteResponse + err = decoder.Decode(&msg) + require.NoError(t, err) + require.Equal(t, "test\n", msg.Data) + err = decoder.Decode(&msg) + require.NoError(t, err) + require.Equal(t, 0, *msg.ExitCode) + }) +} + +type hijackable struct { + *httptest.ResponseRecorder + conn net.Conn +} + +func (h *hijackable) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return h.conn, bufio.NewReadWriter(bufio.NewReader(h.conn), bufio.NewWriter(h.conn)), nil +} diff --git a/cli/vscodeipc_test.go b/cli/vscodeipc_test.go new file mode 100644 index 0000000000000..1edb52102841c --- /dev/null +++ b/cli/vscodeipc_test.go @@ -0,0 +1,44 @@ +package cli_test + +import ( + "io" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/cli/clitest" + "github.com/coder/coder/testutil" +) + +func TestVSCodeIPC(t *testing.T) { + t.Parallel() + // Ensures the vscodeipc command outputs it's running port! + // This signifies to the caller that it's ready to accept requests. + t.Run("PortOutputs", func(t *testing.T) { + t.Parallel() + client, workspace, _ := setupWorkspaceForAgent(t, nil) + cmd, _ := clitest.New(t, "vscodeipc", workspace.LatestBuild.Resources[0].Agents[0].ID.String(), + "--token", client.SessionToken(), "--url", client.URL.String()) + rdr, wtr := io.Pipe() + cmd.SetOut(wtr) + ctx, cancelFunc := testutil.Context(t) + defer cancelFunc() + done := make(chan error, 1) + go func() { + err := cmd.ExecuteContext(ctx) + done <- err + }() + + buf := make([]byte, 64) + require.Eventually(t, func() bool { + t.Log("Looking for address!") + var err error + _, err = rdr.Read(buf) + return err == nil + }, testutil.WaitMedium, testutil.IntervalFast) + t.Logf("Address: %s\n", buf) + + cancelFunc() + <-done + }) +} diff --git a/codersdk/agentconn.go b/codersdk/agentconn.go index dc68bad67d6b9..33bf16bcc3406 100644 --- a/codersdk/agentconn.go +++ b/codersdk/agentconn.go @@ -139,7 +139,9 @@ func (c *AgentConn) AwaitReachable(ctx context.Context) bool { return c.Conn.AwaitReachable(ctx, TailnetIP) } -func (c *AgentConn) Ping(ctx context.Context) (time.Duration, error) { +// Ping pings the agent and returns the round-trip time. +// The bool returns true if the ping was made P2P. +func (c *AgentConn) Ping(ctx context.Context) (time.Duration, bool, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index 63b895a2091e9..258b108737b9c 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -346,7 +346,8 @@ func (c *Client) ListenWorkspaceAgent(ctx context.Context) (net.Conn, error) { type DialWorkspaceAgentOptions struct { Logger slog.Logger // BlockEndpoints forced a direct connection through DERP. - BlockEndpoints bool + BlockEndpoints bool + EnableTrafficStats bool } func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *DialWorkspaceAgentOptions) (*AgentConn, error) { @@ -369,10 +370,11 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti ip := tailnet.IP() conn, err := tailnet.NewConn(&tailnet.Options{ - Addresses: []netip.Prefix{netip.PrefixFrom(ip, 128)}, - DERPMap: connInfo.DERPMap, - Logger: options.Logger, - BlockEndpoints: options.BlockEndpoints, + Addresses: []netip.Prefix{netip.PrefixFrom(ip, 128)}, + DERPMap: connInfo.DERPMap, + Logger: options.Logger, + BlockEndpoints: options.BlockEndpoints, + EnableTrafficStats: options.EnableTrafficStats, }) if err != nil { return nil, xerrors.Errorf("create tailnet: %w", err) diff --git a/enterprise/coderd/replicas_test.go b/enterprise/coderd/replicas_test.go index 31b468f3427f7..4713c276adaf1 100644 --- a/enterprise/coderd/replicas_test.go +++ b/enterprise/coderd/replicas_test.go @@ -81,7 +81,7 @@ func TestReplicas(t *testing.T) { require.Eventually(t, func() bool { ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancelFunc() - _, err = conn.Ping(ctx) + _, _, err = conn.Ping(ctx) return err == nil }, testutil.WaitLong, testutil.IntervalFast) _ = conn.Close() @@ -124,7 +124,7 @@ func TestReplicas(t *testing.T) { require.Eventually(t, func() bool { ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.IntervalSlow) defer cancelFunc() - _, err = conn.Ping(ctx) + _, _, err = conn.Ping(ctx) return err == nil }, testutil.WaitLong, testutil.IntervalFast) _ = conn.Close() diff --git a/flake.nix b/flake.nix index dfc44b91df36f..39ab42744a65e 100644 --- a/flake.nix +++ b/flake.nix @@ -18,6 +18,7 @@ buildInputs = with pkgs; [ bash bat + cairo drpc.defaultPackage.${system} exa getopt @@ -34,7 +35,10 @@ nodejs openssh openssl + pango + pixman postgresql + pkg-config protoc-gen-go ripgrep shellcheck diff --git a/scaletest/agentconn/run.go b/scaletest/agentconn/run.go index ae4e171c30b4b..1a93977efe204 100644 --- a/scaletest/agentconn/run.go +++ b/scaletest/agentconn/run.go @@ -141,7 +141,7 @@ func waitForDisco(ctx context.Context, logs io.Writer, conn *codersdk.AgentConn) for i := 0; i < pingAttempts; i++ { _, _ = fmt.Fprintf(logs, "\tDisco ping attempt %d/%d...\n", i+1, pingAttempts) pingCtx, cancel := context.WithTimeout(ctx, defaultRequestTimeout) - _, err := conn.Ping(pingCtx) + _, _, err := conn.Ping(pingCtx) cancel() if err == nil { break diff --git a/tailnet/conn.go b/tailnet/conn.go index 10dec35287a1d..a23ef3f5a184d 100644 --- a/tailnet/conn.go +++ b/tailnet/conn.go @@ -77,6 +77,7 @@ func NewConn(options *Options) (*Conn, error) { nodePublicKey := nodePrivateKey.Public() netMap := &netmap.NetworkMap{ + DERPMap: options.DERPMap, NodeKey: nodePublicKey, PrivateKey: nodePrivateKey, Addresses: options.Addresses, @@ -407,26 +408,34 @@ func (c *Conn) Status() *ipnstate.Status { } // Ping sends a Disco ping to the Wireguard engine. -func (c *Conn) Ping(ctx context.Context, ip netip.Addr) (time.Duration, error) { +// The bool returned is true if the ping was performed P2P. +func (c *Conn) Ping(ctx context.Context, ip netip.Addr) (time.Duration, bool, error) { errCh := make(chan error, 1) - durCh := make(chan time.Duration, 1) + prChan := make(chan *ipnstate.PingResult, 1) go c.wireguardEngine.Ping(ip, tailcfg.PingDisco, func(pr *ipnstate.PingResult) { if pr.Err != "" { errCh <- xerrors.New(pr.Err) return } - durCh <- time.Duration(pr.LatencySeconds * float64(time.Second)) + prChan <- pr }) select { case err := <-errCh: - return 0, err + return 0, false, err case <-ctx.Done(): - return 0, ctx.Err() - case dur := <-durCh: - return dur, nil + return 0, false, ctx.Err() + case pr := <-prChan: + return time.Duration(pr.LatencySeconds * float64(time.Second)), pr.Endpoint != "", nil } } +// DERPMap returns the currently set DERP mapping. +func (c *Conn) DERPMap() *tailcfg.DERPMap { + c.mutex.Lock() + defer c.mutex.Unlock() + return c.netMap.DERPMap +} + // AwaitReachable pings the provided IP continually until the // address is reachable. It's the callers responsibility to provide // a timeout, otherwise this function will block forever. @@ -445,7 +454,7 @@ func (c *Conn) AwaitReachable(ctx context.Context, ip netip.Addr) bool { ctx, cancel := context.WithTimeout(ctx, 5*time.Minute) defer cancel() - _, err := c.Ping(ctx, ip) + _, _, err := c.Ping(ctx, ip) if err == nil { completed() } @@ -523,20 +532,7 @@ func (c *Conn) sendNode() { c.nodeChanged = true return } - node := &Node{ - ID: c.netMap.SelfNode.ID, - AsOf: database.Now(), - Key: c.netMap.SelfNode.Key, - Addresses: c.netMap.SelfNode.Addresses, - AllowedIPs: c.netMap.SelfNode.AllowedIPs, - DiscoKey: c.magicConn.DiscoPublicKey(), - Endpoints: c.lastEndpoints, - PreferredDERP: c.lastPreferredDERP, - DERPLatency: c.lastDERPLatency, - } - if c.blockEndpoints { - node.Endpoints = nil - } + node := c.selfNode() nodeCallback := c.nodeCallback if nodeCallback == nil { return @@ -557,6 +553,31 @@ func (c *Conn) sendNode() { }() } +// Node returns the last node that was sent to the node callback. +func (c *Conn) Node() *Node { + c.lastMutex.Lock() + defer c.lastMutex.Unlock() + return c.selfNode() +} + +func (c *Conn) selfNode() *Node { + node := &Node{ + ID: c.netMap.SelfNode.ID, + AsOf: database.Now(), + Key: c.netMap.SelfNode.Key, + Addresses: c.netMap.SelfNode.Addresses, + AllowedIPs: c.netMap.SelfNode.AllowedIPs, + DiscoKey: c.magicConn.DiscoPublicKey(), + Endpoints: c.lastEndpoints, + PreferredDERP: c.lastPreferredDERP, + DERPLatency: c.lastDERPLatency, + } + if c.blockEndpoints { + node.Endpoints = nil + } + return node +} + // This and below is taken _mostly_ verbatim from Tailscale: // https://github.com/tailscale/tailscale/blob/c88bd53b1b7b2fcf7ba302f2e53dd1ce8c32dad4/tsnet/tsnet.go#L459-L494