Skip to content

Commit 2c843f4

Browse files
authored
fix: fix --header flag in CLI (coder#8023)
1 parent df842b3 commit 2c843f4

File tree

5 files changed

+140
-51
lines changed

5 files changed

+140
-51
lines changed

cli/root.go

+20-18
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,13 @@ import (
2121
"text/tabwriter"
2222
"time"
2323

24-
"golang.org/x/exp/slices"
25-
"golang.org/x/xerrors"
26-
27-
"cdr.dev/slog"
28-
2924
"github.com/charmbracelet/lipgloss"
30-
"github.com/gobwas/httphead"
3125
"github.com/mattn/go-isatty"
3226
"github.com/mitchellh/go-wordwrap"
27+
"golang.org/x/exp/slices"
28+
"golang.org/x/xerrors"
3329

30+
"cdr.dev/slog"
3431
"github.com/coder/coder/buildinfo"
3532
"github.com/coder/coder/cli/clibase"
3633
"github.com/coder/coder/cli/cliui"
@@ -430,6 +427,15 @@ type RootCmd struct {
430427
}
431428

432429
func addTelemetryHeader(client *codersdk.Client, inv *clibase.Invocation) {
430+
transport, ok := client.HTTPClient.Transport.(*headerTransport)
431+
if !ok {
432+
transport = &headerTransport{
433+
transport: client.HTTPClient.Transport,
434+
header: http.Header{},
435+
}
436+
client.HTTPClient.Transport = transport
437+
}
438+
433439
var topts []telemetry.CLIOption
434440
for _, opt := range inv.Command.FullOptions() {
435441
if opt.ValueSource == clibase.ValueSourceNone || opt.ValueSource == clibase.ValueSourceDefault {
@@ -459,10 +465,7 @@ func addTelemetryHeader(client *codersdk.Client, inv *clibase.Invocation) {
459465
return
460466
}
461467

462-
client.ExtraHeaders.Set(
463-
codersdk.CLITelemetryHeader,
464-
s,
465-
)
468+
transport.header.Add(codersdk.CLITelemetryHeader, s)
466469
}
467470

468471
// InitClient sets client to a new client.
@@ -560,18 +563,17 @@ func (r *RootCmd) setClient(client *codersdk.Client, serverURL *url.URL) error {
560563
transport: http.DefaultTransport,
561564
header: http.Header{},
562565
}
566+
for _, header := range r.header {
567+
parts := strings.SplitN(header, "=", 2)
568+
if len(parts) < 2 {
569+
return xerrors.Errorf("split header %q had less than two parts", header)
570+
}
571+
transport.header.Add(parts[0], parts[1])
572+
}
563573
client.URL = serverURL
564574
client.HTTPClient = &http.Client{
565575
Transport: transport,
566576
}
567-
client.ExtraHeaders = make(http.Header)
568-
for _, hd := range r.header {
569-
k, v, ok := httphead.ParseHeaderLine([]byte(hd))
570-
if !ok {
571-
return xerrors.Errorf("invalid header: %s", hd)
572-
}
573-
client.ExtraHeaders.Add(string(k), string(v))
574-
}
575577
return nil
576578
}
577579

cli/root_test.go

+108-9
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,19 @@ package cli_test
22

33
import (
44
"bytes"
5+
"fmt"
56
"net/http"
67
"net/http/httptest"
8+
"strings"
9+
"sync/atomic"
710
"testing"
811

912
"github.com/coder/coder/cli/clibase"
13+
"github.com/coder/coder/coderd"
14+
"github.com/coder/coder/coderd/coderdtest"
15+
"github.com/coder/coder/codersdk"
16+
"github.com/coder/coder/pty/ptytest"
17+
"github.com/coder/coder/testutil"
1018

1119
"github.com/stretchr/testify/assert"
1220
"github.com/stretchr/testify/require"
@@ -64,21 +72,112 @@ func TestRoot(t *testing.T) {
6472
t.Run("Header", func(t *testing.T) {
6573
t.Parallel()
6674

67-
done := make(chan struct{})
75+
var called int64
6876
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
77+
atomic.AddInt64(&called, 1)
6978
assert.Equal(t, "wow", r.Header.Get("X-Testing"))
79+
assert.Equal(t, "Dean was Here!", r.Header.Get("Cool-Header"))
7080
w.WriteHeader(http.StatusGone)
71-
select {
72-
case <-done:
73-
close(done)
74-
default:
75-
}
7681
}))
7782
defer srv.Close()
7883
buf := new(bytes.Buffer)
79-
inv, _ := clitest.New(t, "--header", "X-Testing=wow", "login", srv.URL)
84+
inv, _ := clitest.New(t,
85+
"--no-feature-warning",
86+
"--no-version-warning",
87+
"--header", "X-Testing=wow",
88+
"--header", "Cool-Header=Dean was Here!",
89+
"login", srv.URL,
90+
)
8091
inv.Stdout = buf
81-
// This won't succeed, because we're using the login cmd to assert requests.
82-
_ = inv.Run()
92+
93+
err := inv.Run()
94+
require.Error(t, err)
95+
require.ErrorContains(t, err, "unexpected status code 410")
96+
require.EqualValues(t, 1, atomic.LoadInt64(&called), "called exactly once")
97+
})
98+
}
99+
100+
// TestDERPHeaders ensures that the client sends the global `--header`s to the
101+
// DERP server when connecting.
102+
func TestDERPHeaders(t *testing.T) {
103+
t.Parallel()
104+
105+
// Create a coderd API instance the hard way since we need to change the
106+
// handler to inject our custom /derp handler.
107+
setHandler, cancelFunc, serverURL, newOptions := coderdtest.NewOptions(t, nil)
108+
109+
// We set the handler after server creation for the access URL.
110+
coderAPI := coderd.New(newOptions)
111+
setHandler(coderAPI.RootHandler)
112+
provisionerCloser := coderdtest.NewProvisionerDaemon(t, coderAPI)
113+
t.Cleanup(func() {
114+
_ = provisionerCloser.Close()
115+
})
116+
client := codersdk.New(serverURL)
117+
t.Cleanup(func() {
118+
cancelFunc()
119+
_ = provisionerCloser.Close()
120+
_ = coderAPI.Close()
121+
client.HTTPClient.CloseIdleConnections()
122+
})
123+
124+
var (
125+
user = coderdtest.CreateFirstUser(t, client)
126+
workspace = runAgent(t, client, user.UserID)
127+
)
128+
129+
// Inject custom /derp handler so we can inspect the headers.
130+
var (
131+
expectedHeaders = map[string]string{
132+
"X-Test-Header": "test-value",
133+
"Cool-Header": "Dean was Here!",
134+
}
135+
derpCalled int64
136+
)
137+
setHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
138+
if strings.HasPrefix(r.URL.Path, "/derp") {
139+
ok := true
140+
for k, v := range expectedHeaders {
141+
if r.Header.Get(k) != v {
142+
ok = false
143+
break
144+
}
145+
}
146+
if ok {
147+
// Only increment if all the headers are set, because the agent
148+
// calls derp also.
149+
atomic.AddInt64(&derpCalled, 1)
150+
}
151+
}
152+
153+
coderAPI.RootHandler.ServeHTTP(w, r)
154+
}))
155+
156+
// Connect with the headers set as args.
157+
args := []string{
158+
"--no-feature-warning",
159+
"--no-version-warning",
160+
"ping", workspace.Name,
161+
"-n", "1",
162+
}
163+
for k, v := range expectedHeaders {
164+
args = append(args, "--header", fmt.Sprintf("%s=%s", k, v))
165+
}
166+
inv, root := clitest.New(t, args...)
167+
clitest.SetupConfig(t, client, root)
168+
pty := ptytest.New(t)
169+
inv.Stdin = pty.Input()
170+
inv.Stderr = pty.Output()
171+
inv.Stdout = pty.Output()
172+
173+
ctx := testutil.Context(t, testutil.WaitLong)
174+
cmdDone := tGo(t, func() {
175+
err := inv.WithContext(ctx).Run()
176+
assert.NoError(t, err)
83177
})
178+
179+
pty.ExpectMatch("pong from " + workspace.Name)
180+
<-cmdDone
181+
182+
require.Greater(t, atomic.LoadInt64(&derpCalled), int64(0), "expected /derp to be called at least once")
84183
}

