Skip to content

fix: properly handle shebangs by writing files #10194

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,11 +292,12 @@ func (a *agent) collectMetadata(ctx context.Context, md codersdk.WorkspaceAgentM
// if it can guarantee the clocks are synchronized.
CollectedAt: now,
}
cmdPty, err := a.sshServer.CreateCommand(ctx, md.Script, nil)
cmdPty, cleanup, err := a.sshServer.CreateCommand(ctx, md.Script, nil)
if err != nil {
result.Error = fmt.Sprintf("create cmd: %+v", err)
return result
}
defer cleanup()
cmd := cmdPty.AsExec()

cmd.Stdout = &out
Expand Down Expand Up @@ -1069,7 +1070,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
}()

// Empty command will default to the users shell!
cmd, err := a.sshServer.CreateCommand(ctx, msg.Command, nil)
cmd, cleanup, err := a.sshServer.CreateCommand(ctx, msg.Command, nil)
if err != nil {
a.metrics.reconnectingPTYErrors.WithLabelValues("create_command").Add(1)
return xerrors.Errorf("create command: %w", err)
Expand All @@ -1082,9 +1083,11 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m

if err = a.trackConnGoroutine(func() {
rpty.Wait()
cleanup()
a.reconnectingPTYs.Delete(msg.ID)
}); err != nil {
rpty.Close(err)
cleanup()
return xerrors.Errorf("start routine: %w", err)
}

Expand Down
3 changes: 2 additions & 1 deletion agent/agentscripts/agentscripts.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,11 @@ func (r *Runner) run(ctx context.Context, script codersdk.WorkspaceAgentScript)
cmdCtx, ctxCancel = context.WithTimeout(ctx, script.Timeout)
defer ctxCancel()
}
cmdPty, err := r.SSHServer.CreateCommand(cmdCtx, script.Script, nil)
cmdPty, cleanup, err := r.SSHServer.CreateCommand(cmdCtx, script.Script, nil)
if err != nil {
return xerrors.Errorf("%s script: create command: %w", logPath, err)
}
defer cleanup()
cmd = cmdPty.AsExec()
cmd.SysProcAttr = cmdSysProcAttr()
cmd.WaitDelay = 10 * time.Second
Expand Down
42 changes: 31 additions & 11 deletions agent/agentssh/agentssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -274,7 +275,7 @@ func (s *Server) sessionStart(session ssh.Session, extraEnv []string) (retErr er
magicTypeLabel := magicTypeMetricLabel(magicType)
sshPty, windowSize, isPty := session.Pty()

cmd, err := s.CreateCommand(ctx, session.RawCommand(), env)
cmd, cleanup, err := s.CreateCommand(ctx, session.RawCommand(), env)
if err != nil {
ptyLabel := "no"
if isPty {
Expand All @@ -283,6 +284,7 @@ func (s *Server) sessionStart(session ssh.Session, extraEnv []string) (retErr er
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, ptyLabel, "create_command").Add(1)
return err
}
defer cleanup()

if ssh.AgentRequested(session) {
l, err := ssh.NewAgentListener()
Expand Down Expand Up @@ -493,21 +495,21 @@ func (s *Server) sftpHandler(session ssh.Session) {
// CreateCommand processes raw command input with OpenSSH-like behavior.
// If the script provided is empty, it will default to the users shell.
// This injects environment variables specified by the user at launch too.
func (s *Server) CreateCommand(ctx context.Context, script string, env []string) (*pty.Cmd, error) {
func (s *Server) CreateCommand(ctx context.Context, script string, env []string) (*pty.Cmd, func(), error) {
currentUser, err := user.Current()
if err != nil {
return nil, xerrors.Errorf("get current user: %w", err)
return nil, nil, xerrors.Errorf("get current user: %w", err)
}
username := currentUser.Username

shell, err := usershell.Get(username)
if err != nil {
return nil, xerrors.Errorf("get user shell: %w", err)
return nil, nil, xerrors.Errorf("get user shell: %w", err)
}

manifest := s.Manifest.Load()
if manifest == nil {
return nil, xerrors.Errorf("no metadata was provided")
return nil, nil, xerrors.Errorf("no metadata was provided")
}

// OpenSSH executes all commands with the users current shell.
Expand All @@ -518,7 +520,9 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
}
name := shell
args := []string{caller, script}

cleanup := func() {
// Default to noop. This only applies for scripts.
}
// A preceding space is generally not idiomatic for a shebang,
// but in Terraform it's quite standard to use <<EOF for a multi-line
// string which would indent with spaces, so we accept it for user-ease.
Expand All @@ -531,15 +535,31 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
shebang = strings.TrimPrefix(shebang, "#!")
words, err := shellquote.Split(shebang)
if err != nil {
return nil, xerrors.Errorf("split shebang: %w", err)
return nil, nil, xerrors.Errorf("split shebang: %w", err)
}
name = words[0]
if len(words) > 1 {
args = words[1:]
} else {
args = []string{}
}
args = append(args, caller, script)
scriptSha := sha256.Sum256([]byte(script))
tempFile, err := os.CreateTemp("", fmt.Sprintf("coder-script-%x.*", scriptSha))
if err != nil {
return nil, nil, xerrors.Errorf("create temp file: %w", err)
}
cleanup = func() {
_ = os.Remove(tempFile.Name())
}
_, err = tempFile.WriteString(script)
if err != nil {
return nil, nil, xerrors.Errorf("write temp file: %w", err)
}
err = tempFile.Close()
if err != nil {
return nil, nil, xerrors.Errorf("close temp file: %w", err)
}
args = append(args, tempFile.Name())
}

// gliderlabs/ssh returns a command slice of zero
Expand All @@ -563,14 +583,14 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
// Default to user home if a directory is not set.
homedir, err := userHomeDir()
if err != nil {
return nil, xerrors.Errorf("get home dir: %w", err)
return nil, nil, xerrors.Errorf("get home dir: %w", err)
}
cmd.Dir = homedir
}
cmd.Env = append(os.Environ(), env...)
executablePath, err := os.Executable()
if err != nil {
return nil, xerrors.Errorf("getting os executable: %w", err)
return nil, nil, xerrors.Errorf("getting os executable: %w", err)
}
// Set environment variables reliable detection of being inside a
// Coder workspace.
Expand Down Expand Up @@ -615,7 +635,7 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, value))
}

return cmd, nil
return cmd, cleanup, nil
}

func (s *Server) Serve(l net.Listener) (retErr error) {
Expand Down
7 changes: 5 additions & 2 deletions agent/agentssh/agentssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,19 +90,22 @@ func TestNewServer_ExecuteShebang(t *testing.T) {

t.Run("Basic", func(t *testing.T) {
t.Parallel()
cmd, err := s.CreateCommand(ctx, `#!/bin/bash
cmd, cleanup, err := s.CreateCommand(ctx, `#!/bin/bash
echo test`, nil)
require.NoError(t, err)
t.Cleanup(cleanup)
output, err := cmd.AsExec().CombinedOutput()
require.NoError(t, err)
require.Equal(t, "test\n", string(output))
})
t.Run("Args", func(t *testing.T) {
t.Parallel()
cmd, err := s.CreateCommand(ctx, `#!/usr/bin/env bash
cmd, cleanup, err := s.CreateCommand(ctx, `#!/usr/bin/env bash
echo test`, nil)
require.NoError(t, err)
t.Cleanup(cleanup)
output, err := cmd.AsExec().CombinedOutput()
t.Log(string(output))
require.NoError(t, err)
require.Equal(t, "test\n", string(output))
})
Expand Down
Loading