diff --git a/cli/cliflag/cliflag.go b/cli/cliflag/cliflag.go index 42722ecc1cfb3..7970fc2d64b58 100644 --- a/cli/cliflag/cliflag.go +++ b/cli/cliflag/cliflag.go @@ -61,6 +61,18 @@ func StringVarP(flagset *pflag.FlagSet, p *string, name string, shorthand string flagset.StringVarP(p, name, shorthand, v, fmtUsage(usage, env)) } +func StringArray(flagset *pflag.FlagSet, name, shorthand, env string, def []string, usage string) { + v, ok := os.LookupEnv(env) + if !ok || v == "" { + if v == "" { + def = []string{} + } else { + def = strings.Split(v, ",") + } + } + flagset.StringArrayP(name, shorthand, def, fmtUsage(usage, env)) +} + func StringArrayVarP(flagset *pflag.FlagSet, ptr *[]string, name string, shorthand string, env string, def []string, usage string) { val, ok := os.LookupEnv(env) if ok { diff --git a/cli/login.go b/cli/login.go index 9abc0b48d9c2f..a178c34ca2b98 100644 --- a/cli/login.go +++ b/cli/login.go @@ -66,7 +66,10 @@ func login() *cobra.Command { serverURL.Scheme = "https" } - client := codersdk.New(serverURL) + client, err := createUnauthenticatedClient(cmd, serverURL) + if err != nil { + return err + } // Try to check the version of the server prior to logging in. // It may be useful to warn the user if they are trying to login diff --git a/cli/root.go b/cli/root.go index 870346d76d360..779c47c07c1c3 100644 --- a/cli/root.go +++ b/cli/root.go @@ -4,6 +4,7 @@ import ( "context" "flag" "fmt" + "net/http" "net/url" "os" "strings" @@ -41,6 +42,7 @@ const ( varAgentToken = "agent-token" varAgentURL = "agent-url" varGlobalConfig = "global-config" + varHeader = "header" varNoOpen = "no-open" varNoVersionCheck = "no-version-warning" varNoFeatureWarning = "no-feature-warning" @@ -174,6 +176,7 @@ func Root(subcommands []*cobra.Command) *cobra.Command { cliflag.String(cmd.PersistentFlags(), varAgentURL, "", "CODER_AGENT_URL", "", "Specify the URL for an agent to access your deployment.") _ = cmd.PersistentFlags().MarkHidden(varAgentURL) cliflag.String(cmd.PersistentFlags(), varGlobalConfig, "", "CODER_CONFIG_DIR", configdir.LocalConfig("coderv2"), "Specify the path to the global `coder` config directory.") + cliflag.StringArray(cmd.PersistentFlags(), varHeader, "", "CODER_HEADER", []string{}, "HTTP headers added to all requests. Provide as \"Key=Value\"") cmd.PersistentFlags().Bool(varForceTty, false, "Force the `coder` command to run as if connected to a TTY.") _ = cmd.PersistentFlags().MarkHidden(varForceTty) cmd.PersistentFlags().Bool(varNoOpen, false, "Block automatically opening URLs in the browser.") @@ -237,8 +240,32 @@ func CreateClient(cmd *cobra.Command) (*codersdk.Client, error) { return nil, err } } + client, err := createUnauthenticatedClient(cmd, serverURL) + if err != nil { + return nil, err + } + client.SessionToken = token + return client, nil +} + +func createUnauthenticatedClient(cmd *cobra.Command, serverURL *url.URL) (*codersdk.Client, error) { client := codersdk.New(serverURL) - client.SessionToken = strings.TrimSpace(token) + headers, err := cmd.Flags().GetStringArray(varHeader) + if err != nil { + return nil, err + } + transport := &headerTransport{ + transport: http.DefaultTransport, + headers: map[string]string{}, + } + for _, header := range headers { + parts := strings.SplitN(header, "=", 2) + if len(parts) < 2 { + return nil, xerrors.Errorf("split header %q had less than two parts", header) + } + transport.headers[parts[0]] = parts[1] + } + client.HTTPClient.Transport = transport return client, nil } @@ -530,3 +557,15 @@ func checkWarnings(cmd *cobra.Command, client *codersdk.Client) error { return nil } + +type headerTransport struct { + transport http.RoundTripper + headers map[string]string +} + +func (h *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) { + for k, v := range h.headers { + req.Header.Add(k, v) + } + return h.transport.RoundTrip(req) +} diff --git a/cli/root_test.go b/cli/root_test.go index 04a9b2c99ecda..617d2f90bc327 100644 --- a/cli/root_test.go +++ b/cli/root_test.go @@ -2,9 +2,12 @@ package cli_test import ( "bytes" + "net/http" + "net/http/httptest" "testing" "github.com/spf13/cobra" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/xerrors" @@ -129,4 +132,25 @@ func TestRoot(t *testing.T) { require.Contains(t, output, buildinfo.Version(), "has version") require.Contains(t, output, buildinfo.ExternalURL(), "has url") }) + + t.Run("Header", func(t *testing.T) { + t.Parallel() + + done := make(chan struct{}) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "wow", r.Header.Get("X-Testing")) + w.WriteHeader(http.StatusGone) + select { + case <-done: + close(done) + default: + } + })) + defer srv.Close() + buf := new(bytes.Buffer) + cmd, _ := clitest.New(t, "--header", "X-Testing=wow", "login", srv.URL) + cmd.SetOut(buf) + // This won't succeed, because we're using the login cmd to assert requests. + _ = cmd.Execute() + }) }