Skip to content

Commit ffc2981

Browse files
committed
add agentscripts test for execute option
1 parent 9aca1c8 commit ffc2981

File tree

1 file changed

+151
-0
lines changed

1 file changed

+151
-0
lines changed

agent/agentscripts/agentscripts_test.go

+151
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import (
44
"context"
55
"path/filepath"
66
"runtime"
7+
"slices"
8+
"sync"
79
"testing"
810
"time"
911

@@ -151,6 +153,155 @@ func TestCronClose(t *testing.T) {
151153
require.NoError(t, runner.Close(), "close runner")
152154
}
153155

156+
func TestExecuteOptions(t *testing.T) {
157+
t.Parallel()
158+
159+
startScript := codersdk.WorkspaceAgentScript{
160+
ID: uuid.New(),
161+
LogSourceID: uuid.New(),
162+
Script: "echo start",
163+
RunOnStart: true,
164+
}
165+
stopScript := codersdk.WorkspaceAgentScript{
166+
ID: uuid.New(),
167+
LogSourceID: uuid.New(),
168+
Script: "echo stop",
169+
RunOnStop: true,
170+
}
171+
postStartScript := codersdk.WorkspaceAgentScript{
172+
ID: uuid.New(),
173+
LogSourceID: uuid.New(),
174+
Script: "echo poststart",
175+
}
176+
regularScript := codersdk.WorkspaceAgentScript{
177+
ID: uuid.New(),
178+
LogSourceID: uuid.New(),
179+
Script: "echo regular",
180+
}
181+
182+
scripts := []codersdk.WorkspaceAgentScript{
183+
startScript,
184+
stopScript,
185+
regularScript,
186+
}
187+
allScripts := append(slices.Clone(scripts), postStartScript)
188+
189+
scriptByID := func(t *testing.T, id uuid.UUID) codersdk.WorkspaceAgentScript {
190+
for _, script := range allScripts {
191+
if script.ID == id {
192+
return script
193+
}
194+
}
195+
t.Fatal("script not found")
196+
return codersdk.WorkspaceAgentScript{}
197+
}
198+
199+
wantOutput := map[uuid.UUID]string{
200+
startScript.ID: "start",
201+
stopScript.ID: "stop",
202+
postStartScript.ID: "poststart",
203+
regularScript.ID: "regular",
204+
}
205+
206+
testCases := []struct {
207+
name string
208+
option agentscripts.ExecuteOption
209+
wantRun []uuid.UUID
210+
}{
211+
{
212+
name: "ExecuteAllScripts",
213+
option: agentscripts.ExecuteAllScripts,
214+
wantRun: []uuid.UUID{startScript.ID, stopScript.ID, regularScript.ID, postStartScript.ID},
215+
},
216+
{
217+
name: "ExecuteStartScripts",
218+
option: agentscripts.ExecuteStartScripts,
219+
wantRun: []uuid.UUID{startScript.ID},
220+
},
221+
{
222+
name: "ExecutePostStartScripts",
223+
option: agentscripts.ExecutePostStartScripts,
224+
wantRun: []uuid.UUID{postStartScript.ID},
225+
},
226+
{
227+
name: "ExecuteStopScripts",
228+
option: agentscripts.ExecuteStopScripts,
229+
wantRun: []uuid.UUID{stopScript.ID},
230+
},
231+
}
232+
233+
for _, tc := range testCases {
234+
t.Run(tc.name, func(t *testing.T) {
235+
t.Parallel()
236+
237+
ctx := testutil.Context(t, testutil.WaitMedium)
238+
executedScripts := make(map[uuid.UUID]bool)
239+
fLogger := &filterTestLogger{
240+
tb: t,
241+
executedScripts: executedScripts,
242+
wantOutput: wantOutput,
243+
}
244+
245+
runner := setup(t, func(_ uuid.UUID) agentscripts.ScriptLogger {
246+
return fLogger
247+
})
248+
defer runner.Close()
249+
250+
aAPI := agenttest.NewFakeAgentAPI(t, testutil.Logger(t), nil, nil)
251+
err := runner.Init(
252+
scripts,
253+
aAPI.ScriptCompleted,
254+
agentscripts.WithPostStartScripts(postStartScript),
255+
)
256+
require.NoError(t, err)
257+
258+
err = runner.Execute(ctx, tc.option)
259+
require.NoError(t, err)
260+
261+
gotRun := map[uuid.UUID]bool{}
262+
for _, id := range tc.wantRun {
263+
gotRun[id] = true
264+
require.True(t, executedScripts[id],
265+
"script %s should have run when using filter %s", scriptByID(t, id).Script, tc.name)
266+
}
267+
268+
for _, script := range allScripts {
269+
if _, ok := gotRun[script.ID]; ok {
270+
continue
271+
}
272+
require.False(t, executedScripts[script.ID],
273+
"script %s should not have run when using filter %s", script.Script, tc.name)
274+
}
275+
})
276+
}
277+
}
278+
279+
type filterTestLogger struct {
280+
tb testing.TB
281+
executedScripts map[uuid.UUID]bool
282+
wantOutput map[uuid.UUID]string
283+
mu sync.Mutex
284+
}
285+
286+
func (l *filterTestLogger) Send(ctx context.Context, logs ...agentsdk.Log) error {
287+
l.mu.Lock()
288+
defer l.mu.Unlock()
289+
for _, log := range logs {
290+
l.tb.Log(log.Output)
291+
for id, output := range l.wantOutput {
292+
if log.Output == output {
293+
l.executedScripts[id] = true
294+
break
295+
}
296+
}
297+
}
298+
return nil
299+
}
300+
301+
func (l *filterTestLogger) Flush(context.Context) error {
302+
return nil
303+
}
304+
154305
func setup(t *testing.T, getScriptLogger func(logSourceID uuid.UUID) agentscripts.ScriptLogger) *agentscripts.Runner {
155306
t.Helper()
156307
if getScriptLogger == nil {

0 commit comments

Comments
 (0)