Skip to content

Commit 489eba0

Browse files
committed
Add authentication header, improve comments, and add tests for the CLI
1 parent efc3025 commit 489eba0

File tree

7 files changed

+139
-12
lines changed

7 files changed

+139
-12
lines changed

cli/ssh_test.go

+2
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ func setupWorkspaceForAgent(t *testing.T, mutate func([]*proto.Agent) []*proto.A
6565
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
6666
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
6767
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
68+
workspace, err := client.Workspace(context.Background(), workspace.ID)
69+
require.NoError(t, err)
6870

6971
return client, workspace, agentToken
7072
}

cli/vscodeipc.go

+28-9
Original file line numberDiff line numberDiff line change
@@ -5,41 +5,46 @@ import (
55
"net"
66
"net/http"
77
"net/url"
8-
"os"
98

109
"github.com/google/uuid"
1110
"github.com/spf13/cobra"
1211
"golang.org/x/xerrors"
1312

13+
"github.com/coder/coder/cli/cliflag"
1414
"github.com/coder/coder/cli/vscodeipc"
1515
"github.com/coder/coder/codersdk"
1616
)
1717

1818
// vscodeipcCmd spawns a local HTTP server on the provided port that listens to messages.
1919
// It's made for use by the Coder VS Code extension. See: https://github.com/coder/vscode-coder
2020
func vscodeipcCmd() *cobra.Command {
21-
var port uint16
21+
var (
22+
rawURL string
23+
token string
24+
port uint16
25+
)
2226
cmd := &cobra.Command{
2327
Use: "vscodeipc <workspace-agent>",
2428
Args: cobra.ExactArgs(1),
2529
Hidden: true,
2630
RunE: func(cmd *cobra.Command, args []string) error {
27-
rawURL := os.Getenv("CODER_URL")
2831
if rawURL == "" {
2932
return xerrors.New("CODER_URL must be set!")
3033
}
31-
token := os.Getenv("CODER_TOKEN")
34+
// token is validated in a header on each request to prevent
35+
// unauthenticated clients from connecting.
3236
if token == "" {
3337
return xerrors.New("CODER_TOKEN must be set!")
3438
}
35-
if port == 0 {
36-
return xerrors.Errorf("port must be specified!")
37-
}
3839
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
3940
if err != nil {
4041
return xerrors.Errorf("listen: %w", err)
4142
}
4243
defer listener.Close()
44+
addr, ok := listener.Addr().(*net.TCPAddr)
45+
if !ok {
46+
return xerrors.Errorf("listener.Addr() is not a *net.TCPAddr: %T", listener.Addr())
47+
}
4348
url, err := url.Parse(rawURL)
4449
if err != nil {
4550
return err
@@ -56,13 +61,27 @@ func vscodeipcCmd() *cobra.Command {
5661
return err
5762
}
5863
defer closer.Close()
64+
// nolint:gosec
5965
server := http.Server{
6066
Handler: handler,
6167
}
62-
cmd.Printf("Ready\n")
63-
return server.Serve(listener)
68+
defer server.Close()
69+
cmd.Printf("%d\n", addr.Port)
70+
errChan := make(chan error, 1)
71+
go func() {
72+
err := server.Serve(listener)
73+
errChan <- err
74+
}()
75+
select {
76+
case <-cmd.Context().Done():
77+
return cmd.Context().Err()
78+
case err := <-errChan:
79+
return err
80+
}
6481
},
6582
}
83+
cliflag.StringVarP(cmd.Flags(), &rawURL, "url", "u", "CODER_URL", "", "The URL of the Coder instance!")
84+
cliflag.StringVarP(cmd.Flags(), &token, "token", "t", "CODER_TOKEN", "", "The session token of the user!")
6685
cmd.Flags().Uint16VarP(&port, "port", "p", 0, "The port to listen on!")
6786
return cmd
6887
}

cli/vscodeipc/vscodeipc.go

+32-3
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ import (
3232
//
3333
// This persists a single workspace connection, and lets you execute commands, check
3434
// for network information, and forward ports.
35+
//
36+
// The VS Code extension is located at https://github.com/coder/vscode-coder. The
37+
// extension downloads the slim binary from `/bin/*` and executes `coder vscodeipc`
38+
// which calls this function. This API must maintain backawards compatibility with
39+
// the extension to support prior versions of Coder.
3540
func New(ctx context.Context, client *codersdk.Client, agentID uuid.UUID, options *codersdk.DialWorkspaceAgentOptions) (http.Handler, io.Closer, error) {
3641
if options == nil {
3742
options = &codersdk.DialWorkspaceAgentOptions{}
@@ -47,6 +52,27 @@ func New(ctx context.Context, client *codersdk.Client, agentID uuid.UUID, option
4752
agentConn: agentConn,
4853
}
4954
r := chi.NewRouter()
55+
// This is to prevent unauthorized clients on the same machine from executing
56+
// requests on behalf of the workspace.
57+
r.Use(func(h http.Handler) http.Handler {
58+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
59+
token := r.Header.Get("Coder-Session-Token")
60+
if token == "" {
61+
httpapi.Write(r.Context(), w, http.StatusUnauthorized, codersdk.Response{
62+
Message: "A session token must be provided in the `Coder-Session-Token` header.",
63+
})
64+
return
65+
}
66+
if token != client.SessionToken() {
67+
httpapi.Write(r.Context(), w, http.StatusUnauthorized, codersdk.Response{
68+
Message: "The session token provided doesn't match the one used to create the client.",
69+
})
70+
return
71+
}
72+
w.Header().Set("Access-Control-Allow-Origin", "*")
73+
h.ServeHTTP(w, r)
74+
})
75+
})
5076
r.Get("/port/{port}", api.port)
5177
r.Get("/network", api.network)
5278
r.Post("/execute", api.execute)
@@ -160,8 +186,9 @@ func (api *api) network(w http.ResponseWriter, r *http.Request) {
160186
totalRx += stat.RxBytes
161187
totalTx += stat.TxBytes
162188
}
189+
// Tracking the time since last request is required because
190+
// ExtractTrafficStats() resets its counters after each call.
163191
dur := time.Since(api.lastNetwork)
164-
165192
uploadSecs := float64(totalTx) / dur.Seconds()
166193
downloadSecs := float64(totalRx) / dur.Seconds()
167194

@@ -198,7 +225,6 @@ func (api *api) execute(w http.ResponseWriter, r *http.Request) {
198225
api.sshClient, api.sshClientErr = api.agentConn.SSHClient(context.Background())
199226
})
200227
if api.sshClientErr != nil {
201-
fmt.Printf("WE GOT TO BEGIN ERR! %s\n", api.sshClientErr)
202228
httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{
203229
Message: "Failed to create SSH client.",
204230
Detail: api.sshClientErr.Error(),
@@ -216,7 +242,10 @@ func (api *api) execute(w http.ResponseWriter, r *http.Request) {
216242
defer session.Close()
217243
f, ok := w.(http.Flusher)
218244
if !ok {
219-
panic("http.ResponseWriter is not http.Flusher")
245+
httpapi.Write(r.Context(), w, http.StatusInternalServerError, codersdk.Response{
246+
Message: fmt.Sprintf("http.ResponseWriter is not http.Flusher: %T", w),
247+
})
248+
return
220249
}
221250

222251
execWriter := &execWriter{w, f}

cli/vscodeipc/vscodeipc_test.go

+29
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515

1616
"github.com/google/uuid"
1717
"github.com/spf13/afero"
18+
"github.com/stretchr/testify/assert"
1819
"github.com/stretchr/testify/require"
1920
"go.uber.org/goleak"
2021
"nhooyr.io/websocket"
@@ -42,31 +43,37 @@ func TestVSCodeIPC(t *testing.T) {
4243
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
4344
switch r.URL.Path {
4445
case fmt.Sprintf("/api/v2/workspaceagents/%s/connection", id):
46+
assert.Equal(t, r.Method, http.MethodGet)
4547
httpapi.Write(ctx, w, http.StatusOK, codersdk.WorkspaceAgentConnectionInfo{
4648
DERPMap: derpMap,
4749
})
4850
return
4951
case fmt.Sprintf("/api/v2/workspaceagents/%s/coordinate", id):
52+
assert.Equal(t, r.Method, http.MethodGet)
5053
ws, err := websocket.Accept(w, r, nil)
5154
require.NoError(t, err)
5255
conn := websocket.NetConn(ctx, ws, websocket.MessageBinary)
5356
_ = coordinator.ServeClient(conn, uuid.New(), id)
5457
return
5558
case "/api/v2/workspaceagents/me/version":
59+
assert.Equal(t, r.Method, http.MethodPost)
5660
w.WriteHeader(http.StatusOK)
5761
return
5862
case "/api/v2/workspaceagents/me/metadata":
63+
assert.Equal(t, r.Method, http.MethodGet)
5964
httpapi.Write(ctx, w, http.StatusOK, codersdk.WorkspaceAgentMetadata{
6065
DERPMap: derpMap,
6166
})
6267
return
6368
case "/api/v2/workspaceagents/me/coordinate":
69+
assert.Equal(t, r.Method, http.MethodGet)
6470
ws, err := websocket.Accept(w, r, nil)
6571
require.NoError(t, err)
6672
conn := websocket.NetConn(ctx, ws, websocket.MessageBinary)
6773
_ = coordinator.ServeAgent(conn, id)
6874
return
6975
case "/api/v2/workspaceagents/me/report-stats":
76+
assert.Equal(t, r.Method, http.MethodPost)
7077
w.WriteHeader(http.StatusOK)
7178
return
7279
case "/":
@@ -80,6 +87,8 @@ func TestVSCodeIPC(t *testing.T) {
8087
srvURL, _ := url.Parse(srv.URL)
8188

8289
client := codersdk.New(srvURL)
90+
token := uuid.New().String()
91+
client.SetSessionToken(token)
8392
agentConn := agent.New(agent.Options{
8493
Client: client,
8594
Filesystem: afero.NewMemMapFs(),
@@ -99,6 +108,7 @@ func TestVSCodeIPC(t *testing.T) {
99108
require.Eventually(t, func() bool {
100109
res := httptest.NewRecorder()
101110
req := httptest.NewRequest(http.MethodGet, "/network", nil)
111+
req.Header.Set("Coder-Session-Token", token)
102112
handler.ServeHTTP(res, req)
103113
network := &vscodeipc.NetworkResponse{}
104114
err = json.NewDecoder(res.Body).Decode(&network)
@@ -109,6 +119,23 @@ func TestVSCodeIPC(t *testing.T) {
109119
_, port, err := net.SplitHostPort(srvURL.Host)
110120
require.NoError(t, err)
111121

122+
t.Run("NoSessionToken", func(t *testing.T) {
123+
t.Parallel()
124+
res := httptest.NewRecorder()
125+
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/port/%s", port), nil)
126+
handler.ServeHTTP(res, req)
127+
require.Equal(t, http.StatusUnauthorized, res.Code)
128+
})
129+
130+
t.Run("MismatchedSessionToken", func(t *testing.T) {
131+
t.Parallel()
132+
res := httptest.NewRecorder()
133+
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/port/%s", port), nil)
134+
req.Header.Set("Coder-Session-Token", uuid.NewString())
135+
handler.ServeHTTP(res, req)
136+
require.Equal(t, http.StatusUnauthorized, res.Code)
137+
})
138+
112139
t.Run("Port", func(t *testing.T) {
113140
// Tests that the port endpoint can be used for forward traffic.
114141
// For this test, we simply use the already listening httptest server.
@@ -118,6 +145,7 @@ func TestVSCodeIPC(t *testing.T) {
118145
defer output.Close()
119146
res := &hijackable{httptest.NewRecorder(), output}
120147
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/port/%s", port), nil)
148+
req.Header.Set("Coder-Session-Token", token)
121149
go handler.ServeHTTP(res, req)
122150

123151
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://127.0.0.1/", nil)
@@ -147,6 +175,7 @@ func TestVSCodeIPC(t *testing.T) {
147175
Command: "echo test",
148176
})
149177
req := httptest.NewRequest(http.MethodPost, "/execute", bytes.NewReader(data))
178+
req.Header.Set("Coder-Session-Token", token)
150179
handler.ServeHTTP(res, req)
151180

152181
decoder := json.NewDecoder(res.Body)

cli/vscodeipc_test.go

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package cli_test
2+
3+
import (
4+
"bytes"
5+
"fmt"
6+
"testing"
7+
8+
"github.com/stretchr/testify/require"
9+
10+
"github.com/coder/coder/cli/clitest"
11+
"github.com/coder/coder/testutil"
12+
)
13+
14+
func TestVSCodeIPC(t *testing.T) {
15+
t.Parallel()
16+
// Ensures the vscodeipc command outputs it's running port!
17+
// This signifies to the caller that it's ready to accept requests.
18+
t.Run("PortOutputs", func(t *testing.T) {
19+
t.Parallel()
20+
client, workspace, _ := setupWorkspaceForAgent(t, nil)
21+
cmd, _ := clitest.New(t, "vscodeipc", workspace.LatestBuild.Resources[0].Agents[0].ID.String(),
22+
"--token", client.SessionToken(), "--url", client.URL.String())
23+
var buf bytes.Buffer
24+
cmd.SetOut(&buf)
25+
ctx, cancelFunc := testutil.Context(t)
26+
defer cancelFunc()
27+
done := make(chan error)
28+
go func() {
29+
err := cmd.ExecuteContext(ctx)
30+
done <- err
31+
}()
32+
33+
var line string
34+
require.Eventually(t, func() bool {
35+
fmt.Printf("Looking for port!\n")
36+
var err error
37+
line, err = buf.ReadString('\n')
38+
return err == nil
39+
}, testutil.WaitMedium, testutil.IntervalFast)
40+
t.Logf("Port: %s\n", line)
41+
42+
cancelFunc()
43+
<-done
44+
})
45+
}

codersdk/agentconn.go

+2
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@ func (c *AgentConn) AwaitReachable(ctx context.Context) bool {
139139
return c.Conn.AwaitReachable(ctx, TailnetIP)
140140
}
141141

142+
// Ping pings the agent and returns the round-trip time.
143+
// The bool returns true if the ping was made P2P.
142144
func (c *AgentConn) Ping(ctx context.Context) (time.Duration, bool, error) {
143145
ctx, span := tracing.StartSpan(ctx)
144146
defer span.End()

tailnet/conn.go

+1
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,7 @@ func (c *Conn) Status() *ipnstate.Status {
408408
}
409409

410410
// Ping sends a Disco ping to the Wireguard engine.
411+
// The bool returned is true if the ping was performed P2P.
411412
func (c *Conn) Ping(ctx context.Context, ip netip.Addr) (time.Duration, bool, error) {
412413
errCh := make(chan error, 1)
413414
prChan := make(chan *ipnstate.PingResult, 1)

0 commit comments

Comments
 (0)