Skip to content

Commit 72155f0

Browse files
committed
refactor(agent/agentssh): move envs to agent and add agentssh config struct
1 parent 0e1bad4 commit 72155f0

File tree

9 files changed

+238
-131
lines changed

9 files changed

+238
-131
lines changed

agent/agent.go

Lines changed: 85 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ func New(options Options) Agent {
146146
logger: options.Logger,
147147
closeCancel: cancelFunc,
148148
closed: make(chan struct{}),
149-
envVars: options.EnvironmentVariables,
149+
environmentVariables: options.EnvironmentVariables,
150150
client: options.Client,
151151
exchangeToken: options.ExchangeToken,
152152
filesystem: options.Filesystem,
@@ -169,6 +169,7 @@ func New(options Options) Agent {
169169
prometheusRegistry: prometheusRegistry,
170170
metrics: newAgentMetrics(prometheusRegistry),
171171
}
172+
a.serviceBanner.Store(&codersdk.ServiceBannerConfig{})
172173
a.init(ctx)
173174
return a
174175
}
@@ -196,7 +197,7 @@ type agent struct {
196197
closeMutex sync.Mutex
197198
closed chan struct{}
198199

199-
envVars map[string]string
200+
environmentVariables map[string]string
200201

201202
manifest atomic.Pointer[agentsdk.Manifest] // manifest is atomic because values can change after reconnection.
202203
reportMetadataInterval time.Duration
@@ -235,14 +236,16 @@ func (a *agent) TailnetConn() *tailnet.Conn {
235236
}
236237

237238
func (a *agent) init(ctx context.Context) {
238-
sshSrv, err := agentssh.NewServer(ctx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, a.sshMaxTimeout, "")
239+
sshSrv, err := agentssh.NewServer(ctx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, &agentssh.Config{
240+
MaxTimeout: a.sshMaxTimeout,
241+
MOTDFile: func() string { return a.manifest.Load().MOTDFile },
242+
ServiceBanner: func() *codersdk.ServiceBannerConfig { return a.serviceBanner.Load() },
243+
UpdateEnv: a.updateCommandEnv,
244+
WorkingDirectory: func() string { return a.manifest.Load().Directory },
245+
})
239246
if err != nil {
240247
panic(err)
241248
}
242-
sshSrv.Env = a.envVars
243-
sshSrv.AgentToken = func() string { return *a.sessionToken.Load() }
244-
sshSrv.Manifest = &a.manifest
245-
sshSrv.ServiceBanner = &a.serviceBanner
246249
a.sshServer = sshSrv
247250
a.scriptRunner = agentscripts.New(agentscripts.Options{
248251
LogDir: a.logDir,
@@ -879,6 +882,80 @@ func (a *agent) run(ctx context.Context) error {
879882
return eg.Wait()
880883
}
881884

885+
// updateCommandEnv updates the provided command environment with the
886+
// following set of environment variables:
887+
// -
888+
func (a *agent) updateCommandEnv(current []string) (updated []string, err error) {
889+
manifest := a.manifest.Load()
890+
if manifest == nil {
891+
return nil, xerrors.Errorf("no manifest")
892+
}
893+
894+
executablePath, err := os.Executable()
895+
if err != nil {
896+
return nil, xerrors.Errorf("getting os executable: %w", err)
897+
}
898+
unixExecutablePath := strings.ReplaceAll(executablePath, "\\", "/")
899+
900+
// Define environment variables that should be set for all commands,
901+
// and then merge them with the current environment.
902+
envs := map[string]string{
903+
// Set env vars indicating we're inside a Coder workspace.
904+
"CODER": "true",
905+
"CODER_WORKSPACE_NAME": manifest.WorkspaceName,
906+
"CODER_WORKSPACE_AGENT_NAME": manifest.AgentName,
907+
908+
// Specific Coder subcommands require the agent token exposed!
909+
"CODER_AGENT_TOKEN": *a.sessionToken.Load(),
910+
911+
// Git on Windows resolves with UNIX-style paths.
912+
// If using backslashes, it's unable to find the executable.
913+
"GIT_SSH_COMMAND": fmt.Sprintf("%s gitssh --", unixExecutablePath),
914+
// Hide Coder message on code-server's "Getting Started" page
915+
"CS_DISABLE_GETTING_STARTED_OVERRIDE": "true",
916+
}
917+
918+
// This adds the ports dialog to code-server that enables
919+
// proxying a port dynamically.
920+
// If this is empty string, do not set anything. Code-server auto defaults
921+
// using its basepath to construct a path based port proxy.
922+
if manifest.VSCodePortProxyURI != "" {
923+
envs["VSCODE_PROXY_URI"] = manifest.VSCodePortProxyURI
924+
}
925+
926+
// Allow any of the current env to override what we defined above.
927+
for _, env := range current {
928+
parts := strings.SplitN(env, "=", 2)
929+
if len(parts) != 2 {
930+
continue
931+
}
932+
if _, ok := envs[parts[0]]; !ok {
933+
envs[parts[0]] = parts[1]
934+
}
935+
}
936+
937+
// Load environment variables passed via the agent manifest.
938+
// These override all variables we manually specify.
939+
for k, v := range manifest.EnvironmentVariables {
940+
// Expanding environment variables allows for customization
941+
// of the $PATH, among other variables. Customers can prepend
942+
// or append to the $PATH, so allowing expand is required!
943+
envs[k] = os.ExpandEnv(v)
944+
}
945+
946+
// Agent-level environment variables should take over all. This is
947+
// used for setting agent-specific variables like CODER_AGENT_TOKEN
948+
// and GIT_ASKPASS.
949+
for k, v := range a.environmentVariables {
950+
envs[k] = v
951+
}
952+
953+
for k, v := range envs {
954+
updated = append(updated, fmt.Sprintf("%s=%s", k, v))
955+
}
956+
return updated, nil
957+
}
958+
882959
func (a *agent) wireguardAddresses(agentID uuid.UUID) []netip.Prefix {
883960
if len(a.addresses) == 0 {
884961
return []netip.Prefix{
@@ -1314,7 +1391,7 @@ func (a *agent) manageProcessPriorityLoop(ctx context.Context) {
13141391
}
13151392
}()
13161393

1317-
if val := a.envVars[EnvProcPrioMgmt]; val == "" || runtime.GOOS != "linux" {
1394+
if val := a.environmentVariables[EnvProcPrioMgmt]; val == "" || runtime.GOOS != "linux" {
13181395
a.logger.Debug(ctx, "process priority not enabled, agent will not manage process niceness/oom_score_adj ",
13191396
slog.F("env_var", EnvProcPrioMgmt),
13201397
slog.F("value", val),

agent/agent_test.go

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,68 @@ func TestAgent_SessionExec(t *testing.T) {
281281
require.Equal(t, "test", strings.TrimSpace(string(output)))
282282
}
283283

284+
//nolint:tparallel // Sub tests need to run sequentially.
285+
func TestAgent_Session_EnvironmentVariables(t *testing.T) {
286+
t.Parallel()
287+
288+
manifest := agentsdk.Manifest{
289+
EnvironmentVariables: map[string]string{
290+
"MY_MANIFEST": "true",
291+
"MY_OVERRIDE": "false",
292+
"MY_SESSION_MANIFEST": "false",
293+
},
294+
}
295+
banner := codersdk.ServiceBannerConfig{}
296+
session := setupSSHSession(t, manifest, banner, nil, func(_ *agenttest.Client, opts *agent.Options) {
297+
opts.EnvironmentVariables["MY_OVERRIDE"] = "true"
298+
})
299+
300+
err := session.Setenv("MY_SESSION_MANIFEST", "true")
301+
require.NoError(t, err)
302+
err = session.Setenv("MY_SESSION", "true")
303+
require.NoError(t, err)
304+
305+
command := "sh"
306+
echoEnv := func(t *testing.T, w io.Writer, r io.Reader, env string) string {
307+
if runtime.GOOS == "windows" {
308+
_, err := fmt.Fprintf(w, "echo %%%s%%\r\n", env)
309+
require.NoError(t, err)
310+
} else {
311+
_, err := fmt.Fprintf(w, "echo $%s\n", env)
312+
require.NoError(t, err)
313+
}
314+
scanner := bufio.NewScanner(r)
315+
require.True(t, scanner.Scan())
316+
t.Logf("%s=%s", env, scanner.Text())
317+
return scanner.Text()
318+
}
319+
if runtime.GOOS == "windows" {
320+
command = "cmd.exe"
321+
}
322+
stdin, err := session.StdinPipe()
323+
require.NoError(t, err)
324+
defer stdin.Close()
325+
stdout, err := session.StdoutPipe()
326+
require.NoError(t, err)
327+
328+
err = session.Start(command)
329+
require.NoError(t, err)
330+
331+
//nolint:paralleltest // These tests need to run sequentially.
332+
for k, partialV := range map[string]string{
333+
"CODER": "true", // From the agent.
334+
"MY_MANIFEST": "true", // From the manifest.
335+
"MY_OVERRIDE": "true", // From the agent environment variables option, overrides manifest.
336+
"MY_SESSION_MANIFEST": "false", // From the manifest, overrides session env.
337+
"MY_SESSION": "true", // From the session.
338+
} {
339+
t.Run(k, func(t *testing.T) {
340+
out := echoEnv(t, stdin, stdout, k)
341+
require.Contains(t, strings.TrimSpace(out), partialV)
342+
})
343+
}
344+
}
345+
284346
func TestAgent_GitSSH(t *testing.T) {
285347
t.Parallel()
286348
session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil)
@@ -1991,15 +2053,17 @@ func setupSSHSession(
19912053
manifest agentsdk.Manifest,
19922054
serviceBanner codersdk.ServiceBannerConfig,
19932055
prepareFS func(fs afero.Fs),
2056+
opts ...func(*agenttest.Client, *agent.Options),
19942057
) *ssh.Session {
19952058
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
19962059
defer cancel()
1997-
//nolint:dogsled
1998-
conn, _, _, fs, _ := setupAgent(t, manifest, 0, func(c *agenttest.Client, _ *agent.Options) {
2060+
opts = append(opts, func(c *agenttest.Client, o *agent.Options) {
19992061
c.SetServiceBannerFunc(func() (codersdk.ServiceBannerConfig, error) {
20002062
return serviceBanner, nil
20012063
})
20022064
})
2065+
//nolint:dogsled
2066+
conn, _, _, fs, _ := setupAgent(t, manifest, 0, opts...)
20032067
if prepareFS != nil {
20042068
prepareFS(fs)
20052069
}
@@ -2057,6 +2121,7 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
20572121
Filesystem: fs,
20582122
Logger: logger.Named("agent"),
20592123
ReconnectingPTYTimeout: ptyTimeout,
2124+
EnvironmentVariables: map[string]string{},
20602125
}
20612126

20622127
for _, opt := range opts {

agent/agentscripts/agentscripts_test.go

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
"github.com/prometheus/client_golang/prometheus"
99
"github.com/spf13/afero"
1010
"github.com/stretchr/testify/require"
11-
"go.uber.org/atomic"
1211
"go.uber.org/goleak"
1312

1413
"cdr.dev/slog/sloggers/slogtest"
@@ -72,10 +71,8 @@ func setup(t *testing.T, patchLogs func(ctx context.Context, req agentsdk.PatchL
7271
}
7372
fs := afero.NewMemMapFs()
7473
logger := slogtest.Make(t, nil)
75-
s, err := agentssh.NewServer(context.Background(), logger, prometheus.NewRegistry(), fs, 0, "")
74+
s, err := agentssh.NewServer(context.Background(), logger, prometheus.NewRegistry(), fs, nil)
7675
require.NoError(t, err)
77-
s.AgentToken = func() string { return "" }
78-
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
7976
t.Cleanup(func() {
8077
_ = s.Close()
8178
})

0 commit comments

Comments
 (0)