@@ -12,12 +12,15 @@ import (
12
12
"strconv"
13
13
"strings"
14
14
"testing"
15
+ "time"
15
16
16
17
"github.com/pion/webrtc/v3"
17
18
"github.com/pkg/sftp"
18
19
"github.com/stretchr/testify/require"
19
20
"go.uber.org/goleak"
20
21
"golang.org/x/crypto/ssh"
22
+ "golang.org/x/text/encoding/unicode"
23
+ "golang.org/x/text/transform"
21
24
22
25
"cdr.dev/slog"
23
26
"cdr.dev/slog/sloggers/slogtest"
@@ -37,7 +40,7 @@ func TestAgent(t *testing.T) {
37
40
t .Parallel ()
38
41
t .Run ("SessionExec" , func (t * testing.T ) {
39
42
t .Parallel ()
40
- session := setupSSHSession (t )
43
+ session := setupSSHSession (t , nil )
41
44
42
45
command := "echo test"
43
46
if runtime .GOOS == "windows" {
@@ -50,7 +53,7 @@ func TestAgent(t *testing.T) {
50
53
51
54
t .Run ("GitSSH" , func (t * testing.T ) {
52
55
t .Parallel ()
53
- session := setupSSHSession (t )
56
+ session := setupSSHSession (t , nil )
54
57
command := "sh -c 'echo $GIT_SSH_COMMAND'"
55
58
if runtime .GOOS == "windows" {
56
59
command = "cmd.exe /c echo %GIT_SSH_COMMAND%"
@@ -62,7 +65,7 @@ func TestAgent(t *testing.T) {
62
65
63
66
t .Run ("SessionTTY" , func (t * testing.T ) {
64
67
t .Parallel ()
65
- session := setupSSHSession (t )
68
+ session := setupSSHSession (t , nil )
66
69
command := "bash"
67
70
if runtime .GOOS == "windows" {
68
71
command = "cmd.exe"
@@ -117,7 +120,7 @@ func TestAgent(t *testing.T) {
117
120
118
121
t .Run ("SFTP" , func (t * testing.T ) {
119
122
t .Parallel ()
120
- sshClient , err := setupAgent (t ).SSHClient ()
123
+ sshClient , err := setupAgent (t , nil ).SSHClient ()
121
124
require .NoError (t , err )
122
125
client , err := sftp .NewClient (sshClient )
123
126
require .NoError (t , err )
@@ -129,10 +132,52 @@ func TestAgent(t *testing.T) {
129
132
_ , err = os .Stat (tempFile )
130
133
require .NoError (t , err )
131
134
})
135
+
136
+ t .Run ("EnvironmentVariables" , func (t * testing.T ) {
137
+ t .Parallel ()
138
+ key := "EXAMPLE"
139
+ value := "value"
140
+ session := setupSSHSession (t , & agent.Options {
141
+ EnvironmentVariables : map [string ]string {
142
+ key : value ,
143
+ },
144
+ })
145
+ command := "sh -c 'echo $" + key + "'"
146
+ if runtime .GOOS == "windows" {
147
+ command = "cmd.exe /c echo %" + key + "%"
148
+ }
149
+ output , err := session .Output (command )
150
+ require .NoError (t , err )
151
+ require .Equal (t , value , strings .TrimSpace (string (output )))
152
+ })
153
+
154
+ t .Run ("StartupScript" , func (t * testing.T ) {
155
+ t .Parallel ()
156
+ tempPath := filepath .Join (os .TempDir (), "content.txt" )
157
+ content := "somethingnice"
158
+ setupAgent (t , & agent.Options {
159
+ StartupScript : "echo " + content + " > " + tempPath ,
160
+ })
161
+ var gotContent string
162
+ require .Eventually (t , func () bool {
163
+ content , err := os .ReadFile (tempPath )
164
+ if err != nil {
165
+ return false
166
+ }
167
+ if runtime .GOOS == "windows" {
168
+ // Windows uses UTF16! 🪟🪟🪟
169
+ content , _ , err = transform .Bytes (unicode .UTF16 (unicode .LittleEndian , unicode .UseBOM ).NewDecoder (), content )
170
+ require .NoError (t , err )
171
+ }
172
+ gotContent = string (content )
173
+ return true
174
+ }, 15 * time .Second , 100 * time .Millisecond )
175
+ require .Equal (t , content , strings .TrimSpace (gotContent ))
176
+ })
132
177
}
133
178
134
179
func setupSSHCommand (t * testing.T , beforeArgs []string , afterArgs []string ) * exec.Cmd {
135
- agentConn := setupAgent (t )
180
+ agentConn := setupAgent (t , nil )
136
181
listener , err := net .Listen ("tcp" , "127.0.0.1:0" )
137
182
require .NoError (t , err )
138
183
go func () {
@@ -160,18 +205,22 @@ func setupSSHCommand(t *testing.T, beforeArgs []string, afterArgs []string) *exe
160
205
return exec .Command ("ssh" , args ... )
161
206
}
162
207
163
- func setupSSHSession (t * testing.T ) * ssh.Session {
164
- sshClient , err := setupAgent (t ).SSHClient ()
208
+ func setupSSHSession (t * testing.T , options * agent. Options ) * ssh.Session {
209
+ sshClient , err := setupAgent (t , options ).SSHClient ()
165
210
require .NoError (t , err )
166
211
session , err := sshClient .NewSession ()
167
212
require .NoError (t , err )
168
213
return session
169
214
}
170
215
171
- func setupAgent (t * testing.T ) * agent.Conn {
216
+ func setupAgent (t * testing.T , options * agent.Options ) * agent.Conn {
217
+ if options == nil {
218
+ options = & agent.Options {}
219
+ }
172
220
client , server := provisionersdk .TransportPipe ()
173
- closer := agent .New (func (ctx context.Context , logger slog.Logger ) (* peerbroker.Listener , error ) {
174
- return peerbroker .Listen (server , nil )
221
+ closer := agent .New (func (ctx context.Context , logger slog.Logger ) (* agent.Options , * peerbroker.Listener , error ) {
222
+ listener , err := peerbroker .Listen (server , nil )
223
+ return options , listener , err
175
224
}, slogtest .Make (t , nil ).Leveled (slog .LevelDebug ))
176
225
t .Cleanup (func () {
177
226
_ = client .Close ()
0 commit comments