Skip to content

Commit 957589e

Browse files
committed
Merge remote-tracking branch 'origin/main' into stevenmasley/abandoned_workspace_test
2 parents 7dd8694 + c916a9e commit 957589e

File tree

12 files changed

+299
-72
lines changed

12 files changed

+299
-72
lines changed

agent/agent.go

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,16 +1025,32 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
10251025
}()
10261026

10271027
var rpty *reconnectingPTY
1028-
rawRPTY, ok := a.reconnectingPTYs.Load(msg.ID)
1028+
sendConnected := make(chan *reconnectingPTY, 1)
1029+
// On store, reserve this ID to prevent multiple concurrent new connections.
1030+
waitReady, ok := a.reconnectingPTYs.LoadOrStore(msg.ID, sendConnected)
10291031
if ok {
1032+
close(sendConnected) // Unused.
10301033
logger.Debug(ctx, "connecting to existing session")
1031-
rpty, ok = rawRPTY.(*reconnectingPTY)
1034+
c, ok := waitReady.(chan *reconnectingPTY)
10321035
if !ok {
1033-
return xerrors.Errorf("found invalid type in reconnecting pty map: %T", rawRPTY)
1036+
return xerrors.Errorf("found invalid type in reconnecting pty map: %T", waitReady)
10341037
}
1038+
rpty, ok = <-c
1039+
if !ok || rpty == nil {
1040+
return xerrors.Errorf("reconnecting pty closed before connection")
1041+
}
1042+
c <- rpty // Put it back for the next reconnect.
10351043
} else {
10361044
logger.Debug(ctx, "creating new session")
10371045

1046+
connected := false
1047+
defer func() {
1048+
if !connected && retErr != nil {
1049+
a.reconnectingPTYs.Delete(msg.ID)
1050+
close(sendConnected)
1051+
}
1052+
}()
1053+
10381054
// An empty command defaults to the user's shell.
10391055
cmd, err := a.sshServer.CreateCommand(ctx, msg.Command, nil)
10401056
if err != nil {
@@ -1055,7 +1071,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
10551071
return xerrors.Errorf("start command: %w", err)
10561072
}
10571073

1058-
ctx, cancelFunc := context.WithCancel(ctx)
1074+
ctx, cancel := context.WithCancel(ctx)
10591075
rpty = &reconnectingPTY{
10601076
activeConns: map[string]net.Conn{
10611077
// We have to put the connection in the map instantly otherwise
@@ -1064,10 +1080,9 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
10641080
},
10651081
ptty: ptty,
10661082
// Timeouts created with an after func can be reset!
1067-
timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancelFunc),
1083+
timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancel),
10681084
circularBuffer: circularBuffer,
10691085
}
1070-
a.reconnectingPTYs.Store(msg.ID, rpty)
10711086
// We don't need to separately monitor for the process exiting.
10721087
// When it exits, our ptty.OutputReader() will return EOF after
10731088
// reading all process output.
@@ -1115,8 +1130,12 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
11151130
rpty.Close()
11161131
a.reconnectingPTYs.Delete(msg.ID)
11171132
}); err != nil {
1133+
_ = process.Kill()
1134+
_ = ptty.Close()
11181135
return xerrors.Errorf("start routine: %w", err)
11191136
}
1137+
connected = true
1138+
sendConnected <- rpty
11201139
}
11211140
// Resize the PTY to initial height + width.
11221141
err := rpty.ptty.Resize(msg.Height, msg.Width)

coderd/database/dbgen/generator.go renamed to coderd/database/dbgen/dbgen.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,3 +525,40 @@ func must[V any](v V, err error) V {
525525
}
526526
return v
527527
}
528+
529+
// takeFirstIP returns the first value whose IP and Mask are both non-empty.
// Selection is delegated to takeFirstF, so when no value qualifies the last
// value is returned, and the zero net.IPNet is returned when values is empty.
func takeFirstIP(values ...net.IPNet) net.IPNet {
530+
return takeFirstF(values, func(v net.IPNet) bool {
531+
// Treat a value as set only when both the address and the mask are present.
return len(v.IP) != 0 && len(v.Mask) != 0
532+
})
533+
}
534+
535+
// takeFirstSlice is the slice counterpart of takeFirst: it returns the first
// slice in values that has a non-zero length.
536+
// Slice types do not satisfy the comparable constraint, so emptiness is
// tested with len instead of ==. Selection is delegated to takeFirstF: the
// last slice is returned when all are empty, and nil when values is empty.
537+
func takeFirstSlice[T any](values ...[]T) []T {
538+
return takeFirstF(values, func(v []T) bool {
539+
return len(v) != 0
540+
})
541+
}
542+
543+
// takeFirstF returns the first element of values for which take reports true.
// When no element is accepted it falls back to the last element, and when
// values is empty it returns the zero value of Value.
544+
func takeFirstF[Value any](values []Value, take func(v Value) bool) Value {
545+
for _, v := range values {
546+
if take(v) {
547+
return v
548+
}
549+
}
550+
// Nothing was accepted: fall back to the last element, if there is one.
551+
if len(values) > 0 {
552+
return values[len(values)-1]
553+
}
554+
// No values were supplied at all; return the zero value.
var empty Value
555+
return empty
556+
}
557+
558+
// takeFirst returns the first value that differs from the zero value of its
// type. Selection is delegated to takeFirstF, so the last value is returned
// when every value is zero, and the zero value is returned when no values
// are given.
559+
func takeFirst[Value comparable](values ...Value) Value {
560+
// The zero value of Value is what counts as "empty" for the comparison below.
var empty Value
561+
return takeFirstF(values, func(v Value) bool {
562+
return v != empty
563+
})
564+
}

