Skip to content

Commit 72b8701

Browse files
committed
chore: make workspace sdk dialer fail fast for authnz errors
1 parent b0ba798 commit 72b8701

File tree

2 files changed

+68
-8
lines changed

2 files changed

+68
-8
lines changed

codersdk/workspacesdk/dialer.go

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ var permanentErrorStatuses = []int{
2424
http.StatusBadRequest, // returned if API mismatch
2525
http.StatusNotFound, // returned if user doesn't have permission or agent doesn't exist
2626
http.StatusInternalServerError, // returned if database is not reachable,
27+
http.StatusUnauthorized, // returned if user is not authenticated
28+
http.StatusForbidden, // returned if user is not authorized
2729
}
2830

2931
type WebsocketDialer struct {
@@ -39,6 +41,24 @@ type WebsocketDialer struct {
3941
isFirst bool
4042
}
4143

44+
// checkResumeTokenFailure checks if the parsed error indicates a resume token failure
45+
// and updates the resumeTokenFailed flag accordingly. Returns true if a resume token
46+
// failure was detected.
47+
func (w *WebsocketDialer) checkResumeTokenFailure(ctx context.Context, sdkErr *codersdk.Error) bool {
48+
if sdkErr == nil {
49+
return false
50+
}
51+
52+
for _, v := range sdkErr.Validations {
53+
if v.Field == "resume_token" {
54+
w.logger.Warn(ctx, "failed to dial tailnet v2+ API: server replied invalid resume token; unsetting for next connection attempt")
55+
w.resumeTokenFailed = true
56+
return true
57+
}
58+
}
59+
return false
60+
}
61+
4262
type WebsocketDialerOption func(*WebsocketDialer)
4363

4464
func WithWorkspaceUpdates(req *proto.WorkspaceUpdatesRequest) WebsocketDialerOption {
@@ -82,9 +102,14 @@ func (w *WebsocketDialer) Dial(ctx context.Context, r tailnet.ResumeTokenControl
82102
if w.isFirst {
83103
if res != nil && slices.Contains(permanentErrorStatuses, res.StatusCode) {
84104
err = codersdk.ReadBodyAsError(res)
85-
// A bit more human-readable help in the case the API version was rejected
86105
var sdkErr *codersdk.Error
87106
if xerrors.As(err, &sdkErr) {
107+
// Check for resume token failure first
108+
if w.checkResumeTokenFailure(ctx, sdkErr) {
109+
return tailnet.ControlProtocolClients{}, err
110+
}
111+
112+
// A bit more human-readable help in the case the API version was rejected
88113
if sdkErr.Message == AgentAPIMismatchMessage &&
89114
sdkErr.StatusCode() == http.StatusBadRequest {
90115
sdkErr.Helper = fmt.Sprintf(
@@ -107,13 +132,8 @@ func (w *WebsocketDialer) Dial(ctx context.Context, r tailnet.ResumeTokenControl
107132
bodyErr := codersdk.ReadBodyAsError(res)
108133
var sdkErr *codersdk.Error
109134
if xerrors.As(bodyErr, &sdkErr) {
110-
for _, v := range sdkErr.Validations {
111-
if v.Field == "resume_token" {
112-
// Unset the resume token for the next attempt
113-
w.logger.Warn(ctx, "failed to dial tailnet v2+ API: server replied invalid resume token; unsetting for next connection attempt")
114-
w.resumeTokenFailed = true
115-
return tailnet.ControlProtocolClients{}, err
116-
}
135+
if w.checkResumeTokenFailure(ctx, sdkErr) {
136+
return tailnet.ControlProtocolClients{}, err
117137
}
118138
}
119139
if !errors.Is(err, context.Canceled) {

codersdk/workspacesdk/dialer_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,46 @@ func TestWebsocketDialer_ResumeTokenFailure(t *testing.T) {
270270
require.Error(t, err)
271271
}
272272

273+
func TestWebsocketDialer_UnauthenticatedFailFast(t *testing.T) {
274+
t.Parallel()
275+
ctx := testutil.Context(t, testutil.WaitShort)
276+
logger := slogtest.Make(t, &slogtest.Options{
277+
IgnoreErrors: true,
278+
}).Leveled(slog.LevelDebug)
279+
280+
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
281+
httpapi.Write(ctx, w, http.StatusUnauthorized, codersdk.Response{})
282+
}))
283+
defer svr.Close()
284+
svrURL, err := url.Parse(svr.URL)
285+
require.NoError(t, err)
286+
287+
uut := workspacesdk.NewWebsocketDialer(logger, svrURL, &websocket.DialOptions{})
288+
289+
_, err = uut.Dial(ctx, nil)
290+
require.Error(t, err)
291+
}
292+
293+
func TestWebsocketDialer_UnauthorizedFailFast(t *testing.T) {
294+
t.Parallel()
295+
ctx := testutil.Context(t, testutil.WaitShort)
296+
logger := slogtest.Make(t, &slogtest.Options{
297+
IgnoreErrors: true,
298+
}).Leveled(slog.LevelDebug)
299+
300+
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
301+
httpapi.Write(ctx, w, http.StatusUnauthorized, codersdk.Response{})
302+
}))
303+
defer svr.Close()
304+
svrURL, err := url.Parse(svr.URL)
305+
require.NoError(t, err)
306+
307+
uut := workspacesdk.NewWebsocketDialer(logger, svrURL, &websocket.DialOptions{})
308+
309+
_, err = uut.Dial(ctx, nil)
310+
require.Error(t, err)
311+
}
312+
273313
func TestWebsocketDialer_UplevelVersion(t *testing.T) {
274314
t.Parallel()
275315
ctx := testutil.Context(t, testutil.WaitShort)

0 commit comments

Comments
 (0)