Skip to content

Commit 866b721

Browse files
committed
Move update predicates to push phase
Instead of the pop phase. This ensures we do not queue up updates that will just end up being discarded once they are popped (which could take some time due to latency to coderd). It also has the side effect of preserving summaries even when the queue gets too big, because now we preserve them as part of pushing, before they might get lost due to getting dropped while we wait on coderd.
1 parent 6d40d40 commit 866b721

File tree

3 files changed

+163
-68
lines changed

3 files changed

+163
-68
lines changed

cli/cliutil/queue.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import (
44
"sync"
55

66
"golang.org/x/xerrors"
7+
8+
"github.com/coder/coder/v2/codersdk"
79
)
810

911
// Queue is a FIFO queue with a fixed size. If the size is exceeded, the first
@@ -14,6 +16,7 @@ type Queue[T any] struct {
1416
mu sync.Mutex
1517
size int
1618
closed bool
19+
pred func(x T) (T, bool)
1720
}
1821

1922
// NewQueue creates a queue with the given size.
@@ -26,6 +29,13 @@ func NewQueue[T any](size int) *Queue[T] {
2629
return q
2730
}
2831

32+
// WithPredicate adds the given predicate function, which can control what is
33+
// pushed to the queue.
34+
func (q *Queue[T]) WithPredicate(pred func(x T) (T, bool)) *Queue[T] {
35+
q.pred = pred
36+
return q
37+
}
38+
2939
// Close aborts any pending pops and makes future pushes error.
3040
func (q *Queue[T]) Close() {
3141
q.mu.Lock()
@@ -41,6 +51,15 @@ func (q *Queue[T]) Push(x T) error {
4151
if q.closed {
4252
return xerrors.New("queue has been closed")
4353
}
54+
// Potentially mutate or skip the push using the predicate.
55+
if q.pred != nil {
56+
var ok bool
57+
x, ok = q.pred(x)
58+
if !ok {
59+
return nil
60+
}
61+
}
62+
// Remove the first item from the queue if it has gotten too big.
4463
if len(q.items) >= q.size {
4564
q.items = q.items[1:]
4665
}
@@ -70,3 +89,72 @@ func (q *Queue[T]) Len() int {
7089
defer q.mu.Unlock()
7190
return len(q.items)
7291
}
92+
93+
type reportTask struct {
94+
link string
95+
messageID int64
96+
selfReported bool
97+
state codersdk.WorkspaceAppStatusState
98+
summary string
99+
}
100+
101+
// StatusQueue is a Queue that:
102+
// 1. Only pushes items that are not duplicates.
103+
// 2. Preserves the existing message and URI when a message is not provided.
104+
// 3. Ignores "working" updates from the status watcher.
105+
type StatusQueue struct {
106+
Queue[reportTask]
107+
// lastMessageID is the ID of the last *user* message that we saw. A user
108+
// message only happens when interacting via the API (as opposed to
109+
// interacting with the terminal directly).
110+
lastMessageID int64
111+
}
112+
113+
func (q *StatusQueue) Push(report reportTask) error {
114+
q.mu.Lock()
115+
defer q.mu.Unlock()
116+
if q.closed {
117+
return xerrors.New("queue has been closed")
118+
}
119+
var lastReport reportTask
120+
if len(q.items) > 0 {
121+
lastReport = q.items[len(q.items)-1]
122+
}
123+
// Use "working" status if this is a new user message. If this is not a new
124+
// user message, and the status is "working" and not self-reported (meaning it
125+
// came from the screen watcher), then it means one of two things:
126+
// 1. The LLM is still working, in which case our last status will already
127+
// have been "working", so there is nothing to do.
128+
// 2. The user has interacted with the terminal directly. For now, we are
129+
// ignoring these updates. This risks missing cases where the user
130+
// manually submits a new prompt and the LLM becomes active and does not
131+
// update itself, but it avoids spamming useless status updates as the user
132+
// is typing, so the tradeoff is worth it. In the future, if we can
133+
// reliably distinguish between user and LLM activity, we can change this.
134+
if report.messageID > q.lastMessageID {
135+
report.state = codersdk.WorkspaceAppStatusStateWorking
136+
} else if report.state == codersdk.WorkspaceAppStatusStateWorking && !report.selfReported {
137+
q.mu.Unlock()
138+
return nil
139+
}
140+
// Preserve previous message and URI if there was no message.
141+
if report.summary == "" {
142+
report.summary = lastReport.summary
143+
if report.link == "" {
144+
report.link = lastReport.link
145+
}
146+
}
147+
// Avoid queueing duplicate updates.
148+
if report.state == lastReport.state &&
149+
report.link == lastReport.link &&
150+
report.summary == lastReport.summary {
151+
return nil
152+
}
153+
// Drop the first item if the queue has gotten too big.
154+
if len(q.items) >= q.size {
155+
q.items = q.items[1:]
156+
}
157+
q.items = append(q.items, report)
158+
q.cond.Broadcast()
159+
return nil
160+
}

cli/cliutil/queue_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,29 @@ func TestQueue(t *testing.T) {
8282
err := q.Push(10)
8383
require.Error(t, err)
8484
})
85+
86+
t.Run("WithPredicate", func(t *testing.T) {
87+
t.Parallel()
88+
89+
q := cliutil.NewQueue[int](10)
90+
q.WithPredicate(func(n int) (int, bool) {
91+
if n == 2 {
92+
return n, false
93+
}
94+
return n + 1, true
95+
})
96+
97+
for i := 0; i < 5; i++ {
98+
err := q.Push(i)
99+
require.NoError(t, err)
100+
}
101+
102+
got := []int{}
103+
for i := 0; i < 4; i++ {
104+
val, ok := q.Pop()
105+
require.True(t, ok)
106+
got = append(got, val)
107+
}
108+
require.Equal(t, []int{1, 2, 4, 5}, got)
109+
})
85110
}

