Skip to content

feat: Use environment variables and startup script in agent #1147

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 25, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"gographviz",
"goleak",
"gossh",
"gsyslog",
"hashicorp",
"hclsyntax",
"httpmw",
Expand Down
92 changes: 87 additions & 5 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,13 @@ import (
"os"
"os/exec"
"os/user"
"runtime"
"sync"
"time"

gsyslog "github.com/hashicorp/go-syslog"
"go.uber.org/atomic"

"cdr.dev/slog"
"github.com/coder/coder/agent/usershell"
"github.com/coder/coder/peer"
Expand All @@ -29,10 +33,11 @@ import (
)

type Options struct {
Logger slog.Logger
EnvironmentVariables map[string]string
StartupScript string
}

type Dialer func(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error)
type Dialer func(ctx context.Context, logger slog.Logger) (*Options, *peerbroker.Listener, error)

func New(dialer Dialer, logger slog.Logger) io.Closer {
ctx, cancelFunc := context.WithCancel(context.Background())
Expand All @@ -55,16 +60,21 @@ type agent struct {
closeMutex sync.Mutex
closed chan struct{}

sshServer *ssh.Server
// Environment variables sent by Coder to inject for shell sessions.
// This is atomic because values can change after reconnect.
envVars atomic.Value
startupScript atomic.Bool
sshServer *ssh.Server
}

func (a *agent) run(ctx context.Context) {
var options *Options
var peerListener *peerbroker.Listener
var err error
// An exponential back-off occurs when the connection is failing to dial.
// This is to prevent server spam in case of a coderd outage.
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
peerListener, err = a.dialer(ctx, a.logger)
options, peerListener, err = a.dialer(ctx, a.logger)
if err != nil {
if errors.Is(err, context.Canceled) {
return
Expand All @@ -83,6 +93,20 @@ func (a *agent) run(ctx context.Context) {
return
default:
}
a.envVars.Store(options.EnvironmentVariables)

if a.startupScript.CAS(false, true) {
// The startup script has not ran yet!
go func() {
err := a.runStartupScript(ctx, options.StartupScript)
if errors.Is(err, context.Canceled) {
return
}
if err != nil {
a.logger.Warn(ctx, "agent script failed", slog.Error(err))
}
}()
}

for {
conn, err := peerListener.Accept()
Expand All @@ -101,6 +125,48 @@ func (a *agent) run(ctx context.Context) {
}
}

func (*agent) runStartupScript(ctx context.Context, script string) error {
if script == "" {
return nil
}
currentUser, err := user.Current()
if err != nil {
return xerrors.Errorf("get current user: %w", err)
}
username := currentUser.Username

shell, err := usershell.Get(username)
if err != nil {
return xerrors.Errorf("get user shell: %w", err)
}

var writer io.WriteCloser
// Attempt to use the syslog to write startup information.
writer, err = gsyslog.NewLogger(gsyslog.LOG_INFO, "USER", "coder-startup-script")
if err != nil {
// If the syslog isn't supported or cannot be created, use a text file in temp.
writer, err = os.CreateTemp("", "coder-startup-script.txt")
if err != nil {
return xerrors.Errorf("open startup script log file: %w", err)
}
}
defer func() {
_ = writer.Close()
}()
caller := "-c"
if runtime.GOOS == "windows" {
caller = "/c"
}
cmd := exec.CommandContext(ctx, shell, caller, script)
cmd.Stdout = writer
cmd.Stderr = writer
err = cmd.Run()
if err != nil {
return xerrors.Errorf("run: %w", err)
}
return nil
}

func (a *agent) handlePeerConn(ctx context.Context, conn *peer.Conn) {
go func() {
select {
Expand Down Expand Up @@ -230,8 +296,24 @@ func (a *agent) handleSSHSession(session ssh.Session) error {

// OpenSSH executes all commands with the users current shell.
// We replicate that behavior for IDE support.
cmd := exec.CommandContext(session.Context(), shell, "-c", command)
caller := "-c"
if runtime.GOOS == "windows" {
caller = "/c"
}
cmd := exec.CommandContext(session.Context(), shell, caller, command)
cmd.Env = append(os.Environ(), session.Environ()...)

// Load environment variables passed via the agent.
envVars := a.envVars.Load()
if envVars != nil {
envVarMap, ok := envVars.(map[string]string)
if ok {
for key, value := range envVarMap {
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", key, value))
}
}
}

executablePath, err := os.Executable()
if err != nil {
return xerrors.Errorf("getting os executable: %w", err)
Expand Down
69 changes: 59 additions & 10 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@ import (
"strconv"
"strings"
"testing"
"time"

"github.com/pion/webrtc/v3"
"github.com/pkg/sftp"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"golang.org/x/crypto/ssh"
"golang.org/x/text/encoding/unicode"
"golang.org/x/text/transform"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
Expand All @@ -37,7 +40,7 @@ func TestAgent(t *testing.T) {
t.Parallel()
t.Run("SessionExec", func(t *testing.T) {
t.Parallel()
session := setupSSHSession(t)
session := setupSSHSession(t, nil)

command := "echo test"
if runtime.GOOS == "windows" {
Expand All @@ -50,7 +53,7 @@ func TestAgent(t *testing.T) {

t.Run("GitSSH", func(t *testing.T) {
t.Parallel()
session := setupSSHSession(t)
session := setupSSHSession(t, nil)
command := "sh -c 'echo $GIT_SSH_COMMAND'"
if runtime.GOOS == "windows" {
command = "cmd.exe /c echo %GIT_SSH_COMMAND%"
Expand All @@ -62,7 +65,7 @@ func TestAgent(t *testing.T) {

t.Run("SessionTTY", func(t *testing.T) {
t.Parallel()
session := setupSSHSession(t)
session := setupSSHSession(t, nil)
command := "bash"
if runtime.GOOS == "windows" {
command = "cmd.exe"
Expand Down Expand Up @@ -117,7 +120,7 @@ func TestAgent(t *testing.T) {

t.Run("SFTP", func(t *testing.T) {
t.Parallel()
sshClient, err := setupAgent(t).SSHClient()
sshClient, err := setupAgent(t, nil).SSHClient()
require.NoError(t, err)
client, err := sftp.NewClient(sshClient)
require.NoError(t, err)
Expand All @@ -129,10 +132,52 @@ func TestAgent(t *testing.T) {
_, err = os.Stat(tempFile)
require.NoError(t, err)
})

t.Run("EnvironmentVariables", func(t *testing.T) {
t.Parallel()
key := "EXAMPLE"
value := "value"
session := setupSSHSession(t, &agent.Options{
EnvironmentVariables: map[string]string{
key: value,
},
})
command := "sh -c 'echo $" + key + "'"
if runtime.GOOS == "windows" {
command = "cmd.exe /c echo %" + key + "%"
}
output, err := session.Output(command)
require.NoError(t, err)
require.Equal(t, value, strings.TrimSpace(string(output)))
})

t.Run("StartupScript", func(t *testing.T) {
t.Parallel()
tempPath := filepath.Join(os.TempDir(), "content.txt")
content := "somethingnice"
setupAgent(t, &agent.Options{
StartupScript: "echo " + content + " > " + tempPath,
})
var gotContent string
require.Eventually(t, func() bool {
content, err := os.ReadFile(tempPath)
if err != nil {
return false
}
if runtime.GOOS == "windows" {
// Windows uses UTF16! 🪟🪟🪟
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👻

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

spooky is right

content, _, err = transform.Bytes(unicode.UTF16(unicode.LittleEndian, unicode.UseBOM).NewDecoder(), content)
require.NoError(t, err)
}
gotContent = string(content)
return true
}, 15*time.Second, 100*time.Millisecond)
require.Equal(t, content, strings.TrimSpace(gotContent))
})
}

func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exec.Cmd {
agentConn := setupAgent(t)
agentConn := setupAgent(t, nil)
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
go func() {
Expand Down Expand Up @@ -160,18 +205,22 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe
return exec.Command("ssh", args...)
}

func setupSSHSession(t *testing.T) *ssh.Session {
sshClient, err := setupAgent(t).SSHClient()
func setupSSHSession(t *testing.T, options *agent.Options) *ssh.Session {
sshClient, err := setupAgent(t, options).SSHClient()
require.NoError(t, err)
session, err := sshClient.NewSession()
require.NoError(t, err)
return session
}

func setupAgent(t *testing.T) *agent.Conn {
func setupAgent(t *testing.T, options *agent.Options) *agent.Conn {
if options == nil {
options = &agent.Options{}
}
client, server := provisionersdk.TransportPipe()
closer := agent.New(func(ctx context.Context, logger slog.Logger) (*peerbroker.Listener, error) {
return peerbroker.Listen(server, nil)
closer := agent.New(func(ctx context.Context, logger slog.Logger) (*agent.Options, *peerbroker.Listener, error) {
listener, err := peerbroker.Listen(server, nil)
return options, listener, err
}, slogtest.Make(t, nil).Leveled(slog.LevelDebug))
t.Cleanup(func() {
_ = client.Close()
Expand Down
5 changes: 3 additions & 2 deletions cli/gitssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ import (
"os/exec"
"strings"

"github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk"
"github.com/spf13/cobra"
"golang.org/x/xerrors"

"github.com/coder/coder/cli/cliui"
"github.com/coder/coder/codersdk"
)

func gitssh() *cobra.Command {
Expand Down
3 changes: 2 additions & 1 deletion coderd/coderd.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@ func New(options *Options) (http.Handler, func()) {
r.Post("/google-instance-identity", api.postWorkspaceAuthGoogleInstanceIdentity)
r.Route("/me", func(r chi.Router) {
r.Use(httpmw.ExtractWorkspaceAgent(options.Database))
r.Get("/", api.workspaceAgentListen)
r.Get("/", api.workspaceAgentMe)
r.Get("/listen", api.workspaceAgentListen)
r.Get("/gitsshkey", api.agentGitSSHKey)
r.Get("/turn", api.workspaceAgentTurn)
r.Get("/iceservers", api.workspaceAgentICEServers)
Expand Down
12 changes: 12 additions & 0 deletions coderd/workspaceagents.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,18 @@ func (api *api) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) {
}
}

func (api *api) workspaceAgentMe(rw http.ResponseWriter, r *http.Request) {
agent := httpmw.WorkspaceAgent(r)
apiAgent, err := convertWorkspaceAgent(agent, api.AgentConnectionUpdateFrequency)
if err != nil {
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
Message: fmt.Sprintf("convert workspace agent: %s", err),
})
return
}
httpapi.Write(rw, http.StatusOK, apiAgent)
}

func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
api.websocketWaitMutex.Lock()
api.websocketWaitGroup.Add(1)
Expand Down
2 changes: 2 additions & 0 deletions coderd/workspaceagents_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ func TestWorkspaceAgentListen(t *testing.T) {
})
_, err = conn.Ping()
require.NoError(t, err)
_, err = agentClient.WorkspaceAgent(context.Background(), codersdk.Me)
require.NoError(t, err)
}

func TestWorkspaceAgentTURN(t *testing.T) {
Expand Down
Loading