Skip to content

Commit 5e84d25

Browse files
refactor: convert workspacesdk.AgentConn to an interface (#19392)
Fixes coder/internal#907 We convert `workspacesdk.AgentConn` to an interface and generate a mock for it. This allows writing `coderd` tests that rely on the agent's HTTP api to not have to set up an entire tailnet networking stack.
1 parent 23c494f commit 5e84d25

File tree

18 files changed

+667
-143
lines changed

18 files changed

+667
-143
lines changed

Makefile

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -636,7 +636,8 @@ GEN_FILES := \
636636
coderd/database/pubsub/psmock/psmock.go \
637637
agent/agentcontainers/acmock/acmock.go \
638638
agent/agentcontainers/dcspec/dcspec_gen.go \
639-
coderd/httpmw/loggermw/loggermock/loggermock.go
639+
coderd/httpmw/loggermw/loggermock/loggermock.go \
640+
codersdk/workspacesdk/agentconnmock/agentconnmock.go
640641

641642
# all gen targets should be added here and to gen/mark-fresh
642643
gen: gen/db gen/golden-files $(GEN_FILES)
@@ -686,6 +687,7 @@ gen/mark-fresh:
686687
agent/agentcontainers/acmock/acmock.go \
687688
agent/agentcontainers/dcspec/dcspec_gen.go \
688689
coderd/httpmw/loggermw/loggermock/loggermock.go \
690+
codersdk/workspacesdk/agentconnmock/agentconnmock.go \
689691
"
690692

691693
for file in $$files; do
@@ -729,6 +731,10 @@ coderd/httpmw/loggermw/loggermock/loggermock.go: coderd/httpmw/loggermw/logger.g
729731
go generate ./coderd/httpmw/loggermw/loggermock/
730732
touch "$@"
731733

734+
codersdk/workspacesdk/agentconnmock/agentconnmock.go: codersdk/workspacesdk/agentconn.go
735+
go generate ./codersdk/workspacesdk/agentconnmock/
736+
touch "$@"
737+
732738
agent/agentcontainers/dcspec/dcspec_gen.go: \
733739
node_modules/.installed \
734740
agent/agentcontainers/dcspec/devContainer.base.schema.json \

agent/agent_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2750,9 +2750,9 @@ func TestAgent_Dial(t *testing.T) {
27502750

27512751
switch l.Addr().Network() {
27522752
case "tcp":
2753-
conn, err = agentConn.Conn.DialContextTCP(ctx, ipp)
2753+
conn, err = agentConn.TailnetConn().DialContextTCP(ctx, ipp)
27542754
case "udp":
2755-
conn, err = agentConn.Conn.DialContextUDP(ctx, ipp)
2755+
conn, err = agentConn.TailnetConn().DialContextUDP(ctx, ipp)
27562756
default:
27572757
t.Fatalf("unknown network: %s", l.Addr().Network())
27582758
}
@@ -2811,7 +2811,7 @@ func TestAgent_UpdatedDERP(t *testing.T) {
28112811
})
28122812

28132813
// Setup a client connection.
2814-
newClientConn := func(derpMap *tailcfg.DERPMap, name string) *workspacesdk.AgentConn {
2814+
newClientConn := func(derpMap *tailcfg.DERPMap, name string) workspacesdk.AgentConn {
28152815
conn, err := tailnet.NewConn(&tailnet.Options{
28162816
Addresses: []netip.Prefix{tailnet.TailscaleServicePrefix.RandomPrefix()},
28172817
DERPMap: derpMap,
@@ -2891,13 +2891,13 @@ func TestAgent_UpdatedDERP(t *testing.T) {
28912891

28922892
// Connect from a second client and make sure it uses the new DERP map.
28932893
conn2 := newClientConn(newDerpMap, "client2")
2894-
require.Equal(t, []int{2}, conn2.DERPMap().RegionIDs())
2894+
require.Equal(t, []int{2}, conn2.TailnetConn().DERPMap().RegionIDs())
28952895
t.Log("conn2 got the new DERPMap")
28962896

28972897
// If the first client gets a DERP map update, it should be able to
28982898
// reconnect just fine.
2899-
conn1.SetDERPMap(newDerpMap)
2900-
require.Equal(t, []int{2}, conn1.DERPMap().RegionIDs())
2899+
conn1.TailnetConn().SetDERPMap(newDerpMap)
2900+
require.Equal(t, []int{2}, conn1.TailnetConn().DERPMap().RegionIDs())
29012901
t.Log("set the new DERPMap on conn1")
29022902
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
29032903
defer cancel()
@@ -3264,7 +3264,7 @@ func setupSSHSessionOnPort(
32643264
}
32653265

32663266
func setupAgent(t testing.TB, metadata agentsdk.Manifest, ptyTimeout time.Duration, opts ...func(*agenttest.Client, *agent.Options)) (
3267-
*workspacesdk.AgentConn,
3267+
workspacesdk.AgentConn,
32683268
*agenttest.Client,
32693269
<-chan *proto.Stats,
32703270
afero.Fs,

cli/ping.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ func (r *RootCmd) ping() *serpent.Command {
147147
}
148148
defer conn.Close()
149149

150-
derpMap := conn.DERPMap()
150+
derpMap := conn.TailnetConn().DERPMap()
151151

152152
diagCtx, diagCancel := context.WithTimeout(inv.Context(), 30*time.Second)
153153
defer diagCancel()
@@ -156,7 +156,7 @@ func (r *RootCmd) ping() *serpent.Command {
156156
// Silent ping to determine whether we should show diags
157157
_, didP2p, _, _ := conn.Ping(ctx)
158158

159-
ni := conn.GetNetInfo()
159+
ni := conn.TailnetConn().GetNetInfo()
160160
connDiags := cliui.ConnDiags{
161161
DisableDirect: r.disableDirect,
162162
LocalNetInfo: ni,

cli/portforward.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ func (r *RootCmd) portForward() *serpent.Command {
221221
func listenAndPortForward(
222222
ctx context.Context,
223223
inv *serpent.Invocation,
224-
conn *workspacesdk.AgentConn,
224+
conn workspacesdk.AgentConn,
225225
wg *sync.WaitGroup,
226226
spec portForwardSpec,
227227
logger slog.Logger,

cli/speedtest.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ func (r *RootCmd) speedtest() *serpent.Command {
139139
if err != nil {
140140
continue
141141
}
142-
status := conn.Status()
142+
status := conn.TailnetConn().Status()
143143
if len(status.Peers()) != 1 {
144144
continue
145145
}
@@ -189,7 +189,7 @@ func (r *RootCmd) speedtest() *serpent.Command {
189189
outputResult.Intervals[i] = interval
190190
}
191191
}
192-
conn.Conn.SendSpeedtestTelemetry(outputResult.Overall.ThroughputMbits)
192+
conn.TailnetConn().SendSpeedtestTelemetry(outputResult.Overall.ThroughputMbits)
193193
out, err := formatter.Format(inv.Context(), outputResult)
194194
if err != nil {
195195
return err

cli/ssh.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,7 @@ func (r *RootCmd) ssh() *serpent.Command {
590590
}
591591

592592
err = sshSession.Wait()
593-
conn.SendDisconnectedTelemetry()
593+
conn.TailnetConn().SendDisconnectedTelemetry()
594594
if err != nil {
595595
if exitErr := (&gossh.ExitError{}); errors.As(err, &exitErr) {
596596
// Clear the error since it's not useful beyond
@@ -1364,7 +1364,7 @@ func getUsageAppName(usageApp string) codersdk.UsageAppName {
13641364

13651365
func setStatsCallback(
13661366
ctx context.Context,
1367-
agentConn *workspacesdk.AgentConn,
1367+
agentConn workspacesdk.AgentConn,
13681368
logger slog.Logger,
13691369
networkInfoDir string,
13701370
networkInfoInterval time.Duration,
@@ -1437,7 +1437,7 @@ func setStatsCallback(
14371437

14381438
now := time.Now()
14391439
cb(now, now.Add(time.Nanosecond), map[netlogtype.Connection]netlogtype.Counts{}, map[netlogtype.Connection]netlogtype.Counts{})
1440-
agentConn.SetConnStatsCallback(networkInfoInterval, 2048, cb)
1440+
agentConn.TailnetConn().SetConnStatsCallback(networkInfoInterval, 2048, cb)
14411441
return errCh, nil
14421442
}
14431443

@@ -1451,13 +1451,13 @@ type sshNetworkStats struct {
14511451
UsingCoderConnect bool `json:"using_coder_connect"`
14521452
}
14531453

1454-
func collectNetworkStats(ctx context.Context, agentConn *workspacesdk.AgentConn, start, end time.Time, counts map[netlogtype.Connection]netlogtype.Counts) (*sshNetworkStats, error) {
1454+
func collectNetworkStats(ctx context.Context, agentConn workspacesdk.AgentConn, start, end time.Time, counts map[netlogtype.Connection]netlogtype.Counts) (*sshNetworkStats, error) {
14551455
latency, p2p, pingResult, err := agentConn.Ping(ctx)
14561456
if err != nil {
14571457
return nil, err
14581458
}
1459-
node := agentConn.Node()
1460-
derpMap := agentConn.DERPMap()
1459+
node := agentConn.TailnetConn().Node()
1460+
derpMap := agentConn.TailnetConn().DERPMap()
14611461

14621462
totalRx := uint64(0)
14631463
totalTx := uint64(0)

coderd/coderd.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,9 @@ func New(options *Options) *API {
325325
})
326326
}
327327

328+
if options.PrometheusRegistry == nil {
329+
options.PrometheusRegistry = prometheus.NewRegistry()
330+
}
328331
if options.Authorizer == nil {
329332
options.Authorizer = rbac.NewCachingAuthorizer(options.PrometheusRegistry)
330333
if buildinfo.IsDev() {
@@ -381,9 +384,6 @@ func New(options *Options) *API {
381384
if options.FilesRateLimit == 0 {
382385
options.FilesRateLimit = 12
383386
}
384-
if options.PrometheusRegistry == nil {
385-
options.PrometheusRegistry = prometheus.NewRegistry()
386-
}
387387
if options.Clock == nil {
388388
options.Clock = quartz.NewReal()
389389
}

coderd/tailnet.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,9 +277,9 @@ func (s *ServerTailnet) dialContext(ctx context.Context, network, addr string) (
277277
}, nil
278278
}
279279

280-
func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (*workspacesdk.AgentConn, func(), error) {
280+
func (s *ServerTailnet) AgentConn(ctx context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) {
281281
var (
282-
conn *workspacesdk.AgentConn
282+
conn workspacesdk.AgentConn
283283
ret func()
284284
)
285285

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
package coderd
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"database/sql"
7+
"fmt"
8+
"io"
9+
"net/http"
10+
"net/http/httptest"
11+
"net/http/httputil"
12+
"net/url"
13+
"strings"
14+
"testing"
15+
16+
"github.com/go-chi/chi/v5"
17+
"github.com/google/uuid"
18+
"github.com/stretchr/testify/require"
19+
"go.uber.org/mock/gomock"
20+
21+
"cdr.dev/slog"
22+
"cdr.dev/slog/sloggers/slogtest"
23+
"github.com/coder/coder/v2/coderd/database"
24+
"github.com/coder/coder/v2/coderd/database/dbmock"
25+
"github.com/coder/coder/v2/coderd/database/dbtime"
26+
"github.com/coder/coder/v2/coderd/httpmw"
27+
"github.com/coder/coder/v2/coderd/workspaceapps/appurl"
28+
"github.com/coder/coder/v2/codersdk"
29+
"github.com/coder/coder/v2/codersdk/workspacesdk"
30+
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
31+
"github.com/coder/coder/v2/codersdk/wsjson"
32+
"github.com/coder/coder/v2/tailnet"
33+
"github.com/coder/coder/v2/tailnet/tailnettest"
34+
"github.com/coder/coder/v2/testutil"
35+
"github.com/coder/websocket"
36+
)
37+
38+
type fakeAgentProvider struct {
39+
agentConn func(ctx context.Context, agentID uuid.UUID) (_ workspacesdk.AgentConn, release func(), _ error)
40+
}
41+
42+
func (fakeAgentProvider) ReverseProxy(targetURL, dashboardURL *url.URL, agentID uuid.UUID, app appurl.ApplicationURL, wildcardHost string) *httputil.ReverseProxy {
43+
panic("unimplemented")
44+
}
45+
46+
func (f fakeAgentProvider) AgentConn(ctx context.Context, agentID uuid.UUID) (_ workspacesdk.AgentConn, release func(), _ error) {
47+
if f.agentConn != nil {
48+
return f.agentConn(ctx, agentID)
49+
}
50+
51+
panic("unimplemented")
52+
}
53+
54+
func (fakeAgentProvider) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) {
55+
panic("unimplemented")
56+
}
57+
58+
func (fakeAgentProvider) Close() error {
59+
return nil
60+
}
61+
62+
func TestWatchAgentContainers(t *testing.T) {
63+
t.Parallel()
64+
65+
t.Run("WebSocketClosesProperly", func(t *testing.T) {
66+
t.Parallel()
67+
68+
// This test ensures that the agent containers `/watch` websocket can gracefully
69+
// handle the underlying websocket unexpectedly closing. This test was created in
70+
// response to this issue: https://github.com/coder/coder/issues/19372
71+
72+
var (
73+
ctx = testutil.Context(t, testutil.WaitShort)
74+
logger = slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug).Named("coderd")
75+
76+
mCtrl = gomock.NewController(t)
77+
mDB = dbmock.NewMockStore(mCtrl)
78+
mCoordinator = tailnettest.NewMockCoordinator(mCtrl)
79+
mAgentConn = agentconnmock.NewMockAgentConn(mCtrl)
80+
81+
fAgentProvider = fakeAgentProvider{
82+
agentConn: func(ctx context.Context, agentID uuid.UUID) (_ workspacesdk.AgentConn, release func(), _ error) {
83+
return mAgentConn, func() {}, nil
84+
},
85+
}
86+
87+
workspaceID = uuid.New()
88+
agentID = uuid.New()
89+
resourceID = uuid.New()
90+
jobID = uuid.New()
91+
buildID = uuid.New()
92+
93+
containersCh = make(chan codersdk.WorkspaceAgentListContainersResponse)
94+
95+
r = chi.NewMux()
96+
97+
api = API{
98+
ctx: ctx,
99+
Options: &Options{
100+
AgentInactiveDisconnectTimeout: testutil.WaitShort,
101+
Database: mDB,
102+
Logger: logger,
103+
DeploymentValues: &codersdk.DeploymentValues{},
104+
TailnetCoordinator: tailnettest.NewFakeCoordinator(),
105+
},
106+
}
107+
)
108+
109+
var tailnetCoordinator tailnet.Coordinator = mCoordinator
110+
api.TailnetCoordinator.Store(&tailnetCoordinator)
111+
api.agentProvider = fAgentProvider
112+
113+
// Setup: Allow `ExtractWorkspaceAgentParams` to complete.
114+
mDB.EXPECT().GetWorkspaceAgentByID(gomock.Any(), agentID).Return(database.WorkspaceAgent{
115+
ID: agentID,
116+
ResourceID: resourceID,
117+
LifecycleState: database.WorkspaceAgentLifecycleStateReady,
118+
FirstConnectedAt: sql.NullTime{Valid: true, Time: dbtime.Now()},
119+
LastConnectedAt: sql.NullTime{Valid: true, Time: dbtime.Now()},
120+
}, nil)
121+
mDB.EXPECT().GetWorkspaceResourceByID(gomock.Any(), resourceID).Return(database.WorkspaceResource{
122+
ID: resourceID,
123+
JobID: jobID,
124+
}, nil)
125+
mDB.EXPECT().GetProvisionerJobByID(gomock.Any(), jobID).Return(database.ProvisionerJob{
126+
ID: jobID,
127+
Type: database.ProvisionerJobTypeWorkspaceBuild,
128+
}, nil)
129+
mDB.EXPECT().GetWorkspaceBuildByJobID(gomock.Any(), jobID).Return(database.WorkspaceBuild{
130+
WorkspaceID: workspaceID,
131+
ID: buildID,
132+
}, nil)
133+
134+
// And: Allow `db2dsk.WorkspaceAgent` to complete.
135+
mCoordinator.EXPECT().Node(gomock.Any()).Return(nil)
136+
137+
// And: Allow `WatchContainers` to be called, returing our `containersCh` channel.
138+
mAgentConn.EXPECT().WatchContainers(gomock.Any(), gomock.Any()).
139+
Return(containersCh, io.NopCloser(&bytes.Buffer{}), nil)
140+
141+
// And: We mount the HTTP Handler
142+
r.With(httpmw.ExtractWorkspaceAgentParam(mDB)).
143+
Get("/workspaceagents/{workspaceagent}/containers/watch", api.watchWorkspaceAgentContainers)
144+
145+
// Given: We create the HTTP server
146+
srv := httptest.NewServer(r)
147+
defer srv.Close()
148+
149+
// And: Dial the WebSocket
150+
wsURL := strings.Replace(srv.URL, "http://", "ws://", 1)
151+
conn, resp, err := websocket.Dial(ctx, fmt.Sprintf("%s/workspaceagents/%s/containers/watch", wsURL, agentID), nil)
152+
require.NoError(t, err)
153+
if resp.Body != nil {
154+
defer resp.Body.Close()
155+
}
156+
157+
// And: Create a streaming decoder
158+
decoder := wsjson.NewDecoder[codersdk.WorkspaceAgentListContainersResponse](conn, websocket.MessageText, logger)
159+
defer decoder.Close()
160+
decodeCh := decoder.Chan()
161+
162+
// And: We can successfully send through the channel.
163+
testutil.RequireSend(ctx, t, containersCh, codersdk.WorkspaceAgentListContainersResponse{
164+
Containers: []codersdk.WorkspaceAgentContainer{{
165+
ID: "test-container-id",
166+
}},
167+
})
168+
169+
// And: Receive the data.
170+
containerResp := testutil.RequireReceive(ctx, t, decodeCh)
171+
require.Len(t, containerResp.Containers, 1)
172+
require.Equal(t, "test-container-id", containerResp.Containers[0].ID)
173+
174+
// When: We close the `containersCh`
175+
close(containersCh)
176+
177+
// Then: We expect `decodeCh` to be closed.
178+
select {
179+
case <-ctx.Done():
180+
t.Fail()
181+
182+
case _, ok := <-decodeCh:
183+
require.False(t, ok, "channel is expected to be closed")
184+
}
185+
})
186+
}

0 commit comments

Comments
 (0)