Skip to content

Commit cc715cd

Browse files
committed
Fix startup and shutdown reporting error
1 parent 8d75963 commit cc715cd

File tree

1 file changed

+58
-30
lines changed

1 file changed

+58
-30
lines changed

agent/agent.go

+58 -30
Original file line number | Diff line number | Diff line change
@@ -652,22 +652,48 @@ func (a *agent) runScript(ctx context.Context, lifecycle, script string) error {
652652
_ = fileWriter.Close()
653653
}()
654654

655-
// Create pipes for startup logs reader and writer
656-
startupLogsReader, startupLogsWriter := io.Pipe()
655+
var writer io.Writer = fileWriter
656+
if lifecycle == "startup" {
657+
// Create pipes for startup logs reader and writer
658+
logsReader, logsWriter := io.Pipe()
659+
defer func() {
660+
_ = logsReader.Close()
661+
}()
662+
writer = io.MultiWriter(fileWriter, logsWriter)
663+
flushedLogs, err := a.trackScriptLogs(ctx, logsReader)
664+
if err != nil {
665+
return xerrors.Errorf("track script logs: %w", err)
666+
}
667+
defer func() {
668+
_ = logsWriter.Close()
669+
<-flushedLogs
670+
}()
671+
}
657672

658-
// Close the pipes when the function returns
659-
defer func() {
660-
_ = startupLogsReader.Close()
661-
_ = startupLogsWriter.Close()
662-
}()
673+
cmd, err := a.createCommand(ctx, script, nil)
674+
if err != nil {
675+
return xerrors.Errorf("create command: %w", err)
676+
}
677+
cmd.Stdout = writer
678+
cmd.Stderr = writer
679+
err = cmd.Run()
680+
if err != nil {
681+
// cmd.Run does not return a context canceled error, it returns "signal: killed".
682+
if ctx.Err() != nil {
683+
return ctx.Err()
684+
}
663685

664-
// Create a multi-writer for startup logs and file writer
665-
writer := io.MultiWriter(startupLogsWriter, fileWriter)
686+
return xerrors.Errorf("run: %w", err)
687+
}
688+
return nil
689+
}
666690

691+
func (a *agent) trackScriptLogs(ctx context.Context, reader io.Reader) (chan struct{}, error) {
667692
// Initialize variables for log management
668693
queuedLogs := make([]agentsdk.StartupLog, 0)
669694
var flushLogsTimer *time.Timer
670695
var logMutex sync.Mutex
696+
logsFlushed := sync.NewCond(&sync.Mutex{})
671697
var logsSending bool
672698
defer func() {
673699
logMutex.Lock()
@@ -720,6 +746,7 @@ func (a *agent) runScript(ctx context.Context, lifecycle, script string) error {
720746
logsSending = false
721747
flushLogsTimer.Reset(100 * time.Millisecond)
722748
logMutex.Unlock()
749+
logsFlushed.Broadcast()
723750
}
724751
// queueLog function appends a log to the queue and triggers sendLogs if necessary
725752
queueLog := func(log agentsdk.StartupLog) {
@@ -743,36 +770,37 @@ func (a *agent) runScript(ctx context.Context, lifecycle, script string) error {
743770
}
744771
flushLogsTimer = time.AfterFunc(100*time.Millisecond, sendLogs)
745772
}
746-
err = a.trackConnGoroutine(func() {
747-
scanner := bufio.NewScanner(startupLogsReader)
773+
774+
// It's important that we either flush or drop all logs before returning
775+
// because the startup state is reported after flush.
776+
//
777+
// It'd be weird for the startup state to be ready, but logs are still
778+
// coming in.
779+
logsFinished := make(chan struct{})
780+
err := a.trackConnGoroutine(func() {
781+
scanner := bufio.NewScanner(reader)
748782
for scanner.Scan() {
749783
queueLog(agentsdk.StartupLog{
750784
CreatedAt: database.Now(),
751785
Output: scanner.Text(),
752786
})
753787
}
788+
defer close(logsFinished)
789+
logsFlushed.L.Lock()
790+
for {
791+
logMutex.Lock()
792+
if len(queuedLogs) == 0 {
793+
logMutex.Unlock()
794+
break
795+
}
796+
logMutex.Unlock()
797+
logsFlushed.Wait()
798+
}
754799
})
755800
if err != nil {
756-
return xerrors.Errorf("track conn goroutine: %w", err)
801+
return nil, xerrors.Errorf("track conn goroutine: %w", err)
757802
}
758-
759-
cmd, err := a.createCommand(ctx, script, nil)
760-
if err != nil {
761-
return xerrors.Errorf("create command: %w", err)
762-
}
763-
cmd.Stdout = writer
764-
cmd.Stderr = writer
765-
err = cmd.Run()
766-
if err != nil {
767-
// cmd.Run does not return a context canceled error, it returns "signal: killed".
768-
if ctx.Err() != nil {
769-
return ctx.Err()
770-
}
771-
772-
return xerrors.Errorf("run: %w", err)
773-
}
774-
775-
return nil
803+
return logsFinished, nil
776804
}
777805

778806
func (a *agent) init(ctx context.Context) {

0 commit comments

Comments
 (0)