Skip to content

Commit a2dd618

Browse files
authored
feat: Use environment variables and startup script in agent (#1147)
These values were ignored. Environment variables are applied to new sessions, and are refreshed on reconnect. This is cool because a workspace could be updated with new environment variables without requiring a complete start/stop. The startup script is only ran once regardless of changes, which feels like the expected behavior.
1 parent 09405dd commit a2dd618

File tree

10 files changed

+189
-28
lines changed

10 files changed

+189
-28
lines changed

.vscode/settings.json

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"gographviz",
1717
"goleak",
1818
"gossh",
19+
"gsyslog",
1920
"hashicorp",
2021
"hclsyntax",
2122
"httpmw",

agent/agent.go

+87-5
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,13 @@ import (
1111
"os"
1212
"os/exec"
1313
"os/user"
14+
"runtime"
1415
"sync"
1516
"time"
1617

18+
gsyslog "github.com/hashicorp/go-syslog"
19+
"go.uber.org/atomic"
20+
1721
"cdr.dev/slog"
1822
"github.com/coder/coder/agent/usershell"
1923
"github.com/coder/coder/peer"
@@ -29,10 +33,11 @@ import (
2933
)
3034

3135
type Options struct {
32-
Logger slog.Logger
36+
EnvironmentVariables map[string]string
37+
StartupScript string
3338
}
3439

35-
type Dialer func(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error)
40+
type Dialer func(ctx context.Context, logger slog.Logger) (*Options, *peerbroker.Listener, error)
3641

3742
func New(dialer Dialer, logger slog.Logger) io.Closer {
3843
ctx, cancelFunc := context.WithCancel(context.Background())
@@ -55,16 +60,21 @@ type agent struct {
5560
closeMutex sync.Mutex
5661
closed chan struct{}
5762

58-
sshServer *ssh.Server
63+
// Environment variables sent by Coder to inject for shell sessions.
64+
// This is atomic because values can change after reconnect.
65+
envVars atomic.Value
66+
startupScript atomic.Bool
67+
sshServer *ssh.Server
5968
}
6069

6170
func (a *agent) run(ctx context.Context) {
71+
var options *Options
6272
var peerListener *peerbroker.Listener
6373
var err error
6474
// An exponential back-off occurs when the connection is failing to dial.
6575
// This is to prevent server spam in case of a coderd outage.
6676
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
67-
peerListener, err = a.dialer(ctx, a.logger)
77+
options, peerListener, err = a.dialer(ctx, a.logger)
6878
if err != nil {
6979
if errors.Is(err, context.Canceled) {
7080
return
@@ -83,6 +93,20 @@ func (a *agent) run(ctx context.Context) {
8393
return
8494
default:
8595
}
96+
a.envVars.Store(options.EnvironmentVariables)
97+
98+
if a.startupScript.CAS(false, true) {
99+
// The startup script has not ran yet!
100+
go func() {
101+
err := a.runStartupScript(ctx, options.StartupScript)
102+
if errors.Is(err, context.Canceled) {
103+
return
104+
}
105+
if err != nil {
106+
a.logger.Warn(ctx, "agent script failed", slog.Error(err))
107+
}
108+
}()
109+
}
86110

87111
for {
88112
conn, err := peerListener.Accept()
@@ -101,6 +125,48 @@ func (a *agent) run(ctx context.Context) {
101125
}
102126
}
103127

128+
func (*agent) runStartupScript(ctx context.Context, script string) error {
129+
if script == "" {
130+
return nil
131+
}
132+
currentUser, err := user.Current()
133+
if err != nil {
134+
return xerrors.Errorf("get current user: %w", err)
135+
}
136+
username := currentUser.Username
137+
138+
shell, err := usershell.Get(username)
139+
if err != nil {
140+
return xerrors.Errorf("get user shell: %w", err)
141+
}
142+
143+
var writer io.WriteCloser
144+
// Attempt to use the syslog to write startup information.
145+
writer, err = gsyslog.NewLogger(gsyslog.LOG_INFO, "USER", "coder-startup-script")
146+
if err != nil {
147+
// If the syslog isn't supported or cannot be created, use a text file in temp.
148+
writer, err = os.CreateTemp("", "coder-startup-script.txt")
149+
if err != nil {
150+
return xerrors.Errorf("open startup script log file: %w", err)
151+
}
152+
}
153+
defer func() {
154+
_ = writer.Close()
155+
}()
156+
caller := "-c"
157+
if runtime.GOOS == "windows" {
158+
caller = "/c"
159+
}
160+
cmd := exec.CommandContext(ctx, shell, caller, script)
161+
cmd.Stdout = writer
162+
cmd.Stderr = writer
163+
err = cmd.Run()
164+
if err != nil {
165+
return xerrors.Errorf("run: %w", err)
166+
}
167+
return nil
168+
}
169+
104170
func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
105171
go func() {
106172
select {
@@ -230,8 +296,24 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
230296

231297
// OpenSSH executes all commands with the users current shell.
232298
// We replicate that behavior for IDE support.
233-
cmd := exec.CommandContext(session.Context(), shell, "-c", command)
299+
caller := "-c"
300+
if runtime.GOOS == "windows" {
301+
caller = "/c"
302+
}
303+
cmd := exec.CommandContext(session.Context(), shell, caller, command)
234304
cmd.Env = append(os.Environ(), session.Environ()...)
305+
306+
// Load environment variables passed via the agent.
307+
envVars := a.envVars.Load()
308+
if envVars != nil {
309+
envVarMap, ok := envVars.(map[string]string)
310+
if ok {
311+
for key, value := range envVarMap {
312+
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", key, value))
313+
}
314+
}
315+
}
316+
235317
executablePath, err := os.Executable()
236318
if err != nil {
237319
return xerrors.Errorf("getting os executable: %w", err)

agent/agent_test.go

+59-10
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,15 @@ import (
1212
"strconv"
1313
"strings"
1414
"testing"
15+
"time"
1516

1617
"github.com/pion/webrtc/v3"
1718
"github.com/pkg/sftp"
1819
"github.com/stretchr/testify/require"
1920
"go.uber.org/goleak"
2021
"golang.org/x/crypto/ssh"
22+
"golang.org/x/text/encoding/unicode"
23+
"golang.org/x/text/transform"
2124

2225
"cdr.dev/slog"
2326
"cdr.dev/slog/sloggers/slogtest"
@@ -37,7 +40,7 @@ func TestAgent(t *testing.T) {
3740
t.Parallel()
3841
t.Run("SessionExec", func(t *testing.T) {
3942
t.Parallel()
40-
session := setupSSHSession(t)
43+
session := setupSSHSession(t, nil)
4144

4245
command := "echo test"
4346
if runtime.GOOS == "windows" {
@@ -50,7 +53,7 @@ func TestAgent(t *testing.T) {
5053

5154
t.Run("GitSSH", func(t *testing.T) {
5255
t.Parallel()
53-
session := setupSSHSession(t)
56+
session := setupSSHSession(t, nil)
5457
command := "sh -c 'echo $GIT_SSH_COMMAND'"
5558
if runtime.GOOS == "windows" {
5659
command = "cmd.exe /c echo %GIT_SSH_COMMAND%"
@@ -68,7 +71,7 @@ func TestAgent(t *testing.T) {
6871
// it seems like it could be either.
6972
t.Skip("ConPTY appears to be inconsistent on Windows.")
7073
}
71-
session := setupSSHSession(t)
74+
session := setupSSHSession(t, nil)
7275
command := "bash"
7376
if runtime.GOOS == "windows" {
7477
command = "cmd.exe"
@@ -128,7 +131,7 @@ func TestAgent(t *testing.T) {
128131

129132
t.Run("SFTP", func(t *testing.T) {
130133
t.Parallel()
131-
sshClient, err := setupAgent(t).SSHClient()
134+
sshClient, err := setupAgent(t, nil).SSHClient()
132135
require.NoError(t, err)
133136
client, err := sftp.NewClient(sshClient)
134137
require.NoError(t, err)
@@ -140,10 +143,52 @@ func TestAgent(t *testing.T) {
140143
_, err = os.Stat(tempFile)
141144
require.NoError(t, err)
142145
})
146+
147+
t.Run("EnvironmentVariables", func(t *testing.T) {
148+
t.Parallel()
149+
key := "EXAMPLE"
150+
value := "value"
151+
session := setupSSHSession(t, &agent.Options{
152+
EnvironmentVariables: map[string]string{
153+
key: value,
154+
},
155+
})
156+
command := "sh -c 'echo $" + key + "'"
157+
if runtime.GOOS == "windows" {
158+
command = "cmd.exe /c echo %" + key + "%"
159+
}
160+
output, err := session.Output(command)
161+
require.NoError(t, err)
162+
require.Equal(t, value, strings.TrimSpace(string(output)))
163+
})
164+
165+
t.Run("StartupScript", func(t *testing.T) {
166+
t.Parallel()
167+
tempPath := filepath.Join(os.TempDir(), "content.txt")
168+
content := "somethingnice"
169+
setupAgent(t, &agent.Options{
170+
StartupScript: "echo " + content + " > " + tempPath,
171+
})
172+
var gotContent string
173+
require.Eventually(t, func() bool {
174+
content, err := os.ReadFile(tempPath)
175+
if err != nil {
176+
return false
177+
}
178+
if runtime.GOOS == "windows" {
179+
// Windows uses UTF16! 🪟🪟🪟
180+
content, _, err = transform.Bytes(unicode.UTF16(unicode.LittleEndian, unicode.UseBOM).NewDecoder(), content)
181+
require.NoError(t, err)
182+
}
183+
gotContent = string(content)
184+
return true
185+
}, 15*time.Second, 100*time.Millisecond)
186+
require.Equal(t, content, strings.TrimSpace(gotContent))
187+
})
143188
}
144189

145190
func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exec.Cmd {
146-
agentConn := setupAgent(t)
191+
agentConn := setupAgent(t, nil)
147192
listener, err := net.Listen("tcp", "127.0.0.1:0")
148193
require.NoError(t, err)
149194
go func() {
@@ -171,18 +216,22 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe
171216
return exec.Command("ssh", args...)
172217
}
173218

174-
func setupSSHSession(t *testing.T) *ssh.Session {
175-
sshClient, err := setupAgent(t).SSHClient()
219+
func setupSSHSession(t *testing.T, options *agent.Options) *ssh.Session {
220+
sshClient, err := setupAgent(t, options).SSHClient()
176221
require.NoError(t, err)
177222
session, err := sshClient.NewSession()
178223
require.NoError(t, err)
179224
return session
180225
}
181226

182-
func setupAgent(t *testing.T) *agent.Conn {
227+
func setupAgent(t *testing.T, options *agent.Options) *agent.Conn {
228+
if options == nil {
229+
options = &agent.Options{}
230+
}
183231
client, server := provisionersdk.TransportPipe()
184-
closer := agent.New(func(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error) {
185-
return peerbroker.Listen(server, nil)
232+
closer := agent.New(func(ctx context.Context, logger slog.Logger) (*agent.Options, *peerbroker.Listener, error) {
233+
listener, err := peerbroker.Listen(server, nil)
234+
return options, listener, err
186235
}, slogtest.Make(t, nil).Leveled(slog.LevelDebug))
187236
t.Cleanup(func() {
188237
_ = client.Close()

cli/gitssh.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@ import (
77
"os/exec"
88
"strings"
99

10-
"github.com/coder/coder/cli/cliui"
11-
"github.com/coder/coder/codersdk"
1210
"github.com/spf13/cobra"
1311
"golang.org/x/xerrors"
12+
13+
"github.com/coder/coder/cli/cliui"
14+
"github.com/coder/coder/codersdk"
1415
)
1516

1617
func gitssh() *cobra.Command {

coderd/coderd.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,8 @@ func New(options *Options) (http.Handler, func()) {
197197
r.Post("/google-instance-identity", api.postWorkspaceAuthGoogleInstanceIdentity)
198198
r.Route("/me", func(r chi.Router) {
199199
r.Use(httpmw.ExtractWorkspaceAgent(options.Database))
200-
r.Get("/", api.workspaceAgentListen)
200+
r.Get("/", api.workspaceAgentMe)
201+
r.Get("/listen", api.workspaceAgentListen)
201202
r.Get("/gitsshkey", api.agentGitSSHKey)
202203
r.Get("/turn", api.workspaceAgentTurn)
203204
r.Get("/iceservers", api.workspaceAgentICEServers)

coderd/workspaceagents.go

+12
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,18 @@ func (api *api) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) {
8888
}
8989
}
9090

91+
func (api *api) workspaceAgentMe(rw http.ResponseWriter, r *http.Request) {
92+
agent := httpmw.WorkspaceAgent(r)
93+
apiAgent, err := convertWorkspaceAgent(agent, api.AgentConnectionUpdateFrequency)
94+
if err != nil {
95+
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
96+
Message: fmt.Sprintf("convert workspace agent: %s", err),
97+
})
98+
return
99+
}
100+
httpapi.Write(rw, http.StatusOK, apiAgent)
101+
}
102+
91103
func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
92104
api.websocketWaitMutex.Lock()
93105
api.websocketWaitGroup.Add(1)

coderd/workspaceagents_test.go

+2
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ func TestWorkspaceAgentListen(t *testing.T) {
102102
})
103103
_, err = conn.Ping()
104104
require.NoError(t, err)
105+
_, err = agentClient.WorkspaceAgent(context.Background(), codersdk.Me)
106+
require.NoError(t, err)
105107
}
106108

107109
func TestWorkspaceAgentTURN(t *testing.T) {

0 commit comments

Comments
 (0)