Skip to content

Commit ffc599d

Browse files
committed
Move update predicates to push phase
Instead of the pop phase. This ensures we do not queue up updates that will just end up being discarded once they are popped (which could take some time due to latency to coderd). It also has the side effect of preserving summaries even when the queue gets too big, because now we preserve them as part of pushing, before they might get lost due to getting dropped while we wait on coderd.
1 parent 6d40d40 commit ffc599d

File tree

3 files changed

+168
-68
lines changed

3 files changed

+168
-68
lines changed

cli/cliutil/queue.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import (
44
"sync"
55

66
"golang.org/x/xerrors"
7+
8+
"github.com/coder/coder/v2/codersdk"
79
)
810

911
// Queue is a FIFO queue with a fixed size. If the size is exceeded, the first
@@ -14,6 +16,7 @@ type Queue[T any] struct {
1416
mu sync.Mutex
1517
size int
1618
closed bool
19+
pred func(x T) (T, bool)
1720
}
1821

1922
// NewQueue creates a queue with the given size.
@@ -26,6 +29,13 @@ func NewQueue[T any](size int) *Queue[T] {
2629
return q
2730
}
2831

32+
// WithPredicate adds the given predicate function, which can control what is
33+
// pushed to the queue.
34+
func (q *Queue[T]) WithPredicate(pred func(x T) (T, bool)) *Queue[T] {
35+
q.pred = pred
36+
return q
37+
}
38+
2939
// Close aborts any pending pops and makes future pushes error.
3040
func (q *Queue[T]) Close() {
3141
q.mu.Lock()
@@ -41,6 +51,15 @@ func (q *Queue[T]) Push(x T) error {
4151
if q.closed {
4252
return xerrors.New("queue has been closed")
4353
}
54+
// Potentially mutate or skip the push using the predicate.
55+
if q.pred != nil {
56+
var ok bool
57+
x, ok = q.pred(x)
58+
if !ok {
59+
return nil
60+
}
61+
}
62+
// Remove the first item from the queue if it has gotten too big.
4463
if len(q.items) >= q.size {
4564
q.items = q.items[1:]
4665
}
@@ -70,3 +89,72 @@ func (q *Queue[T]) Len() int {
7089
defer q.mu.Unlock()
7190
return len(q.items)
7291
}
92+
93+
type reportTask struct {
94+
link string
95+
messageID int64
96+
selfReported bool
97+
state codersdk.WorkspaceAppStatusState
98+
summary string
99+
}
100+
101+
// statusQueue is a Queue that:
102+
// 1. Only pushes items that are not duplicates.
103+
// 2. Preserves the existing message and URI when one a message is not provided.
104+
// 3. Ignores "working" updates from the status watcher.
105+
type StatusQueue struct {
106+
Queue[reportTask]
107+
// lastMessageID is the ID of the last *user* message that we saw. A user
108+
// message only happens when interacting via the API (as opposed to
109+
// interacting with the terminal directly).
110+
lastMessageID int64
111+
}
112+
113+
func (q *StatusQueue) Push(report reportTask) error {
114+
q.mu.Lock()
115+
defer q.mu.Unlock()
116+
if q.closed {
117+
return xerrors.New("queue has been closed")
118+
}
119+
var lastReport reportTask
120+
if len(q.items) > 0 {
121+
lastReport = q.items[len(q.items)-1]
122+
}
123+
// Use "working" status if this is a new user message. If this is not a new
124+
// user message, and the status is "working" and not self-reported (meaning it
125+
// came from the screen watcher), then it means one of two things:
126+
// 1. The LLM is still working, in which case our last status will already
127+
// have been "working", so there is nothing to do.
128+
// 2. The user has interacted with the terminal directly. For now, we are
129+
// ignoring these updates. This risks missing cases where the user
130+
// manually submits a new prompt and the LLM becomes active and does not
131+
// update itself, but it avoids spamming useless status updates as the user
132+
// is typing, so the tradeoff is worth it. In the future, if we can
133+
// reliably distinguish between user and LLM activity, we can change this.
134+
if report.messageID > q.lastMessageID {
135+
report.state = codersdk.WorkspaceAppStatusStateWorking
136+
} else if report.state == codersdk.WorkspaceAppStatusStateWorking && !report.selfReported {
137+
q.mu.Unlock()
138+
return nil
139+
}
140+
// Preserve previous message and URI if there was no message.
141+
if report.summary == "" {
142+
report.summary = lastReport.summary
143+
if report.link == "" {
144+
report.link = lastReport.link
145+
}
146+
}
147+
// Avoid queueing duplicate updates.
148+
if report.state == lastReport.state &&
149+
report.link == lastReport.link &&
150+
report.summary == lastReport.summary {
151+
return nil
152+
}
153+
// Drop the first item if the queue has gotten too big.
154+
if len(q.items) >= q.size {
155+
q.items = q.items[1:]
156+
}
157+
q.items = append(q.items, report)
158+
q.cond.Broadcast()
159+
return nil
160+
}

cli/cliutil/queue_test.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,35 @@ func TestQueue(t *testing.T) {
8282
err := q.Push(10)
8383
require.Error(t, err)
8484
})
85+
86+
t.Run("WithPredicate", func(t *testing.T) {
87+
t.Parallel()
88+
89+
q := cliutil.NewQueue[int](10)
90+
q.WithPredicate(func(n int) (int, bool) {
91+
if n == 2 {
92+
return n, false
93+
}
94+
return n + 1, true
95+
})
96+
97+
done := make(chan bool)
98+
go func() {
99+
_, ok := q.Pop()
100+
done <- ok
101+
}()
102+
103+
for i := 0; i < 5; i++ {
104+
err := q.Push(i)
105+
require.NoError(t, err)
106+
}
107+
108+
got := []int{}
109+
for i := 0; i < 4; i++ {
110+
val, ok := q.Pop()
111+
require.True(t, ok)
112+
got = append(got, val)
113+
}
114+
require.Equal(t, []int{1, 2, 4, 5}, got)
115+
})
85116
}

