diff --git a/Makefile b/Makefile index 9040a891700e1..a5341ee79f753 100644 --- a/Makefile +++ b/Makefile @@ -636,7 +636,8 @@ GEN_FILES := \ coderd/database/pubsub/psmock/psmock.go \ agent/agentcontainers/acmock/acmock.go \ agent/agentcontainers/dcspec/dcspec_gen.go \ - coderd/httpmw/loggermw/loggermock/loggermock.go + coderd/httpmw/loggermw/loggermock/loggermock.go \ + codersdk/workspacesdk/agentconnmock/agentconnmock.go # all gen targets should be added here and to gen/mark-fresh gen: gen/db gen/golden-files $(GEN_FILES) @@ -686,6 +687,7 @@ gen/mark-fresh: agent/agentcontainers/acmock/acmock.go \ agent/agentcontainers/dcspec/dcspec_gen.go \ coderd/httpmw/loggermw/loggermock/loggermock.go \ + codersdk/workspacesdk/agentconnmock/agentconnmock.go \ " for file in $$files; do @@ -729,6 +731,10 @@ coderd/httpmw/loggermw/loggermock/loggermock.go: coderd/httpmw/loggermw/logger.g go generate ./coderd/httpmw/loggermw/loggermock/ touch "$@" +codersdk/workspacesdk/agentconnmock/agentconnmock.go: codersdk/workspacesdk/agentconn.go + go generate ./codersdk/workspacesdk/agentconnmock/ + touch "$@" + agent/agentcontainers/dcspec/dcspec_gen.go: \ node_modules/.installed \ agent/agentcontainers/dcspec/devContainer.base.schema.json \ diff --git a/agent/agent_test.go b/agent/agent_test.go index 52d8cfc09d573..2425fd81a0ead 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -2750,9 +2750,9 @@ func TestAgent_Dial(t *testing.T) { switch l.Addr().Network() { case "tcp": - conn, err = agentConn.Conn.DialContextTCP(ctx, ipp) + conn, err = agentConn.TailnetConn().DialContextTCP(ctx, ipp) case "udp": - conn, err = agentConn.Conn.DialContextUDP(ctx, ipp) + conn, err = agentConn.TailnetConn().DialContextUDP(ctx, ipp) default: t.Fatalf("unknown network: %s", l.Addr().Network()) } @@ -2811,7 +2811,7 @@ func TestAgent_UpdatedDERP(t *testing.T) { }) // Setup a client connection. - newClientConn := func(derpMap *tailcfg.DERPMap, name string) *workspacesdk.AgentConn { + newClientConn := func(derpMap *tailcfg.DERPMap, name string) workspacesdk.AgentConn { conn, err := tailnet.NewConn(&tailnet.Options{ Addresses: []netip.Prefix{tailnet.TailscaleServicePrefix.RandomPrefix()}, DERPMap: derpMap, @@ -2891,13 +2891,13 @@ func TestAgent_UpdatedDERP(t *testing.T) { // Connect from a second client and make sure it uses the new DERP map. conn2 := newClientConn(newDerpMap, "client2") - require.Equal(t, []int{2}, conn2.DERPMap().RegionIDs()) + require.Equal(t, []int{2}, conn2.TailnetConn().DERPMap().RegionIDs()) t.Log("conn2 got the new DERPMap") // If the first client gets a DERP map update, it should be able to // reconnect just fine. - conn1.SetDERPMap(newDerpMap) - require.Equal(t, []int{2}, conn1.DERPMap().RegionIDs()) + conn1.TailnetConn().SetDERPMap(newDerpMap) + require.Equal(t, []int{2}, conn1.TailnetConn().DERPMap().RegionIDs()) t.Log("set the new DERPMap on conn1") ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() @@ -3264,7 +3264,7 @@ func setupSSHSessionOnPort( } func setupAgent(t testing.TB, metadata agentsdk.Manifest, ptyTimeout time.Duration, opts ...func(*agenttest.Client, *agent.Options)) ( - *workspacesdk.AgentConn, + workspacesdk.AgentConn, *agenttest.Client, <-chan *proto.Stats, afero.Fs, diff --git a/cli/ping.go b/cli/ping.go index 0b9fde5c62eb8..29af06affeaee 100644 --- a/cli/ping.go +++ b/cli/ping.go @@ -147,7 +147,7 @@ func (r *RootCmd) ping() *serpent.Command { } defer conn.Close() - derpMap := conn.DERPMap() + derpMap := conn.TailnetConn().DERPMap() diagCtx, diagCancel := context.WithTimeout(inv.Context(), 30*time.Second) defer diagCancel() @@ -156,7 +156,7 @@ func (r *RootCmd) ping() *serpent.Command { // Silent ping to determine whether we should show diags _, didP2p, _, _ := conn.Ping(ctx) - ni := conn.GetNetInfo() + ni := conn.TailnetConn().GetNetInfo() connDiags := cliui.ConnDiags{ DisableDirect: r.disableDirect, LocalNetInfo: ni, diff --git a/cli/portforward.go b/cli/portforward.go index 59c1f5827b06f..1b055d9e4362e 100644 --- a/cli/portforward.go +++ b/cli/portforward.go @@ -221,7 +221,7 @@ func (r *RootCmd) portForward() *serpent.Command { func listenAndPortForward( ctx context.Context, inv *serpent.Invocation, - conn *workspacesdk.AgentConn, + conn workspacesdk.AgentConn, wg *sync.WaitGroup, spec portForwardSpec, logger slog.Logger, diff --git a/cli/speedtest.go b/cli/speedtest.go index 3827b45125842..86d0e6a9ee63c 100644 --- a/cli/speedtest.go +++ b/cli/speedtest.go @@ -139,7 +139,7 @@ func (r *RootCmd) speedtest() *serpent.Command { if err != nil { continue } - status := conn.Status() + status := conn.TailnetConn().Status() if len(status.Peers()) != 1 { continue } @@ -189,7 +189,7 @@ func (r *RootCmd) speedtest() *serpent.Command { outputResult.Intervals[i] = interval } } - conn.Conn.SendSpeedtestTelemetry(outputResult.Overall.ThroughputMbits) + conn.TailnetConn().SendSpeedtestTelemetry(outputResult.Overall.ThroughputMbits) out, err := formatter.Format(inv.Context(), outputResult) if err != nil { return err diff --git a/cli/ssh.go b/cli/ssh.go index bc2bb24235ad2..a2f0db7327bef 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -590,7 +590,7 @@ func (r *RootCmd) ssh() *serpent.Command { } err = sshSession.Wait() - conn.SendDisconnectedTelemetry() + conn.TailnetConn().SendDisconnectedTelemetry() if err != nil { if exitErr := (&gossh.ExitError{}); errors.As(err, &exitErr) { // Clear the error since it's not useful beyond @@ -1364,7 +1364,7 @@ func getUsageAppName(usageApp string) codersdk.UsageAppName { func setStatsCallback( ctx context.Context, - agentConn *workspacesdk.AgentConn, + agentConn workspacesdk.AgentConn, logger slog.Logger, networkInfoDir string, networkInfoInterval time.Duration, @@ -1437,7 +1437,7 @@ func setStatsCallback( now := time.Now() cb(now, now.Add(time.Nanosecond), map[netlogtype.Connection]netlogtype.Counts{}, map[netlogtype.Connection]netlogtype.Counts{}) - agentConn.SetConnStatsCallback(networkInfoInterval, 2048, cb) + agentConn.TailnetConn().SetConnStatsCallback(networkInfoInterval, 2048, cb) return errCh, nil } @@ -1451,13 +1451,13 @@ type sshNetworkStats struct { UsingCoderConnect bool `json:"using_coder_connect"` } -func collectNetworkStats(ctx context.Context, agentConn *workspacesdk.AgentConn, start, end time.Time, counts map[netlogtype.Connection]netlogtype.Counts) (*sshNetworkStats, error) { +func collectNetworkStats(ctx context.Context, agentConn workspacesdk.AgentConn, start, end time.Time, counts map[netlogtype.Connection]netlogtype.Counts) (*sshNetworkStats, error) { latency, p2p, pingResult, err := agentConn.Ping(ctx) if err != nil { return nil, err } - node := agentConn.Node() - derpMap := agentConn.DERPMap() + node := agentConn.TailnetConn().Node() + derpMap := agentConn.TailnetConn().DERPMap() totalRx := uint64(0) totalTx := uint64(0) diff --git a/coderd/coderd.go b/coderd/coderd.go index a934536c0aef0..8ab204f8a31ef 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -325,6 +325,9 @@ func New(options *Options) *API { }) } + if options.PrometheusRegistry == nil { + options.PrometheusRegistry = prometheus.NewRegistry() + } if options.Authorizer == nil { options.Authorizer = rbac.NewCachingAuthorizer(options.PrometheusRegistry) if buildinfo.IsDev() { @@ -381,9 +384,6 @@ func New(options *Options) *API { if options.FilesRateLimit == 0 { options.FilesRateLimit = 12 } - if options.PrometheusRegistry == nil { - options.PrometheusRegistry = prometheus.NewRegistry() - } if options.Clock == nil { options.Clock = quartz.NewReal() } diff --git a/coderd/tailnet.go b/coderd/tailnet.go index 172edea95a586..cdcf657fe732d 100644 --- a/coderd/tailnet.go +++ b/coderd/tailnet.go @@ -277,9 +277,9 @@ func (s *ServerTailnet) dialContext(ctx context.Context, network, addr string) ( }, nil } -func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*workspacesdk.AgentConn, func(), error) { +func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) { var ( - conn *workspacesdk.AgentConn + conn workspacesdk.AgentConn ret func() ) diff --git a/coderd/workspaceagents_internal_test.go b/coderd/workspaceagents_internal_test.go new file mode 100644 index 0000000000000..c7520f05ab503 --- /dev/null +++ b/coderd/workspaceagents_internal_test.go @@ -0,0 +1,186 @@ +package coderd + +import ( + "bytes" + "context" + "database/sql" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/http/httputil" + "net/url" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/workspaceapps/appurl" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/codersdk/workspacesdk" + "github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock" + "github.com/coder/coder/v2/codersdk/wsjson" + "github.com/coder/coder/v2/tailnet" + "github.com/coder/coder/v2/tailnet/tailnettest" + "github.com/coder/coder/v2/testutil" + "github.com/coder/websocket" +) + +type fakeAgentProvider struct { + agentConn func(ctx context.Context, agentID uuid.UUID) (_ workspacesdk.AgentConn, release func(), _ error) +} + +func (fakeAgentProvider) ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID, app appurl.ApplicationURL, wildcardHost string) *httputil.ReverseProxy { + panic("unimplemented") +} + +func (f fakeAgentProvider) AgentConn(ctx context.Context, agentID uuid.UUID) (_ workspacesdk.AgentConn, release func(), _ error) { + if f.agentConn != nil { + return f.agentConn(ctx, agentID) + } + + panic("unimplemented") +} + +func (fakeAgentProvider) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) { + panic("unimplemented") +} + +func (fakeAgentProvider) Close() error { + return nil +} + +func TestWatchAgentContainers(t *testing.T) { + t.Parallel() + + t.Run("WebSocketClosesProperly", func(t *testing.T) { + t.Parallel() + + // This test ensures that the agent containers `/watch` websocket can gracefully + // handle the underlying websocket unexpectedly closing. This test was created in + // response to this issue: https://github.com/coder/coder/issues/19372 + + var ( + ctx = testutil.Context(t, testutil.WaitShort) + logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug).Named("coderd") + + mCtrl = gomock.NewController(t) + mDB = dbmock.NewMockStore(mCtrl) + mCoordinator = tailnettest.NewMockCoordinator(mCtrl) + mAgentConn = agentconnmock.NewMockAgentConn(mCtrl) + + fAgentProvider = fakeAgentProvider{ + agentConn: func(ctx context.Context, agentID uuid.UUID) (_ workspacesdk.AgentConn, release func(), _ error) { + return mAgentConn, func() {}, nil + }, + } + + workspaceID = uuid.New() + agentID = uuid.New() + resourceID = uuid.New() + jobID = uuid.New() + buildID = uuid.New() + + containersCh = make(chan codersdk.WorkspaceAgentListContainersResponse) + + r = chi.NewMux() + + api = API{ + ctx: ctx, + Options: &Options{ + AgentInactiveDisconnectTimeout: testutil.WaitShort, + Database: mDB, + Logger: logger, + DeploymentValues: &codersdk.DeploymentValues{}, + TailnetCoordinator: tailnettest.NewFakeCoordinator(), + }, + } + ) + + var tailnetCoordinator tailnet.Coordinator = mCoordinator + api.TailnetCoordinator.Store(&tailnetCoordinator) + api.agentProvider = fAgentProvider + + // Setup: Allow `ExtractWorkspaceAgentParams` to complete. + mDB.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).Return(database.WorkspaceAgent{ + ID: agentID, + ResourceID: resourceID, + LifecycleState: database.WorkspaceAgentLifecycleStateReady, + FirstConnectedAt: sql.NullTime{Valid: true, Time: dbtime.Now()}, + LastConnectedAt: sql.NullTime{Valid: true, Time: dbtime.Now()}, + }, nil) + mDB.EXPECT().GetWorkspaceResourceByID(gomock.Any(), resourceID).Return(database.WorkspaceResource{ + ID: resourceID, + JobID: jobID, + }, nil) + mDB.EXPECT().GetProvisionerJobByID(gomock.Any(), jobID).Return(database.ProvisionerJob{ + ID: jobID, + Type: database.ProvisionerJobTypeWorkspaceBuild, + }, nil) + mDB.EXPECT().GetWorkspaceBuildByJobID(gomock.Any(), jobID).Return(database.WorkspaceBuild{ + WorkspaceID: workspaceID, + ID: buildID, + }, nil) + + // And: Allow `db2dsk.WorkspaceAgent` to complete. + mCoordinator.EXPECT().Node(gomock.Any()).Return(nil) + + // And: Allow `WatchContainers` to be called, returing our `containersCh` channel. + mAgentConn.EXPECT().WatchContainers(gomock.Any(), gomock.Any()). + Return(containersCh, io.NopCloser(&bytes.Buffer{}), nil) + + // And: We mount the HTTP Handler + r.With(httpmw.ExtractWorkspaceAgentParam(mDB)). + Get("/workspaceagents/{workspaceagent}/containers/watch", api.watchWorkspaceAgentContainers) + + // Given: We create the HTTP server + srv := httptest.NewServer(r) + defer srv.Close() + + // And: Dial the WebSocket + wsURL := strings.Replace(srv.URL, "http://", "ws://", 1) + conn, resp, err := websocket.Dial(ctx, fmt.Sprintf("%s/workspaceagents/%s/containers/watch", wsURL, agentID), nil) + require.NoError(t, err) + if resp.Body != nil { + defer resp.Body.Close() + } + + // And: Create a streaming decoder + decoder := wsjson.NewDecoder[codersdk.WorkspaceAgentListContainersResponse](conn, websocket.MessageText, logger) + defer decoder.Close() + decodeCh := decoder.Chan() + + // And: We can successfully send through the channel. + testutil.RequireSend(ctx, t, containersCh, codersdk.WorkspaceAgentListContainersResponse{ + Containers: []codersdk.WorkspaceAgentContainer{{ + ID: "test-container-id", + }}, + }) + + // And: Receive the data. + containerResp := testutil.RequireReceive(ctx, t, decodeCh) + require.Len(t, containerResp.Containers, 1) + require.Equal(t, "test-container-id", containerResp.Containers[0].ID) + + // When: We close the `containersCh` + close(containersCh) + + // Then: We expect `decodeCh` to be closed. + select { + case <-ctx.Done(): + t.Fail() + + case _, ok := <-decodeCh: + require.False(t, ok, "channel is expected to be closed") + } + }) +} diff --git a/coderd/workspaceagents_test.go b/coderd/workspaceagents_test.go index 1855ed8a7e8fc..ac58df1b772ad 100644 --- a/coderd/workspaceagents_test.go +++ b/coderd/workspaceagents_test.go @@ -593,7 +593,7 @@ func TestWorkspaceAgentTailnet(t *testing.T) { _ = agenttest.New(t, client.URL, r.AgentToken) resources := coderdtest.AwaitWorkspaceAgents(t, client, r.Workspace.ID) - conn, err := func() (*workspacesdk.AgentConn, error) { + conn, err := func() (workspacesdk.AgentConn, error) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() // Connection should remain open even if the dial context is canceled. @@ -1574,82 +1574,6 @@ func TestWatchWorkspaceAgentDevcontainers(t *testing.T) { } } }) - - t.Run("PayloadTooLarge", func(t *testing.T) { - t.Parallel() - - var ( - ctx = testutil.Context(t, testutil.WaitSuperLong) - logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) - mClock = quartz.NewMock(t) - updaterTickerTrap = mClock.Trap().TickerFunc("updaterLoop") - mCtrl = gomock.NewController(t) - mCCLI = acmock.NewMockContainerCLI(mCtrl) - - client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{Logger: &logger}) - user = coderdtest.CreateFirstUser(t, client) - r = dbfake.WorkspaceBuild(t, db, database.WorkspaceTable{ - OrganizationID: user.OrganizationID, - OwnerID: user.UserID, - }).WithAgent(func(agents []*proto.Agent) []*proto.Agent { - return agents - }).Do() - ) - - // WebSocket limit is 4MiB, so we want to ensure we create _more_ than 4MiB worth of payload. - // Creating 20,000 fake containers creates a payload of roughly 7MiB. - var fakeContainers []codersdk.WorkspaceAgentContainer - for range 20_000 { - fakeContainers = append(fakeContainers, codersdk.WorkspaceAgentContainer{ - CreatedAt: time.Now(), - ID: uuid.NewString(), - FriendlyName: uuid.NewString(), - Image: "busybox:latest", - Labels: map[string]string{ - agentcontainers.DevcontainerLocalFolderLabel: "/home/coder/project", - agentcontainers.DevcontainerConfigFileLabel: "/home/coder/project/.devcontainer/devcontainer.json", - }, - Running: false, - Ports: []codersdk.WorkspaceAgentContainerPort{}, - Status: string(codersdk.WorkspaceAgentDevcontainerStatusRunning), - Volumes: map[string]string{}, - }) - } - - mCCLI.EXPECT().List(gomock.Any()).Return(codersdk.WorkspaceAgentListContainersResponse{Containers: fakeContainers}, nil) - mCCLI.EXPECT().DetectArchitecture(gomock.Any(), gomock.Any()).Return("", nil).AnyTimes() - - _ = agenttest.New(t, client.URL, r.AgentToken, func(o *agent.Options) { - o.Logger = logger.Named("agent") - o.Devcontainers = true - o.DevcontainerAPIOptions = []agentcontainers.Option{ - agentcontainers.WithClock(mClock), - agentcontainers.WithContainerCLI(mCCLI), - agentcontainers.WithWatcher(watcher.NewNoop()), - } - }) - - resources := coderdtest.NewWorkspaceAgentWaiter(t, client, r.Workspace.ID).Wait() - require.Len(t, resources, 1, "expected one resource") - require.Len(t, resources[0].Agents, 1, "expected one agent") - agentID := resources[0].Agents[0].ID - - updaterTickerTrap.MustWait(ctx).MustRelease(ctx) - defer updaterTickerTrap.Close() - - containers, closer, err := client.WatchWorkspaceAgentContainers(ctx, agentID) - require.NoError(t, err) - defer func() { - closer.Close() - }() - - select { - case <-ctx.Done(): - t.Fail() - case _, ok := <-containers: - require.False(t, ok) - } - }) } func TestWorkspaceAgentRecreateDevcontainer(t *testing.T) { @@ -2497,7 +2421,7 @@ func TestWorkspaceAgent_UpdatedDERP(t *testing.T) { agentID := resources[0].Agents[0].ID // Connect from a client. - conn1, err := func() (*workspacesdk.AgentConn, error) { + conn1, err := func() (workspacesdk.AgentConn, error) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() // Connection should remain open even if the dial context is canceled. @@ -2538,7 +2462,7 @@ func TestWorkspaceAgent_UpdatedDERP(t *testing.T) { // Wait for the DERP map to be updated on the existing client. require.Eventually(t, func() bool { - regionIDs := conn1.Conn.DERPMap().RegionIDs() + regionIDs := conn1.TailnetConn().DERPMap().RegionIDs() return len(regionIDs) == 1 && regionIDs[0] == 2 }, testutil.WaitLong, testutil.IntervalFast) @@ -2555,7 +2479,7 @@ func TestWorkspaceAgent_UpdatedDERP(t *testing.T) { defer conn2.Close() ok = conn2.AwaitReachable(ctx) require.True(t, ok) - require.Equal(t, []int{2}, conn2.DERPMap().RegionIDs()) + require.Equal(t, []int{2}, conn2.TailnetConn().DERPMap().RegionIDs()) } func TestWorkspaceAgentExternalAuthListen(t *testing.T) { diff --git a/coderd/workspaceapps/proxy.go b/coderd/workspaceapps/proxy.go index 2f1294558f67a..002bb1ea05aae 100644 --- a/coderd/workspaceapps/proxy.go +++ b/coderd/workspaceapps/proxy.go @@ -74,7 +74,7 @@ type AgentProvider interface { ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID, app appurl.ApplicationURL, wildcardHost string) *httputil.ReverseProxy // AgentConn returns a new connection to the specified agent. - AgentConn(ctx context.Context, agentID uuid.UUID) (_ *workspacesdk.AgentConn, release func(), _ error) + AgentConn(ctx context.Context, agentID uuid.UUID) (_ workspacesdk.AgentConn, release func(), _ error) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) diff --git a/codersdk/workspacesdk/agentconn.go b/codersdk/workspacesdk/agentconn.go index 9c65b7ee9a1e1..bb929c9ba2a04 100644 --- a/codersdk/workspacesdk/agentconn.go +++ b/codersdk/workspacesdk/agentconn.go @@ -34,8 +34,8 @@ import ( // to the WorkspaceAgentConn, or it may be shared in the case of coderd. If the // conn is shared and closing it is undesirable, you may return ErrNoClose from // opts.CloseFunc. This will ensure the underlying conn is not closed. -func NewAgentConn(conn *tailnet.Conn, opts AgentConnOptions) *AgentConn { - return &AgentConn{ +func NewAgentConn(conn *tailnet.Conn, opts AgentConnOptions) AgentConn { + return &agentConn{ Conn: conn, opts: opts, } @@ -43,23 +43,54 @@ func NewAgentConn(conn *tailnet.Conn, opts AgentConnOptions) *AgentConn { // AgentConn represents a connection to a workspace agent. // @typescript-ignore AgentConn -type AgentConn struct { +type AgentConn interface { + TailnetConn() *tailnet.Conn + + AwaitReachable(ctx context.Context) bool + Close() error + DebugLogs(ctx context.Context) ([]byte, error) + DebugMagicsock(ctx context.Context) ([]byte, error) + DebugManifest(ctx context.Context) ([]byte, error) + DialContext(ctx context.Context, network string, addr string) (net.Conn, error) + GetPeerDiagnostics() tailnet.PeerDiagnostics + ListContainers(ctx context.Context) (codersdk.WorkspaceAgentListContainersResponse, error) + ListeningPorts(ctx context.Context) (codersdk.WorkspaceAgentListeningPortsResponse, error) + Netcheck(ctx context.Context) (healthsdk.AgentNetcheckReport, error) + Ping(ctx context.Context) (time.Duration, bool, *ipnstate.PingResult, error) + PrometheusMetrics(ctx context.Context) ([]byte, error) + ReconnectingPTY(ctx context.Context, id uuid.UUID, height uint16, width uint16, command string, initOpts ...AgentReconnectingPTYInitOption) (net.Conn, error) + RecreateDevcontainer(ctx context.Context, devcontainerID string) (codersdk.Response, error) + SSH(ctx context.Context) (*gonet.TCPConn, error) + SSHClient(ctx context.Context) (*ssh.Client, error) + SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Client, error) + SSHOnPort(ctx context.Context, port uint16) (*gonet.TCPConn, error) + Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) + WatchContainers(ctx context.Context, logger slog.Logger) (<-chan codersdk.WorkspaceAgentListContainersResponse, io.Closer, error) +} + +// AgentConn represents a connection to a workspace agent. +// @typescript-ignore AgentConn +type agentConn struct { *tailnet.Conn opts AgentConnOptions } +func (c *agentConn) TailnetConn() *tailnet.Conn { + return c.Conn +} + // @typescript-ignore AgentConnOptions type AgentConnOptions struct { AgentID uuid.UUID CloseFunc func() error } -func (c *AgentConn) agentAddress() netip.Addr { +func (c *agentConn) agentAddress() netip.Addr { return tailnet.TailscaleServicePrefix.AddrFromUUID(c.opts.AgentID) } // AwaitReachable waits for the agent to be reachable. -func (c *AgentConn) AwaitReachable(ctx context.Context) bool { +func (c *agentConn) AwaitReachable(ctx context.Context) bool { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -68,7 +99,7 @@ func (c *AgentConn) AwaitReachable(ctx context.Context) bool { // Ping pings the agent and returns the round-trip time. // The bool returns true if the ping was made P2P. -func (c *AgentConn) Ping(ctx context.Context) (time.Duration, bool, *ipnstate.PingResult, error) { +func (c *agentConn) Ping(ctx context.Context) (time.Duration, bool, *ipnstate.PingResult, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -76,7 +107,7 @@ func (c *AgentConn) Ping(ctx context.Context) (time.Duration, bool, *ipnstate.Pi } // Close ends the connection to the workspace agent. -func (c *AgentConn) Close() error { +func (c *agentConn) Close() error { var cerr error if c.opts.CloseFunc != nil { cerr = c.opts.CloseFunc() @@ -131,7 +162,7 @@ type ReconnectingPTYRequest struct { // ReconnectingPTY spawns a new reconnecting terminal session. // `ReconnectingPTYRequest` should be JSON marshaled and written to the returned net.Conn. // Raw terminal output will be read from the returned net.Conn. -func (c *AgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, width uint16, command string, initOpts ...AgentReconnectingPTYInitOption) (net.Conn, error) { +func (c *agentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, width uint16, command string, initOpts ...AgentReconnectingPTYInitOption) (net.Conn, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -171,13 +202,13 @@ func (c *AgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, w // SSH pipes the SSH protocol over the returned net.Conn. // This connects to the built-in SSH server in the workspace agent. -func (c *AgentConn) SSH(ctx context.Context) (*gonet.TCPConn, error) { +func (c *agentConn) SSH(ctx context.Context) (*gonet.TCPConn, error) { return c.SSHOnPort(ctx, AgentSSHPort) } // SSHOnPort pipes the SSH protocol over the returned net.Conn. // This connects to the built-in SSH server in the workspace agent on the specified port. -func (c *AgentConn) SSHOnPort(ctx context.Context, port uint16) (*gonet.TCPConn, error) { +func (c *agentConn) SSHOnPort(ctx context.Context, port uint16) (*gonet.TCPConn, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -190,12 +221,12 @@ func (c *AgentConn) SSHOnPort(ctx context.Context, port uint16) (*gonet.TCPConn, } // SSHClient calls SSH to create a client -func (c *AgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) { +func (c *agentConn) SSHClient(ctx context.Context) (*ssh.Client, error) { return c.SSHClientOnPort(ctx, AgentSSHPort) } // SSHClientOnPort calls SSH to create a client on a specific port -func (c *AgentConn) SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Client, error) { +func (c *agentConn) SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Client, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -218,7 +249,7 @@ func (c *AgentConn) SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Clie } // Speedtest runs a speedtest against the workspace agent. -func (c *AgentConn) Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) { +func (c *agentConn) Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -242,7 +273,7 @@ func (c *AgentConn) Speedtest(ctx context.Context, direction speedtest.Direction // DialContext dials the address provided in the workspace agent. // The network must be "tcp" or "udp". -func (c *AgentConn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { +func (c *agentConn) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -265,7 +296,7 @@ func (c *AgentConn) DialContext(ctx context.Context, network string, addr string } // ListeningPorts lists the ports that are currently in use by the workspace. -func (c *AgentConn) ListeningPorts(ctx context.Context) (codersdk.WorkspaceAgentListeningPortsResponse, error) { +func (c *agentConn) ListeningPorts(ctx context.Context) (codersdk.WorkspaceAgentListeningPortsResponse, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodGet, "/api/v0/listening-ports", nil) @@ -282,7 +313,7 @@ func (c *AgentConn) ListeningPorts(ctx context.Context) (codersdk.WorkspaceAgent } // Netcheck returns a network check report from the workspace agent. -func (c *AgentConn) Netcheck(ctx context.Context) (healthsdk.AgentNetcheckReport, error) { +func (c *agentConn) Netcheck(ctx context.Context) (healthsdk.AgentNetcheckReport, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodGet, "/api/v0/netcheck", nil) @@ -299,7 +330,7 @@ func (c *AgentConn) Netcheck(ctx context.Context) (healthsdk.AgentNetcheckReport } // DebugMagicsock makes a request to the workspace agent's magicsock debug endpoint. -func (c *AgentConn) DebugMagicsock(ctx context.Context) ([]byte, error) { +func (c *agentConn) DebugMagicsock(ctx context.Context) ([]byte, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodGet, "/debug/magicsock", nil) @@ -319,7 +350,7 @@ func (c *AgentConn) DebugMagicsock(ctx context.Context) ([]byte, error) { // DebugManifest returns the agent's in-memory manifest. Unfortunately this must // be returns as a []byte to avoid an import cycle. -func (c *AgentConn) DebugManifest(ctx context.Context) ([]byte, error) { +func (c *agentConn) DebugManifest(ctx context.Context) ([]byte, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodGet, "/debug/manifest", nil) @@ -338,7 +369,7 @@ func (c *AgentConn) DebugManifest(ctx context.Context) ([]byte, error) { } // DebugLogs returns up to the last 10MB of `/tmp/coder-agent.log` -func (c *AgentConn) DebugLogs(ctx context.Context) ([]byte, error) { +func (c *agentConn) DebugLogs(ctx context.Context) ([]byte, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodGet, "/debug/logs", nil) @@ -357,7 +388,7 @@ func (c *AgentConn) DebugLogs(ctx context.Context) ([]byte, error) { } // PrometheusMetrics returns a response from the agent's prometheus metrics endpoint -func (c *AgentConn) PrometheusMetrics(ctx context.Context) ([]byte, error) { +func (c *agentConn) PrometheusMetrics(ctx context.Context) ([]byte, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodGet, "/debug/prometheus", nil) @@ -376,7 +407,7 @@ func (c *AgentConn) PrometheusMetrics(ctx context.Context) ([]byte, error) { } // ListContainers returns a response from the agent's containers endpoint -func (c *AgentConn) ListContainers(ctx context.Context) (codersdk.WorkspaceAgentListContainersResponse, error) { +func (c *agentConn) ListContainers(ctx context.Context) (codersdk.WorkspaceAgentListContainersResponse, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodGet, "/api/v0/containers", nil) @@ -391,7 +422,7 @@ func (c *AgentConn) ListContainers(ctx context.Context) (codersdk.WorkspaceAgent return resp, json.NewDecoder(res.Body).Decode(&resp) } -func (c *AgentConn) WatchContainers(ctx context.Context, logger slog.Logger) (<-chan codersdk.WorkspaceAgentListContainersResponse, io.Closer, error) { +func (c *agentConn) WatchContainers(ctx context.Context, logger slog.Logger) (<-chan codersdk.WorkspaceAgentListContainersResponse, io.Closer, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -427,7 +458,7 @@ func (c *AgentConn) WatchContainers(ctx context.Context, logger slog.Logger) (<- // RecreateDevcontainer recreates a devcontainer with the given container. // This is a blocking call and will wait for the container to be recreated. -func (c *AgentConn) RecreateDevcontainer(ctx context.Context, devcontainerID string) (codersdk.Response, error) { +func (c *agentConn) RecreateDevcontainer(ctx context.Context, devcontainerID string) (codersdk.Response, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodPost, "/api/v0/containers/devcontainers/"+devcontainerID+"/recreate", nil) @@ -446,7 +477,7 @@ func (c *AgentConn) RecreateDevcontainer(ctx context.Context, devcontainerID str } // apiRequest makes a request to the workspace agent's HTTP API server. -func (c *AgentConn) apiRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) { +func (c *agentConn) apiRequest(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -463,7 +494,7 @@ func (c *AgentConn) apiRequest(ctx context.Context, method, path string, body io // apiClient returns an HTTP client that can be used to make // requests to the workspace agent's HTTP API server. -func (c *AgentConn) apiClient() *http.Client { +func (c *agentConn) apiClient() *http.Client { return &http.Client{ Transport: &http.Transport{ // Disable keep alives as we're usually only making a single @@ -504,6 +535,6 @@ func (c *AgentConn) apiClient() *http.Client { } } -func (c *AgentConn) GetPeerDiagnostics() tailnet.PeerDiagnostics { +func (c *agentConn) GetPeerDiagnostics() tailnet.PeerDiagnostics { return c.Conn.GetPeerDiagnostics(c.opts.AgentID) } diff --git a/codersdk/workspacesdk/agentconnmock/agentconnmock.go b/codersdk/workspacesdk/agentconnmock/agentconnmock.go new file mode 100644 index 0000000000000..eb55bb27938c0 --- /dev/null +++ b/codersdk/workspacesdk/agentconnmock/agentconnmock.go @@ -0,0 +1,373 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: .. (interfaces: AgentConn) +// +// Generated by this command: +// +// mockgen -destination ./agentconnmock.go -package agentconnmock .. AgentConn +// + +// Package agentconnmock is a generated GoMock package. +package agentconnmock + +import ( + context "context" + io "io" + net "net" + reflect "reflect" + time "time" + + slog "cdr.dev/slog" + codersdk "github.com/coder/coder/v2/codersdk" + healthsdk "github.com/coder/coder/v2/codersdk/healthsdk" + workspacesdk "github.com/coder/coder/v2/codersdk/workspacesdk" + tailnet "github.com/coder/coder/v2/tailnet" + uuid "github.com/google/uuid" + gomock "go.uber.org/mock/gomock" + ssh "golang.org/x/crypto/ssh" + gonet "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + ipnstate "tailscale.com/ipn/ipnstate" + speedtest "tailscale.com/net/speedtest" +) + +// MockAgentConn is a mock of AgentConn interface. +type MockAgentConn struct { + ctrl *gomock.Controller + recorder *MockAgentConnMockRecorder + isgomock struct{} +} + +// MockAgentConnMockRecorder is the mock recorder for MockAgentConn. +type MockAgentConnMockRecorder struct { + mock *MockAgentConn +} + +// NewMockAgentConn creates a new mock instance. +func NewMockAgentConn(ctrl *gomock.Controller) *MockAgentConn { + mock := &MockAgentConn{ctrl: ctrl} + mock.recorder = &MockAgentConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockAgentConn) EXPECT() *MockAgentConnMockRecorder { + return m.recorder +} + +// AwaitReachable mocks base method. +func (m *MockAgentConn) AwaitReachable(ctx context.Context) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AwaitReachable", ctx) + ret0, _ := ret[0].(bool) + return ret0 +} + +// AwaitReachable indicates an expected call of AwaitReachable. +func (mr *MockAgentConnMockRecorder) AwaitReachable(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AwaitReachable", reflect.TypeOf((*MockAgentConn)(nil).AwaitReachable), ctx) +} + +// Close mocks base method. +func (m *MockAgentConn) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockAgentConnMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockAgentConn)(nil).Close)) +} + +// DebugLogs mocks base method. +func (m *MockAgentConn) DebugLogs(ctx context.Context) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DebugLogs", ctx) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DebugLogs indicates an expected call of DebugLogs. +func (mr *MockAgentConnMockRecorder) DebugLogs(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DebugLogs", reflect.TypeOf((*MockAgentConn)(nil).DebugLogs), ctx) +} + +// DebugMagicsock mocks base method. +func (m *MockAgentConn) DebugMagicsock(ctx context.Context) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DebugMagicsock", ctx) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DebugMagicsock indicates an expected call of DebugMagicsock. +func (mr *MockAgentConnMockRecorder) DebugMagicsock(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DebugMagicsock", reflect.TypeOf((*MockAgentConn)(nil).DebugMagicsock), ctx) +} + +// DebugManifest mocks base method. +func (m *MockAgentConn) DebugManifest(ctx context.Context) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DebugManifest", ctx) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DebugManifest indicates an expected call of DebugManifest. +func (mr *MockAgentConnMockRecorder) DebugManifest(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DebugManifest", reflect.TypeOf((*MockAgentConn)(nil).DebugManifest), ctx) +} + +// DialContext mocks base method. +func (m *MockAgentConn) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DialContext", ctx, network, addr) + ret0, _ := ret[0].(net.Conn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DialContext indicates an expected call of DialContext. +func (mr *MockAgentConnMockRecorder) DialContext(ctx, network, addr any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialContext", reflect.TypeOf((*MockAgentConn)(nil).DialContext), ctx, network, addr) +} + +// GetPeerDiagnostics mocks base method. +func (m *MockAgentConn) GetPeerDiagnostics() tailnet.PeerDiagnostics { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPeerDiagnostics") + ret0, _ := ret[0].(tailnet.PeerDiagnostics) + return ret0 +} + +// GetPeerDiagnostics indicates an expected call of GetPeerDiagnostics. +func (mr *MockAgentConnMockRecorder) GetPeerDiagnostics() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerDiagnostics", reflect.TypeOf((*MockAgentConn)(nil).GetPeerDiagnostics)) +} + +// ListContainers mocks base method. +func (m *MockAgentConn) ListContainers(ctx context.Context) (codersdk.WorkspaceAgentListContainersResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListContainers", ctx) + ret0, _ := ret[0].(codersdk.WorkspaceAgentListContainersResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListContainers indicates an expected call of ListContainers. +func (mr *MockAgentConnMockRecorder) ListContainers(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListContainers", reflect.TypeOf((*MockAgentConn)(nil).ListContainers), ctx) +} + +// ListeningPorts mocks base method. +func (m *MockAgentConn) ListeningPorts(ctx context.Context) (codersdk.WorkspaceAgentListeningPortsResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListeningPorts", ctx) + ret0, _ := ret[0].(codersdk.WorkspaceAgentListeningPortsResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListeningPorts indicates an expected call of ListeningPorts. +func (mr *MockAgentConnMockRecorder) ListeningPorts(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListeningPorts", reflect.TypeOf((*MockAgentConn)(nil).ListeningPorts), ctx) +} + +// Netcheck mocks base method. +func (m *MockAgentConn) Netcheck(ctx context.Context) (healthsdk.AgentNetcheckReport, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Netcheck", ctx) + ret0, _ := ret[0].(healthsdk.AgentNetcheckReport) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Netcheck indicates an expected call of Netcheck. +func (mr *MockAgentConnMockRecorder) Netcheck(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Netcheck", reflect.TypeOf((*MockAgentConn)(nil).Netcheck), ctx) +} + +// Ping mocks base method. +func (m *MockAgentConn) Ping(ctx context.Context) (time.Duration, bool, *ipnstate.PingResult, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Ping", ctx) + ret0, _ := ret[0].(time.Duration) + ret1, _ := ret[1].(bool) + ret2, _ := ret[2].(*ipnstate.PingResult) + ret3, _ := ret[3].(error) + return ret0, ret1, ret2, ret3 +} + +// Ping indicates an expected call of Ping. +func (mr *MockAgentConnMockRecorder) Ping(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ping", reflect.TypeOf((*MockAgentConn)(nil).Ping), ctx) +} + +// PrometheusMetrics mocks base method. +func (m *MockAgentConn) PrometheusMetrics(ctx context.Context) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PrometheusMetrics", ctx) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PrometheusMetrics indicates an expected call of PrometheusMetrics. +func (mr *MockAgentConnMockRecorder) PrometheusMetrics(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PrometheusMetrics", reflect.TypeOf((*MockAgentConn)(nil).PrometheusMetrics), ctx) +} + +// ReconnectingPTY mocks base method. +func (m *MockAgentConn) ReconnectingPTY(ctx context.Context, id uuid.UUID, height, width uint16, command string, initOpts ...workspacesdk.AgentReconnectingPTYInitOption) (net.Conn, error) { + m.ctrl.T.Helper() + varargs := []any{ctx, id, height, width, command} + for _, a := range initOpts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "ReconnectingPTY", varargs...) + ret0, _ := ret[0].(net.Conn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReconnectingPTY indicates an expected call of ReconnectingPTY. +func (mr *MockAgentConnMockRecorder) ReconnectingPTY(ctx, id, height, width, command any, initOpts ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, id, height, width, command}, initOpts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReconnectingPTY", reflect.TypeOf((*MockAgentConn)(nil).ReconnectingPTY), varargs...) +} + +// RecreateDevcontainer mocks base method. +func (m *MockAgentConn) RecreateDevcontainer(ctx context.Context, devcontainerID string) (codersdk.Response, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RecreateDevcontainer", ctx, devcontainerID) + ret0, _ := ret[0].(codersdk.Response) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RecreateDevcontainer indicates an expected call of RecreateDevcontainer. +func (mr *MockAgentConnMockRecorder) RecreateDevcontainer(ctx, devcontainerID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecreateDevcontainer", reflect.TypeOf((*MockAgentConn)(nil).RecreateDevcontainer), ctx, devcontainerID) +} + +// SSH mocks base method. +func (m *MockAgentConn) SSH(ctx context.Context) (*gonet.TCPConn, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SSH", ctx) + ret0, _ := ret[0].(*gonet.TCPConn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SSH indicates an expected call of SSH. +func (mr *MockAgentConnMockRecorder) SSH(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SSH", reflect.TypeOf((*MockAgentConn)(nil).SSH), ctx) +} + +// SSHClient mocks base method. +func (m *MockAgentConn) SSHClient(ctx context.Context) (*ssh.Client, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SSHClient", ctx) + ret0, _ := ret[0].(*ssh.Client) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SSHClient indicates an expected call of SSHClient. +func (mr *MockAgentConnMockRecorder) SSHClient(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SSHClient", reflect.TypeOf((*MockAgentConn)(nil).SSHClient), ctx) +} + +// SSHClientOnPort mocks base method. +func (m *MockAgentConn) SSHClientOnPort(ctx context.Context, port uint16) (*ssh.Client, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SSHClientOnPort", ctx, port) + ret0, _ := ret[0].(*ssh.Client) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SSHClientOnPort indicates an expected call of SSHClientOnPort. +func (mr *MockAgentConnMockRecorder) SSHClientOnPort(ctx, port any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SSHClientOnPort", reflect.TypeOf((*MockAgentConn)(nil).SSHClientOnPort), ctx, port) +} + +// SSHOnPort mocks base method. +func (m *MockAgentConn) SSHOnPort(ctx context.Context, port uint16) (*gonet.TCPConn, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SSHOnPort", ctx, port) + ret0, _ := ret[0].(*gonet.TCPConn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SSHOnPort indicates an expected call of SSHOnPort. +func (mr *MockAgentConnMockRecorder) SSHOnPort(ctx, port any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SSHOnPort", reflect.TypeOf((*MockAgentConn)(nil).SSHOnPort), ctx, port) +} + +// Speedtest mocks base method. +func (m *MockAgentConn) Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Speedtest", ctx, direction, duration) + ret0, _ := ret[0].([]speedtest.Result) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Speedtest indicates an expected call of Speedtest. +func (mr *MockAgentConnMockRecorder) Speedtest(ctx, direction, duration any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Speedtest", reflect.TypeOf((*MockAgentConn)(nil).Speedtest), ctx, direction, duration) +} + +// TailnetConn mocks base method. +func (m *MockAgentConn) TailnetConn() *tailnet.Conn { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TailnetConn") + ret0, _ := ret[0].(*tailnet.Conn) + return ret0 +} + +// TailnetConn indicates an expected call of TailnetConn. +func (mr *MockAgentConnMockRecorder) TailnetConn() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TailnetConn", reflect.TypeOf((*MockAgentConn)(nil).TailnetConn)) +} + +// WatchContainers mocks base method. +func (m *MockAgentConn) WatchContainers(ctx context.Context, logger slog.Logger) (<-chan codersdk.WorkspaceAgentListContainersResponse, io.Closer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WatchContainers", ctx, logger) + ret0, _ := ret[0].(<-chan codersdk.WorkspaceAgentListContainersResponse) + ret1, _ := ret[1].(io.Closer) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// WatchContainers indicates an expected call of WatchContainers. +func (mr *MockAgentConnMockRecorder) WatchContainers(ctx, logger any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WatchContainers", reflect.TypeOf((*MockAgentConn)(nil).WatchContainers), ctx, logger) +} diff --git a/codersdk/workspacesdk/agentconnmock/doc.go b/codersdk/workspacesdk/agentconnmock/doc.go new file mode 100644 index 0000000000000..a795b21a4a89d --- /dev/null +++ b/codersdk/workspacesdk/agentconnmock/doc.go @@ -0,0 +1,4 @@ +// Package agentconnmock contains a mock implementation of workspacesdk.AgentConn for use in tests. +package agentconnmock + +//go:generate mockgen -destination ./agentconnmock.go -package agentconnmock .. AgentConn diff --git a/codersdk/workspacesdk/workspacesdk.go b/codersdk/workspacesdk/workspacesdk.go index 9f587cf5267a8..ddaec06388238 100644 --- a/codersdk/workspacesdk/workspacesdk.go +++ b/codersdk/workspacesdk/workspacesdk.go @@ -202,7 +202,7 @@ func (c *Client) RewriteDERPMap(derpMap *tailcfg.DERPMap) { tailnet.RewriteDERPMapDefaultRelay(context.Background(), c.client.Logger(), derpMap, c.client.URL) } -func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *DialAgentOptions) (agentConn *AgentConn, err error) { +func (c *Client) DialAgent(dialCtx context.Context, agentID uuid.UUID, options *DialAgentOptions) (agentConn AgentConn, err error) { if options == nil { options = &DialAgentOptions{} } diff --git a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go index b0051551a0f3d..72f5a4291c40e 100644 --- a/enterprise/wsproxy/wsproxysdk/wsproxysdk.go +++ b/enterprise/wsproxy/wsproxysdk/wsproxysdk.go @@ -75,7 +75,7 @@ func (c *Client) RequestIgnoreRedirects(ctx context.Context, method, path string // DialWorkspaceAgent calls the underlying codersdk.Client's DialWorkspaceAgent // method. -func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *workspacesdk.DialAgentOptions) (agentConn *workspacesdk.AgentConn, err error) { +func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, options *workspacesdk.DialAgentOptions) (agentConn workspacesdk.AgentConn, err error) { return workspacesdk.New(c.SDKClient).DialAgent(ctx, agentID, options) } diff --git a/scaletest/agentconn/run.go b/scaletest/agentconn/run.go index dba21cc24e3a0..b0990d9cb11a6 100644 --- a/scaletest/agentconn/run.go +++ b/scaletest/agentconn/run.go @@ -89,7 +89,7 @@ func (r *Runner) Run(ctx context.Context, _ string, w io.Writer) error { // Ensure DERP for completeness. if r.cfg.ConnectionMode == ConnectionModeDerp { - status := conn.Status() + status := conn.TailnetConn().Status() if len(status.Peers()) != 1 { return xerrors.Errorf("check connection mode: expected 1 peer, got %d", len(status.Peers())) } @@ -133,7 +133,7 @@ func (r *Runner) Run(ctx context.Context, _ string, w io.Writer) error { return nil } -func waitForDisco(ctx context.Context, logs io.Writer, conn *workspacesdk.AgentConn) error { +func waitForDisco(ctx context.Context, logs io.Writer, conn workspacesdk.AgentConn) error { const pingAttempts = 10 const pingDelay = 1 * time.Second @@ -165,7 +165,7 @@ func waitForDisco(ctx context.Context, logs io.Writer, conn *workspacesdk.AgentC return nil } -func waitForDirectConnection(ctx context.Context, logs io.Writer, conn *workspacesdk.AgentConn) error { +func waitForDirectConnection(ctx context.Context, logs io.Writer, conn workspacesdk.AgentConn) error { const directConnectionAttempts = 30 const directConnectionDelay = 1 * time.Second @@ -174,7 +174,7 @@ func waitForDirectConnection(ctx context.Context, logs io.Writer, conn *workspac for i := 0; i < directConnectionAttempts; i++ { _, _ = fmt.Fprintf(logs, "\tDirect connection check %d/%d...\n", i+1, directConnectionAttempts) - status := conn.Status() + status := conn.TailnetConn().Status() var err error if len(status.Peers()) != 1 { @@ -207,7 +207,7 @@ func waitForDirectConnection(ctx context.Context, logs io.Writer, conn *workspac return nil } -func verifyConnection(ctx context.Context, logs io.Writer, conn *workspacesdk.AgentConn) error { +func verifyConnection(ctx context.Context, logs io.Writer, conn workspacesdk.AgentConn) error { const verifyConnectionAttempts = 30 const verifyConnectionDelay = 1 * time.Second @@ -249,7 +249,7 @@ func verifyConnection(ctx context.Context, logs io.Writer, conn *workspacesdk.Ag return nil } -func performInitialConnections(ctx context.Context, logs io.Writer, conn *workspacesdk.AgentConn, specs []Connection) error { +func performInitialConnections(ctx context.Context, logs io.Writer, conn workspacesdk.AgentConn, specs []Connection) error { if len(specs) == 0 { return nil } @@ -287,7 +287,7 @@ func performInitialConnections(ctx context.Context, logs io.Writer, conn *worksp return nil } -func holdConnection(ctx context.Context, logs io.Writer, conn *workspacesdk.AgentConn, holdDur time.Duration, specs []Connection) error { +func holdConnection(ctx context.Context, logs io.Writer, conn workspacesdk.AgentConn, holdDur time.Duration, specs []Connection) error { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -364,7 +364,7 @@ func holdConnection(ctx context.Context, logs io.Writer, conn *workspacesdk.Agen return nil } -func agentHTTPClient(conn *workspacesdk.AgentConn) *http.Client { +func agentHTTPClient(conn workspacesdk.AgentConn) *http.Client { return &http.Client{ Transport: &http.Transport{ DisableKeepAlives: true, diff --git a/support/support.go b/support/support.go index 2fa41ce7eca8c..31080faaf023b 100644 --- a/support/support.go +++ b/support/support.go @@ -390,7 +390,7 @@ func connectedAgentInfo(ctx context.Context, client *codersdk.Client, log slog.L if err := conn.Close(); err != nil { log.Error(ctx, "failed to close agent connection", slog.Error(err)) } - <-conn.Closed() + <-conn.TailnetConn().Closed() } eg.Go(func() error { @@ -399,7 +399,7 @@ func connectedAgentInfo(ctx context.Context, client *codersdk.Client, log slog.L return xerrors.Errorf("create request: %w", err) } rr := httptest.NewRecorder() - conn.MagicsockServeHTTPDebug(rr, req) + conn.TailnetConn().MagicsockServeHTTPDebug(rr, req) a.ClientMagicsockHTML = rr.Body.Bytes() return nil })