coderd/database/dbgen/take.go

Lines changed: 0 additions & 40 deletions
This file was deleted.

coderd/database/pubsub.go

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,8 @@ func (q *msgQueue) dropped() {
163163
// Pubsub implementation using PostgreSQL.
164164
type pgPubsub struct {
165165
ctx context.Context
166+
cancel context.CancelFunc
167+
listenDone chan struct{}
166168
pgListener *pq.Listener
167169
db *sql.DB
168170
mut sync.Mutex
@@ -228,7 +230,7 @@ func (p *pgPubsub) Publish(event string, message []byte) error {
228230
// This is safe because we are calling pq.QuoteLiteral. pg_notify doesn't
229231
// support the first parameter being a prepared statement.
230232
//nolint:gosec
231-
_, err := p.db.ExecContext(context.Background(), `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message)
233+
_, err := p.db.ExecContext(p.ctx, `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message)
232234
if err != nil {
233235
return xerrors.Errorf("exec pg_notify: %w", err)
234236
}
@@ -237,19 +239,24 @@ func (p *pgPubsub) Publish(event string, message []byte) error {
237239

238240
// Close shuts down the pubsub instance: it cancels the pubsub's context,
// closes the pq listener, waits for listenDone, and returns the error from
// closing the listener.
// NOTE(review): listen also defers pgListener.Close(); if that deferred call
// runs first after cancel, the Close call here may report an already-closed
// error. Confirm against pq.Listener.Close semantics.
239241
func (p *pgPubsub) Close() error {
240-
return p.pgListener.Close()
242+
p.cancel()
243+
err := p.pgListener.Close()
244+
// Block until the listen goroutine has returned; it closes listenDone on exit.
<-p.listenDone
245+
return err
241246
}
242247

243248
// listen begins receiving messages on the pq listener.
244-
func (p *pgPubsub) listen(ctx context.Context) {
249+
func (p *pgPubsub) listen() {
250+
defer close(p.listenDone)
251+
defer p.pgListener.Close()
252+
245253
var (
246254
notif *pq.Notification
247255
ok bool
248256
)
249-
defer p.pgListener.Close()
250257
for {
251258
select {
252-
case <-ctx.Done():
259+
case <-p.ctx.Done():
253260
return
254261
case notif, ok = <-p.pgListener.Notify:
255262
if !ok {
@@ -292,7 +299,7 @@ func (p *pgPubsub) recordReconnect() {
292299
func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub, error) {
293300
// Creates a new listener using pq.
294301
errCh := make(chan error)
295-
listener := pq.NewListener(connectURL, time.Second, time.Minute, func(event pq.ListenerEventType, err error) {
302+
listener := pq.NewListener(connectURL, time.Second, time.Minute, func(_ pq.ListenerEventType, err error) {
296303
// This callback gets events whenever the connection state changes.
297304
// Don't send if the errChannel has already been closed.
298305
select {
@@ -306,18 +313,25 @@ func NewPubsub(ctx context.Context, database *sql.DB, connectURL string) (Pubsub
306313
select {
307314
case err := <-errCh:
308315
if err != nil {
316+
_ = listener.Close()
309317
return nil, xerrors.Errorf("create pq listener: %w", err)
310318
}
311319
case <-ctx.Done():
320+
_ = listener.Close()
312321
return nil, ctx.Err()
313322
}
323+
324+
// Start a new context that will be canceled when the pubsub is closed.
325+
ctx, cancel := context.WithCancel(context.Background())
314326
pgPubsub := &pgPubsub{
315327
ctx: ctx,
328+
cancel: cancel,
329+
listenDone: make(chan struct{}),
316330
db: database,
317331
pgListener: listener,
318332
queues: make(map[string]map[uuid.UUID]*msgQueue),
319333
}
320-
go pgPubsub.listen(ctx)
334+
go pgPubsub.listen()
321335

322336
return pgPubsub, nil
323337
}

coderd/database/pubsub_test.go

Lines changed: 87 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,11 @@ func TestPubsub(t *testing.T) {
4545
event := "test"
4646
data := "testing"
4747
messageChannel := make(chan []byte)
48-
cancelFunc, err = pubsub.Subscribe(event, func(ctx context.Context, message []byte) {
48+
unsub, err := pubsub.Subscribe(event, func(ctx context.Context, message []byte) {
4949
messageChannel <- message
5050
})
5151
require.NoError(t, err)
52-
defer cancelFunc()
52+
defer unsub()
5353
go func() {
5454
err = pubsub.Publish(event, []byte(data))
5555
assert.NoError(t, err)
@@ -72,6 +72,91 @@ func TestPubsub(t *testing.T) {
7272
defer pubsub.Close()
7373
cancelFunc()
7474
})
75+
76+
t.Run("NotClosedOnCancelContext", func(t *testing.T) {
77+
ctx, cancel := context.WithCancel(context.Background())
78+
defer cancel()
79+
connectionURL, closePg, err := postgres.Open()
80+
require.NoError(t, err)
81+
defer closePg()
82+
db, err := sql.Open("postgres", connectionURL)
83+
require.NoError(t, err)
84+
defer db.Close()
85+
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
86+
require.NoError(t, err)
87+
defer pubsub.Close()
88+
89+
// The provided context only needs to be active during NewPubsub; canceling it afterwards must not stop the pubsub.
90+
cancel()
91+
92+
event := "test"
93+
data := "testing"
94+
messageChannel := make(chan []byte)
95+
unsub, err := pubsub.Subscribe(event, func(_ context.Context, message []byte) {
96+
messageChannel <- message
97+
})
98+
require.NoError(t, err)
99+
defer unsub()
100+
go func() {
101+
err = pubsub.Publish(event, []byte(data))
102+
assert.NoError(t, err)
103+
}()
104+
message := <-messageChannel
105+
assert.Equal(t, string(message), data)
106+
})
107+
108+
t.Run("ClosePropagatesContextCancellationToSubscription", func(t *testing.T) {
109+
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
110+
defer cancel()
111+
connectionURL, closePg, err := postgres.Open()
112+
require.NoError(t, err)
113+
defer closePg()
114+
db, err := sql.Open("postgres", connectionURL)
115+
require.NoError(t, err)
116+
defer db.Close()
117+
pubsub, err := database.NewPubsub(ctx, db, connectionURL)
118+
require.NoError(t, err)
119+
defer pubsub.Close()
120+
121+
event := "test"
122+
done := make(chan struct{})
123+
called := make(chan struct{})
124+
unsub, err := pubsub.Subscribe(event, func(subCtx context.Context, _ []byte) {
125+
defer close(done)
126+
select {
127+
case <-subCtx.Done():
128+
assert.Fail(t, "context should not be canceled")
129+
default:
130+
}
131+
close(called)
132+
select {
133+
case <-subCtx.Done():
134+
case <-ctx.Done():
135+
assert.Fail(t, "timeout waiting for sub context to be canceled")
136+
}
137+
})
138+
require.NoError(t, err)
139+
defer unsub()
140+
141+
go func() {
142+
err := pubsub.Publish(event, nil)
143+
assert.NoError(t, err)
144+
}()
145+
146+
select {
147+
case <-called:
148+
case <-ctx.Done():
149+
require.Fail(t, "timeout waiting for handler to be called")
150+
}
151+
err = pubsub.Close()
152+
require.NoError(t, err)
153+
154+
select {
155+
case <-done:
156+
case <-ctx.Done():
157+
require.Fail(t, "timeout waiting for handler to finish")
158+
}
159+
})
75160
}
76161

77162
func TestPubsub_ordering(t *testing.T) {

coderd/workspaceagents.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1434,17 +1434,15 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ
14341434
})
14351435
}
14361436

1437-
// Send initial metadata.
1438-
sendMetadata(true)
1439-
14401437
// We debounce metadata updates to avoid overloading the frontend when
14411438
// an agent is sending a lot of updates.
14421439
pubsubDebounce := debounce.New(time.Second)
14431440
if flag.Lookup("test.v") != nil {
14441441
pubsubDebounce = debounce.New(time.Millisecond * 100)
14451442
}
14461443

1447-
// Send metadata on updates.
1444+
// Send metadata on updates. We must subscribe before sending the
1445+
// initial metadata to guarantee that events in between are not missed.
14481446
cancelSub, err := api.Pubsub.Subscribe(watchWorkspaceAgentMetadataChannel(workspaceAgent.ID), func(_ context.Context, _ []byte) {
14491447
pubsubDebounce(func() {
14501448
sendMetadata(true)
@@ -1456,12 +1454,14 @@ func (api *API) watchWorkspaceAgentMetadata(rw http.ResponseWriter, r *http.Requ
14561454
}
14571455
defer cancelSub()
14581456

1457+
// Send initial metadata.
1458+
sendMetadata(true)
1459+
14591460
for {
14601461
select {
14611462
case <-senderClosed:
14621463
return
14631464
case <-refreshTicker.C:
1464-
break
14651465
}
14661466

14671467
// Avoid spamming the DB with reads we know there are no updates. We want

0 commit comments

Comments
 (0)