cli/exp_mcp.go

Lines changed: 49 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ func (*RootCmd) mcpConfigureCursor() *serpent.Command {
361361
return cmd
362362
}
363363

364-
type reportTask struct {
364+
type taskReport struct {
365365
link string
366366
messageID int64
367367
selfReported bool
@@ -374,7 +374,7 @@ type mcpServer struct {
374374
appStatusSlug string
375375
client *codersdk.Client
376376
llmClient *agentapi.Client
377-
queue *cliutil.Queue[reportTask]
377+
queue *cliutil.Queue[taskReport]
378378
}
379379

380380
func (r *RootCmd) mcpServer() *serpent.Command {
@@ -388,9 +388,49 @@ func (r *RootCmd) mcpServer() *serpent.Command {
388388
return &serpent.Command{
389389
Use: "server",
390390
Handler: func(inv *serpent.Invocation) error {
391+
// lastMessageID is the ID of the last *user* message that we saw. A user
392+
// message only happens when interacting via the API (as opposed to
393+
// interacting with the terminal directly).
394+
var lastMessageID int64
395+
var lastReport taskReport
396+
// Create a queue that skips duplicates and preserves summaries.
397+
queue := cliutil.NewQueue[taskReport](512).WithPredicate(func(report taskReport) (taskReport, bool) {
398+
// Use "working" status if this is a new user message. If this is not a new
399+
// user message, and the status is "working" and not self-reported (meaning it
400+
// came from the screen watcher), then it means one of two things:
401+
// 1. The LLM is still working, in which case our last status will already
402+
// have been "working", so there is nothing to do.
403+
// 2. The user has interacted with the terminal directly. For now, we are
404+
// ignoring these updates. This risks missing cases where the user
405+
// manually submits a new prompt and the LLM becomes active and does not
406+
// update itself, but it avoids spamming useless status updates as the user
407+
// is typing, so the tradeoff is worth it. In the future, if we can
408+
// reliably distinguish between user and LLM activity, we can change this.
409+
if report.messageID > lastMessageID {
410+
report.state = codersdk.WorkspaceAppStatusStateWorking
411+
} else if report.state == codersdk.WorkspaceAppStatusStateWorking && !report.selfReported {
412+
return report, false
413+
}
414+
// Preserve previous message and URI if there was no message.
415+
if report.summary == "" {
416+
report.summary = lastReport.summary
417+
if report.link == "" {
418+
report.link = lastReport.link
419+
}
420+
}
421+
// Avoid queueing duplicate updates.
422+
if report.state == lastReport.state &&
423+
report.link == lastReport.link &&
424+
report.summary == lastReport.summary {
425+
return report, false
426+
}
427+
lastReport = report
428+
return report, true
429+
})
430+
391431
srv := &mcpServer{
392432
appStatusSlug: appStatusSlug,
393-
queue: cliutil.NewQueue[reportTask](100),
433+
queue: queue,
394434
}
395435

396436
// Display client URL separately from authentication status.
@@ -505,35 +545,6 @@ func (r *RootCmd) mcpServer() *serpent.Command {
505545
}
506546

507547
func (s *mcpServer) startReporter(ctx context.Context, inv *serpent.Invocation) {
508-
// lastMessageID is the ID of the last *user* message that we saw. A user
509-
// message only happens when interacting via the API (as opposed to
510-
// interacting with the terminal directly).
511-
var lastMessageID int64
512-
shouldUpdate := func(item reportTask) codersdk.WorkspaceAppStatusState {
513-
// Always send self-reported updates.
514-
if item.selfReported {
515-
return item.state
516-
}
517-
// Always send completed states.
518-
switch item.state {
519-
case codersdk.WorkspaceAppStatusStateComplete,
520-
codersdk.WorkspaceAppStatusStateFailure:
521-
return item.state
522-
}
523-
// Always send "working" when there is a new user message, since we know the
524-
// LLM will begin work soon if it has not already.
525-
if item.messageID > lastMessageID {
526-
return codersdk.WorkspaceAppStatusStateWorking
527-
}
528-
// Otherwise, if the state is "working" and there have been no new user
529-
// messages, it means either that the LLM is still working or it means the
530-
// user has interacted with the terminal directly. For now, we are ignoring
531-
// these updates. This risks missing cases where the user manually submits
532-
// a new prompt and the LLM becomes active and does not update itself, but
533-
// it avoids spamming useless status updates.
534-
return ""
535-
}
536-
var lastPayload agentsdk.PatchAppStatus
537548
go func() {
538549
for {
539550
// TODO: Even with the queue, there is still the potential that a message
@@ -545,45 +556,15 @@ func (s *mcpServer) startReporter(ctx context.Context, inv *serpent.Invocation)
545556
return
546557
}
547558

548-
state := shouldUpdate(item)
549-
if state == "" {
550-
continue
551-
}
552-
553-
if item.messageID != 0 {
554-
lastMessageID = item.messageID
555-
}
556-
557-
payload := agentsdk.PatchAppStatus{
559+
err := s.agentClient.PatchAppStatus(ctx, agentsdk.PatchAppStatus{
558560
AppSlug: s.appStatusSlug,
559561
Message: item.summary,
560562
URI: item.link,
561-
State: state,
562-
}
563-
564-
// Preserve previous message and URI if there was no message.
565-
if payload.Message == "" {
566-
payload.Message = lastPayload.Message
567-
if payload.URI == "" {
568-
payload.URI = lastPayload.URI
569-
}
570-
}
571-
572-
// Avoid sending duplicate updates.
573-
if lastPayload.State == payload.State &&
574-
lastPayload.URI == payload.URI &&
575-
lastPayload.Message == payload.Message {
576-
continue
577-
}
578-
579-
err := s.agentClient.PatchAppStatus(ctx, payload)
563+
State: item.state,
564+
})
580565
if err != nil && !errors.Is(err, context.Canceled) {
581566
cliui.Warnf(inv.Stderr, "Failed to report task status: %s", err)
582567
}
583-
584-
if err == nil {
585-
lastPayload = payload
586-
}
587568
}
588569
}()
589570
}
@@ -607,7 +588,7 @@ func (s *mcpServer) startWatcher(ctx context.Context, inv *serpent.Invocation) {
607588
if ev.Status == agentapi.StatusStable {
608589
state = codersdk.WorkspaceAppStatusStateComplete
609590
}
610-
err := s.queue.Push(reportTask{
591+
err := s.queue.Push(taskReport{
611592
state: state,
612593
})
613594
if err != nil {
@@ -616,7 +597,7 @@ func (s *mcpServer) startWatcher(ctx context.Context, inv *serpent.Invocation) {
616597
}
617598
case agentapi.EventMessageUpdate:
618599
if ev.Role == agentapi.RoleUser {
619-
err := s.queue.Push(reportTask{
600+
err := s.queue.Push(taskReport{
620601
messageID: ev.Id,
621602
})
622603
if err != nil {
@@ -667,7 +648,7 @@ func (s *mcpServer) startServer(ctx context.Context, inv *serpent.Invocation, in
667648
// Add tool dependencies.
668649
toolOpts := []func(*toolsdk.Deps){
669650
toolsdk.WithTaskReporter(func(args toolsdk.ReportTaskArgs) error {
670-
return s.queue.Push(reportTask{
651+
return s.queue.Push(taskReport{
671652
link: args.Link,
672653
selfReported: true,
673654
state: codersdk.WorkspaceAppStatusState(args.State),

0 commit comments

Comments
 (0)