diff --git a/cli/root.go b/cli/root.go index 3cad8a9d4a4db..ef6b5fd8c2adb 100644 --- a/cli/root.go +++ b/cli/root.go @@ -27,6 +27,7 @@ import ( "cdr.dev/slog" "github.com/charmbracelet/lipgloss" + "github.com/gobwas/httphead" "github.com/mattn/go-isatty" "github.com/coder/coder/buildinfo" @@ -428,9 +429,9 @@ type RootCmd struct { noFeatureWarning bool } -func telemetryInvocation(i *clibase.Invocation) telemetry.CLIInvocation { +func addTelemetryHeader(client *codersdk.Client, inv *clibase.Invocation) { var topts []telemetry.CLIOption - for _, opt := range i.Command.FullOptions() { + for _, opt := range inv.Command.FullOptions() { if opt.ValueSource == clibase.ValueSourceNone || opt.ValueSource == clibase.ValueSourceDefault { continue } @@ -439,11 +440,29 @@ func telemetryInvocation(i *clibase.Invocation) telemetry.CLIInvocation { ValueSource: string(opt.ValueSource), }) } - return telemetry.CLIInvocation{ - Command: i.Command.FullName(), + ti := telemetry.CLIInvocation{ + Command: inv.Command.FullName(), Options: topts, InvokedAt: time.Now(), } + + byt, err := json.Marshal(ti) + if err != nil { + // Should be impossible + panic(err) + } + + // Per https://stackoverflow.com/questions/686217/maximum-on-http-header-values, + // we don't want to send headers that are too long. + s := base64.StdEncoding.EncodeToString(byt) + if len(s) > 4096 { + return + } + + client.ExtraHeaders.Set( + codersdk.CLITelemetryHeader, + s, + ) } // InitClient sets client to a new client. @@ -456,7 +475,7 @@ func (r *RootCmd) InitClient(client *codersdk.Client) clibase.MiddlewareFunc { panic("root is nil") } return func(next clibase.HandlerFunc) clibase.HandlerFunc { - return func(i *clibase.Invocation) error { + return func(inv *clibase.Invocation) error { conf := r.createConfig() var err error if r.clientURL == nil || r.clientURL.String() == "" { @@ -485,23 +504,15 @@ func (r *RootCmd) InitClient(client *codersdk.Client) clibase.MiddlewareFunc { return err } } - - telemInv := telemetryInvocation(i) - byt, err := json.Marshal(telemInv) - if err != nil { - // Should be impossible - panic(err) - } err = r.setClient( client, r.clientURL, - append(r.header, codersdk.CLITelemetryHeader+"="+ - base64.StdEncoding.EncodeToString(byt), - ), ) if err != nil { return err } + addTelemetryHeader(client, inv) + client.SetSessionToken(r.token) if r.debugHTTP { @@ -515,57 +526,58 @@ func (r *RootCmd) InitClient(client *codersdk.Client) clibase.MiddlewareFunc { warningErr = make(chan error) ) go func() { - versionErr <- r.checkVersions(i, client) + versionErr <- r.checkVersions(inv, client) close(versionErr) }() go func() { - warningErr <- r.checkWarnings(i, client) + warningErr <- r.checkWarnings(inv, client) close(warningErr) }() if err = <-versionErr; err != nil { // Just log the error here. We never want to fail a command // due to a pre-run. - _, _ = fmt.Fprintf(i.Stderr, + _, _ = fmt.Fprintf(inv.Stderr, cliui.Styles.Warn.Render("check versions error: %s"), err) - _, _ = fmt.Fprintln(i.Stderr) + _, _ = fmt.Fprintln(inv.Stderr) } if err = <-warningErr; err != nil { // Same as above - _, _ = fmt.Fprintf(i.Stderr, + _, _ = fmt.Fprintf(inv.Stderr, cliui.Styles.Warn.Render("check entitlement warnings error: %s"), err) - _, _ = fmt.Fprintln(i.Stderr) + _, _ = fmt.Fprintln(inv.Stderr) } - return next(i) + return next(inv) } } } -func (*RootCmd) setClient(client *codersdk.Client, serverURL *url.URL, headers []string) error { +func (r *RootCmd) setClient(client *codersdk.Client, serverURL *url.URL) error { transport := &headerTransport{ transport: http.DefaultTransport, header: http.Header{}, } - for _, header := range headers { - parts := strings.SplitN(header, "=", 2) - if len(parts) < 2 { - return xerrors.Errorf("split header %q had less than two parts", header) - } - transport.header.Add(parts[0], parts[1]) - } client.URL = serverURL client.HTTPClient = &http.Client{ Transport: transport, } + client.ExtraHeaders = make(http.Header) + for _, hd := range r.header { + k, v, ok := httphead.ParseHeaderLine([]byte(hd)) + if !ok { + return xerrors.Errorf("invalid header: %s", hd) + } + client.ExtraHeaders.Add(string(k), string(v)) + } return nil } func (r *RootCmd) createUnauthenticatedClient(serverURL *url.URL) (*codersdk.Client, error) { var client codersdk.Client - err := r.setClient(&client, serverURL, r.header) + err := r.setClient(&client, serverURL) return &client, err } diff --git a/cli/vscodessh.go b/cli/vscodessh.go index 424aa362e3590..363f0215a7696 100644 --- a/cli/vscodessh.go +++ b/cli/vscodessh.go @@ -83,7 +83,7 @@ func (r *RootCmd) vscodeSSH() *clibase.Cmd { client.SetSessionToken(string(sessionToken)) // This adds custom headers to the request! - err = r.setClient(client, serverURL, r.header) + err = r.setClient(client, serverURL) if err != nil { return xerrors.Errorf("set client: %w", err) } diff --git a/codersdk/client.go b/codersdk/client.go index 19f765c097c5d..3c1bf39c192a8 100644 --- a/codersdk/client.go +++ b/codersdk/client.go @@ -85,8 +85,9 @@ var loggableMimeTypes = map[string]struct{}{ // New creates a Coder client for the provided URL. func New(serverURL *url.URL) *Client { return &Client{ - URL: serverURL, - HTTPClient: &http.Client{}, + URL: serverURL, + HTTPClient: &http.Client{}, + ExtraHeaders: make(http.Header), } } @@ -96,6 +97,9 @@ type Client struct { mu sync.RWMutex // Protects following. sessionToken string + // ExtraHeaders are headers to add to every request. + ExtraHeaders http.Header + HTTPClient *http.Client URL *url.URL @@ -189,6 +193,8 @@ func (c *Client) Request(ctx context.Context, method, path string, body interfac return nil, xerrors.Errorf("create request: %w", err) } + req.Header = c.ExtraHeaders.Clone() + tokenHeader := c.SessionTokenHeader if tokenHeader == "" { tokenHeader = SessionTokenHeader diff --git a/go.mod b/go.mod index f1b1e58f75831..cc7edcc5d7e25 100644 --- a/go.mod +++ b/go.mod @@ -352,3 +352,5 @@ require ( howett.net/plist v1.0.0 // indirect inet.af/peercred v0.0.0-20210906144145-0893ea02156a // indirect ) + +require github.com/gobwas/httphead v0.1.0