Skip to content

Commit 743db9e

Browse files
committed
fix: properly handle shebangs by writing files
See #10134 (comment)
1 parent db8592f commit 743db9e

File tree

4 files changed

+42
-16
lines changed

4 files changed

+42
-16
lines changed

agent/agent.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -292,11 +292,12 @@ func (a *agent) collectMetadata(ctx context.Context, md codersdk.WorkspaceAgentM
292292
// if it can guarantee the clocks are synchronized.
293293
CollectedAt: now,
294294
}
295-
cmdPty, err := a.sshServer.CreateCommand(ctx, md.Script, nil)
295+
cmdPty, cleanup, err := a.sshServer.CreateCommand(ctx, md.Script, nil)
296296
if err != nil {
297297
result.Error = fmt.Sprintf("create cmd: %+v", err)
298298
return result
299299
}
300+
defer cleanup()
300301
cmd := cmdPty.AsExec()
301302

302303
cmd.Stdout = &out
@@ -1069,7 +1070,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
10691070
}()
10701071

10711072
// Empty command will default to the users shell!
1072-
cmd, err := a.sshServer.CreateCommand(ctx, msg.Command, nil)
1073+
cmd, cleanup, err := a.sshServer.CreateCommand(ctx, msg.Command, nil)
10731074
if err != nil {
10741075
a.metrics.reconnectingPTYErrors.WithLabelValues("create_command").Add(1)
10751076
return xerrors.Errorf("create command: %w", err)
@@ -1082,6 +1083,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
10821083

10831084
if err = a.trackConnGoroutine(func() {
10841085
rpty.Wait()
1086+
cleanup()
10851087
a.reconnectingPTYs.Delete(msg.ID)
10861088
}); err != nil {
10871089
rpty.Close(err)

agent/agentscripts/agentscripts.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,11 @@ func (r *Runner) run(ctx context.Context, script codersdk.WorkspaceAgentScript)
171171
cmdCtx, ctxCancel = context.WithTimeout(ctx, script.Timeout)
172172
defer ctxCancel()
173173
}
174-
cmdPty, err := r.SSHServer.CreateCommand(cmdCtx, script.Script, nil)
174+
cmdPty, cleanup, err := r.SSHServer.CreateCommand(cmdCtx, script.Script, nil)
175175
if err != nil {
176176
return xerrors.Errorf("%s script: create command: %w", logPath, err)
177177
}
178+
defer cleanup()
178179
cmd = cmdPty.AsExec()
179180
cmd.SysProcAttr = cmdSysProcAttr()
180181
cmd.WaitDelay = 10 * time.Second

agent/agentssh/agentssh.go

+31-11
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"context"
66
"crypto/rand"
77
"crypto/rsa"
8+
"crypto/sha256"
89
"errors"
910
"fmt"
1011
"io"
@@ -274,7 +275,7 @@ func (s *Server) sessionStart(session ssh.Session, extraEnv []string) (retErr er
274275
magicTypeLabel := magicTypeMetricLabel(magicType)
275276
sshPty, windowSize, isPty := session.Pty()
276277

277-
cmd, err := s.CreateCommand(ctx, session.RawCommand(), env)
278+
cmd, cleanup, err := s.CreateCommand(ctx, session.RawCommand(), env)
278279
if err != nil {
279280
ptyLabel := "no"
280281
if isPty {
@@ -283,6 +284,7 @@ func (s *Server) sessionStart(session ssh.Session, extraEnv []string) (retErr er
283284
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, ptyLabel, "create_command").Add(1)
284285
return err
285286
}
287+
defer cleanup()
286288

287289
if ssh.AgentRequested(session) {
288290
l, err := ssh.NewAgentListener()
@@ -493,21 +495,21 @@ func (s *Server) sftpHandler(session ssh.Session) {
493495
// CreateCommand processes raw command input with OpenSSH-like behavior.
494496
// If the script provided is empty, it will default to the users shell.
495497
// This injects environment variables specified by the user at launch too.
496-
func (s *Server) CreateCommand(ctx context.Context, script string, env []string) (*pty.Cmd, error) {
498+
func (s *Server) CreateCommand(ctx context.Context, script string, env []string) (*pty.Cmd, func(), error) {
497499
currentUser, err := user.Current()
498500
if err != nil {
499-
return nil, xerrors.Errorf("get current user: %w", err)
501+
return nil, nil, xerrors.Errorf("get current user: %w", err)
500502
}
501503
username := currentUser.Username
502504

503505
shell, err := usershell.Get(username)
504506
if err != nil {
505-
return nil, xerrors.Errorf("get user shell: %w", err)
507+
return nil, nil, xerrors.Errorf("get user shell: %w", err)
506508
}
507509

508510
manifest := s.Manifest.Load()
509511
if manifest == nil {
510-
return nil, xerrors.Errorf("no metadata was provided")
512+
return nil, nil, xerrors.Errorf("no metadata was provided")
511513
}
512514

513515
// OpenSSH executes all commands with the users current shell.
@@ -518,7 +520,9 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
518520
}
519521
name := shell
520522
args := []string{caller, script}
521-
523+
cleanup := func() {
524+
// Default to noop. This only applies for scripts.
525+
}
522526
// A preceding space is generally not idiomatic for a shebang,
523527
// but in Terraform it's quite standard to use <<EOF for a multi-line
524528
// string which would indent with spaces, so we accept it for user-ease.
@@ -531,15 +535,31 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
531535
shebang = strings.TrimPrefix(shebang, "#!")
532536
words, err := shellquote.Split(shebang)
533537
if err != nil {
534-
return nil, xerrors.Errorf("split shebang: %w", err)
538+
return nil, nil, xerrors.Errorf("split shebang: %w", err)
535539
}
536540
name = words[0]
537541
if len(words) > 1 {
538542
args = words[1:]
539543
} else {
540544
args = []string{}
541545
}
542-
args = append(args, caller, script)
546+
scriptSha := sha256.Sum256([]byte(script))
547+
tempFile, err := os.CreateTemp("", fmt.Sprintf("coder-script-%x", scriptSha))
548+
if err != nil {
549+
return nil, nil, xerrors.Errorf("create temp file: %w", err)
550+
}
551+
cleanup = func() {
552+
_ = os.Remove(tempFile.Name())
553+
}
554+
_, err = tempFile.WriteString(script)
555+
if err != nil {
556+
return nil, nil, xerrors.Errorf("write temp file: %w", err)
557+
}
558+
err = tempFile.Close()
559+
if err != nil {
560+
return nil, nil, xerrors.Errorf("close temp file: %w", err)
561+
}
562+
args = append(args, tempFile.Name())
543563
}
544564

545565
// gliderlabs/ssh returns a command slice of zero
@@ -563,14 +583,14 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
563583
// Default to user home if a directory is not set.
564584
homedir, err := userHomeDir()
565585
if err != nil {
566-
return nil, xerrors.Errorf("get home dir: %w", err)
586+
return nil, nil, xerrors.Errorf("get home dir: %w", err)
567587
}
568588
cmd.Dir = homedir
569589
}
570590
cmd.Env = append(os.Environ(), env...)
571591
executablePath, err := os.Executable()
572592
if err != nil {
573-
return nil, xerrors.Errorf("getting os executable: %w", err)
593+
return nil, nil, xerrors.Errorf("getting os executable: %w", err)
574594
}
575595
// Set environment variables reliable detection of being inside a
576596
// Coder workspace.
@@ -615,7 +635,7 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
615635
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", envKey, value))
616636
}
617637

618-
return cmd, nil
638+
return cmd, cleanup, nil
619639
}
620640

621641
func (s *Server) Serve(l net.Listener) (retErr error) {

agent/agentssh/agentssh_test.go

+5-2
Original file line numberDiff line numberDiff line change
@@ -90,19 +90,22 @@ func TestNewServer_ExecuteShebang(t *testing.T) {
9090

9191
t.Run("Basic", func(t *testing.T) {
9292
t.Parallel()
93-
cmd, err := s.CreateCommand(ctx, `#!/bin/bash
93+
cmd, cleanup, err := s.CreateCommand(ctx, `#!/bin/bash
9494
echo test`, nil)
9595
require.NoError(t, err)
96+
t.Cleanup(cleanup)
9697
output, err := cmd.AsExec().CombinedOutput()
9798
require.NoError(t, err)
9899
require.Equal(t, "test\n", string(output))
99100
})
100101
t.Run("Args", func(t *testing.T) {
101102
t.Parallel()
102-
cmd, err := s.CreateCommand(ctx, `#!/usr/bin/env bash
103+
cmd, cleanup, err := s.CreateCommand(ctx, `#!/usr/bin/env bash
103104
echo test`, nil)
104105
require.NoError(t, err)
106+
t.Cleanup(cleanup)
105107
output, err := cmd.AsExec().CombinedOutput()
108+
t.Log(string(output))
106109
require.NoError(t, err)
107110
require.Equal(t, "test\n", string(output))
108111
})

0 commit comments

Comments
 (0)