Skip to content

Commit f771652

Browse files
committed
feat: Use environment variables and startup script in agent
These values were ignored. Environment variables are applied to new sessions, and are refreshed on reconnect. This is cool because a workspace could be updated with new environment variables without requiring a complete start/stop. The startup script is only ran once regardless of changes, which feels like the expected behavior.
1 parent 8c27b4e commit f771652

File tree

10 files changed

+189
-28
lines changed

10 files changed

+189
-28
lines changed

.vscode/settings.json

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"gographviz",
1717
"goleak",
1818
"gossh",
19+
"gsyslog",
1920
"hashicorp",
2021
"hclsyntax",
2122
"httpmw",

agent/agent.go

+87-5
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,13 @@ import (
1111
"os"
1212
"os/exec"
1313
"os/user"
14+
"runtime"
1415
"sync"
1516
"time"
1617

18+
gsyslog "github.com/hashicorp/go-syslog"
19+
"go.uber.org/atomic"
20+
1721
"cdr.dev/slog"
1822
"github.com/coder/coder/agent/usershell"
1923
"github.com/coder/coder/peer"
@@ -29,10 +33,11 @@ import (
2933
)
3034

3135
type Options struct {
32-
Logger slog.Logger
36+
EnvironmentVariables map[string]string
37+
StartupScript string
3338
}
3439

35-
type Dialer func(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error)
40+
type Dialer func(ctx context.Context, logger slog.Logger) (*Options, *peerbroker.Listener, error)
3641

3742
func New(dialer Dialer, logger slog.Logger) io.Closer {
3843
ctx, cancelFunc := context.WithCancel(context.Background())
@@ -55,16 +60,21 @@ type agent struct {
5560
closeMutex sync.Mutex
5661
closed chan struct{}
5762

58-
sshServer *ssh.Server
63+
// Environment variables sent by Coder to inject for shell sessions.
64+
// This is atomic because values can change after reconnect.
65+
envVars atomic.Value
66+
startupScript atomic.Bool
67+
sshServer *ssh.Server
5968
}
6069

6170
func (a *agent) run(ctx context.Context) {
71+
var options *Options
6272
var peerListener *peerbroker.Listener
6373
var err error
6474
// An exponential back-off occurs when the connection is failing to dial.
6575
// This is to prevent server spam in case of a coderd outage.
6676
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
67-
peerListener, err = a.dialer(ctx, a.logger)
77+
options, peerListener, err = a.dialer(ctx, a.logger)
6878
if err != nil {
6979
if errors.Is(err, context.Canceled) {
7080
return
@@ -83,6 +93,20 @@ func (a *agent) run(ctx context.Context) {
8393
return
8494
default:
8595
}
96+
a.envVars.Store(options.EnvironmentVariables)
97+
98+
if a.startupScript.CAS(false, true) {
99+
// The startup script has not ran yet!
100+
go func() {
101+
err := a.runStartupScript(ctx, options.StartupScript)
102+
if errors.Is(err, context.Canceled) {
103+
return
104+
}
105+
if err != nil {
106+
a.logger.Warn(ctx, "agent script failed", slog.Error(err))
107+
}
108+
}()
109+
}
86110

87111
for {
88112
conn, err := peerListener.Accept()
@@ -101,6 +125,48 @@ func (a *agent) run(ctx context.Context) {
101125
}
102126
}
103127

128+
func (*agent) runStartupScript(ctx context.Context, script string) error {
129+
if script == "" {
130+
return nil
131+
}
132+
currentUser, err := user.Current()
133+
if err != nil {
134+
return xerrors.Errorf("get current user: %w", err)
135+
}
136+
username := currentUser.Username
137+
138+
shell, err := usershell.Get(username)
139+
if err != nil {
140+
return xerrors.Errorf("get user shell: %w", err)
141+
}
142+
143+
var writer io.WriteCloser
144+
// Attempt to use the syslog to write startup information.
145+
writer, err = gsyslog.NewLogger(gsyslog.LOG_INFO, "USER", "coder-startup-script")
146+
if err != nil {
147+
// If the syslog isn't supported or cannot be created, use a text file in temp.
148+
writer, err = os.CreateTemp("", "coder-startup-script.txt")
149+
if err != nil {
150+
return xerrors.Errorf("open startup script log file: %w", err)
151+
}
152+
}
153+
defer func() {
154+
_ = writer.Close()
155+
}()
156+
caller := "-c"
157+
if runtime.GOOS == "windows" {
158+
caller = "/c"
159+
}
160+
cmd := exec.CommandContext(ctx, shell, caller, script)
161+
cmd.Stdout = writer
162+
cmd.Stderr = writer
163+
err = cmd.Run()
164+
if err != nil {
165+
return xerrors.Errorf("run: %w", err)
166+
}
167+
return nil
168+
}
169+
104170
func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
105171
go func() {
106172
select {
@@ -230,8 +296,24 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
230296

231297
// OpenSSH executes all commands with the users current shell.
232298
// We replicate that behavior for IDE support.
233-
cmd := exec.CommandContext(session.Context(), shell, "-c", command)
299+
caller := "-c"
300+
if runtime.GOOS == "windows" {
301+
caller = "/c"
302+
}
303+
cmd := exec.CommandContext(session.Context(), shell, caller, command)
234304
cmd.Env = append(os.Environ(), session.Environ()...)
305+
306+
// Load environment variables passed via the agent.
307+
envVars := a.envVars.Load()
308+
if envVars != nil {
309+
envVarMap, ok := envVars.(map[string]string)
310+
if ok {
311+
for key, value := range envVarMap {
312+
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", key, value))
313+
}
314+
}
315+
}
316+
235317
executablePath, err := os.Executable()
236318
if err != nil {
237319
return xerrors.Errorf("getting os executable: %w", err)

agent/agent_test.go

+59-10
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,15 @@ import (
1212
"strconv"
1313
"strings"
1414
"testing"
15+
"time"
1516

1617
"github.com/pion/webrtc/v3"
1718
"github.com/pkg/sftp"
1819
"github.com/stretchr/testify/require"
1920
"go.uber.org/goleak"
2021
"golang.org/x/crypto/ssh"
22+
"golang.org/x/text/encoding/unicode"
23+
"golang.org/x/text/transform"
2124

2225
"cdr.dev/slog"
2326
"cdr.dev/slog/sloggers/slogtest"
@@ -37,7 +40,7 @@ func TestAgent(t *testing.T) {
3740
t.Parallel()
3841
t.Run("SessionExec", func(t *testing.T) {
3942
t.Parallel()
40-
session := setupSSHSession(t)
43+
session := setupSSHSession(t, nil)
4144

4245
command := "echo test"
4346
if runtime.GOOS == "windows" {
@@ -50,7 +53,7 @@ func TestAgent(t *testing.T) {
5053

5154
t.Run("GitSSH", func(t *testing.T) {
5255
t.Parallel()
53-
session := setupSSHSession(t)
56+
session := setupSSHSession(t, nil)
5457
command := "sh -c 'echo $GIT_SSH_COMMAND'"
5558
if runtime.GOOS == "windows" {
5659
command = "cmd.exe /c echo %GIT_SSH_COMMAND%"
@@ -62,7 +65,7 @@ func TestAgent(t *testing.T) {
6265

6366
t.Run("SessionTTY", func(t *testing.T) {
6467
t.Parallel()
65-
session := setupSSHSession(t)
68+
session := setupSSHSession(t, nil)
6669
command := "bash"
6770
if runtime.GOOS == "windows" {
6871
command = "cmd.exe"
@@ -117,7 +120,7 @@ func TestAgent(t *testing.T) {
117120

118121
t.Run("SFTP", func(t *testing.T) {
119122
t.Parallel()
120-
sshClient, err := setupAgent(t).SSHClient()
123+
sshClient, err := setupAgent(t, nil).SSHClient()
121124
require.NoError(t, err)
122125
client, err := sftp.NewClient(sshClient)
123126
require.NoError(t, err)
@@ -129,10 +132,52 @@ func TestAgent(t *testing.T) {
129132
_, err = os.Stat(tempFile)
130133
require.NoError(t, err)
131134
})
135+
136+
t.Run("EnvironmentVariables", func(t *testing.T) {
137+
t.Parallel()
138+
key := "EXAMPLE"
139+
value := "value"
140+
session := setupSSHSession(t, &agent.Options{
141+
EnvironmentVariables: map[string]string{
142+
key: value,
143+
},
144+
})
145+
command := "sh -c 'echo $" + key + "'"
146+
if runtime.GOOS == "windows" {
147+
command = "cmd.exe /c echo %" + key + "%"
148+
}
149+
output, err := session.Output(command)
150+
require.NoError(t, err)
151+
require.Equal(t, value, strings.TrimSpace(string(output)))
152+
})
153+
154+
t.Run("StartupScript", func(t *testing.T) {
155+
t.Parallel()
156+
tempPath := filepath.Join(os.TempDir(), "content.txt")
157+
content := "somethingnice"
158+
setupAgent(t, &agent.Options{
159+
StartupScript: "echo " + content + " > " + tempPath,
160+
})
161+
var gotContent string
162+
require.Eventually(t, func() bool {
163+
content, err := os.ReadFile(tempPath)
164+
if err != nil {
165+
return false
166+
}
167+
if runtime.GOOS == "windows" {
168+
// Windows uses UTF16! 🪟🪟🪟
169+
content, _, err = transform.Bytes(unicode.UTF16(unicode.LittleEndian, unicode.UseBOM).NewDecoder(), content)
170+
require.NoError(t, err)
171+
}
172+
gotContent = string(content)
173+
return true
174+
}, 15*time.Second, 100*time.Millisecond)
175+
require.Equal(t, content, strings.TrimSpace(gotContent))
176+
})
132177
}
133178

134179
func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exec.Cmd {
135-
agentConn := setupAgent(t)
180+
agentConn := setupAgent(t, nil)
136181
listener, err := net.Listen("tcp", "127.0.0.1:0")
137182
require.NoError(t, err)
138183
go func() {
@@ -160,18 +205,22 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe
160205
return exec.Command("ssh", args...)
161206
}
162207

163-
func setupSSHSession(t *testing.T) *ssh.Session {
164-
sshClient, err := setupAgent(t).SSHClient()
208+
func setupSSHSession(t *testing.T, options *agent.Options) *ssh.Session {
209+
sshClient, err := setupAgent(t, options).SSHClient()
165210
require.NoError(t, err)
166211
session, err := sshClient.NewSession()
167212
require.NoError(t, err)
168213
return session
169214
}
170215

171-
func setupAgent(t *testing.T) *agent.Conn {
216+
func setupAgent(t *testing.T, options *agent.Options) *agent.Conn {
217+
if options == nil {
218+
options = &agent.Options{}
219+
}
172220
client, server := provisionersdk.TransportPipe()
173-
closer := agent.New(func(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error) {
174-
return peerbroker.Listen(server, nil)
221+
closer := agent.New(func(ctx context.Context, logger slog.Logger) (*agent.Options, *peerbroker.Listener, error) {
222+
listener, err := peerbroker.Listen(server, nil)
223+
return options, listener, err
175224
}, slogtest.Make(t, nil).Leveled(slog.LevelDebug))
176225
t.Cleanup(func() {
177226
_ = client.Close()

cli/gitssh.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@ import (
77
"os/exec"
88
"strings"
99

10-
"github.com/coder/coder/cli/cliui"
11-
"github.com/coder/coder/codersdk"
1210
"github.com/spf13/cobra"
1311
"golang.org/x/xerrors"
12+
13+
"github.com/coder/coder/cli/cliui"
14+
"github.com/coder/coder/codersdk"
1415
)
1516

1617
func gitssh() *cobra.Command {

coderd/coderd.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,8 @@ func New(options *Options) (http.Handler, func()) {
197197
r.Post("/google-instance-identity", api.postWorkspaceAuthGoogleInstanceIdentity)
198198
r.Route("/me", func(r chi.Router) {
199199
r.Use(httpmw.ExtractWorkspaceAgent(options.Database))
200-
r.Get("/", api.workspaceAgentListen)
200+
r.Get("/", api.workspaceAgentMe)
201+
r.Get("/listen", api.workspaceAgentListen)
201202
r.Get("/gitsshkey", api.agentGitSSHKey)
202203
r.Get("/turn", api.workspaceAgentTurn)
203204
r.Get("/iceservers", api.workspaceAgentICEServers)

coderd/workspaceagents.go

+12
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,18 @@ func (api *api) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) {
8888
}
8989
}
9090

91+
func (api *api) workspaceAgentMe(rw http.ResponseWriter, r *http.Request) {
92+
agent := httpmw.WorkspaceAgent(r)
93+
apiAgent, err := convertWorkspaceAgent(agent, api.AgentConnectionUpdateFrequency)
94+
if err != nil {
95+
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
96+
Message: fmt.Sprintf("convert workspace agent: %s", err),
97+
})
98+
return
99+
}
100+
httpapi.Write(rw, http.StatusOK, apiAgent)
101+
}
102+
91103
func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
92104
api.websocketWaitMutex.Lock()
93105
api.websocketWaitGroup.Add(1)

coderd/workspaceagents_test.go

+2
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ func TestWorkspaceAgentListen(t *testing.T) {
102102
})
103103
_, err = conn.Ping()
104104
require.NoError(t, err)
105+
_, err = agentClient.WorkspaceAgent(context.Background(), codersdk.Me)
106+
require.NoError(t, err)
105107
}
106108

107109
func TestWorkspaceAgentTURN(t *testing.T) {

0 commit comments

Comments
 (0)