Skip to content

Commit 744a00a

Browse files
authored
feat: Add GIT_COMMITTER information to agent env vars (#1171)
This makes setting up git a bit simpler, and users can always override these values! We'll probably add a way to disable our Git integration anyways, so these could be part of that.
1 parent 877854a commit 744a00a

File tree

6 files changed

+110
-70
lines changed

6 files changed

+110
-70
lines changed

agent/agent.go

+24-15
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,14 @@ import (
3333
"golang.org/x/xerrors"
3434
)
3535

36-
type Options struct {
37-
EnvironmentVariables map[string]string
38-
StartupScript string
36+
type Metadata struct {
37+
OwnerEmail string `json:"owner_email"`
38+
OwnerUsername string `json:"owner_username"`
39+
EnvironmentVariables map[string]string `json:"environment_variables"`
40+
StartupScript string `json:"startup_script"`
3941
}
4042

41-
type Dialer func(ctx context.Context, logger slog.Logger) (*Options, *peerbroker.Listener, error)
43+
type Dialer func(ctx context.Context, logger slog.Logger) (Metadata, *peerbroker.Listener, error)
4244

4345
func New(dialer Dialer, logger slog.Logger) io.Closer {
4446
ctx, cancelFunc := context.WithCancel(context.Background())
@@ -62,14 +64,16 @@ type agent struct {
6264
closed chan struct{}
6365

6466
// Environment variables sent by Coder to inject for shell sessions.
65-
// This is atomic because values can change after reconnect.
67+
// These are atomic because values can change after reconnect.
6668
envVars atomic.Value
69+
ownerEmail atomic.String
70+
ownerUsername atomic.String
6771
startupScript atomic.Bool
6872
sshServer *ssh.Server
6973
}
7074

7175
func (a *agent) run(ctx context.Context) {
72-
var options *Options
76+
var options Metadata
7377
var peerListener *peerbroker.Listener
7478
var err error
7579
// An exponential back-off occurs when the connection is failing to dial.
@@ -95,6 +99,8 @@ func (a *agent) run(ctx context.Context) {
9599
default:
96100
}
97101
a.envVars.Store(options.EnvironmentVariables)
102+
a.ownerEmail.Store(options.OwnerEmail)
103+
a.ownerUsername.Store(options.OwnerUsername)
98104

99105
if a.startupScript.CAS(false, true) {
100106
// The startup script has not ran yet!
@@ -303,8 +309,20 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
303309
}
304310
cmd := exec.CommandContext(session.Context(), shell, caller, command)
305311
cmd.Env = append(os.Environ(), session.Environ()...)
312+
executablePath, err := os.Executable()
313+
if err != nil {
314+
return xerrors.Errorf("getting os executable: %w", err)
315+
}
316+
// Git on Windows resolves with UNIX-style paths.
317+
// If using backslashes, it's unable to find the executable.
318+
executablePath = strings.ReplaceAll(executablePath, "\\", "/")
319+
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_SSH_COMMAND=%s gitssh --`, executablePath))
320+
// These prevent the user from having to specify _anything_ to successfully commit.
321+
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_COMMITTER_EMAIL=%s`, a.ownerEmail.Load()))
322+
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_COMMITTER_NAME=%s`, a.ownerUsername.Load()))
306323

307324
// Load environment variables passed via the agent.
325+
// These should override all variables we manually specify.
308326
envVars := a.envVars.Load()
309327
if envVars != nil {
310328
envVarMap, ok := envVars.(map[string]string)
@@ -315,15 +333,6 @@ func (a *agent) handleSSHSession(session ssh.Session) error {
315333
}
316334
}
317335

318-
executablePath, err := os.Executable()
319-
if err != nil {
320-
return xerrors.Errorf("getting os executable: %w", err)
321-
}
322-
// Git on Windows resolves with UNIX-style paths.
323-
// If using backslashes, it's unable to find the executable.
324-
executablePath = strings.ReplaceAll(executablePath, "\\", "/")
325-
cmd.Env = append(cmd.Env, fmt.Sprintf(`GIT_SSH_COMMAND=%s gitssh --`, executablePath))
326-
327336
sshPty, windowSize, isPty := session.Pty()
328337
if isPty {
329338
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", sshPty.Term))

agent/agent_test.go

+10-13
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ func TestAgent(t *testing.T) {
4040
t.Parallel()
4141
t.Run("SessionExec", func(t *testing.T) {
4242
t.Parallel()
43-
session := setupSSHSession(t, nil)
43+
session := setupSSHSession(t, agent.Metadata{})
4444

4545
command := "echo test"
4646
if runtime.GOOS == "windows" {
@@ -53,7 +53,7 @@ func TestAgent(t *testing.T) {
5353

5454
t.Run("GitSSH", func(t *testing.T) {
5555
t.Parallel()
56-
session := setupSSHSession(t, nil)
56+
session := setupSSHSession(t, agent.Metadata{})
5757
command := "sh -c 'echo $GIT_SSH_COMMAND'"
5858
if runtime.GOOS == "windows" {
5959
command = "cmd.exe /c echo %GIT_SSH_COMMAND%"
@@ -71,7 +71,7 @@ func TestAgent(t *testing.T) {
7171
// it seems like it could be either.
7272
t.Skip("ConPTY appears to be inconsistent on Windows.")
7373
}
74-
session := setupSSHSession(t, nil)
74+
session := setupSSHSession(t, agent.Metadata{})
7575
command := "bash"
7676
if runtime.GOOS == "windows" {
7777
command = "cmd.exe"
@@ -131,7 +131,7 @@ func TestAgent(t *testing.T) {
131131

132132
t.Run("SFTP", func(t *testing.T) {
133133
t.Parallel()
134-
sshClient, err := setupAgent(t, nil).SSHClient()
134+
sshClient, err := setupAgent(t, agent.Metadata{}).SSHClient()
135135
require.NoError(t, err)
136136
client, err := sftp.NewClient(sshClient)
137137
require.NoError(t, err)
@@ -148,7 +148,7 @@ func TestAgent(t *testing.T) {
148148
t.Parallel()
149149
key := "EXAMPLE"
150150
value := "value"
151-
session := setupSSHSession(t, &agent.Options{
151+
session := setupSSHSession(t, agent.Metadata{
152152
EnvironmentVariables: map[string]string{
153153
key: value,
154154
},
@@ -166,7 +166,7 @@ func TestAgent(t *testing.T) {
166166
t.Parallel()
167167
tempPath := filepath.Join(os.TempDir(), "content.txt")
168168
content := "somethingnice"
169-
setupAgent(t, &agent.Options{
169+
setupAgent(t, agent.Metadata{
170170
StartupScript: "echo " + content + " > " + tempPath,
171171
})
172172
var gotContent string
@@ -191,7 +191,7 @@ func TestAgent(t *testing.T) {
191191
}
192192

193193
func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exec.Cmd {
194-
agentConn := setupAgent(t, nil)
194+
agentConn := setupAgent(t, agent.Metadata{})
195195
listener, err := net.Listen("tcp", "127.0.0.1:0")
196196
require.NoError(t, err)
197197
go func() {
@@ -219,20 +219,17 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe
219219
return exec.Command("ssh", args...)
220220
}
221221

222-
func setupSSHSession(t *testing.T, options *agent.Options) *ssh.Session {
222+
func setupSSHSession(t *testing.T, options agent.Metadata) *ssh.Session {
223223
sshClient, err := setupAgent(t, options).SSHClient()
224224
require.NoError(t, err)
225225
session, err := sshClient.NewSession()
226226
require.NoError(t, err)
227227
return session
228228
}
229229

230-
func setupAgent(t *testing.T, options *agent.Options) *agent.Conn {
231-
if options == nil {
232-
options = &agent.Options{}
233-
}
230+
func setupAgent(t *testing.T, options agent.Metadata) *agent.Conn {
234231
client, server := provisionersdk.TransportPipe()
235-
closer := agent.New(func(ctx context.Context, logger slog.Logger) (*agent.Options, *peerbroker.Listener, error) {
232+
closer := agent.New(func(ctx context.Context, logger slog.Logger) (agent.Metadata, *peerbroker.Listener, error) {
236233
listener, err := peerbroker.Listen(server, nil)
237234
return options, listener, err
238235
}, slogtest.Make(t, nil).Leveled(slog.LevelDebug))

coderd/coderd.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ func New(options *Options) (http.Handler, func()) {
201201
r.Post("/google-instance-identity", api.postWorkspaceAuthGoogleInstanceIdentity)
202202
r.Route("/me", func(r chi.Router) {
203203
r.Use(httpmw.ExtractWorkspaceAgent(options.Database))
204-
r.Get("/", api.workspaceAgentMe)
204+
r.Get("/metadata", api.workspaceAgentMetadata)
205205
r.Get("/listen", api.workspaceAgentListen)
206206
r.Get("/gitsshkey", api.agentGitSSHKey)
207207
r.Get("/turn", api.workspaceAgentTurn)

coderd/workspaceagents.go

+59-25
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"nhooyr.io/websocket"
1616

1717
"cdr.dev/slog"
18+
"github.com/coder/coder/agent"
1819
"github.com/coder/coder/coderd/database"
1920
"github.com/coder/coder/coderd/httpapi"
2021
"github.com/coder/coder/coderd/httpmw"
@@ -25,8 +26,8 @@ import (
2526
)
2627

2728
func (api *api) workspaceAgent(rw http.ResponseWriter, r *http.Request) {
28-
agent := httpmw.WorkspaceAgentParam(r)
29-
apiAgent, err := convertWorkspaceAgent(agent, api.AgentConnectionUpdateFrequency)
29+
workspaceAgent := httpmw.WorkspaceAgentParam(r)
30+
apiAgent, err := convertWorkspaceAgent(workspaceAgent, api.AgentConnectionUpdateFrequency)
3031
if err != nil {
3132
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
3233
Message: fmt.Sprintf("convert workspace agent: %s", err),
@@ -43,8 +44,8 @@ func (api *api) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) {
4344
api.websocketWaitMutex.Unlock()
4445
defer api.websocketWaitGroup.Done()
4546

46-
agent := httpmw.WorkspaceAgentParam(r)
47-
apiAgent, err := convertWorkspaceAgent(agent, api.AgentConnectionUpdateFrequency)
47+
workspaceAgent := httpmw.WorkspaceAgentParam(r)
48+
apiAgent, err := convertWorkspaceAgent(workspaceAgent, api.AgentConnectionUpdateFrequency)
4849
if err != nil {
4950
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
5051
Message: fmt.Sprintf("convert workspace agent: %s", err),
@@ -78,7 +79,7 @@ func (api *api) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) {
7879
return
7980
}
8081
err = peerbroker.ProxyListen(r.Context(), session, peerbroker.ProxyOptions{
81-
ChannelID: agent.ID.String(),
82+
ChannelID: workspaceAgent.ID.String(),
8283
Logger: api.Logger.Named("peerbroker-proxy-dial"),
8384
Pubsub: api.Pubsub,
8485
})
@@ -88,16 +89,49 @@ func (api *api) workspaceAgentDial(rw http.ResponseWriter, r *http.Request) {
8889
}
8990
}
9091

91-
func (api *api) workspaceAgentMe(rw http.ResponseWriter, r *http.Request) {
92-
agent := httpmw.WorkspaceAgent(r)
93-
apiAgent, err := convertWorkspaceAgent(agent, api.AgentConnectionUpdateFrequency)
92+
func (api *api) workspaceAgentMetadata(rw http.ResponseWriter, r *http.Request) {
93+
workspaceAgent := httpmw.WorkspaceAgent(r)
94+
apiAgent, err := convertWorkspaceAgent(workspaceAgent, api.AgentConnectionUpdateFrequency)
9495
if err != nil {
9596
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
9697
Message: fmt.Sprintf("convert workspace agent: %s", err),
9798
})
9899
return
99100
}
100-
httpapi.Write(rw, http.StatusOK, apiAgent)
101+
resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), workspaceAgent.ResourceID)
102+
if err != nil {
103+
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
104+
Message: fmt.Sprintf("get workspace resource: %s", err),
105+
})
106+
return
107+
}
108+
build, err := api.Database.GetWorkspaceBuildByJobID(r.Context(), resource.JobID)
109+
if err != nil {
110+
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
111+
Message: fmt.Sprintf("get workspace build: %s", err),
112+
})
113+
return
114+
}
115+
workspace, err := api.Database.GetWorkspaceByID(r.Context(), build.WorkspaceID)
116+
if err != nil {
117+
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
118+
Message: fmt.Sprintf("get workspace build: %s", err),
119+
})
120+
return
121+
}
122+
owner, err := api.Database.GetUserByID(r.Context(), workspace.OwnerID)
123+
if err != nil {
124+
httpapi.Write(rw, http.StatusInternalServerError, httpapi.Response{
125+
Message: fmt.Sprintf("get workspace build: %s", err),
126+
})
127+
return
128+
}
129+
httpapi.Write(rw, http.StatusOK, agent.Metadata{
130+
OwnerEmail: owner.Email,
131+
OwnerUsername: owner.Username,
132+
EnvironmentVariables: apiAgent.EnvironmentVariables,
133+
StartupScript: apiAgent.StartupScript,
134+
})
101135
}
102136

103137
func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
@@ -106,7 +140,7 @@ func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
106140
api.websocketWaitMutex.Unlock()
107141
defer api.websocketWaitGroup.Done()
108142

109-
agent := httpmw.WorkspaceAgent(r)
143+
workspaceAgent := httpmw.WorkspaceAgent(r)
110144
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
111145
CompressionMode: websocket.CompressionDisabled,
112146
})
@@ -116,7 +150,7 @@ func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
116150
})
117151
return
118152
}
119-
resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), agent.ResourceID)
153+
resource, err := api.Database.GetWorkspaceResourceByID(r.Context(), workspaceAgent.ResourceID)
120154
if err != nil {
121155
httpapi.Write(rw, http.StatusBadRequest, httpapi.Response{
122156
Message: fmt.Sprintf("accept websocket: %s", err),
@@ -135,7 +169,7 @@ func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
135169
return
136170
}
137171
closer, err := peerbroker.ProxyDial(proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(session)), peerbroker.ProxyOptions{
138-
ChannelID: agent.ID.String(),
172+
ChannelID: workspaceAgent.ID.String(),
139173
Pubsub: api.Pubsub,
140174
Logger: api.Logger.Named("peerbroker-proxy-listen"),
141175
})
@@ -144,7 +178,7 @@ func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
144178
return
145179
}
146180
defer closer.Close()
147-
firstConnectedAt := agent.FirstConnectedAt
181+
firstConnectedAt := workspaceAgent.FirstConnectedAt
148182
if !firstConnectedAt.Valid {
149183
firstConnectedAt = sql.NullTime{
150184
Time: database.Now(),
@@ -155,10 +189,10 @@ func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
155189
Time: database.Now(),
156190
Valid: true,
157191
}
158-
disconnectedAt := agent.DisconnectedAt
192+
disconnectedAt := workspaceAgent.DisconnectedAt
159193
updateConnectionTimes := func() error {
160194
err = api.Database.UpdateWorkspaceAgentConnectionByID(r.Context(), database.UpdateWorkspaceAgentConnectionByIDParams{
161-
ID: agent.ID,
195+
ID: workspaceAgent.ID,
162196
FirstConnectedAt: firstConnectedAt,
163197
LastConnectedAt: lastConnectedAt,
164198
DisconnectedAt: disconnectedAt,
@@ -205,7 +239,7 @@ func (api *api) workspaceAgentListen(rw http.ResponseWriter, r *http.Request) {
205239
return
206240
}
207241

208-
api.Logger.Info(r.Context(), "accepting agent", slog.F("resource", resource), slog.F("agent", agent))
242+
api.Logger.Info(r.Context(), "accepting agent", slog.F("resource", resource), slog.F("agent", workspaceAgent))
209243

210244
ticker := time.NewTicker(api.AgentConnectionUpdateFrequency)
211245
defer ticker.Stop()
@@ -294,7 +328,7 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, agentUpdateFrequency
294328
return codersdk.WorkspaceAgent{}, xerrors.Errorf("unmarshal: %w", err)
295329
}
296330
}
297-
agent := codersdk.WorkspaceAgent{
331+
workspaceAgent := codersdk.WorkspaceAgent{
298332
ID: dbAgent.ID,
299333
CreatedAt: dbAgent.CreatedAt,
300334
UpdatedAt: dbAgent.UpdatedAt,
@@ -307,31 +341,31 @@ func convertWorkspaceAgent(dbAgent database.WorkspaceAgent, agentUpdateFrequency
307341
EnvironmentVariables: envs,
308342
}
309343
if dbAgent.FirstConnectedAt.Valid {
310-
agent.FirstConnectedAt = &dbAgent.FirstConnectedAt.Time
344+
workspaceAgent.FirstConnectedAt = &dbAgent.FirstConnectedAt.Time
311345
}
312346
if dbAgent.LastConnectedAt.Valid {
313-
agent.LastConnectedAt = &dbAgent.LastConnectedAt.Time
347+
workspaceAgent.LastConnectedAt = &dbAgent.LastConnectedAt.Time
314348
}
315349
if dbAgent.DisconnectedAt.Valid {
316-
agent.DisconnectedAt = &dbAgent.DisconnectedAt.Time
350+
workspaceAgent.DisconnectedAt = &dbAgent.DisconnectedAt.Time
317351
}
318352
switch {
319353
case !dbAgent.FirstConnectedAt.Valid:
320354
// If the agent never connected, it's waiting for the compute
321355
// to start up.
322-
agent.Status = codersdk.WorkspaceAgentConnecting
356+
workspaceAgent.Status = codersdk.WorkspaceAgentConnecting
323357
case dbAgent.DisconnectedAt.Time.After(dbAgent.LastConnectedAt.Time):
324358
// If we've disconnected after our last connection, we know the
325359
// agent is no longer connected.
326-
agent.Status = codersdk.WorkspaceAgentDisconnected
360+
workspaceAgent.Status = codersdk.WorkspaceAgentDisconnected
327361
case agentUpdateFrequency*2 >= database.Now().Sub(dbAgent.LastConnectedAt.Time):
328362
// The connection updated it's timestamp within the update frequency.
329363
// We multiply by two to allow for some lag.
330-
agent.Status = codersdk.WorkspaceAgentConnected
364+
workspaceAgent.Status = codersdk.WorkspaceAgentConnected
331365
case database.Now().Sub(dbAgent.LastConnectedAt.Time) > agentUpdateFrequency*2:
332366
// The connection died without updating the last connected.
333-
agent.Status = codersdk.WorkspaceAgentDisconnected
367+
workspaceAgent.Status = codersdk.WorkspaceAgentDisconnected
334368
}
335369

336-
return agent, nil
370+
return workspaceAgent, nil
337371
}

coderd/workspaceagents_test.go

-2
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,6 @@ func TestWorkspaceAgentListen(t *testing.T) {
102102
})
103103
_, err = conn.Ping()
104104
require.NoError(t, err)
105-
_, err = agentClient.WorkspaceAgent(context.Background(), codersdk.Me)
106-
require.NoError(t, err)
107105
}
108106

109107
func TestWorkspaceAgentTURN(t *testing.T) {

0 commit comments

Comments
 (0)