Skip to content

Commit 5fa6879

Browse files
committed
feat: check agent API version on connection
1 parent 35d99a8 commit 5fa6879

File tree

12 files changed

+59
-19
lines changed

12 files changed

+59
-19
lines changed

Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,7 @@ gen/mark-fresh:
507507
examples/examples.gen.json \
508508
tailnet/tailnettest/coordinatormock.go \
509509
tailnet/tailnettest/coordinateemock.go \
510-
tailnet/tailnettest/multiagentmockmock.go \
510+
tailnet/tailnettest/multiagentmock.go \
511511
"
512512
for file in $$files; do
513513
echo "$$file"

agent/proto/version.go

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package proto
2+
3+
import (
4+
"github.com/coder/coder/v2/tailnet/proto"
5+
)
6+
7+
// CurrentVersion is the current version of the agent API. It is tied to the
8+
// tailnet API version to avoid confusion, since agents connect to the tailnet
9+
// API over the same websocket.
10+
var CurrentVersion = proto.CurrentVersion

coderd/workspaceagents.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ import (
4343
"github.com/coder/coder/v2/codersdk"
4444
"github.com/coder/coder/v2/codersdk/agentsdk"
4545
"github.com/coder/coder/v2/tailnet"
46+
"github.com/coder/coder/v2/tailnet/proto"
4647
)
4748

4849
// @Summary Get workspace agent by ID
@@ -1162,7 +1163,7 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
11621163
if qv != "" {
11631164
version = qv
11641165
}
1165-
if err := tailnet.CurrentVersion.Validate(version); err != nil {
1166+
if err := proto.CurrentVersion.Validate(version); err != nil {
11661167
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
11671168
Message: "Unknown or unsupported API version",
11681169
Validations: []codersdk.ValidationError{

coderd/workspaceagentsrpc.go

+18
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"nhooyr.io/websocket"
1717

1818
"cdr.dev/slog"
19+
"github.com/coder/coder/v2/agent/proto"
1920
"github.com/coder/coder/v2/coderd/agentapi"
2021
"github.com/coder/coder/v2/coderd/database"
2122
"github.com/coder/coder/v2/coderd/database/dbauthz"
@@ -37,6 +38,23 @@ import (
3738
func (api *API) workspaceAgentRPC(rw http.ResponseWriter, r *http.Request) {
3839
ctx := r.Context()
3940

41+
version := r.URL.Query().Get("version")
42+
if version == "" {
43+
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
44+
Message: "Missing required query parameter: version",
45+
})
46+
return
47+
}
48+
if err := proto.CurrentVersion.Validate(version); err != nil {
49+
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
50+
Message: "Unknown or unsupported API version",
51+
Validations: []codersdk.ValidationError{
52+
{Field: "version", Detail: err.Error()},
53+
},
54+
})
55+
return
56+
}
57+
4058
api.WebsocketWaitMutex.Lock()
4159
api.WebsocketWaitGroup.Add(1)
4260
api.WebsocketWaitMutex.Unlock()

codersdk/agentsdk/agentsdk.go

+9-4
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"tailscale.com/tailcfg"
2222

2323
"cdr.dev/slog"
24+
"github.com/coder/coder/v2/agent/proto"
2425
"github.com/coder/coder/v2/codersdk"
2526
drpcsdk "github.com/coder/coder/v2/codersdk/drpc"
2627
"github.com/coder/retry"
@@ -281,18 +282,22 @@ func (c *Client) DERPMapUpdates(ctx context.Context) (<-chan DERPMapUpdate, io.C
281282
}, nil
282283
}
283284