coderd/coderdtest/coderdtest.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,8 @@ func NewOptions(t testing.TB, options *Options) (func(http.Handler), context.Can
256256
var handler http.Handler
257257
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
258258
mutex.RLock()
259-
defer mutex.RUnlock()
259+
handler := handler
260+
mutex.RUnlock()
260261
if handler != nil {
261262
handler.ServeHTTP(w, r)
262263
}

codersdk/client.go

+2-8
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,8 @@ var loggableMimeTypes = map[string]struct{}{
8585
// New creates a Coder client for the provided URL.
8686
func New(serverURL *url.URL) *Client {
8787
return &Client{
88-
URL: serverURL,
89-
HTTPClient: &http.Client{},
90-
ExtraHeaders: make(http.Header),
88+
URL: serverURL,
89+
HTTPClient: &http.Client{},
9190
}
9291
}
9392

@@ -97,9 +96,6 @@ type Client struct {
9796
mu sync.RWMutex // Protects following.
9897
sessionToken string
9998

100-
// ExtraHeaders are headers to add to every request.
101-
ExtraHeaders http.Header
102-
10399
HTTPClient *http.Client
104100
URL *url.URL
105101

@@ -193,8 +189,6 @@ func (c *Client) Request(ctx context.Context, method, path string, body interfac
193189
return nil, xerrors.Errorf("create request: %w", err)
194190
}
195191

196-
req.Header = c.ExtraHeaders.Clone()
197-
198192
tokenHeader := c.SessionTokenHeader
199193
if tokenHeader == "" {
200194
tokenHeader = SessionTokenHeader

go.mod

+8-15
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ require (
8080
github.com/coreos/go-oidc/v3 v3.6.0
8181
github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf
8282
github.com/creack/pty v1.1.18
83+
github.com/dave/dst v0.27.2
8384
github.com/elastic/go-sysinfo v1.11.0
8485
github.com/fatih/color v1.15.0
8586
github.com/fatih/structs v1.1.0
@@ -189,6 +190,7 @@ require (
189190
github.com/Microsoft/go-winio v0.6.1 // indirect
190191
github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect
191192
github.com/OneOfOne/xxhash v1.2.8 // indirect
193+
github.com/ProtonMail/go-crypto v0.0.0-20230217124315-7d5c6f04bbb8 // indirect
192194
github.com/agext/levenshtein v1.2.3 // indirect
193195
github.com/agnivade/levenshtein v1.1.1 // indirect
194196
github.com/akutz/memconn v0.1.0 // indirect
@@ -206,6 +208,7 @@ require (
206208
github.com/charmbracelet/bubbles v0.15.0 // indirect
207209
github.com/charmbracelet/bubbletea v0.23.2 // indirect
208210
github.com/clbanning/mxj/v2 v2.5.7 // indirect
211+
github.com/cloudflare/circl v1.3.3 // indirect
209212
github.com/containerd/console v1.0.3 // indirect
210213
github.com/containerd/continuity v0.3.0 // indirect
211214
github.com/coreos/go-iptables v0.6.0 // indirect
@@ -217,6 +220,7 @@ require (
217220
github.com/docker/go-units v0.5.0 // indirect
218221
github.com/elastic/go-windows v1.0.0 // indirect
219222
github.com/fxamacker/cbor/v2 v2.4.0 // indirect
223+
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
220224
github.com/ghodss/yaml v1.0.0 // indirect
221225
github.com/gin-gonic/gin v1.9.1 // indirect
222226
github.com/go-logr/stdr v1.2.2 // indirect
@@ -227,6 +231,7 @@ require (
227231
github.com/go-openapi/swag v0.19.15 // indirect
228232
github.com/go-playground/locales v0.14.1 // indirect
229233
github.com/go-playground/universal-translator v0.18.1 // indirect
234+
github.com/go-test/deep v1.0.8 // indirect
230235
github.com/go-toast/toast v0.0.0-20190211030409-01e6764cf0a4 // indirect
231236
github.com/gobwas/glob v0.2.3 // indirect
232237
github.com/gobwas/ws v1.1.0 // indirect
@@ -317,6 +322,7 @@ require (
317322
github.com/tchap/go-patricia/v2 v2.3.1 // indirect
318323
github.com/tcnksm/go-httpstat v0.2.0 // indirect
319324
github.com/tdewolff/parse/v2 v2.6.6 // indirect
325+
github.com/tdewolff/test v1.0.9 // indirect
320326
github.com/u-root/uio v0.0.0-20221213070652-c3537552635f // indirect
321327
github.com/vishvananda/netlink v1.1.1-0.20211118161826-650dca95af54 // indirect
322328
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect
@@ -346,22 +352,9 @@ require (
346352
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
347353
google.golang.org/appengine v1.6.7 // indirect
348354
google.golang.org/genproto v0.0.0-20230530153820-e85fd2cbaebc // indirect
355+
google.golang.org/genproto/googleapis/api v0.0.0-20230530153820-e85fd2cbaebc // indirect
356+
google.golang.org/genproto/googleapis/rpc v0.0.0-20230530153820-e85fd2cbaebc // indirect
349357
gopkg.in/yaml.v2 v2.4.0 // indirect
350358
howett.net/plist v1.0.0 // indirect
351359
inet.af/peercred v0.0.0-20210906144145-0893ea02156a // indirect
352360
)
353-
354-
require (
355-
github.com/dave/dst v0.27.2
356-
github.com/gobwas/httphead v0.1.0
357-
)
358-
359-
require (
360-
github.com/ProtonMail/go-crypto v0.0.0-20230217124315-7d5c6f04bbb8 // indirect
361-
github.com/cloudflare/circl v1.3.3 // indirect
362-
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
363-
github.com/go-test/deep v1.0.8 // indirect
364-
github.com/tdewolff/test v1.0.9 // indirect
365-
google.golang.org/genproto/googleapis/api v0.0.0-20230530153820-e85fd2cbaebc // indirect
366-
google.golang.org/genproto/googleapis/rpc v0.0.0-20230530153820-e85fd2cbaebc // indirect
367-
)

0 commit comments

Comments
 (0)