Skip to content

Commit ab45e44

Browse files
committed
Fix requested changes
1 parent cb3725e commit ab45e44

File tree

4 files changed

+45
-34
lines changed

4 files changed

+45
-34
lines changed

cli/vscodeipc.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ func vscodeipcCmd() *cobra.Command {
6666
Handler: handler,
6767
}
6868
defer server.Close()
69-
cmd.Printf("%d\n", addr.Port)
69+
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s\n", addr.String())
7070
errChan := make(chan error, 1)
7171
go func() {
7272
err := server.Serve(listener)

cli/vscodeipc/vscodeipc.go

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ import (
3535
//
3636
// The VS Code extension is located at https://github.com/coder/vscode-coder. The
3737
// extension downloads the slim binary from `/bin/*` and executes `coder vscodeipc`
38-
// which calls this function. This API must maintain backawards compatibility with
38+
// which calls this function. This API must maintain backward compatibility with
3939
// the extension to support prior versions of Coder.
4040
func New(ctx context.Context, client *codersdk.Client, agentID uuid.UUID, options *codersdk.DialWorkspaceAgentOptions) (http.Handler, io.Closer, error) {
4141
if options == nil {
@@ -54,28 +54,12 @@ func New(ctx context.Context, client *codersdk.Client, agentID uuid.UUID, option
5454
r := chi.NewRouter()
5555
// This is to prevent unauthorized clients on the same machine from executing
5656
// requests on behalf of the workspace.
57-
r.Use(func(h http.Handler) http.Handler {
58-
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
59-
token := r.Header.Get("Coder-Session-Token")
60-
if token == "" {
61-
httpapi.Write(r.Context(), w, http.StatusUnauthorized, codersdk.Response{
62-
Message: "A session token must be provided in the `Coder-Session-Token` header.",
63-
})
64-
return
65-
}
66-
if token != client.SessionToken() {
67-
httpapi.Write(r.Context(), w, http.StatusUnauthorized, codersdk.Response{
68-
Message: "The session token provided doesn't match the one used to create the client.",
69-
})
70-
return
71-
}
72-
w.Header().Set("Access-Control-Allow-Origin", "*")
73-
h.ServeHTTP(w, r)
74-
})
57+
r.Use(sessionTokenMiddleware(client.SessionToken()))
58+
r.Route("/v1", func(r chi.Router) {
59+
r.Get("/port/{port}", api.port)
60+
r.Get("/network", api.network)
61+
r.Post("/execute", api.execute)
7562
})
76-
r.Get("/port/{port}", api.port)
77-
r.Get("/network", api.network)
78-
r.Post("/execute", api.execute)
7963
return r, api, nil
8064
}
8165

@@ -137,9 +121,9 @@ func (api *api) port(w http.ResponseWriter, r *http.Request) {
137121
httpapi.InternalServerError(w, err)
138122
return
139123
}
124+
defer localConn.Close()
140125

141126
_ = brw.Flush()
142-
defer localConn.Close()
143127
agent.Bicopy(r.Context(), localConn, remoteConn)
144128
}
145129

@@ -222,6 +206,9 @@ func (api *api) execute(w http.ResponseWriter, r *http.Request) {
222206
api.sshClientOnce.Do(func() {
223207
// The SSH client is lazily created because it's not needed for
224208
// all requests. It's only needed for the execute endpoint.
209+
//
210+
// It's alright if this fails on the first execution, because
211+
// a new instance of this API is created for each remote SSH request.
225212
api.sshClient, api.sshClientErr = api.agentConn.SSHClient(context.Background())
226213
})
227214
if api.sshClientErr != nil {
@@ -292,7 +279,32 @@ func (e *execWriter) Write(data []byte) (int, error) {
292279
if err != nil {
293280
return 0, err
294281
}
295-
_, _ = e.w.Write(js)
282+
_, err = e.w.Write(js)
283+
if err != nil {
284+
return 0, err
285+
}
296286
e.f.Flush()
297287
return len(data), nil
298288
}
289+
290+
func sessionTokenMiddleware(sessionToken string) func(h http.Handler) http.Handler {
291+
return func(h http.Handler) http.Handler {
292+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
293+
token := r.Header.Get("Coder-IPC-Token")
294+
if token == "" {
295+
httpapi.Write(r.Context(), w, http.StatusUnauthorized, codersdk.Response{
296+
Message: "A session token must be provided in the `Coder-IPC-Token` header.",
297+
})
298+
return
299+
}
300+
if token != sessionToken {
301+
httpapi.Write(r.Context(), w, http.StatusUnauthorized, codersdk.Response{
302+
Message: "The session token provided doesn't match the one used to create the client.",
303+
})
304+
return
305+
}
306+
w.Header().Set("Access-Control-Allow-Origin", "*")
307+
h.ServeHTTP(w, r)
308+
})
309+
}
310+
}

cli/vscodeipc/vscodeipc_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ func TestVSCodeIPC(t *testing.T) {
115115
handler.ServeHTTP(res, req)
116116
network := &vscodeipc.NetworkResponse{}
117117
err = json.NewDecoder(res.Body).Decode(&network)
118-
require.NoError(t, err)
118+
assert.NoError(t, err)
119119
return network.Latency != 0
120120
}, testutil.WaitLong, testutil.IntervalFast)
121121

cli/vscodeipc_test.go

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
package cli_test
22

33
import (
4-
"bytes"
5-
"fmt"
4+
"io"
65
"testing"
76

87
"github.com/stretchr/testify/require"
@@ -20,8 +19,8 @@ func TestVSCodeIPC(t *testing.T) {
2019
client, workspace, _ := setupWorkspaceForAgent(t, nil)
2120
cmd, _ := clitest.New(t, "vscodeipc", workspace.LatestBuild.Resources[0].Agents[0].ID.String(),
2221
"--token", client.SessionToken(), "--url", client.URL.String())
23-
var buf bytes.Buffer
24-
cmd.SetOut(&buf)
22+
rdr, wtr := io.Pipe()
23+
cmd.SetOut(wtr)
2524
ctx, cancelFunc := testutil.Context(t)
2625
defer cancelFunc()
2726
done := make(chan error, 1)
@@ -30,14 +29,14 @@ func TestVSCodeIPC(t *testing.T) {
3029
done <- err
3130
}()
3231

33-
var line string
32+
buf := make([]byte, 64)
3433
require.Eventually(t, func() bool {
35-
t.Log("Looking for port!")
34+
t.Log("Looking for address!")
3635
var err error
37-
line, err = buf.ReadString('\n')
36+
_, err = rdr.Read(buf)
3837
return err == nil
3938
}, testutil.WaitMedium, testutil.IntervalFast)
40-
t.Logf("Port: %s\n", line)
39+
t.Logf("Address: %s\n", buf)
4140

4241
cancelFunc()
4342
<-done

0 commit comments

Comments
 (0)