Skip to content

Commit 0862e1d

Browse files
authored
Merge branch 'main' into mafredri/test-minor-fixes
2 parents fc150a8 + c916a9e commit 0862e1d

File tree

4 files changed

+62
-46
lines changed

4 files changed

+62
-46
lines changed

agent/agent.go

+25-6
Original file line numberDiff line numberDiff line change
@@ -1025,16 +1025,32 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
10251025
}()
10261026

10271027
var rpty *reconnectingPTY
1028-
rawRPTY, ok := a.reconnectingPTYs.Load(msg.ID)
1028+
sendConnected := make(chan *reconnectingPTY, 1)
1029+
// On store, reserve this ID to prevent multiple concurrent new connections.
1030+
waitReady, ok := a.reconnectingPTYs.LoadOrStore(msg.ID, sendConnected)
10291031
if ok {
1032+
close(sendConnected) // Unused.
10301033
logger.Debug(ctx, "connecting to existing session")
1031-
rpty, ok = rawRPTY.(*reconnectingPTY)
1034+
c, ok := waitReady.(chan *reconnectingPTY)
10321035
if !ok {
1033-
return xerrors.Errorf("found invalid type in reconnecting pty map: %T", rawRPTY)
1036+
return xerrors.Errorf("found invalid type in reconnecting pty map: %T", waitReady)
10341037
}
1038+
rpty, ok = <-c
1039+
if !ok || rpty == nil {
1040+
return xerrors.Errorf("reconnecting pty closed before connection")
1041+
}
1042+
c <- rpty // Put it back for the next reconnect.
10351043
} else {
10361044
logger.Debug(ctx, "creating new session")
10371045

1046+
connected := false
1047+
defer func() {
1048+
if !connected && retErr != nil {
1049+
a.reconnectingPTYs.Delete(msg.ID)
1050+
close(sendConnected)
1051+
}
1052+
}()
1053+
10381054
// Empty command will default to the users shell!
10391055
cmd, err := a.sshServer.CreateCommand(ctx, msg.Command, nil)
10401056
if err != nil {
@@ -1055,7 +1071,7 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
10551071
return xerrors.Errorf("start command: %w", err)
10561072
}
10571073

1058-
ctx, cancelFunc := context.WithCancel(ctx)
1074+
ctx, cancel := context.WithCancel(ctx)
10591075
rpty = &reconnectingPTY{
10601076
activeConns: map[string]net.Conn{
10611077
// We have to put the connection in the map instantly otherwise
@@ -1064,10 +1080,9 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
10641080
},
10651081
ptty: ptty,
10661082
// Timeouts created with an after func can be reset!
1067-
timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancelFunc),
1083+
timeout: time.AfterFunc(a.reconnectingPTYTimeout, cancel),
10681084
circularBuffer: circularBuffer,
10691085
}
1070-
a.reconnectingPTYs.Store(msg.ID, rpty)
10711086
// We don't need to separately monitor for the process exiting.
10721087
// When it exits, our ptty.OutputReader() will return EOF after
10731088
// reading all process output.
@@ -1115,8 +1130,12 @@ func (a *agent) handleReconnectingPTY(ctx context.Context, logger slog.Logger, m
11151130
rpty.Close()
11161131
a.reconnectingPTYs.Delete(msg.ID)
11171132
}); err != nil {
1133+
_ = process.Kill()
1134+
_ = ptty.Close()
11181135
return xerrors.Errorf("start routine: %w", err)
11191136
}
1137+
connected = true
1138+
sendConnected <- rpty
11201139
}
11211140
// Resize the PTY to initial height + width.
11221141
err := rpty.ptty.Resize(msg.Height, msg.Width)

coderd/database/dbgen/generator.go renamed to coderd/database/dbgen/dbgen.go

+37
Original file line numberDiff line numberDiff line change
@@ -525,3 +525,40 @@ func must[V any](v V, err error) V {
525525
}
526526
return v
527527
}
528+
529+
func takeFirstIP(values ...net.IPNet) net.IPNet {
530+
return takeFirstF(values, func(v net.IPNet) bool {
531+
return len(v.IP) != 0 && len(v.Mask) != 0
532+
})
533+
}
534+
535+
// takeFirstSlice implements takeFirst for []any.
536+
// []any is not a comparable type.
537+
func takeFirstSlice[T any](values ...[]T) []T {
538+
return takeFirstF(values, func(v []T) bool {
539+
return len(v) != 0
540+
})
541+
}
542+
543+
// takeFirstF takes the first value that returns true
544+
func takeFirstF[Value any](values []Value, take func(v Value) bool) Value {
545+
for _, v := range values {
546+
if take(v) {
547+
return v
548+
}
549+
}
550+
// If all empty, return the last element
551+
if len(values) > 0 {
552+
return values[len(values)-1]
553+
}
554+
var empty Value
555+
return empty
556+
}
557+
558+
// takeFirst will take the first non-empty value.
559+
func takeFirst[Value comparable](values ...Value) Value {
560+
var empty Value
561+
return takeFirstF(values, func(v Value) bool {
562+
return v != empty
563+
})
564+
}

coderd/database/dbgen/take.go

-40
This file was deleted.

0 commit comments

Comments
 (0)