Skip to content

Commit a209825

Browse files
authored
feat: Support --header for CLI commands to support proxies (#4008)
Fixes #3527.
1 parent 846dd99 commit a209825

File tree

4 files changed

+80
-2
lines changed

4 files changed

+80
-2
lines changed

cli/cliflag/cliflag.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,18 @@ func StringVarP(flagset *pflag.FlagSet, p *string, name string, shorthand string
6161
flagset.StringVarP(p, name, shorthand, v, fmtUsage(usage, env))
6262
}
6363

64+
func StringArray(flagset *pflag.FlagSet, name, shorthand, env string, def []string, usage string) {
65+
v, ok := os.LookupEnv(env)
66+
if !ok || v == "" {
67+
if v == "" {
68+
def = []string{}
69+
} else {
70+
def = strings.Split(v, ",")
71+
}
72+
}
73+
flagset.StringArrayP(name, shorthand, def, fmtUsage(usage, env))
74+
}
75+
6476
func StringArrayVarP(flagset *pflag.FlagSet, ptr *[]string, name string, shorthand string, env string, def []string, usage string) {
6577
val, ok := os.LookupEnv(env)
6678
if ok {

cli/login.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,10 @@ func login() *cobra.Command {
6666
serverURL.Scheme = "https"
6767
}
6868

69-
client := codersdk.New(serverURL)
69+
client, err := createUnauthenticatedClient(cmd, serverURL)
70+
if err != nil {
71+
return err
72+
}
7073

7174
// Try to check the version of the server prior to logging in.
7275
// It may be useful to warn the user if they are trying to login

cli/root.go

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"flag"
66
"fmt"
7+
"net/http"
78
"net/url"
89
"os"
910
"strings"
@@ -41,6 +42,7 @@ const (
4142
varAgentToken = "agent-token"
4243
varAgentURL = "agent-url"
4344
varGlobalConfig = "global-config"
45+
varHeader = "header"
4446
varNoOpen = "no-open"
4547
varNoVersionCheck = "no-version-warning"
4648
varNoFeatureWarning = "no-feature-warning"
@@ -174,6 +176,7 @@ func Root(subcommands []*cobra.Command) *cobra.Command {
174176
cliflag.String(cmd.PersistentFlags(), varAgentURL, "", "CODER_AGENT_URL", "", "Specify the URL for an agent to access your deployment.")
175177
_ = cmd.PersistentFlags().MarkHidden(varAgentURL)
176178
cliflag.String(cmd.PersistentFlags(), varGlobalConfig, "", "CODER_CONFIG_DIR", configdir.LocalConfig("coderv2"), "Specify the path to the global `coder` config directory.")
179+
cliflag.StringArray(cmd.PersistentFlags(), varHeader, "", "CODER_HEADER", []string{}, "HTTP headers added to all requests. Provide as \"Key=Value\"")
177180
cmd.PersistentFlags().Bool(varForceTty, false, "Force the `coder` command to run as if connected to a TTY.")
178181
_ = cmd.PersistentFlags().MarkHidden(varForceTty)
179182
cmd.PersistentFlags().Bool(varNoOpen, false, "Block automatically opening URLs in the browser.")
@@ -237,8 +240,32 @@ func CreateClient(cmd *cobra.Command) (*codersdk.Client, error) {
237240
return nil, err
238241
}
239242
}
243+
client, err := createUnauthenticatedClient(cmd, serverURL)
244+
if err != nil {
245+
return nil, err
246+
}
247+
client.SessionToken = token
248+
return client, nil
249+
}
250+
251+
func createUnauthenticatedClient(cmd *cobra.Command, serverURL *url.URL) (*codersdk.Client, error) {
240252
client := codersdk.New(serverURL)
241-
client.SessionToken = strings.TrimSpace(token)
253+
headers, err := cmd.Flags().GetStringArray(varHeader)
254+
if err != nil {
255+
return nil, err
256+
}
257+
transport := &headerTransport{
258+
transport: http.DefaultTransport,
259+
headers: map[string]string{},
260+
}
261+
for _, header := range headers {
262+
parts := strings.SplitN(header, "=", 2)
263+
if len(parts) < 2 {
264+
return nil, xerrors.Errorf("split header %q had less than two parts", header)
265+
}
266+
transport.headers[parts[0]] = parts[1]
267+
}
268+
client.HTTPClient.Transport = transport
242269
return client, nil
243270
}
244271

@@ -530,3 +557,15 @@ func checkWarnings(cmd *cobra.Command, client *codersdk.Client) error {
530557

531558
return nil
532559
}
560+
561+
type headerTransport struct {
562+
transport http.RoundTripper
563+
headers map[string]string
564+
}
565+
566+
func (h *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
567+
for k, v := range h.headers {
568+
req.Header.Add(k, v)
569+
}
570+
return h.transport.RoundTrip(req)
571+
}

cli/root_test.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@ package cli_test
22

33
import (
44
"bytes"
5+
"net/http"
6+
"net/http/httptest"
57
"testing"
68

79
"github.com/spf13/cobra"
10+
"github.com/stretchr/testify/assert"
811
"github.com/stretchr/testify/require"
912
"golang.org/x/xerrors"
1013

@@ -129,4 +132,25 @@ func TestRoot(t *testing.T) {
129132
require.Contains(t, output, buildinfo.Version(), "has version")
130133
require.Contains(t, output, buildinfo.ExternalURL(), "has url")
131134
})
135+
136+
t.Run("Header", func(t *testing.T) {
137+
t.Parallel()
138+
139+
done := make(chan struct{})
140+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
141+
assert.Equal(t, "wow", r.Header.Get("X-Testing"))
142+
w.WriteHeader(http.StatusGone)
143+
select {
144+
case <-done:
145+
close(done)
146+
default:
147+
}
148+
}))
149+
defer srv.Close()
150+
buf := new(bytes.Buffer)
151+
cmd, _ := clitest.New(t, "--header", "X-Testing=wow", "login", srv.URL)
152+
cmd.SetOut(buf)
153+
// This won't succeed, because we're using the login cmd to assert requests.
154+
_ = cmd.Execute()
155+
})
132156
}

0 commit comments

Comments
 (0)