Skip to content

refactor(agent/agentssh): move envs to agent and add agentssh config struct #12204

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 89 additions & 8 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func New(options Options) Agent {
logger: options.Logger,
closeCancel: cancelFunc,
closed: make(chan struct{}),
envVars: options.EnvironmentVariables,
environmentVariables: options.EnvironmentVariables,
client: options.Client,
exchangeToken: options.ExchangeToken,
filesystem: options.Filesystem,
Expand All @@ -169,6 +169,8 @@ func New(options Options) Agent {
prometheusRegistry: prometheusRegistry,
metrics: newAgentMetrics(prometheusRegistry),
}
a.serviceBanner.Store(new(codersdk.ServiceBannerConfig))
a.sessionToken.Store(new(string))
a.init(ctx)
return a
}
Expand Down Expand Up @@ -196,7 +198,7 @@ type agent struct {
closeMutex sync.Mutex
closed chan struct{}

envVars map[string]string
environmentVariables map[string]string

manifest atomic.Pointer[agentsdk.Manifest] // manifest is atomic because values can change after reconnection.
reportMetadataInterval time.Duration
Expand Down Expand Up @@ -235,14 +237,16 @@ func (a *agent) TailnetConn() *tailnet.Conn {
}

func (a *agent) init(ctx context.Context) {
sshSrv, err := agentssh.NewServer(ctx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, a.sshMaxTimeout, "")
sshSrv, err := agentssh.NewServer(ctx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, &agentssh.Config{
MaxTimeout: a.sshMaxTimeout,
MOTDFile: func() string { return a.manifest.Load().MOTDFile },
ServiceBanner: func() *codersdk.ServiceBannerConfig { return a.serviceBanner.Load() },
UpdateEnv: a.updateCommandEnv,
WorkingDirectory: func() string { return a.manifest.Load().Directory },
})
if err != nil {
panic(err)
}
sshSrv.Env = a.envVars
sshSrv.AgentToken = func() string { return *a.sessionToken.Load() }
sshSrv.Manifest = &a.manifest
sshSrv.ServiceBanner = &a.serviceBanner
a.sshServer = sshSrv
a.scriptRunner = agentscripts.New(agentscripts.Options{
LogDir: a.logDir,
Expand Down Expand Up @@ -879,6 +883,83 @@ func (a *agent) run(ctx context.Context) error {
return eg.Wait()
}

// updateCommandEnv updates the provided command environment with the
// following set of environment variables:
// - Predefined workspace environment variables
// - Environment variables currently set (overriding predefined)
// - Environment variables passed via the agent manifest (overriding predefined and current)
// - Agent-level environment variables (overriding all)
func (a *agent) updateCommandEnv(current []string) (updated []string, err error) {
manifest := a.manifest.Load()
if manifest == nil {
return nil, xerrors.Errorf("no manifest")
}

executablePath, err := os.Executable()
if err != nil {
return nil, xerrors.Errorf("getting os executable: %w", err)
}
unixExecutablePath := strings.ReplaceAll(executablePath, "\\", "/")

// Define environment variables that should be set for all commands,
// and then merge them with the current environment.
envs := map[string]string{
// Set env vars indicating we're inside a Coder workspace.
"CODER": "true",
"CODER_WORKSPACE_NAME": manifest.WorkspaceName,
"CODER_WORKSPACE_AGENT_NAME": manifest.AgentName,

// Specific Coder subcommands require the agent token exposed!
"CODER_AGENT_TOKEN": *a.sessionToken.Load(),

// Git on Windows resolves with UNIX-style paths.
// If using backslashes, it's unable to find the executable.
"GIT_SSH_COMMAND": fmt.Sprintf("%s gitssh --", unixExecutablePath),
// Hide Coder message on code-server's "Getting Started" page
"CS_DISABLE_GETTING_STARTED_OVERRIDE": "true",
}

// This adds the ports dialog to code-server that enables
// proxying a port dynamically.
// If this is empty string, do not set anything. Code-server auto defaults
// using its basepath to construct a path based port proxy.
if manifest.VSCodePortProxyURI != "" {
envs["VSCODE_PROXY_URI"] = manifest.VSCodePortProxyURI
}

// Allow any of the current env to override what we defined above.
for _, env := range current {
parts := strings.SplitN(env, "=", 2)
if len(parts) != 2 {
continue
}
if _, ok := envs[parts[0]]; !ok {
envs[parts[0]] = parts[1]
}
}

// Load environment variables passed via the agent manifest.
// These override all variables we manually specify.
for k, v := range manifest.EnvironmentVariables {
// Expanding environment variables allows for customization
// of the $PATH, among other variables. Customers can prepend
// or append to the $PATH, so allowing expand is required!
envs[k] = os.ExpandEnv(v)
}

// Agent-level environment variables should take over all. This is
// used for setting agent-specific variables like CODER_AGENT_TOKEN
// and GIT_ASKPASS.
for k, v := range a.environmentVariables {
envs[k] = v
}

for k, v := range envs {
updated = append(updated, fmt.Sprintf("%s=%s", k, v))
}
return updated, nil
}

func (a *agent) wireguardAddresses(agentID uuid.UUID) []netip.Prefix {
if len(a.addresses) == 0 {
return []netip.Prefix{
Expand Down Expand Up @@ -1314,7 +1395,7 @@ func (a *agent) manageProcessPriorityLoop(ctx context.Context) {
}
}()

if val := a.envVars[EnvProcPrioMgmt]; val == "" || runtime.GOOS != "linux" {
if val := a.environmentVariables[EnvProcPrioMgmt]; val == "" || runtime.GOOS != "linux" {
a.logger.Debug(ctx, "process priority not enabled, agent will not manage process niceness/oom_score_adj ",
slog.F("env_var", EnvProcPrioMgmt),
slog.F("value", val),
Expand Down
85 changes: 83 additions & 2 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
Expand Down Expand Up @@ -281,6 +282,83 @@ func TestAgent_SessionExec(t *testing.T) {
require.Equal(t, "test", strings.TrimSpace(string(output)))
}

//nolint:tparallel // Sub tests need to run sequentially.
func TestAgent_Session_EnvironmentVariables(t *testing.T) {
t.Parallel()

manifest := agentsdk.Manifest{
EnvironmentVariables: map[string]string{
"MY_MANIFEST": "true",
"MY_OVERRIDE": "false",
"MY_SESSION_MANIFEST": "false",
},
}
banner := codersdk.ServiceBannerConfig{}
session := setupSSHSession(t, manifest, banner, nil, func(_ *agenttest.Client, opts *agent.Options) {
opts.EnvironmentVariables["MY_OVERRIDE"] = "true"
})

err := session.Setenv("MY_SESSION_MANIFEST", "true")
require.NoError(t, err)
err = session.Setenv("MY_SESSION", "true")
require.NoError(t, err)

command := "sh"
echoEnv := func(t *testing.T, w io.Writer, env string) {
if runtime.GOOS == "windows" {
_, err := fmt.Fprintf(w, "echo %%%s%%\r\n", env)
require.NoError(t, err)
} else {
_, err := fmt.Fprintf(w, "echo $%s\n", env)
require.NoError(t, err)
}
}
if runtime.GOOS == "windows" {
command = "cmd.exe"
}
stdin, err := session.StdinPipe()
require.NoError(t, err)
defer stdin.Close()
stdout, err := session.StdoutPipe()
require.NoError(t, err)

err = session.Start(command)
require.NoError(t, err)

// Context is fine here since we're not doing a parallel subtest.
ctx := testutil.Context(t, testutil.WaitLong)
go func() {
<-ctx.Done()
_ = session.Close()
}()

s := bufio.NewScanner(stdout)

//nolint:paralleltest // These tests need to run sequentially.
for k, partialV := range map[string]string{
"CODER": "true", // From the agent.
"MY_MANIFEST": "true", // From the manifest.
"MY_OVERRIDE": "true", // From the agent environment variables option, overrides manifest.
"MY_SESSION_MANIFEST": "false", // From the manifest, overrides session env.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review: This is not necessarily correct behavior (IMO), but I wanted to stay truthful to the existing implementation.

"MY_SESSION": "true", // From the session.
} {
t.Run(k, func(t *testing.T) {
echoEnv(t, stdin, k)
// Windows is unreliable, so keep scanning until we find a match.
for s.Scan() {
got := strings.TrimSpace(s.Text())
t.Logf("%s=%s", k, got)
if strings.Contains(got, partialV) {
break
}
}
if err := s.Err(); !errors.Is(err, io.EOF) {
require.NoError(t, err)
}
})
}
}

func TestAgent_GitSSH(t *testing.T) {
t.Parallel()
session := setupSSHSession(t, agentsdk.Manifest{}, codersdk.ServiceBannerConfig{}, nil)
Expand Down Expand Up @@ -1991,15 +2069,17 @@ func setupSSHSession(
manifest agentsdk.Manifest,
serviceBanner codersdk.ServiceBannerConfig,
prepareFS func(fs afero.Fs),
opts ...func(*agenttest.Client, *agent.Options),
) *ssh.Session {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
//nolint:dogsled
conn, _, _, fs, _ := setupAgent(t, manifest, 0, func(c *agenttest.Client, _ *agent.Options) {
opts = append(opts, func(c *agenttest.Client, o *agent.Options) {
c.SetServiceBannerFunc(func() (codersdk.ServiceBannerConfig, error) {
return serviceBanner, nil
})
})
//nolint:dogsled
conn, _, _, fs, _ := setupAgent(t, manifest, 0, opts...)
if prepareFS != nil {
prepareFS(fs)
}
Expand Down Expand Up @@ -2057,6 +2137,7 @@ func setupAgent(t *testing.T, metadata agentsdk.Manifest, ptyTimeout time.Durati
Filesystem: fs,
Logger: logger.Named("agent"),
ReconnectingPTYTimeout: ptyTimeout,
EnvironmentVariables: map[string]string{},
}

for _, opt := range opts {
Expand Down
5 changes: 1 addition & 4 deletions agent/agentscripts/agentscripts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/spf13/afero"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
"go.uber.org/goleak"

"cdr.dev/slog/sloggers/slogtest"
Expand Down Expand Up @@ -72,10 +71,8 @@ func setup(t *testing.T, patchLogs func(ctx context.Context, req agentsdk.PatchL
}
fs := afero.NewMemMapFs()
logger := slogtest.Make(t, nil)
s, err := agentssh.NewServer(context.Background(), logger, prometheus.NewRegistry(), fs, 0, "")
s, err := agentssh.NewServer(context.Background(), logger, prometheus.NewRegistry(), fs, nil)
require.NoError(t, err)
s.AgentToken = func() string { return "" }
s.Manifest = atomic.NewPointer(&agentsdk.Manifest{})
t.Cleanup(func() {
_ = s.Close()
})
Expand Down
Loading