Skip to content

Commit e9b7463

Browse files
committed
Add workspace route proxying endpoint
- Makes the workspace conn cache concurrency-safe - Reduces unnecessary open checks in `peer.Channel` - Fixes the use of a temporary context when dialing a workspace agent
1 parent 4d8b257 commit e9b7463

File tree

13 files changed

+247
-182
lines changed

13 files changed

+247
-182
lines changed

agent/conn.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ func (c *Conn) DialContext(ctx context.Context, network string, addr string) (ne
102102
var res dialResponse
103103
err = dec.Decode(&res)
104104
if err != nil {
105-
return nil, xerrors.Errorf("failed to decode initial packet: %w", err)
105+
return nil, xerrors.Errorf("decode agent dial response: %w", err)
106106
}
107107
if res.Error != "" {
108108
_ = channel.Close()

coderd/coderd.go

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@ import (
1515
"golang.org/x/xerrors"
1616
"google.golang.org/api/idtoken"
1717

18-
"github.com/go-chi/cors"
19-
2018
sdktrace "go.opentelemetry.io/otel/sdk/trace"
2119

2220
"cdr.dev/slog"
@@ -97,7 +95,7 @@ func New(options *Options) *API {
9795
tracing.HTTPMW(api.TracerProvider, "coderd.http"),
9896
)
9997

100-
r.Route("/{user}/{workspaceagent}/{application}", func(r chi.Router) {
98+
r.Route("/@{user}/{workspaceagent}/apps/{application}", func(r chi.Router) {
10199
r.Use(
102100
httpmw.RateLimitPerMinute(options.APIRateLimit),
103101
apiKeyMiddleware,
@@ -327,9 +325,6 @@ func New(options *Options) *API {
327325
r.Put("/extend", api.putExtendWorkspace)
328326
})
329327
})
330-
r.Route("/wildcardauth", func(r chi.Router) {
331-
r.Use(cors.Handler(cors.Options{}))
332-
})
333328
r.Route("/workspacebuilds/{workspacebuild}", func(r chi.Router) {
334329
r.Use(
335330
apiKeyMiddleware,
@@ -357,10 +352,12 @@ type API struct {
357352
}
358353

359354
// Close waits for all WebSocket connections to drain before returning.
360-
func (api *API) Close() {
355+
func (api *API) Close() error {
361356
api.websocketWaitMutex.Lock()
362357
api.websocketWaitGroup.Wait()
363358
api.websocketWaitMutex.Unlock()
359+
360+
return api.workspaceAgentCache.Close()
364361
}
365362

366363
func debugLogRequest(log slog.Logger) func(http.Handler) http.Handler {

coderd/coderdtest/coderdtest.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, *coderd.API)
172172
cancelFunc()
173173
_ = turnServer.Close()
174174
srv.Close()
175-
coderAPI.Close()
175+
_ = coderAPI.Close()
176176
})
177177

178178
return codersdk.New(serverURL), coderAPI

coderd/database/databasefake/databasefake.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1061,7 +1061,7 @@ func (q *fakeQuerier) GetWorkspaceAgentsByResourceIDs(_ context.Context, resourc
10611061
return workspaceAgents, nil
10621062
}
10631063

1064-
func (q *fakeQuerier) GetWorkspaceAppByAgentIDAndName(ctx context.Context, arg database.GetWorkspaceAppByAgentIDAndNameParams) (database.WorkspaceApp, error) {
1064+
func (q *fakeQuerier) GetWorkspaceAppByAgentIDAndName(_ context.Context, arg database.GetWorkspaceAppByAgentIDAndNameParams) (database.WorkspaceApp, error) {
10651065
q.mutex.RLock()
10661066
defer q.mutex.RUnlock()
10671067

coderd/workspaceagents.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package coderd
22

33
import (
4+
"context"
45
"database/sql"
56
"encoding/json"
67
"fmt"
@@ -382,12 +383,12 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
382383
}()
383384
// Accept text connections, because it's more developer friendly.
384385
wsNetConn := websocket.NetConn(r.Context(), conn, websocket.MessageBinary)
385-
agentConn, err := api.dialWorkspaceAgent(r, workspaceAgent.ID)
386+
agentConn, release, err := api.workspaceAgentCache.Acquire(r, workspaceAgent.ID)
386387
if err != nil {
387388
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial workspace agent: %s", err))
388389
return
389390
}
390-
defer agentConn.Close()
391+
defer release()
391392
ptNetConn, err := agentConn.ReconnectingPTY(reconnect.String(), uint16(height), uint16(width), "")
392393
if err != nil {
393394
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("dial: %s", err))
@@ -404,8 +405,9 @@ func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
404405
// dialWorkspaceAgent connects to a workspace agent by ID.
405406
func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.Conn, error) {
406407
client, server := provisionersdk.TransportPipe()
408+
ctx, cancelFunc := context.WithCancel(context.Background())
407409
go func() {
408-
_ = peerbroker.ProxyListen(r.Context(), server, peerbroker.ProxyOptions{
410+
_ = peerbroker.ProxyListen(ctx, server, peerbroker.ProxyOptions{
409411
ChannelID: agentID.String(),
410412
Logger: api.Logger.Named("peerbroker-proxy-dial"),
411413
Pubsub: api.Pubsub,
@@ -415,8 +417,9 @@ func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.C
415417
}()
416418

417419
peerClient := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(client))
418-
stream, err := peerClient.NegotiateConnection(r.Context())
420+
stream, err := peerClient.NegotiateConnection(ctx)
419421
if err != nil {
422+
cancelFunc()
420423
return nil, xerrors.Errorf("negotiate: %w", err)
421424
}
422425
options := &peer.ConnOptions{
@@ -452,8 +455,13 @@ func (api *API) dialWorkspaceAgent(r *http.Request, agentID uuid.UUID) (*agent.C
452455
}))
453456
peerConn, err := peerbroker.Dial(stream, append(api.ICEServers, turnconn.Proxy), options)
454457
if err != nil {
458+
cancelFunc()
455459
return nil, xerrors.Errorf("dial: %w", err)
456460
}
461+
go func() {
462+
<-peerConn.Closed()
463+
cancelFunc()
464+
}()
457465
return &agent.Conn{
458466
Negotiator: peerClient,
459467
Conn: peerConn,

coderd/workspaceapps.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,22 @@ func (api *API) workspaceAppsProxyPath(rw http.ResponseWriter, r *http.Request)
123123
defer release()
124124

125125
proxy := httputil.NewSingleHostReverseProxy(appURL)
126+
// Write the error directly using our format!
127+
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
128+
httpapi.Write(w, http.StatusBadGateway, httpapi.Response{
129+
Message: err.Error(),
130+
})
131+
}
126132
proxy.Transport = conn.HTTPTransport()
127-
r.URL.Path = chi.URLParam(r, "*")
133+
path := chi.URLParam(r, "*")
134+
if !strings.HasSuffix(r.URL.Path, "/") && path == "" {
135+
// Web applications typically request paths relative to the
136+
// root URL. This allows for routing behind a proxy or subpath.
137+
// See https://github.com/coder/code-server/issues/241 for examples.
138+
r.URL.Path += "/"
139+
http.Redirect(rw, r, r.URL.String(), http.StatusTemporaryRedirect)
140+
return
141+
}
142+
r.URL.Path = path
128143
proxy.ServeHTTP(rw, r)
129144
}

coderd/workspaceapps_test.go

Lines changed: 68 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -21,69 +21,81 @@ import (
2121

2222
func TestWorkspaceAppsProxyPath(t *testing.T) {
2323
t.Parallel()
24-
t.Run("Proxies", func(t *testing.T) {
25-
t.Parallel()
26-
// #nosec
27-
ln, err := net.Listen("tcp", ":0")
28-
require.NoError(t, err)
29-
server := http.Server{
30-
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
31-
w.WriteHeader(http.StatusOK)
32-
}),
33-
}
34-
t.Cleanup(func() {
35-
_ = server.Close()
36-
_ = ln.Close()
37-
})
38-
go server.Serve(ln)
39-
tcpAddr, _ := ln.Addr().(*net.TCPAddr)
24+
// #nosec
25+
ln, err := net.Listen("tcp", ":0")
26+
require.NoError(t, err)
27+
server := http.Server{
28+
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
29+
w.WriteHeader(http.StatusOK)
30+
}),
31+
}
32+
t.Cleanup(func() {
33+
_ = server.Close()
34+
_ = ln.Close()
35+
})
36+
go server.Serve(ln)
37+
tcpAddr, _ := ln.Addr().(*net.TCPAddr)
4038

41-
client, coderAPI := coderdtest.NewWithAPI(t, nil)
42-
user := coderdtest.CreateFirstUser(t, client)
43-
daemonCloser := coderdtest.NewProvisionerDaemon(t, coderAPI)
44-
authToken := uuid.NewString()
45-
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
46-
Parse: echo.ParseComplete,
47-
ProvisionDryRun: echo.ProvisionComplete,
48-
Provision: []*proto.Provision_Response{{
49-
Type: &proto.Provision_Response_Complete{
50-
Complete: &proto.Provision_Complete{
51-
Resources: []*proto.Resource{{
52-
Name: "example",
53-
Type: "aws_instance",
54-
Agents: []*proto.Agent{{
55-
Id: uuid.NewString(),
56-
Auth: &proto.Agent_Token{
57-
Token: authToken,
58-
},
59-
Apps: []*proto.App{{
60-
Name: "example",
61-
Url: fmt.Sprintf("http://127.0.0.1:%d", tcpAddr.Port),
62-
}},
39+
client, coderAPI := coderdtest.NewWithAPI(t, nil)
40+
user := coderdtest.CreateFirstUser(t, client)
41+
daemonCloser := coderdtest.NewProvisionerDaemon(t, coderAPI)
42+
authToken := uuid.NewString()
43+
version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{
44+
Parse: echo.ParseComplete,
45+
ProvisionDryRun: echo.ProvisionComplete,
46+
Provision: []*proto.Provision_Response{{
47+
Type: &proto.Provision_Response_Complete{
48+
Complete: &proto.Provision_Complete{
49+
Resources: []*proto.Resource{{
50+
Name: "example",
51+
Type: "aws_instance",
52+
Agents: []*proto.Agent{{
53+
Id: uuid.NewString(),
54+
Auth: &proto.Agent_Token{
55+
Token: authToken,
56+
},
57+
Apps: []*proto.App{{
58+
Name: "example",
59+
Url: fmt.Sprintf("http://127.0.0.1:%d", tcpAddr.Port),
6360
}},
6461
}},
65-
},
62+
}},
6663
},
67-
}},
68-
})
69-
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
70-
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
71-
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
72-
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
73-
daemonCloser.Close()
64+
},
65+
}},
66+
})
67+
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
68+
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
69+
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
70+
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
71+
daemonCloser.Close()
7472

75-
agentClient := codersdk.New(client.URL)
76-
agentClient.SessionToken = authToken
77-
agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{
78-
Logger: slogtest.Make(t, nil),
79-
})
80-
t.Cleanup(func() {
81-
_ = agentCloser.Close()
82-
})
83-
coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
73+
agentClient := codersdk.New(client.URL)
74+
agentClient.SessionToken = authToken
75+
agentCloser := agent.New(agentClient.ListenWorkspaceAgent, &agent.Options{
76+
Logger: slogtest.Make(t, nil),
77+
})
78+
t.Cleanup(func() {
79+
_ = agentCloser.Close()
80+
})
81+
coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID)
82+
client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
83+
return http.ErrUseLastResponse
84+
}
8485

85-
resp, err := client.Request(context.Background(), http.MethodGet, "/me/"+workspace.Name+"/example", nil)
86+
t.Run("RedirectsWithSlash", func(t *testing.T) {
87+
t.Parallel()
88+
resp, err := client.Request(context.Background(), http.MethodGet, "/@me/"+workspace.Name+"/apps/example", nil)
89+
require.NoError(t, err)
90+
defer resp.Body.Close()
91+
require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
92+
})
93+
94+
t.Run("Proxies", func(t *testing.T) {
95+
t.Parallel()
96+
resp, err := client.Request(context.Background(), http.MethodGet, "/@me/"+workspace.Name+"/apps/example/", nil)
8697
require.NoError(t, err)
98+
defer resp.Body.Close()
8799
body, err := io.ReadAll(resp.Body)
88100
require.NoError(t, err)
89101
require.Equal(t, "", string(body))

0 commit comments

Comments
 (0)