From 5c4e24b4fc576e9065bc08038492b9735a0cf311 Mon Sep 17 00:00:00 2001 From: Ammar Bandukwala Date: Wed, 24 May 2023 19:16:20 +0000 Subject: [PATCH 1/3] chore(cli): correctly report telemetry even when transport replaced By introducing the "ExtraHeaders" map, we can apply headers even when handlers replace the transport, as in the case of our scaletests. --- cli/root.go | 27 +++++++++++++++------------ cli/vscodessh.go | 2 +- codersdk/client.go | 10 ++++++++-- go.mod | 2 ++ 4 files changed, 26 insertions(+), 15 deletions(-) diff --git a/cli/root.go b/cli/root.go index 3cad8a9d4a4db..f505e653a30c2 100644 --- a/cli/root.go +++ b/cli/root.go @@ -27,6 +27,7 @@ import ( "cdr.dev/slog" "github.com/charmbracelet/lipgloss" + "github.com/gobwas/httphead" "github.com/mattn/go-isatty" "github.com/coder/coder/buildinfo" @@ -494,14 +495,15 @@ func (r *RootCmd) InitClient(client *codersdk.Client) clibase.MiddlewareFunc { } err = r.setClient( client, r.clientURL, - append(r.header, codersdk.CLITelemetryHeader+"="+ - base64.StdEncoding.EncodeToString(byt), - ), ) if err != nil { return err } + client.ExtraHeaders.Set(codersdk.CLITelemetryHeader, + base64.StdEncoding.EncodeToString(byt), + ) + client.SetSessionToken(r.token) if r.debugHTTP { @@ -544,28 +546,29 @@ func (r *RootCmd) InitClient(client *codersdk.Client) clibase.MiddlewareFunc { } } -func (*RootCmd) setClient(client *codersdk.Client, serverURL *url.URL, headers []string) error { +func (r *RootCmd) setClient(client *codersdk.Client, serverURL *url.URL) error { transport := &headerTransport{ transport: http.DefaultTransport, header: http.Header{}, } - for _, header := range headers { - parts := strings.SplitN(header, "=", 2) - if len(parts) < 2 { - return xerrors.Errorf("split header %q had less than two parts", header) - } - transport.header.Add(parts[0], parts[1]) - } client.URL = serverURL client.HTTPClient = &http.Client{ Transport: transport, } + client.ExtraHeaders = make(http.Header) + for _, hd := range r.header { + k, v, ok := httphead.ParseHeaderLine([]byte(hd)) + if !ok { + return xerrors.Errorf("invalid header: %s", hd) + } + client.ExtraHeaders.Add(string(k), string(v)) + } return nil } func (r *RootCmd) createUnauthenticatedClient(serverURL *url.URL) (*codersdk.Client, error) { var client codersdk.Client - err := r.setClient(&client, serverURL, r.header) + err := r.setClient(&client, serverURL) return &client, err } diff --git a/cli/vscodessh.go b/cli/vscodessh.go index 424aa362e3590..363f0215a7696 100644 --- a/cli/vscodessh.go +++ b/cli/vscodessh.go @@ -83,7 +83,7 @@ func (r *RootCmd) vscodeSSH() *clibase.Cmd { client.SetSessionToken(string(sessionToken)) // This adds custom headers to the request! - err = r.setClient(client, serverURL, r.header) + err = r.setClient(client, serverURL) if err != nil { return xerrors.Errorf("set client: %w", err) } diff --git a/codersdk/client.go b/codersdk/client.go index 19f765c097c5d..3c1bf39c192a8 100644 --- a/codersdk/client.go +++ b/codersdk/client.go @@ -85,8 +85,9 @@ var loggableMimeTypes = map[string]struct{}{ // New creates a Coder client for the provided URL. func New(serverURL *url.URL) *Client { return &Client{ - URL: serverURL, - HTTPClient: &http.Client{}, + URL: serverURL, + HTTPClient: &http.Client{}, + ExtraHeaders: make(http.Header), } } @@ -96,6 +97,9 @@ type Client struct { mu sync.RWMutex // Protects following. sessionToken string + // ExtraHeaders are headers to add to every request. + ExtraHeaders http.Header + HTTPClient *http.Client URL *url.URL @@ -189,6 +193,8 @@ func (c *Client) Request(ctx context.Context, method, path string, body interfac return nil, xerrors.Errorf("create request: %w", err) } + req.Header = c.ExtraHeaders.Clone() + tokenHeader := c.SessionTokenHeader if tokenHeader == "" { tokenHeader = SessionTokenHeader diff --git a/go.mod b/go.mod index f1b1e58f75831..cc7edcc5d7e25 100644 --- a/go.mod +++ b/go.mod @@ -352,3 +352,5 @@ require ( howett.net/plist v1.0.0 // indirect inet.af/peercred v0.0.0-20210906144145-0893ea02156a // indirect ) + +require github.com/gobwas/httphead v0.1.0 From e5f218cc31dbe83f323fc6266b5f3cf31b5f40e6 Mon Sep 17 00:00:00 2001 From: Ammar Bandukwala Date: Wed, 24 May 2023 19:34:52 +0000 Subject: [PATCH 2/3] Only send telemetry header when it's small --- cli/root.go | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/cli/root.go b/cli/root.go index f505e653a30c2..fc29764d8a920 100644 --- a/cli/root.go +++ b/cli/root.go @@ -486,24 +486,29 @@ func (r *RootCmd) InitClient(client *codersdk.Client) clibase.MiddlewareFunc { return err } } + err = r.setClient( + client, r.clientURL, + ) + if err != nil { + return err + } + // Add telemetry headers: telemInv := telemetryInvocation(i) byt, err := json.Marshal(telemInv) if err != nil { // Should be impossible panic(err) } - err = r.setClient( - client, r.clientURL, - ) - if err != nil { - return err + if s := base64.StdEncoding.EncodeToString(byt); len(s) < 4096 { + // Per https://stackoverflow.com/questions/686217/maximum-on-http-header-values, + // we don't want to send headers that are too long. + client.ExtraHeaders.Set( + codersdk.CLITelemetryHeader, + s, + ) } - client.ExtraHeaders.Set(codersdk.CLITelemetryHeader, - base64.StdEncoding.EncodeToString(byt), - ) - client.SetSessionToken(r.token) if r.debugHTTP { From 796cd2037c4bdcc0e0f483faf03b037ccabef723 Mon Sep 17 00:00:00 2001 From: Ammar Bandukwala Date: Wed, 24 May 2023 19:38:31 +0000 Subject: [PATCH 3/3] Use functions --- cli/root.go | 58 ++++++++++++++++++++++++++++------------------------- 1 file changed, 31 insertions(+), 27 deletions(-) diff --git a/cli/root.go b/cli/root.go index fc29764d8a920..ef6b5fd8c2adb 100644 --- a/cli/root.go +++ b/cli/root.go @@ -429,9 +429,9 @@ type RootCmd struct { noFeatureWarning bool } -func telemetryInvocation(i *clibase.Invocation) telemetry.CLIInvocation { +func addTelemetryHeader(client *codersdk.Client, inv *clibase.Invocation) { var topts []telemetry.CLIOption - for _, opt := range i.Command.FullOptions() { + for _, opt := range inv.Command.FullOptions() { if opt.ValueSource == clibase.ValueSourceNone || opt.ValueSource == clibase.ValueSourceDefault { continue } @@ -440,11 +440,29 @@ func telemetryInvocation(i *clibase.Invocation) telemetry.CLIInvocation { ValueSource: string(opt.ValueSource), }) } - return telemetry.CLIInvocation{ - Command: i.Command.FullName(), + ti := telemetry.CLIInvocation{ + Command: inv.Command.FullName(), Options: topts, InvokedAt: time.Now(), } + + byt, err := json.Marshal(ti) + if err != nil { + // Should be impossible + panic(err) + } + + // Per https://stackoverflow.com/questions/686217/maximum-on-http-header-values, + // we don't want to send headers that are too long. + s := base64.StdEncoding.EncodeToString(byt) + if len(s) > 4096 { + return + } + + client.ExtraHeaders.Set( + codersdk.CLITelemetryHeader, + s, + ) } // InitClient sets client to a new client. @@ -457,7 +475,7 @@ func (r *RootCmd) InitClient(client *codersdk.Client) clibase.MiddlewareFunc { panic("root is nil") } return func(next clibase.HandlerFunc) clibase.HandlerFunc { - return func(i *clibase.Invocation) error { + return func(inv *clibase.Invocation) error { conf := r.createConfig() var err error if r.clientURL == nil || r.clientURL.String() == "" { @@ -493,21 +511,7 @@ func (r *RootCmd) InitClient(client *codersdk.Client) clibase.MiddlewareFunc { return err } - // Add telemetry headers: - telemInv := telemetryInvocation(i) - byt, err := json.Marshal(telemInv) - if err != nil { - // Should be impossible - panic(err) - } - if s := base64.StdEncoding.EncodeToString(byt); len(s) < 4096 { - // Per https://stackoverflow.com/questions/686217/maximum-on-http-header-values, - // we don't want to send headers that are too long. - client.ExtraHeaders.Set( - codersdk.CLITelemetryHeader, - s, - ) - } + addTelemetryHeader(client, inv) client.SetSessionToken(r.token) @@ -522,31 +526,31 @@ func (r *RootCmd) InitClient(client *codersdk.Client) clibase.MiddlewareFunc { warningErr = make(chan error) ) go func() { - versionErr <- r.checkVersions(i, client) + versionErr <- r.checkVersions(inv, client) close(versionErr) }() go func() { - warningErr <- r.checkWarnings(i, client) + warningErr <- r.checkWarnings(inv, client) close(warningErr) }() if err = <-versionErr; err != nil { // Just log the error here. We never want to fail a command // due to a pre-run. - _, _ = fmt.Fprintf(i.Stderr, + _, _ = fmt.Fprintf(inv.Stderr, cliui.Styles.Warn.Render("check versions error: %s"), err) - _, _ = fmt.Fprintln(i.Stderr) + _, _ = fmt.Fprintln(inv.Stderr) } if err = <-warningErr; err != nil { // Same as above - _, _ = fmt.Fprintf(i.Stderr, + _, _ = fmt.Fprintf(inv.Stderr, cliui.Styles.Warn.Render("check entitlement warnings error: %s"), err) - _, _ = fmt.Fprintln(i.Stderr) + _, _ = fmt.Fprintln(inv.Stderr) } - return next(i) + return next(inv) } } }