Skip to content

Commit aa99c6c

Browse files
committed
Adjust to use interface to cast headers
1 parent 2e81ac6 commit aa99c6c

File tree

5 files changed

+22
-19
lines changed

5 files changed

+22
-19
lines changed

cli/root.go

+11-11
Original file line numberDiff line numberDiff line change
@@ -334,22 +334,16 @@ func createUnauthenticatedClient(cmd *cobra.Command, serverURL *url.URL) (*coder
334334
}
335335
transport := &headerTransport{
336336
transport: http.DefaultTransport,
337-
headers: map[string]string{},
337+
header: http.Header{},
338338
}
339339
for _, header := range headers {
340340
parts := strings.SplitN(header, "=", 2)
341341
if len(parts) < 2 {
342342
return nil, xerrors.Errorf("split header %q had less than two parts", header)
343343
}
344-
transport.headers[parts[0]] = parts[1]
344+
transport.header.Add(parts[0], parts[1])
345345
}
346-
347346
client.HTTPClient.Transport = transport
348-
client.DERPHeader = &http.Header{}
349-
for header, value := range transport.headers {
350-
client.DERPHeader.Set(header, value)
351-
}
352-
353347
return client, nil
354348
}
355349

@@ -661,12 +655,18 @@ func checkWarnings(cmd *cobra.Command, client *codersdk.Client) error {
661655

662656
type headerTransport struct {
663657
transport http.RoundTripper
664-
headers map[string]string
658+
header http.Header
659+
}
660+
661+
func (h *headerTransport) Header() http.Header {
662+
return h.header.Clone()
665663
}
666664

667665
func (h *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
668-
for k, v := range h.headers {
669-
req.Header.Add(k, v)
666+
for k, v := range h.header {
667+
for _, vv := range v {
668+
req.Header.Add(k, vv)
669+
}
670670
}
671671
return h.transport.RoundTrip(req)
672672
}

cli/scaletest.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -330,8 +330,8 @@ func scaletestCleanup() *cobra.Command {
330330
client.HTTPClient = &http.Client{
331331
Transport: &headerTransport{
332332
transport: http.DefaultTransport,
333-
headers: map[string]string{
334-
codersdk.BypassRatelimitHeader: "true",
333+
header: map[string][]string{
334+
codersdk.BypassRatelimitHeader: {"true"},
335335
},
336336
},
337337
}
@@ -515,8 +515,8 @@ It is recommended that all rate limits are disabled on the server before running
515515
client.HTTPClient = &http.Client{
516516
Transport: &headerTransport{
517517
transport: http.DefaultTransport,
518-
headers: map[string]string{
519-
codersdk.BypassRatelimitHeader: "true",
518+
header: map[string][]string{
519+
codersdk.BypassRatelimitHeader: {"true"},
520520
},
521521
},
522522
}

codersdk/client.go

-2
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,6 @@ type Client struct {
7979
HTTPClient *http.Client
8080
URL *url.URL
8181

82-
DERPHeader *http.Header
83-
8482
// Logger is optionally provided to log requests.
8583
// Method, URL, and response code will be logged by default.
8684
Logger slog.Logger

codersdk/workspaceagents.go

+7-1
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,13 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti
143143
}
144144

145145
ip := tailnet.IP()
146-
header := c.DERPHeader.Clone()
146+
var header http.Header
147+
headerTransport, ok := c.HTTPClient.Transport.(interface {
148+
Header() http.Header
149+
})
150+
if ok {
151+
header = headerTransport.Header().Clone()
152+
}
147153
conn, err := tailnet.NewConn(&tailnet.Options{
148154
Addresses: []netip.Prefix{netip.PrefixFrom(ip, 128)},
149155
DERPMap: connInfo.DERPMap,

tailnet/conn.go

-1
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,6 @@ func NewConn(options *Options) (conn *Conn, err error) {
161161
if !ok {
162162
return nil, xerrors.New("get wireguard internals")
163163
}
164-
165164
if options.DERPHeader != nil {
166165
magicConn.SetDERPHeader(options.DERPHeader.Clone())
167166
}

0 commit comments

Comments
 (0)