cli/exp_mcp.go

Lines changed: 50 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ func (*RootCmd) mcpConfigureCursor() *serpent.Command {
361361
return cmd
362362
}
363363

364-
type reportTask struct {
364+
type taskReport struct {
365365
link string
366366
messageID int64
367367
selfReported bool
@@ -374,7 +374,7 @@ type mcpServer struct {
374374
appStatusSlug string
375375
client *codersdk.Client
376376
llmClient *agentapi.Client
377-
queue *cliutil.Queue[reportTask]
377+
queue *cliutil.Queue[taskReport]
378378
}
379379

380380
func (r *RootCmd) mcpServer() *serpent.Command {
@@ -388,9 +388,50 @@ func (r *RootCmd) mcpServer() *serpent.Command {
388388
return &serpent.Command{
389389
Use: "server",
390390
Handler: func(inv *serpent.Invocation) error {
391+
// lastUserMessageID is the ID of the last *user* message that we saw. A
392+
// user message only happens when interacting via the LLM agent API (as
393+
// opposed to interacting with the terminal directly).
394+
var lastUserMessageID int64
395+
var lastReport taskReport
396+
// Create a queue that skips duplicates and preserves summaries.
397+
queue := cliutil.NewQueue[taskReport](512).WithPredicate(func(report taskReport) (taskReport, bool) {
398+
// Use "working" status if this is a new user message. If this is not a
399+
// new user message, and the status is "working" and not self-reported
400+
// (meaning it came from the screen watcher), then it means one of two
401+
// things:
402+
// 1. The LLM is still working, so there is nothing to update.
403+
// 2. The LLM stopped working, then the user has interacted with the
404+
// terminal directly. For now, we are ignoring these updates. This
405+
// risks missing cases where the user manually submits a new prompt
406+
// and the LLM becomes active and does not update itself, but it
407+
// avoids spamming useless status updates as the user is typing, so
408+
// the tradeoff is worth it. In the future, if we can reliably
409+
// distinguish between user and LLM activity, we can change this.
410+
if report.messageID > lastUserMessageID {
411+
report.state = codersdk.WorkspaceAppStatusStateWorking
412+
} else if report.state == codersdk.WorkspaceAppStatusStateWorking && !report.selfReported {
413+
return report, false
414+
}
415+
// Preserve previous message and URI if there was no message.
416+
if report.summary == "" {
417+
report.summary = lastReport.summary
418+
if report.link == "" {
419+
report.link = lastReport.link
420+
}
421+
}
422+
// Avoid queueing duplicate updates.
423+
if report.state == lastReport.state &&
424+
report.link == lastReport.link &&
425+
report.summary == lastReport.summary {
426+
return report, false
427+
}
428+
lastReport = report
429+
return report, true
430+
})
431+
391432
srv := &mcpServer{
392433
appStatusSlug: appStatusSlug,
393-
queue: cliutil.NewQueue[reportTask](100),
434+
queue: queue,
394435
}
395436

396437
// Display client URL separately from authentication status.
@@ -505,35 +546,6 @@ func (r *RootCmd) mcpServer() *serpent.Command {
505546
}
506547

507548
func (s *mcpServer) startReporter(ctx context.Context, inv *serpent.Invocation) {
508-
// lastMessageID is the ID of the last *user* message that we saw. A user
509-
// message only happens when interacting via the API (as opposed to
510-
// interacting with the terminal directly).
511-
var lastMessageID int64
512-
shouldUpdate := func(item reportTask) codersdk.WorkspaceAppStatusState {
513-
// Always send self-reported updates.
514-
if item.selfReported {
515-
return item.state
516-
}
517-
// Always send completed states.
518-
switch item.state {
519-
case codersdk.WorkspaceAppStatusStateComplete,
520-
codersdk.WorkspaceAppStatusStateFailure:
521-
return item.state
522-
}
523-
// Always send "working" when there is a new user message, since we know the
524-
// LLM will begin work soon if it has not already.
525-
if item.messageID > lastMessageID {
526-
return codersdk.WorkspaceAppStatusStateWorking
527-
}
528-
// Otherwise, if the state is "working" and there have been no new user
529-
// messages, it means either that the LLM is still working or it means the
530-
// user has interacted with the terminal directly. For now, we are ignoring
531-
// these updates. This risks missing cases where the user manually submits
532-
// a new prompt and the LLM becomes active and does not update itself, but
533-
// it avoids spamming useless status updates.
534-
return ""
535-
}
536-
var lastPayload agentsdk.PatchAppStatus
537549
go func() {
538550
for {
539551
// TODO: Even with the queue, there is still the potential that a message
@@ -545,45 +557,15 @@ func (s *mcpServer) startReporter(ctx context.Context, inv *serpent.Invocation)
545557
return
546558
}
547559

548-
state := shouldUpdate(item)
549-
if state == "" {
550-
continue
551-
}
552-
553-
if item.messageID != 0 {
554-
lastMessageID = item.messageID
555-
}
556-
557-
payload := agentsdk.PatchAppStatus{
560+
err := s.agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
558561
AppSlug: s.appStatusSlug,
559562
Message: item.summary,
560563
URI: item.link,
561-
State: state,
562-
}
563-
564-
// Preserve previous message and URI if there was no message.
565-
if payload.Message == "" {
566-
payload.Message = lastPayload.Message
567-
if payload.URI == "" {
568-
payload.URI = lastPayload.URI
569-
}
570-
}
571-
572-
// Avoid sending duplicate updates.
573-
if lastPayload.State == payload.State &&
574-
lastPayload.URI == payload.URI &&
575-
lastPayload.Message == payload.Message {
576-
continue
577-
}
578-
579-
err := s.agentClient.PatchAppStatus(ctx, payload)
564+
State: item.state,
565+
})
580566
if err != nil && !errors.Is(err, context.Canceled) {
581567
cliui.Warnf(inv.Stderr, "Failed to report task status: %s", err)
582568
}
583-
584-
if err == nil {
585-
lastPayload = payload
586-
}
587569
}
588570
}()
589571
}
@@ -607,7 +589,7 @@ func (s *mcpServer) startWatcher(ctx context.Context, inv *serpent.Invocation) {
607589
if ev.Status == agentapi.StatusStable {
608590
state = codersdk.WorkspaceAppStatusStateComplete
609591
}
610-
err := s.queue.Push(reportTask{
592+
err := s.queue.Push(taskReport{
611593
state: state,
612594
})
613595
if err != nil {
@@ -616,7 +598,7 @@ func (s *mcpServer) startWatcher(ctx context.Context, inv *serpent.Invocation) {
616598
}
617599
case agentapi.EventMessageUpdate:
618600
if ev.Role == agentapi.RoleUser {
619-
err := s.queue.Push(reportTask{
601+
err := s.queue.Push(taskReport{
620602
messageID: ev.Id,
621603
})
622604
if err != nil {
@@ -667,7 +649,7 @@ func (s *mcpServer) startServer(ctx context.Context, inv *serpent.Invocation, in
667649
// Add tool dependencies.
668650
toolOpts := []func(*toolsdk.Deps){
669651
toolsdk.WithTaskReporter(func(args toolsdk.ReportTaskArgs) error {
670-
return s.queue.Push(reportTask{
652+
return s.queue.Push(taskReport{
671653
link: args.Link,
672654
selfReported: true,
673655
state: codersdk.WorkspaceAppStatusState(args.State),

0 commit comments

Comments
 (0)