Skip to content

Commit a607a1b

Browse files
committed
Merge branch 'main' into updatetf
2 parents 2e983be + b948f2d commit a607a1b

File tree

20 files changed

+146
-117
lines changed

20 files changed

+146
-117
lines changed

agent/agent.go

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ import (
4040

4141
type Options struct {
4242
ReconnectingPTYTimeout time.Duration
43+
EnvironmentVariables map[string]string
4344
Logger slog.Logger
4445
}
4546

@@ -67,6 +68,7 @@ func New(dialer Dialer, options *Options) io.Closer {
6768
logger: options.Logger,
6869
closeCancel: cancelFunc,
6970
closed: make(chan struct{}),
71+
envVars: options.EnvironmentVariables,
7072
}
7173
server.init(ctx)
7274
return server
@@ -84,24 +86,21 @@ type agent struct {
8486
closeMutex sync.Mutex
8587
closed chan struct{}
8688

87-
// Environment variables sent by Coder to inject for shell sessions.
88-
// These are atomic because values can change after reconnect.
89-
envVars atomic.Value
90-
directory atomic.String
91-
ownerEmail atomic.String
92-
ownerUsername atomic.String
89+
envVars map[string]string
90+
// metadata is atomic because values can change after reconnection.
91+
metadata atomic.Value
9392
startupScript atomic.Bool
9493
sshServer *ssh.Server
9594
}
9695

9796
func (a *agent) run(ctx context.Context) {
98-
var options Metadata
97+
var metadata Metadata
9998
var peerListener *peerbroker.Listener
10099
var err error
101100
// An exponential back-off occurs when the connection is failing to dial.
102101
// This is to prevent server spam in case of a coderd outage.
103102
for retrier := retry.New(50*time.Millisecond, 10*time.Second); retrier.Wait(ctx); {
104-
options, peerListener, err = a.dialer(ctx, a.logger)
103+
metadata, peerListener, err = a.dialer(ctx, a.logger)
105104
if err != nil {
106105
if errors.Is(err, context.Canceled) {
107106
return
@@ -120,15 +119,12 @@ func (a *agent) run(ctx context.Context) {
120119
return
121120
default:
122121
}
123-
a.directory.Store(options.Directory)
124-
a.envVars.Store(options.EnvironmentVariables)
125-
a.ownerEmail.Store(options.OwnerEmail)
126-
a.ownerUsername.Store(options.OwnerUsername)
122+
a.metadata.Store(metadata)
127123

128124
if a.startupScript.CAS(false, true) {
129125
// The startup script has not ran yet!
130126
go func() {
131-
err := a.runStartupScript(ctx, options.StartupScript)
127+
err := a.runStartupScript(ctx, metadata.StartupScript)
132128
if errors.Is(err, context.Canceled) {
133129
return
134130
}
@@ -175,7 +171,7 @@ func (*agent) runStartupScript(ctx context.Context, script string) error {
175171
writer, err = gsyslog.NewLogger(gsyslog.LOG_INFO, "USER", "coder-startup-script")
176172
if err != nil {
177173
// If the syslog isn't supported or cannot be created, use a text file in temp.
178-
writer, err = os.CreateTemp("", "coder-startup-script.txt")
174+
writer, err = os.CreateTemp("", "coder-startup-script-*.txt")
179175
if err != nil {
180176
return xerrors.Errorf("open startup script log file: %w", err)
181177
}
@@ -322,6 +318,15 @@ func (a *agent) createCommand(ctx context.Context, rawCommand string, env []stri
322318
return nil, xerrors.Errorf("get user shell: %w", err)
323319
}
324320

321+
rawMetadata := a.metadata.Load()
322+
if rawMetadata == nil {
323+
return nil, xerrors.Errorf("no metadata was provided: %w", err)
324+
}
325+
metadata, valid := rawMetadata.(Metadata)
326+
if !valid {
327+
return nil, xerrors.Errorf("metadata is the wrong type: %T", metadata)
328+
}
329+
325330
// gliderlabs/ssh returns a command slice of zero
326331
// when a shell is requested.
327332
command := rawCommand
@@ -336,7 +341,7 @@ func (a *agent) createCommand(ctx context.Context, rawCommand string, env []stri
336341
caller = "/c"
337342
}
338343
cmd := exec.CommandContext(ctx, shell, caller, command)
339-
cmd.Dir = a.directory.Load()
344+
cmd.Dir = metadata.Directory
340345
if cmd.Dir == "" {
341346
// Default to $HOME if a directory is not set!
342347
cmd.Dir = os.Getenv("HOME")
@@ -351,20 +356,24 @@ func (a *agent) createCommand(ctx context.Context, rawCommand string, env []stri
351356
executablePath = strings.ReplaceAll(executablePath, "\\", "/")
352357
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_SSH_COMMAND=%s gitssh --`, executablePath))
353358
// These prevent the user from having to specify _anything_ to successfully commit.
354-
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_AUTHOR_EMAIL=%s`, a.ownerEmail.Load()))
355-
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_AUTHOR_NAME=%s`, a.ownerUsername.Load()))
359+
// Both author and committer must be set!
360+
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_AUTHOR_EMAIL=%s`, metadata.OwnerEmail))
361+
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_COMMITTER_EMAIL=%s`, metadata.OwnerEmail))
362+
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_AUTHOR_NAME=%s`, metadata.OwnerUsername))
363+
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_COMMITTER_NAME=%s`, metadata.OwnerUsername))
356364

357365
// Load environment variables passed via the agent.
358366
// These should override all variables we manually specify.
359-
envVars := a.envVars.Load()
360-
if envVars != nil {
361-
envVarMap, ok := envVars.(map[string]string)
362-
if ok {
363-
for key, value := range envVarMap {
364-
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", key, value))
365-
}
366-
}
367+
for key, value := range metadata.EnvironmentVariables {
368+
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", key, value))
369+
}
370+
371+
// Agent-level environment variables should take over all!
372+
// This is used for setting agent-specific variables like "CODER_AGENT_TOKEN".
373+
for key, value := range a.envVars {
374+
cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", key, value))
367375
}
376+
368377
return cmd, nil
369378
}
370379

cli/agent.go

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,16 @@ import (
2121

2222
func workspaceAgent() *cobra.Command {
2323
var (
24-
rawURL string
25-
auth string
26-
token string
24+
auth string
2725
)
2826
cmd := &cobra.Command{
2927
Use: "agent",
3028
// This command isn't useful to manually execute.
3129
Hidden: true,
3230
RunE: func(cmd *cobra.Command, args []string) error {
33-
if rawURL == "" {
34-
return xerrors.New("CODER_URL must be set")
31+
rawURL, err := cmd.Flags().GetString(varAgentURL)
32+
if err != nil {
33+
return xerrors.Errorf("CODER_AGENT_URL must be set: %w", err)
3534
}
3635
coderURL, err := url.Parse(rawURL)
3736
if err != nil {
@@ -46,8 +45,9 @@ func workspaceAgent() *cobra.Command {
4645
var exchangeToken func(context.Context) (codersdk.WorkspaceAgentAuthenticateResponse, error)
4746
switch auth {
4847
case "token":
49-
if token == "" {
50-
return xerrors.Errorf("CODER_TOKEN must be set for token auth")
48+
token, err := cmd.Flags().GetString(varAgentToken)
49+
if err != nil {
50+
return xerrors.Errorf("CODER_AGENT_TOKEN must be set for token auth: %w", err)
5151
}
5252
client.SessionToken = token
5353
case "google-instance-identity":
@@ -115,27 +115,19 @@ func workspaceAgent() *cobra.Command {
115115
}
116116
}
117117

118-
cfg := createConfig(cmd)
119-
err = cfg.AgentSession().Write(client.SessionToken)
120-
if err != nil {
121-
return xerrors.Errorf("writing agent session token to config: %w", err)
122-
}
123-
err = cfg.URL().Write(client.URL.String())
124-
if err != nil {
125-
return xerrors.Errorf("writing agent url to config: %w", err)
126-
}
127-
128118
closer := agent.New(client.ListenWorkspaceAgent, &agent.Options{
129119
Logger: logger,
120+
EnvironmentVariables: map[string]string{
121+
// Override the "CODER_AGENT_TOKEN" variable in all
122+
// shells so "gitssh" works!
123+
"CODER_AGENT_TOKEN": client.SessionToken,
124+
},
130125
})
131126
<-cmd.Context().Done()
132127
return closer.Close()
133128
},
134129
}
135130

136-
cliflag.StringVarP(cmd.Flags(), &auth, "auth", "", "CODER_AUTH", "token", "Specify the authentication type to use for the agent")
137-
cliflag.StringVarP(cmd.Flags(), &rawURL, "url", "", "CODER_URL", "", "Specify the URL to access Coder")
138-
cliflag.StringVarP(cmd.Flags(), &token, "token", "", "CODER_TOKEN", "", "Specifies the authentication token to access Coder")
139-
131+
cliflag.StringVarP(cmd.Flags(), &auth, "auth", "", "CODER_AGENT_AUTH", "token", "Specify the authentication type to use for the agent")
140132
return cmd
141133
}

cli/agent_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ func TestWorkspaceAgent(t *testing.T) {
4646
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
4747
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
4848

49-
cmd, _ := clitest.New(t, "agent", "--auth", "azure-instance-identity", "--url", client.URL.String())
49+
cmd, _ := clitest.New(t, "agent", "--auth", "azure-instance-identity", "--agent-url", client.URL.String())
5050
ctx, cancelFunc := context.WithCancel(context.Background())
5151
defer cancelFunc()
5252
go func() {
@@ -100,7 +100,7 @@ func TestWorkspaceAgent(t *testing.T) {
100100
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
101101
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
102102

103-
cmd, _ := clitest.New(t, "agent", "--auth", "aws-instance-identity", "--url", client.URL.String())
103+
cmd, _ := clitest.New(t, "agent", "--auth", "aws-instance-identity", "--agent-url", client.URL.String())
104104
ctx, cancelFunc := context.WithCancel(context.Background())
105105
defer cancelFunc()
106106
go func() {
@@ -154,7 +154,7 @@ func TestWorkspaceAgent(t *testing.T) {
154154
workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID)
155155
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
156156

157-
cmd, _ := clitest.New(t, "agent", "--auth", "google-instance-identity", "--url", client.URL.String())
157+
cmd, _ := clitest.New(t, "agent", "--auth", "google-instance-identity", "--agent-url", client.URL.String())
158158
ctx, cancelFunc := context.WithCancel(context.Background())
159159
defer cancelFunc()
160160
go func() {

cli/cliflag/cliflag.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,15 @@ import (
1919
"github.com/spf13/pflag"
2020
)
2121

22+
// String sets a string flag on the given flag set.
23+
func String(flagset *pflag.FlagSet, name, shorthand, env, def, usage string) {
24+
v, ok := os.LookupEnv(env)
25+
if !ok || v == "" {
26+
v = def
27+
}
28+
flagset.StringP(name, shorthand, v, fmtUsage(usage, env))
29+
}
30+
2231
// StringVarP sets a string flag on the given flag set.
2332
func StringVarP(flagset *pflag.FlagSet, p *string, name string, shorthand string, env string, def string, usage string) {
2433
v, ok := os.LookupEnv(env)

cli/cliflag/cliflag_test.go

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,28 @@ import (
1616
//nolint:paralleltest
1717
func TestCliflag(t *testing.T) {
1818
t.Run("StringDefault", func(t *testing.T) {
19+
flagset, name, shorthand, env, usage := randomFlag()
20+
def, _ := cryptorand.String(10)
21+
cliflag.String(flagset, name, shorthand, env, def, usage)
22+
got, err := flagset.GetString(name)
23+
require.NoError(t, err)
24+
require.Equal(t, def, got)
25+
require.Contains(t, flagset.FlagUsages(), usage)
26+
require.Contains(t, flagset.FlagUsages(), fmt.Sprintf(" - consumes $%s", env))
27+
})
28+
29+
t.Run("StringEnvVar", func(t *testing.T) {
30+
flagset, name, shorthand, env, usage := randomFlag()
31+
envValue, _ := cryptorand.String(10)
32+
t.Setenv(env, envValue)
33+
def, _ := cryptorand.String(10)
34+
cliflag.String(flagset, name, shorthand, env, def, usage)
35+
got, err := flagset.GetString(name)
36+
require.NoError(t, err)
37+
require.Equal(t, envValue, got)
38+
})
39+
40+
t.Run("StringVarPDefault", func(t *testing.T) {
1941
var ptr string
2042
flagset, name, shorthand, env, usage := randomFlag()
2143
def, _ := cryptorand.String(10)
@@ -28,7 +50,7 @@ func TestCliflag(t *testing.T) {
2850
require.Contains(t, flagset.FlagUsages(), fmt.Sprintf(" - consumes $%s", env))
2951
})
3052

31-
t.Run("StringEnvVar", func(t *testing.T) {
53+
t.Run("StringVarPEnvVar", func(t *testing.T) {
3254
var ptr string
3355
flagset, name, shorthand, env, usage := randomFlag()
3456
envValue, _ := cryptorand.String(10)

cli/config/file.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,6 @@ func (r Root) Organization() File {
2121
return File(filepath.Join(string(r), "organization"))
2222
}
2323

24-
func (r Root) AgentSession() File {
25-
return File(filepath.Join(string(r), "agentsession"))
26-
}
27-
2824
// File provides convenience methods for interacting with *os.File.
2925
type File string
3026

cli/gitssh.go

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package cli
22

33
import (
44
"fmt"
5-
"net/url"
65
"os"
76
"os/exec"
87
"strings"
@@ -11,7 +10,6 @@ import (
1110
"golang.org/x/xerrors"
1211

1312
"github.com/coder/coder/cli/cliui"
14-
"github.com/coder/coder/codersdk"
1513
)
1614

1715
func gitssh() *cobra.Command {
@@ -20,22 +18,10 @@ func gitssh() *cobra.Command {
2018
Hidden: true,
2119
Short: `Wraps the "ssh" command and uses the coder gitssh key for authentication`,
2220
RunE: func(cmd *cobra.Command, args []string) error {
23-
cfg := createConfig(cmd)
24-
rawURL, err := cfg.URL().Read()
21+
client, err := createAgentClient(cmd)
2522
if err != nil {
26-
return xerrors.Errorf("read agent url from config: %w", err)
23+
return xerrors.Errorf("create agent client: %w", err)
2724
}
28-
parsedURL, err := url.Parse(rawURL)
29-
if err != nil {
30-
return xerrors.Errorf("parse agent url from config: %w", err)
31-
}
32-
session, err := cfg.AgentSession().Read()
33-
if err != nil {
34-
return xerrors.Errorf("read agent session from config: %w", err)
35-
}
36-
client := codersdk.New(parsedURL)
37-
client.SessionToken = session
38-
3925
key, err := client.AgentGitSSHKey(cmd.Context())
4026
if err != nil {
4127
return xerrors.Errorf("get agent git ssh token: %w", err)

cli/gitssh_test.go

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,10 @@ import (
99

1010
"github.com/gliderlabs/ssh"
1111
"github.com/google/uuid"
12-
"github.com/spf13/cobra"
1312
"github.com/stretchr/testify/require"
1413
gossh "golang.org/x/crypto/ssh"
1514

1615
"github.com/coder/coder/cli/clitest"
17-
"github.com/coder/coder/cli/config"
1816
"github.com/coder/coder/coderd/coderdtest"
1917
"github.com/coder/coder/codersdk"
2018
"github.com/coder/coder/provisioner/echo"
@@ -61,7 +59,7 @@ func TestGitSSH(t *testing.T) {
6159
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
6260

6361
// start workspace agent
64-
cmd, root := clitest.New(t, "agent", "--token", agentToken, "--url", client.URL.String())
62+
cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String())
6563
agentClient := &*client
6664
clitest.SetupConfig(t, agentClient, root)
6765
ctx, cancelFunc := context.WithCancel(context.Background())
@@ -92,7 +90,7 @@ func TestGitSSH(t *testing.T) {
9290
// as long as we get a successful session we don't care if the server errors
9391
_ = ssh.Serve(l, func(s ssh.Session) {
9492
atomic.AddInt64(&inc, 1)
95-
t.Log("got authenticated sesion")
93+
t.Log("got authenticated session")
9694
err := s.Exit(0)
9795
require.NoError(t, err)
9896
}, publicKeyOption)
@@ -101,22 +99,10 @@ func TestGitSSH(t *testing.T) {
10199
// start ssh session
102100
addr, ok := l.Addr().(*net.TCPAddr)
103101
require.True(t, ok)
104-
cfgDir := createConfig(cmd)
105102
// set to agent config dir
106-
cmd, root = clitest.New(t, "gitssh", "--global-config="+string(cfgDir), "--", fmt.Sprintf("-p%d", addr.Port), "-o", "StrictHostKeyChecking=no", "127.0.0.1")
107-
clitest.SetupConfig(t, agentClient, root)
108-
103+
cmd, _ = clitest.New(t, "gitssh", "--agent-url", agentClient.URL.String(), "--agent-token", agentToken, "--", fmt.Sprintf("-p%d", addr.Port), "-o", "StrictHostKeyChecking=no", "127.0.0.1")
109104
err = cmd.ExecuteContext(context.Background())
110105
require.NoError(t, err)
111106
require.EqualValues(t, 1, inc)
112107
})
113108
}
114-
115-
// createConfig consumes the global configuration flag to produce a config root.
116-
func createConfig(cmd *cobra.Command) config.Root {
117-
globalRoot, err := cmd.Flags().GetString("global-config")
118-
if err != nil {
119-
panic(err)
120-
}
121-
return config.Root(globalRoot)
122-
}

0 commit comments

Comments
 (0)