diff --git a/cli/gitssh.go b/cli/gitssh.go index e36f5deadc3d2..b18b919f79515 100644 --- a/cli/gitssh.go +++ b/cli/gitssh.go @@ -1,9 +1,15 @@ package cli import ( + "bufio" + "bytes" + "context" "fmt" + "io" "os" "os/exec" + "os/signal" + "path/filepath" "strings" "github.com/spf13/cobra" @@ -13,16 +19,30 @@ import ( ) func gitssh() *cobra.Command { - return &cobra.Command{ + cmd := &cobra.Command{ Use: "gitssh", Hidden: true, Short: `Wraps the "ssh" command and uses the coder gitssh key for authentication`, RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + env := os.Environ() + + // Catch interrupt signals to ensure the temporary private + // key file is cleaned up on most cases. + ctx, stop := signal.NotifyContext(ctx, interruptSignals...) + defer stop() + + // Early check so errors are reported immediately. + identityFiles, err := parseIdentityFilesForHost(ctx, args, env) + if err != nil { + return err + } + client, err := createAgentClient(cmd) if err != nil { return xerrors.Errorf("create agent client: %w", err) } - key, err := client.AgentGitSSHKey(cmd.Context()) + key, err := client.AgentGitSSHKey(ctx) if err != nil { return xerrors.Errorf("get agent git ssh token: %w", err) } @@ -44,8 +64,23 @@ func gitssh() *cobra.Command { return xerrors.Errorf("close temp gitsshkey file: %w", err) } - args = append([]string{"-i", privateKeyFile.Name()}, args...) - c := exec.CommandContext(cmd.Context(), "ssh", args...) + // Append our key, giving precedence to user keys. Note that + // OpenSSH server are typically configured with MaxAuthTries + // set to the default value of 6. This means that only the 6 + // first keys can be tried. However, we will assume that if + // a user has configured 6+ keys for a host, they know what + // they're doing. This behavior is critical if a server has + // been configured with MaxAuthTries set to 1. + identityFiles = append(identityFiles, privateKeyFile.Name()) + + var identityArgs []string + for _, id := range identityFiles { + identityArgs = append(identityArgs, "-i", id) + } + + args = append(identityArgs, args...) + c := exec.CommandContext(ctx, "ssh", args...) + c.Env = append(c.Env, env...) c.Stderr = cmd.ErrOrStderr() c.Stdout = cmd.OutOrStdout() c.Stdin = cmd.InOrStdin() @@ -69,4 +104,86 @@ func gitssh() *cobra.Command { return nil }, } + + return cmd +} + +// fallbackIdentityFiles is the list of identity files SSH tries when +// none have been defined for a host. +var fallbackIdentityFiles = strings.Join([]string{ + "identityfile ~/.ssh/id_rsa", + "identityfile ~/.ssh/id_dsa", + "identityfile ~/.ssh/id_ecdsa", + "identityfile ~/.ssh/id_ecdsa_sk", + "identityfile ~/.ssh/id_ed25519", + "identityfile ~/.ssh/id_ed25519_sk", + "identityfile ~/.ssh/id_xmss", +}, "\n") + +// parseIdentityFilesForHost uses ssh -G to discern what SSH keys have +// been enabled for the host (via the users SSH config) and returns a +// list of existing identity files. +// +// We do this because when no keys are defined for a host, SSH uses +// fallback keys (see above). However, by passing `-i` to attach our +// private key, we're effectively disabling the fallback keys. +// +// Example invocation: +// +// ssh -G -o SendEnv=GIT_PROTOCOL git@github.com git-upload-pack 'coder/coder' +// +// The extra arguments work without issue and lets us run the command +// as-is without stripping out the excess (git-upload-pack 'coder/coder'). +func parseIdentityFilesForHost(ctx context.Context, args, env []string) (identityFiles []string, error error) { + home, err := os.UserHomeDir() + if err != nil { + return nil, xerrors.Errorf("get user home dir failed: %w", err) + } + + var outBuf bytes.Buffer + var r io.Reader = &outBuf + + args = append([]string{"-G"}, args...) + cmd := exec.CommandContext(ctx, "ssh", args...) + cmd.Env = append(cmd.Env, env...) + cmd.Stdout = &outBuf + cmd.Stderr = io.Discard + err = cmd.Run() + if err != nil { + // If ssh -G failed, the SSH version is likely too old, fallback + // to using the default identity files. + r = strings.NewReader(fallbackIdentityFiles) + } + + s := bufio.NewScanner(r) + for s.Scan() { + line := s.Text() + if strings.HasPrefix(line, "identityfile ") { + id := strings.TrimPrefix(line, "identityfile ") + if strings.HasPrefix(id, "~/") { + id = home + id[1:] + } + // OpenSSH on Windows is weird, it supports using (and does + // use) mixed \ and / in paths. + // + // Example: C:\Users\ZeroCool/.ssh/known_hosts + // + // To check the file existence in Go, though, we want to use + // proper Windows paths. + // OpenSSH is amazing, this will work on Windows too: + // C:\Users\ZeroCool/.ssh/id_rsa + id = filepath.FromSlash(id) + + // Only include the identity file if it exists. + if _, err := os.Stat(id); err == nil { + identityFiles = append(identityFiles, id) + } + } + } + if err := s.Err(); err != nil { + // This should never happen, the check is for completeness. + return nil, xerrors.Errorf("scan ssh output: %w", err) + } + + return identityFiles, nil } diff --git a/cli/gitssh_test.go b/cli/gitssh_test.go index cdbbe6bd1f4ff..a187566f61b11 100644 --- a/cli/gitssh_test.go +++ b/cli/gitssh_test.go @@ -2,8 +2,16 @@ package cli_test import ( "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "encoding/pem" "fmt" "net" + "os" + "path/filepath" + "strings" "sync/atomic" "testing" @@ -17,99 +25,245 @@ import ( "github.com/coder/coder/codersdk" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" + "github.com/coder/coder/pty/ptytest" + "github.com/coder/coder/testutil" ) +func prepareTestGitSSH(ctx context.Context, t *testing.T) (*codersdk.Client, string, gossh.PublicKey) { + t.Helper() + + client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) + user := coderdtest.CreateFirstUser(t, client) + + ctx, cancel := context.WithCancel(ctx) + defer t.Cleanup(cancel) // Defer so that cancel is the first cleanup. + + // get user public key + keypair, err := client.GitSSHKey(ctx, codersdk.Me) + require.NoError(t, err) + //nolint:dogsled + pubkey, _, _, _, err := gossh.ParseAuthorizedKey([]byte(keypair.PublicKey)) + require.NoError(t, err) + + // setup template + agentToken := uuid.NewString() + version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ + Parse: echo.ParseComplete, + ProvisionDryRun: echo.ProvisionComplete, + Provision: []*proto.Provision_Response{{ + Type: &proto.Provision_Response_Complete{ + Complete: &proto.Provision_Complete{ + Resources: []*proto.Resource{{ + Name: "somename", + Type: "someinstance", + Agents: []*proto.Agent{{ + Auth: &proto.Agent_Token{ + Token: agentToken, + }, + }}, + }}, + }, + }, + }}, + }) + template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) + coderdtest.AwaitTemplateVersionJob(t, client, version.ID) + workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) + coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) + + // start workspace agent + cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String(), "--wireguard=false") + agentClient := client + clitest.SetupConfig(t, agentClient, root) + + errC := make(chan error, 1) + go func() { + errC <- cmd.ExecuteContext(ctx) + }() + t.Cleanup(func() { require.NoError(t, <-errC) }) + + coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID) + resources, err := client.WorkspaceResourcesByBuild(ctx, workspace.LatestBuild.ID) + require.NoError(t, err) + dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil) + require.NoError(t, err) + defer dialer.Close() + _, err = dialer.Ping() + require.NoError(t, err) + + return agentClient, agentToken, pubkey +} + +func serveSSHForGitSSH(t *testing.T, handler func(ssh.Session), pubkeys ...gossh.PublicKey) *net.TCPAddr { + t.Helper() + + // start ssh server + l, err := net.Listen("tcp", "localhost:0") + require.NoError(t, err) + t.Cleanup(func() { _ = l.Close() }) + + serveOpts := []ssh.Option{ + ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool { + for _, pubkey := range pubkeys { + if ssh.KeysEqual(pubkey, key) { + return true + } + } + return false + }), + } + errC := make(chan error, 1) + go func() { + // as long as we get a successful session we don't care if the server errors + errC <- ssh.Serve(l, handler, serveOpts...) + }() + t.Cleanup(func() { + _ = l.Close() // Ensure server shutdown. + <-errC + }) + + // start ssh session + addr, ok := l.Addr().(*net.TCPAddr) + require.True(t, ok) + + return addr +} + +func writePrivateKeyToFile(t *testing.T, name string, key *ecdsa.PrivateKey) { + t.Helper() + + b, err := x509.MarshalPKCS8PrivateKey(key) + require.NoError(t, err) + b = pem.EncodeToMemory(&pem.Block{ + Type: "PRIVATE KEY", + Bytes: b, + }) + + err = os.WriteFile(name, b, 0o600) + require.NoError(t, err) +} + func TestGitSSH(t *testing.T) { t.Parallel() t.Run("Dial", func(t *testing.T) { t.Parallel() - client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) - user := coderdtest.CreateFirstUser(t, client) - // get user public key - keypair, err := client.GitSSHKey(context.Background(), codersdk.Me) - require.NoError(t, err) - publicKey, _, _, _, err := gossh.ParseAuthorizedKey([]byte(keypair.PublicKey)) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() + + client, token, pubkey := prepareTestGitSSH(ctx, t) + var inc int64 + errC := make(chan error, 1) + addr := serveSSHForGitSSH(t, func(s ssh.Session) { + atomic.AddInt64(&inc, 1) + t.Log("got authenticated session") + select { + case errC <- s.Exit(0): + default: + t.Error("error channel is full") + } + }, pubkey) + + // set to agent config dir + cmd, _ := clitest.New(t, + "gitssh", + "--agent-url", client.URL.String(), + "--agent-token", token, + "--", + fmt.Sprintf("-p%d", addr.Port), + "-o", "StrictHostKeyChecking=no", + "-o", "IdentitiesOnly=yes", + "127.0.0.1", + ) + err := cmd.ExecuteContext(ctx) require.NoError(t, err) + require.EqualValues(t, 1, inc) - // setup template - agentToken := uuid.NewString() - version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, &echo.Responses{ - Parse: echo.ParseComplete, - ProvisionDryRun: echo.ProvisionComplete, - Provision: []*proto.Provision_Response{{ - Type: &proto.Provision_Response_Complete{ - Complete: &proto.Provision_Complete{ - Resources: []*proto.Resource{{ - Name: "somename", - Type: "someinstance", - Agents: []*proto.Agent{{ - Auth: &proto.Agent_Token{ - Token: agentToken, - }, - }}, - }}, - }, - }, - }}, - }) - template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) - coderdtest.AwaitTemplateVersionJob(t, client, version.ID) - workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) - coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) - - // start workspace agent - cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String(), "--wireguard=false") - agentClient := client - clitest.SetupConfig(t, agentClient, root) - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - agentErrC := make(chan error) - go func() { - agentErrC <- cmd.ExecuteContext(ctx) - }() - - coderdtest.AwaitWorkspaceAgents(t, client, workspace.LatestBuild.ID) - resources, err := client.WorkspaceResourcesByBuild(context.Background(), workspace.LatestBuild.ID) + err = <-errC + require.NoError(t, err, "error in agent execute") + }) + + t.Run("Local SSH Keys", func(t *testing.T) { + t.Parallel() + + home := t.TempDir() + sshdir := filepath.Join(home, ".ssh") + err := os.MkdirAll(sshdir, 0o700) require.NoError(t, err) - dialer, err := client.DialWorkspaceAgent(context.Background(), resources[0].Agents[0].ID, nil) + + idFile := filepath.Join(sshdir, "id_ed25519") + privkey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) require.NoError(t, err) - defer dialer.Close() - _, err = dialer.Ping() + localPubkey, err := gossh.NewPublicKey(&privkey.PublicKey) require.NoError(t, err) + writePrivateKeyToFile(t, idFile, privkey) + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) + defer cancel() - // start ssh server - l, err := net.Listen("tcp", "localhost:0") + client, token, coderPubkey := prepareTestGitSSH(ctx, t) + + authkey := make(chan gossh.PublicKey, 1) + addr := serveSSHForGitSSH(t, func(s ssh.Session) { + t.Logf("authenticated with: %s", gossh.MarshalAuthorizedKey(s.PublicKey())) + select { + case authkey <- s.PublicKey(): + default: + t.Error("authkey channel is full") + } + }, localPubkey, coderPubkey) + + // Create a new config which sets an identity file. + config := filepath.Join(sshdir, "config") + knownHosts := filepath.Join(sshdir, "known_hosts") + err = os.WriteFile(config, []byte(strings.Join([]string{ + "Host mytest", + " HostName 127.0.0.1", + fmt.Sprintf(" Port %d", addr.Port), + " StrictHostKeyChecking no", + " UserKnownHostsFile=" + knownHosts, + " IdentitiesOnly yes", + " IdentityFile=" + idFile, + }, "\n")), 0o600) require.NoError(t, err) - defer l.Close() - publicKeyOption := ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool { - return ssh.KeysEqual(publicKey, key) - }) - var inc int64 - sshErrC := make(chan error) - go func() { - // as long as we get a successful session we don't care if the server errors - _ = ssh.Serve(l, func(s ssh.Session) { - atomic.AddInt64(&inc, 1) - t.Log("got authenticated session") - sshErrC <- s.Exit(0) - }, publicKeyOption) - }() - - // start ssh session - addr, ok := l.Addr().(*net.TCPAddr) - require.True(t, ok) - // set to agent config dir - gitsshCmd, _ := clitest.New(t, "gitssh", "--agent-url", agentClient.URL.String(), "--agent-token", agentToken, "--", fmt.Sprintf("-p%d", addr.Port), "-o", "StrictHostKeyChecking=no", "-o", "IdentitiesOnly=yes", "127.0.0.1") - err = gitsshCmd.ExecuteContext(context.Background()) + + pty := ptytest.New(t) + cmdArgs := []string{ + "gitssh", + "--agent-url", client.URL.String(), + "--agent-token", token, + "--", + "-F", config, + "mytest", + } + // Test authentication via local private key. + cmd, _ := clitest.New(t, cmdArgs...) + cmd.SetOut(pty.Output()) + cmd.SetErr(pty.Output()) + err = cmd.ExecuteContext(ctx) require.NoError(t, err) - require.EqualValues(t, 1, inc) + select { + case key := <-authkey: + require.Equal(t, localPubkey, key) + case <-ctx.Done(): + t.Fatal("timeout waiting for auth") + } - err = <-sshErrC - require.NoError(t, err, "error in ssh session exit") + // Delete the local private key. + err = os.Remove(idFile) + require.NoError(t, err) - cancelFunc() - err = <-agentErrC - require.NoError(t, err, "error in agent execute") + // With the local file deleted, the coder key should be used. + cmd, _ = clitest.New(t, cmdArgs...) + cmd.SetOut(pty.Output()) + cmd.SetErr(pty.Output()) + err = cmd.ExecuteContext(ctx) + require.NoError(t, err) + select { + case key := <-authkey: + require.Equal(t, coderPubkey, key) + case <-ctx.Done(): + t.Fatal("timeout waiting for auth") + } }) }