diff --git a/cli/configssh.go b/cli/configssh.go index 09a2b471fcd53..cb91cec8c0d8d 100644 --- a/cli/configssh.go +++ b/cli/configssh.go @@ -19,6 +19,7 @@ import ( "github.com/cli/safeexec" "github.com/pkg/diff" "github.com/pkg/diff/write" + "golang.org/x/exp/constraints" "golang.org/x/exp/slices" "golang.org/x/sync/errgroup" "golang.org/x/xerrors" @@ -51,6 +52,8 @@ type sshConfigOptions struct { userHostPrefix string sshOptions []string disableAutostart bool + header []string + headerCommand string } // addOptions expects options in the form of "option=value" or "option value". @@ -100,15 +103,25 @@ func (o *sshConfigOptions) addOption(option string) error { } func (o sshConfigOptions) equal(other sshConfigOptions) bool { - // Compare without side-effects or regard to order. - opt1 := slices.Clone(o.sshOptions) - sort.Strings(opt1) - opt2 := slices.Clone(other.sshOptions) - sort.Strings(opt2) - if !slices.Equal(opt1, opt2) { + if !slicesSortedEqual(o.sshOptions, other.sshOptions) { return false } - return o.waitEnum == other.waitEnum && o.userHostPrefix == other.userHostPrefix && o.disableAutostart == other.disableAutostart + if !slicesSortedEqual(o.header, other.header) { + return false + } + return o.waitEnum == other.waitEnum && o.userHostPrefix == other.userHostPrefix && o.disableAutostart == other.disableAutostart && o.headerCommand == other.headerCommand +} + +// slicesSortedEqual compares two slices without side-effects or regard to order. +func slicesSortedEqual[S ~[]E, E constraints.Ordered](a, b S) bool { + if len(a) != len(b) { + return false + } + a = slices.Clone(a) + slices.Sort(a) + b = slices.Clone(b) + slices.Sort(b) + return slices.Equal(a, b) } func (o sshConfigOptions) asList() (list []string) { @@ -124,6 +137,13 @@ func (o sshConfigOptions) asList() (list []string) { for _, opt := range o.sshOptions { list = append(list, fmt.Sprintf("ssh-option: %s", opt)) } + for _, h := range o.header { + list = append(list, fmt.Sprintf("header: %s", h)) + } + if o.headerCommand != "" { + list = append(list, fmt.Sprintf("header-command: %s", o.headerCommand)) + } + return list } @@ -230,6 +250,8 @@ func (r *RootCmd) configSSH() *clibase.Cmd { // specifies skip-proxy-command, then wait cannot be applied. return xerrors.Errorf("cannot specify both --skip-proxy-command and --wait") } + sshConfigOpts.header = r.header + sshConfigOpts.headerCommand = r.headerCommand recvWorkspaceConfigs := sshPrepareWorkspaceConfigs(inv.Context(), client) @@ -393,6 +415,14 @@ func (r *RootCmd) configSSH() *clibase.Cmd { } if !skipProxyCommand { + rootFlags := fmt.Sprintf("--global-config %s", escapedGlobalConfig) + for _, h := range sshConfigOpts.header { + rootFlags += fmt.Sprintf(" --header %q", h) + } + if sshConfigOpts.headerCommand != "" { + rootFlags += fmt.Sprintf(" --header-command %q", sshConfigOpts.headerCommand) + } + flags := "" if sshConfigOpts.waitEnum != "auto" { flags += " --wait=" + sshConfigOpts.waitEnum @@ -401,8 +431,8 @@ func (r *RootCmd) configSSH() *clibase.Cmd { flags += " --disable-autostart=true" } defaultOptions = append(defaultOptions, fmt.Sprintf( - "ProxyCommand %s --global-config %s ssh --stdio%s %s", - escapedCoderBinary, escapedGlobalConfig, flags, workspaceHostname, + "ProxyCommand %s %s ssh --stdio%s %s", + escapedCoderBinary, rootFlags, flags, workspaceHostname, )) } @@ -623,6 +653,12 @@ func sshConfigWriteSectionHeader(w io.Writer, addNewline bool, o sshConfigOption for _, opt := range o.sshOptions { _, _ = fmt.Fprintf(&ow, "# :%s=%s\n", "ssh-option", opt) } + for _, h := range o.header { + _, _ = fmt.Fprintf(&ow, "# :%s=%s\n", "header", h) + } + if o.headerCommand != "" { + _, _ = fmt.Fprintf(&ow, "# :%s=%s\n", "header-command", o.headerCommand) + } if ow.Len() > 0 { _, _ = fmt.Fprint(w, sshConfigOptionsHeader) _, _ = fmt.Fprint(w, ow.String()) @@ -654,6 +690,10 @@ func sshConfigParseLastOptions(r io.Reader) (o sshConfigOptions) { o.sshOptions = append(o.sshOptions, parts[1]) case "disable-autostart": o.disableAutostart, _ = strconv.ParseBool(parts[1]) + case "header": + o.header = append(o.header, parts[1]) + case "header-command": + o.headerCommand = parts[1] default: // Unknown option, ignore. } diff --git a/cli/configssh_test.go b/cli/configssh_test.go index d87d1fa7024e6..ee66e350c1582 100644 --- a/cli/configssh_test.go +++ b/cli/configssh_test.go @@ -462,6 +462,9 @@ func TestConfigSSH_FileWriteAndOptionsFlow(t *testing.T) { "# Last config-ssh options:", "# :wait=yes", "# :ssh-host-prefix=coder-test.", + "# :header=X-Test-Header=foo", + "# :header=X-Test-Header2=bar", + "# :header-command=printf h1=v1 h2=\"v2\" h3='v3'", "#", headerEnd, "", @@ -471,6 +474,9 @@ func TestConfigSSH_FileWriteAndOptionsFlow(t *testing.T) { "--yes", "--wait=yes", "--ssh-host-prefix", "coder-test.", + "--header", "X-Test-Header=foo", + "--header", "X-Test-Header2=bar", + "--header-command", "printf h1=v1 h2=\"v2\" h3='v3'", }, }, { @@ -563,6 +569,55 @@ func TestConfigSSH_FileWriteAndOptionsFlow(t *testing.T) { regexMatch: "ProxyCommand /foo/bar/coder", }, }, + { + name: "Header", + args: []string{ + "--yes", + "--header", "X-Test-Header=foo", + "--header", "X-Test-Header2=bar", + }, + wantErr: false, + hasAgent: true, + wantConfig: wantConfig{ + regexMatch: `ProxyCommand .* --header "X-Test-Header=foo" --header "X-Test-Header2=bar" ssh`, + }, + }, + { + name: "Header command", + args: []string{ + "--yes", + "--header-command", "printf h1=v1", + }, + wantErr: false, + hasAgent: true, + wantConfig: wantConfig{ + regexMatch: `ProxyCommand .* --header-command "printf h1=v1" ssh`, + }, + }, + { + name: "Header command with double quotes", + args: []string{ + "--yes", + "--header-command", "printf h1=v1 h2=\"v2\"", + }, + wantErr: false, + hasAgent: true, + wantConfig: wantConfig{ + regexMatch: `ProxyCommand .* --header-command "printf h1=v1 h2=\\\"v2\\\"" ssh`, + }, + }, + { + name: "Header command with single quotes", + args: []string{ + "--yes", + "--header-command", "printf h1=v1 h2='v2'", + }, + wantErr: false, + hasAgent: true, + wantConfig: wantConfig{ + regexMatch: `ProxyCommand .* --header-command "printf h1=v1 h2='v2'" ssh`, + }, + }, } for _, tt := range tests { tt := tt