diff --git a/cli/configssh.go b/cli/configssh.go index bc4ea9541e11c..6151b0a6d5323 100644 --- a/cli/configssh.go +++ b/cli/configssh.go @@ -89,18 +89,23 @@ func sshFetchWorkspaceConfigs(ctx context.Context, client *codersdk.Client) ([]s } wc := sshWorkspaceConfig{Name: workspace.Name} + var agents []codersdk.WorkspaceAgent for _, resource := range resources { if resource.Transition != codersdk.WorkspaceTransitionStart { continue } - for _, agent := range resource.Agents { - hostname := workspace.Name - if len(resource.Agents) > 1 { - hostname += "." + agent.Name - } - wc.Hosts = append(wc.Hosts, hostname) - } + agents = append(agents, resource.Agents...) + } + + // handle both WORKSPACE and WORKSPACE.AGENT syntax + if len(agents) == 1 { + wc.Hosts = append(wc.Hosts, workspace.Name) + } + for _, agent := range agents { + hostname := workspace.Name + "." + agent.Name + wc.Hosts = append(wc.Hosts, hostname) } + workspaceConfigs[i] = wc return nil diff --git a/cli/configssh_test.go b/cli/configssh_test.go index df3aa0b99f872..2a6c8cd88f322 100644 --- a/cli/configssh_test.go +++ b/cli/configssh_test.go @@ -1,6 +1,8 @@ package cli_test import ( + "bufio" + "bytes" "context" "fmt" "io" @@ -692,3 +694,152 @@ func TestConfigSSH_FileWriteAndOptionsFlow(t *testing.T) { }) } } + +func TestConfigSSH_Hostnames(t *testing.T) { + t.Parallel() + + type resourceSpec struct { + name string + agents []string + } + tests := []struct { + name string + resources []resourceSpec + expected []string + }{ + { + name: "one resource with one agent", + resources: []resourceSpec{ + {name: "foo", agents: []string{"agent1"}}, + }, + expected: []string{"coder.@", "coder.@.agent1"}, + }, + { + name: "one resource with two agents", + resources: []resourceSpec{ + {name: "foo", agents: []string{"agent1", "agent2"}}, + }, + expected: []string{"coder.@.agent1", "coder.@.agent2"}, + }, + { + name: "two resources with one agent", + resources: []resourceSpec{ + {name: "foo", agents: []string{"agent1"}}, + {name: "bar"}, + }, + expected: []string{"coder.@", "coder.@.agent1"}, + }, + { + name: "two resources with two agents", + resources: []resourceSpec{ + {name: "foo", agents: []string{"agent1"}}, + {name: "bar", agents: []string{"agent2"}}, + }, + expected: []string{"coder.@.agent1", "coder.@.agent2"}, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var resources []*proto.Resource + for _, resourceSpec := range tt.resources { + resource := &proto.Resource{ + Name: resourceSpec.name, + Type: "aws_instance", + } + for _, agentName := range resourceSpec.agents { + resource.Agents = append(resource.Agents, &proto.Agent{ + Id: uuid.NewString(), + Name: agentName, + }) + } + resources = append(resources, resource) + } + + provisionResponse := []*proto.Provision_Response{{ + Type: &proto.Provision_Response_Complete{ + Complete: &proto.Provision_Complete{ + Resources: resources, + }, + }, + }} + + client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerD: true}) + user := coderdtest.CreateFirstUser(t, client) + // authToken := uuid.NewString() + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionDryRun: provisionResponse, + Provision: provisionResponse, + }) + coderdtest.AwaitTemplateVersionJob(t, client, version.ID) + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) + coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) + + sshConfigFile, _ := sshConfigFileNames(t) + + cmd, root := clitest.New(t, "config-ssh", "--ssh-config-file", sshConfigFile) + clitest.SetupConfig(t, client, root) + doneChan := make(chan struct{}) + pty := ptytest.New(t) + cmd.SetIn(pty.Input()) + cmd.SetOut(pty.Output()) + go func() { + defer close(doneChan) + err := cmd.Execute() + assert.NoError(t, err) + }() + + matches := []struct { + match, write string + }{ + {match: "Continue?", write: "yes"}, + } + for _, m := range matches { + pty.ExpectMatch(m.match) + pty.WriteLine(m.write) + } + + <-doneChan + + var expectedHosts []string + for _, hostnamePattern := range tt.expected { + hostname := strings.ReplaceAll(hostnamePattern, "@", workspace.Name) + expectedHosts = append(expectedHosts, hostname) + } + + hosts := sshConfigFileParseHosts(t, sshConfigFile) + require.ElementsMatch(t, expectedHosts, hosts) + }) + } +} + +// sshConfigFileParseHosts reads a file in the format of .ssh/config and extracts +// the hostnames that are listed in "Host" directives. +func sshConfigFileParseHosts(t *testing.T, name string) []string { + t.Helper() + b, err := os.ReadFile(name) + require.NoError(t, err) + + var result []string + lineScanner := bufio.NewScanner(bytes.NewBuffer(b)) + for lineScanner.Scan() { + line := lineScanner.Text() + line = strings.TrimSpace(line) + + tokenScanner := bufio.NewScanner(bytes.NewBufferString(line)) + tokenScanner.Split(bufio.ScanWords) + ok := tokenScanner.Scan() + if ok && tokenScanner.Text() == "Host" { + for tokenScanner.Scan() { + result = append(result, tokenScanner.Text()) + } + } + } + + return result +}