Skip to content

feat: add agent exec pkg #15577

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Nov 25, 2024
Next Next commit
feat: add agentexec pkg
  • Loading branch information
sreya committed Nov 18, 2024
commit 90281bd9d00f812606d0ee2279a0bb113dbcc1e0
75 changes: 75 additions & 0 deletions agent/agentexec/cli.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package agentexec

import (
"context"
"fmt"
"os"
"os/exec"
"runtime"
"slices"
"strconv"
"strings"
"syscall"

"golang.org/x/xerrors"
)

const (
EnvProcOOMScore = "CODER_PROC_OOM_SCORE"
EnvProcNiceScore = "CODER_PROC_NICE_SCORE"
)

// CLI runs the agent-exec command. It should only be called by the cli package.
func CLI(ctx context.Context, args []string, environ []string) error {
if runtime.GOOS != "linux" {
return xerrors.Errorf("agent-exec is only supported on Linux")
}

pid := os.Getpid()

oomScore, ok := envVal(environ, EnvProcOOMScore)
if !ok {
return xerrors.Errorf("missing %q", EnvProcOOMScore)
}

niceScore, ok := envVal(environ, EnvProcNiceScore)
if !ok {
return xerrors.Errorf("missing %q", EnvProcNiceScore)
}

score, err := strconv.Atoi(niceScore)
if err != nil {
return xerrors.Errorf("invalid nice score: %w", err)
}

err = syscall.Setpriority(syscall.PRIO_PROCESS, pid, score)
if err != nil {
return xerrors.Errorf("set nice score: %w", err)
}

oomPath := fmt.Sprintf("/proc/%d/oom_score_adj", pid)
err = os.WriteFile(oomPath, []byte(oomScore), 0o600)
if err != nil {
return xerrors.Errorf("set oom score: %w", err)
}

path, err := exec.LookPath(args[0])
if err != nil {
return xerrors.Errorf("look path: %w", err)
}

env := slices.DeleteFunc(environ, func(env string) bool {
return strings.HasPrefix(env, EnvProcOOMScore) || strings.HasPrefix(env, EnvProcNiceScore)
})

return syscall.Exec(path, args, env)
}

func envVal(environ []string, key string) (string, bool) {
for _, env := range environ {
if strings.HasPrefix(env, key+"=") {
return strings.TrimPrefix(env, key+"="), true
}
}
return "", false
}
17 changes: 17 additions & 0 deletions agent/agentexec/cmdtest/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package main

import (
"context"
"fmt"
"os"

"github.com/coder/coder/v2/agent/agentexec"
)