284-
// Listen connects to the workspace agent coordinate WebSocket
285+
// Listen connects to the workspace agent API WebSocket
285286
// that handles connection negotiation.
286287
func (c *Client) Listen(ctx context.Context) (drpc.Conn, error) {
287-
coordinateURL, err := c.SDK.URL.Parse("/api/v2/workspaceagents/me/rpc")
288+
rpcURL, err := c.SDK.URL.Parse("/api/v2/workspaceagents/me/rpc")
288289
if err != nil {
289290
return nil, xerrors.Errorf("parse url: %w", err)
290291
}
292+
q := rpcURL.Query()
293+
q.Add("version", proto.CurrentVersion.String())
294+
rpcURL.RawQuery = q.Encode()
295+
291296
jar, err := cookiejar.New(nil)
292297
if err != nil {
293298
return nil, xerrors.Errorf("create cookie jar: %w", err)
294299
}
295-
jar.SetCookies(coordinateURL, []*http.Cookie{{
300+
jar.SetCookies(rpcURL, []*http.Cookie{{
296301
Name: codersdk.SessionTokenCookie,
297302
Value: c.SDK.SessionToken(),
298303
}})
@@ -301,7 +306,7 @@ func (c *Client) Listen(ctx context.Context) (drpc.Conn, error) {
301306
Transport: c.SDK.HTTPClient.Transport,
302307
}
303308
// nolint:bodyclose
304-
conn, res, err := websocket.Dial(ctx, coordinateURL.String(), &websocket.DialOptions{
309+
conn, res, err := websocket.Dial(ctx, rpcURL.String(), &websocket.DialOptions{
305310
HTTPClient: httpClient,
306311
})
307312
if err != nil {

codersdk/workspaceagents.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"cdr.dev/slog"
2323
"github.com/coder/coder/v2/coderd/tracing"
2424
"github.com/coder/coder/v2/tailnet"
25+
"github.com/coder/coder/v2/tailnet/proto"
2526
"github.com/coder/retry"
2627
)
2728

@@ -314,7 +315,7 @@ func (c *Client) DialWorkspaceAgent(dialCtx context.Context, agentID uuid.UUID,
314315
return nil, xerrors.Errorf("parse url: %w", err)
315316
}
316317
q := coordinateURL.Query()
317-
q.Add("version", tailnet.CurrentVersion.String())
318+
q.Add("version", proto.CurrentVersion.String())
318319
coordinateURL.RawQuery = q.Encode()
319320
closedCoordinator := make(chan struct{})
320321
// Must only ever be used once, send error OR close to avoid

enterprise/coderd/workspaceproxycoordinate.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import (
1111
"github.com/coder/coder/v2/coderd/util/apiversion"
1212
"github.com/coder/coder/v2/codersdk"
1313
"github.com/coder/coder/v2/enterprise/wsproxy/wsproxysdk"
14-
agpl "github.com/coder/coder/v2/tailnet"
14+
"github.com/coder/coder/v2/tailnet/proto"
1515
)
1616

1717
// @Summary Agent is legacy
@@ -59,7 +59,7 @@ func (api *API) workspaceProxyCoordinate(rw http.ResponseWriter, r *http.Request
5959
if qv != "" {
6060
version = qv
6161
}
62-
if err := agpl.CurrentVersion.Validate(version); err != nil {
62+
if err := proto.CurrentVersion.Validate(version); err != nil {
6363
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
6464
Message: "Unknown or unsupported API version",
6565
Validations: []codersdk.ValidationError{

enterprise/wsproxy/wsproxysdk/wsproxysdk.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@ func (c *Client) DialCoordinator(ctx context.Context) (agpl.MultiAgentConn, erro
439439
return nil, xerrors.Errorf("parse url: %w", err)
440440
}
441441
q := coordinateURL.Query()
442-
q.Add("version", agpl.CurrentVersion.String())
442+
q.Add("version", proto.CurrentVersion.String())
443443
coordinateURL.RawQuery = q.Encode()
444444
coordinateHeaders := make(http.Header)
445445
tokenHeader := codersdk.SessionTokenHeader

enterprise/wsproxy/wsproxysdk/wsproxysdk_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ func TestDialCoordinator(t *testing.T) {
194194
return
195195
}
196196
version := r.URL.Query().Get("version")
197-
if !assert.Equal(t, version, agpl.CurrentVersion.String()) {
197+
if !assert.Equal(t, version, proto.CurrentVersion.String()) {
198198
return
199199
}
200200
nc := websocket.NetConn(r.Context(), conn, websocket.MessageBinary)

tailnet/coordinator_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ func TestRemoteCoordination(t *testing.T) {
460460

461461
serveErr := make(chan error, 1)
462462
go func() {
463-
err := svc.ServeClient(ctx, tailnet.CurrentVersion.String(), sC, clientID, agentID)
463+
err := svc.ServeClient(ctx, proto.CurrentVersion.String(), sC, clientID, agentID)
464464
serveErr <- err
465465
}()
466466

tailnet/proto/version.go

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package proto
2+
3+
import (
4+
"github.com/coder/coder/v2/coderd/util/apiversion"
5+
)
6+
7+
const (
8+
CurrentMajor = 2
9+
CurrentMinor = 0
10+
)
11+
12+
var CurrentVersion = apiversion.New(CurrentMajor, CurrentMinor).WithBackwardCompat(1)

tailnet/service.go

-7
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,6 @@ import (
2020
"golang.org/x/xerrors"
2121
)
2222

23-
const (
24-
CurrentMajor = 2
25-
CurrentMinor = 0
26-
)
27-
28-
var CurrentVersion = apiversion.New(CurrentMajor, CurrentMinor).WithBackwardCompat(1)
29-
3023
type streamIDContextKey struct{}
3124

3225
// StreamID identifies the caller of the CoordinateTailnet RPC. We store this

0 commit comments

Comments
 (0)