diff --git a/cli/root.go b/cli/root.go index f3db7af279f62..8d6f14b06f9a1 100644 --- a/cli/root.go +++ b/cli/root.go @@ -334,14 +334,14 @@ func createUnauthenticatedClient(cmd *cobra.Command, serverURL *url.URL) (*coder } transport := &headerTransport{ transport: http.DefaultTransport, - headers: map[string]string{}, + header: http.Header{}, } for _, header := range headers { parts := strings.SplitN(header, "=", 2) if len(parts) < 2 { return nil, xerrors.Errorf("split header %q had less than two parts", header) } - transport.headers[parts[0]] = parts[1] + transport.header.Add(parts[0], parts[1]) } client.HTTPClient.Transport = transport return client, nil @@ -655,12 +655,18 @@ func checkWarnings(cmd *cobra.Command, client *codersdk.Client) error { type headerTransport struct { transport http.RoundTripper - headers map[string]string + header http.Header +} + +func (h *headerTransport) Header() http.Header { + return h.header.Clone() } func (h *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) { - for k, v := range h.headers { - req.Header.Add(k, v) + for k, v := range h.header { + for _, vv := range v { + req.Header.Add(k, vv) + } } return h.transport.RoundTrip(req) } diff --git a/cli/scaletest.go b/cli/scaletest.go index b367b580bba60..bea8c7fd17c9d 100644 --- a/cli/scaletest.go +++ b/cli/scaletest.go @@ -330,8 +330,8 @@ func scaletestCleanup() *cobra.Command { client.HTTPClient = &http.Client{ Transport: &headerTransport{ transport: http.DefaultTransport, - headers: map[string]string{ - codersdk.BypassRatelimitHeader: "true", + header: map[string][]string{ + codersdk.BypassRatelimitHeader: {"true"}, }, }, } @@ -515,8 +515,8 @@ It is recommended that all rate limits are disabled on the server before running client.HTTPClient = &http.Client{ Transport: &headerTransport{ transport: http.DefaultTransport, - headers: map[string]string{ - codersdk.BypassRatelimitHeader: "true", + header: map[string][]string{ + codersdk.BypassRatelimitHeader: {"true"}, }, }, } diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index 4d953f8e050d5..5c52061c37ec8 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -143,9 +143,17 @@ func (c *Client) DialWorkspaceAgent(ctx context.Context, agentID uuid.UUID, opti } ip := tailnet.IP() + var header http.Header + headerTransport, ok := c.HTTPClient.Transport.(interface { + Header() http.Header + }) + if ok { + header = headerTransport.Header() + } conn, err := tailnet.NewConn(&tailnet.Options{ Addresses: []netip.Prefix{netip.PrefixFrom(ip, 128)}, DERPMap: connInfo.DERPMap, + DERPHeader: &header, Logger: options.Logger, BlockEndpoints: options.BlockEndpoints, }) diff --git a/tailnet/conn.go b/tailnet/conn.go index a43e249f80dac..db259573e05dc 100644 --- a/tailnet/conn.go +++ b/tailnet/conn.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net" + "net/http" "net/netip" "reflect" "strconv" @@ -50,8 +51,9 @@ func init() { } type Options struct { - Addresses []netip.Prefix - DERPMap *tailcfg.DERPMap + Addresses []netip.Prefix + DERPMap *tailcfg.DERPMap + DERPHeader *http.Header // BlockEndpoints specifies whether P2P endpoints are blocked. // If so, only DERPs can establish connections. @@ -159,6 +161,9 @@ func NewConn(options *Options) (conn *Conn, err error) { if !ok { return nil, xerrors.New("get wireguard internals") } + if options.DERPHeader != nil { + magicConn.SetDERPHeader(options.DERPHeader.Clone()) + } // Update the keys for the magic connection! err = magicConn.SetPrivateKey(nodePrivateKey)