func main() {
err := agentexec.CLI(context.Background(), os.Args, os.Environ())
if err != nil {
_, _ = fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}
41 changes: 41 additions & 0 deletions agent/agentexec/exec.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package agentexec

import (
"context"
"os"
"os/exec"
"path/filepath"
"runtime"

"golang.org/x/xerrors"
)

const (
// EnvProcPrioMgmt is the environment variable that determines whether
// we attempt to manage process CPU and OOM Killer priority.
EnvProcPrioMgmt = "CODER_PROC_PRIO_MGMT"
)

// CommandContext returns an exec.Cmd that calls "coder agent-exec" prior to exec'ing
// the provided command if CODER_PROC_PRIO_MGMT is set, otherwise a normal exec.Cmd
// is returned. All instances of exec.Cmd should flow through this function to ensure
// proper resource constraints are applied to the child process.
func CommandContext(ctx context.Context, cmd string, args []string) (*exec.Cmd, error) {
_, enabled := envVal(os.Environ(), EnvProcPrioMgmt)
if runtime.GOOS != "linux" || !enabled {
return exec.CommandContext(ctx, cmd, args...), nil
}

executable, err := os.Executable()
if err != nil {
return nil, xerrors.Errorf("get executable: %w", err)
}

bin, err := filepath.EvalSymlinks(executable)
if err != nil {
return nil, xerrors.Errorf("eval symlinks: %w", err)
}

args = append([]string{"agent-exec", cmd}, args...)
return exec.CommandContext(ctx, bin, args...), nil
}
122 changes: 122 additions & 0 deletions cli/agentexec.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package cli

import (
"context"
"fmt"
"os"
"os/exec"
"runtime"
"slices"
"strconv"
"strings"
"syscall"

"golang.org/x/xerrors"

"github.com/spf13/afero"

"github.com/coder/serpent"
)

const EnvProcOOMScore = "CODER_PROC_OOM_SCORE"
const EnvProcNiceScore = "CODER_PROC_NICE_SCORE"

func (*RootCmd) agentExec() *serpent.Command {
return &serpent.Command{
Use: "agent-exec",
Hidden: true,
RawArgs: true,
Handler: func(inv *serpent.Invocation) error {
if runtime.GOOS != "linux" {
return xerrors.Errorf("agent-exec is only supported on Linux")
}

var (
pid = os.Getpid()
args = inv.Args
oomScore = inv.Environ.Get(EnvProcOOMScore)
niceScore = inv.Environ.Get(EnvProcNiceScore)

fs = fsFromContext(inv.Context())
syscaller = syscallerFromContext(inv.Context())
)

score, err := strconv.Atoi(niceScore)
if err != nil {
return xerrors.Errorf("invalid nice score: %w", err)
}

err = syscaller.Setpriority(syscall.PRIO_PROCESS, pid, score)
if err != nil {
return xerrors.Errorf("set nice score: %w", err)
}

oomPath := fmt.Sprintf("/proc/%d/oom_score_adj", pid)
err = afero.WriteFile(fs, oomPath, []byte(oomScore), 0o600)
if err != nil {
return xerrors.Errorf("set oom score: %w", err)
}

path, err := exec.LookPath(args[0])
if err != nil {
return xerrors.Errorf("look path: %w", err)
}

env := slices.DeleteFunc(inv.Environ.ToOS(), excludeKeys(EnvProcOOMScore, EnvProcNiceScore))

return syscall.Exec(path, args, env)
},
}
}

func excludeKeys(keys ...string) func(env string) bool {
return func(env string) bool {
for _, key := range keys {
if strings.HasPrefix(env, key+"=") {
return true
}
}
return false
}
}

type Syscaller interface {
Setpriority(int, int, int) error
Exec(string, []string, []string) error
}

type linuxSyscaller struct{}

func (linuxSyscaller) Setpriority(which, pid, nice int) error {
return syscall.Setpriority(which, pid, nice)
}

func (linuxSyscaller) Exec(path string, args, env []string) error {
return syscall.Exec(path, args, env)
}

type syscallerKey struct{}

func WithSyscaller(ctx context.Context, syscaller Syscaller) context.Context {
return context.WithValue(ctx, syscallerKey{}, syscaller)
}

func syscallerFromContext(ctx context.Context) Syscaller {
if syscaller, ok := ctx.Value(syscallerKey{}).(Syscaller); ok {
return syscaller
}
return linuxSyscaller{}
}

type fsKey struct{}

func WithFS(ctx context.Context, fs afero.Fs) context.Context {
return context.WithValue(ctx, fsKey{}, fs)
}

func fsFromContext(ctx context.Context) afero.Fs {
if fs, ok := ctx.Value(fsKey{}).(afero.Fs); ok {
return fs
}
return afero.NewOsFs()
}
49 changes: 49 additions & 0 deletions cli/agentexec_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package cli_test

import (
"bytes"
"io"
"runtime"
"sync"
"testing"

"github.com/stretchr/testify/require"

"github.com/coder/coder/v2/cli"
"github.com/coder/coder/v2/cli/clitest"
)

func TestAgentExec(t *testing.T) {
t.Parallel()

if runtime.GOOS != "linux" {
t.Skip("agent-exec is only supported on Linux")
}

t.Run("OK", func(t *testing.T) {
t.Parallel()

inv, _ := clitest.New(t, "agent-exec", "echo", "hello")
inv.Environ.Set(cli.EnvProcOOMScore, "1000")
inv.Environ.Set(cli.EnvProcNiceScore, "10")
var buf bytes.Buffer
wr := &syncWriter{W: &buf}
inv.Stdout = wr
inv.Stderr = wr
clitest.Start(t, inv)

require.Equal(t, "hello\n", buf.String())
})

}

type syncWriter struct {
W io.Writer
mu sync.Mutex
}

func (w *syncWriter) Write(p []byte) (n int, err error) {
w.mu.Lock()
defer w.mu.Unlock()
return w.W.Write(p)
}
1 change: 1 addition & 0 deletions cli/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ func (r *RootCmd) CoreSubcommands() []*serpent.Command {
r.whoami(),

// Hidden
r.agentExec(),
r.expCmd(),
r.gitssh(),
r.support(),
Expand Down