Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
feat: Support --header for CLI commands to support proxies
Fixes #3527.
  • Loading branch information
kylecarbs committed Sep 12, 2022
commit 7a42d60cfb4fdf1373eefbe5875e5287d9c7e904
12 changes: 12 additions & 0 deletions cli/cliflag/cliflag.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@ func StringVarP(flagset *pflag.FlagSet, p *string, name string, shorthand string
flagset.StringVarP(p, name, shorthand, v, fmtUsage(usage, env))
}

func StringArray(flagset *pflag.FlagSet, name, shorthand, env string, def []string, usage string) {
v, ok := os.LookupEnv(env)
if !ok || v == "" {
if v == "" {
def = []string{}
} else {
def = strings.Split(v, ",")
}
}
flagset.StringArrayP(name, shorthand, def, fmtUsage(usage, env))
}

func StringArrayVarP(flagset *pflag.FlagSet, ptr *[]string, name string, shorthand string, env string, def []string, usage string) {
val, ok := os.LookupEnv(env)
if ok {
Expand Down
5 changes: 4 additions & 1 deletion cli/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,10 @@ func login() *cobra.Command {
serverURL.Scheme = "https"
}

client := codersdk.New(serverURL)
client, err := createUnauthenticatedClient(cmd, serverURL)
if err != nil {
return err
}

// Try to check the version of the server prior to logging in.
// It may be useful to warn the user if they are trying to login
Expand Down
41 changes: 40 additions & 1 deletion cli/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"flag"
"fmt"
"net/http"
"net/url"
"os"
"strings"
Expand Down Expand Up @@ -41,6 +42,7 @@ const (
varAgentToken = "agent-token"
varAgentURL = "agent-url"
varGlobalConfig = "global-config"
varHeader = "header"
varNoOpen = "no-open"
varNoVersionCheck = "no-version-warning"
varNoFeatureWarning = "no-feature-warning"
Expand Down Expand Up @@ -174,6 +176,7 @@ func Root(subcommands []*cobra.Command) *cobra.Command {
cliflag.String(cmd.PersistentFlags(), varAgentURL, "", "CODER_AGENT_URL", "", "Specify the URL for an agent to access your deployment.")
_ = cmd.PersistentFlags().MarkHidden(varAgentURL)
cliflag.String(cmd.PersistentFlags(), varGlobalConfig, "", "CODER_CONFIG_DIR", configdir.LocalConfig("coderv2"), "Specify the path to the global `coder` config directory.")
cliflag.StringArray(cmd.PersistentFlags(), varHeader, "", "CODER_HEADER", []string{}, "HTTP headers added to all requests. Provide as \"Key=Value\"")
cmd.PersistentFlags().Bool(varForceTty, false, "Force the `coder` command to run as if connected to a TTY.")
_ = cmd.PersistentFlags().MarkHidden(varForceTty)
cmd.PersistentFlags().Bool(varNoOpen, false, "Block automatically opening URLs in the browser.")
Expand Down Expand Up @@ -237,8 +240,32 @@ func CreateClient(cmd *cobra.Command) (*codersdk.Client, error) {
return nil, err
}
}
client, err := createUnauthenticatedClient(cmd, serverURL)
if err != nil {
return nil, err
}
client.SessionToken = token
return client, nil
}

func createUnauthenticatedClient(cmd *cobra.Command, serverURL *url.URL) (*codersdk.Client, error) {
client := codersdk.New(serverURL)
client.SessionToken = strings.TrimSpace(token)
headers, err := cmd.Flags().GetStringArray(varHeader)
if err != nil {
return nil, err
}
transport := &headerTransport{
transport: http.DefaultTransport,
headers: map[string]string{},
}
for _, header := range headers {
parts := strings.SplitN(header, "=", 2)
if len(parts) < 2 {
return nil, xerrors.Errorf("split header %q had less than two parts", header)
}
transport.headers[parts[0]] = parts[1]
}
client.HTTPClient.Transport = transport
return client, nil
}

Expand Down Expand Up @@ -530,3 +557,15 @@ func checkWarnings(cmd *cobra.Command, client *codersdk.Client) error {

return nil
}

type headerTransport struct {
transport http.RoundTripper
headers map[string]string
}

func (h *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
for k, v := range h.headers {
req.Header.Add(k, v)
}
return h.transport.RoundTrip(req)
}
24 changes: 24 additions & 0 deletions cli/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ package cli_test

import (
"bytes"
"net/http"
"net/http/httptest"
"testing"

"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"

Expand Down Expand Up @@ -129,4 +132,25 @@ func TestRoot(t *testing.T) {
require.Contains(t, output, buildinfo.Version(), "has version")
require.Contains(t, output, buildinfo.ExternalURL(), "has url")
})

t.Run("Header", func(t *testing.T) {
t.Parallel()

done := make(chan struct{})
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "wow", r.Header.Get("X-Testing"))
w.WriteHeader(http.StatusGone)
select {
case <-done:
close(done)
default:
}
}))
defer srv.Close()
buf := new(bytes.Buffer)
cmd, _ := clitest.New(t, "--header", "X-Testing=wow", "login", srv.URL)
cmd.SetOut(buf)
// This won't succeed, because we're using the login cmd to assert requests.
_ = cmd.Execute()
})
}