
refactor(agent): Move SSH server into agentssh package #7004


Merged · 8 commits · Apr 6, 2023

Changes from 1 commit:
Improve handling of serve/close
mafredri committed Apr 6, 2023
commit ed63a2bcf048d0fc178908ff06743c4abd823fd0
128 changes: 120 additions & 8 deletions agent/agentssh/agentssh.go
```diff
@@ -47,10 +47,16 @@ const (
 )
 
 type Server struct {
-	serveWg sync.WaitGroup
-	logger  slog.Logger
+	mu        sync.RWMutex // Protects following.
+	listeners map[net.Listener]struct{}
+	conns     map[net.Conn]struct{}
+	closing   chan struct{}
+	// Wait for goroutines to exit, waited without
+	// a lock on mu but protected by closing.
+	wg sync.WaitGroup
 
-	srv *ssh.Server
+	logger slog.Logger
+	srv    *ssh.Server
 
 	Env        map[string]string
 	AgentToken func() string
```
```diff
@@ -78,7 +84,9 @@ func NewServer(ctx context.Context, logger slog.Logger, maxTimeout time.Duration
 	unixForwardHandler := &forwardedUnixHandler{log: logger}
 
 	s := &Server{
-		logger: logger,
+		listeners: make(map[net.Listener]struct{}),
+		conns:     make(map[net.Conn]struct{}),
+		logger:    logger,
 	}
 
 	s.srv = &ssh.Server{
```
```diff
@@ -472,14 +480,118 @@ func (s *Server) CreateCommand(ctx context.Context, script string, env []string)
 }
 
 func (s *Server) Serve(l net.Listener) error {
-	s.serveWg.Add(1)
-	defer s.serveWg.Done()
-	return s.srv.Serve(l)
+	defer l.Close()
+
+	s.trackListener(l, true)
+	defer s.trackListener(l, false)
+	for {
+		conn, err := l.Accept()
+		if err != nil {
+			return err
+		}
+		go s.handleConn(l, conn)
+	}
 }
 
+func (s *Server) handleConn(l net.Listener, c net.Conn) {
+	defer c.Close()
+
+	if !s.trackConn(l, c, true) {
+		// Server is closed or we no longer want
+		// connections from this listener.
+		s.logger.Debug(context.Background(), "received connection after server closed")
+		return
+	}
+	defer s.trackConn(l, c, false)
+
+	s.srv.HandleConn(c)
+}
+
+// trackListener registers the listener with the server. If the server is
+// closing, the function will block until the server is closed.
+//
+//nolint:revive
+func (s *Server) trackListener(l net.Listener, add bool) {
```
Member:
This seems quite complicated, since we only have one listener ever being served from what I can tell.

mafredri (Member, Author) on Apr 6, 2023:
Do you have a suggestion for a simplification? I wanted this package to be able to manage its own state and give guarantees for close/shutdown. This is in part motivated by the current setup of tailnet in the agent, which can re-run if an error is encountered (i.e. after a call to the SSH server's Serve).

(We also can't rely on the ssh package because it has broken guarantees in this regard.)

Member:
Why not just have a single listener on the struct instead of a map, but still add to the waitgroup? It seems it's written in a way where it can be reused after close by calling Serve again, but I don't believe we use that anywhere, so it seems unnecessary.

mafredri (Member, Author) on Apr 6, 2023:
> It seems it's written in a way where it can be reused after close by calling Serve again

That's actually what happens if createTailnet encounters an error and a new tailnet is set up in the next retry: Serve will be called again.
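For illustration, a minimal sketch of the retry pattern being described (the runSSH helper, the createTailnet signature, and the backoff are hypothetical simplifications; this is not the actual agent code):

```go
package agentexample

import (
	"context"
	"net"
	"time"

	"github.com/coder/coder/agent/agentssh"
)

// runSSH is a hypothetical helper: the same *agentssh.Server outlives
// tailnet failures, so Serve is called again with a fresh listener on
// every retry, which is why the server must tolerate repeated
// Serve/Close cycles.
func runSSH(ctx context.Context, s *agentssh.Server, createTailnet func(context.Context) (net.Listener, error)) error {
	for ctx.Err() == nil {
		ln, err := createTailnet(ctx) // Hypothetical; stands in for tailnet setup.
		if err != nil {
			time.Sleep(time.Second) // Simplified backoff.
			continue
		}
		// Serve blocks until the listener fails or the server is
		// closed; either way, loop around and set up tailnet again.
		_ = s.Serve(ln)
	}
	return ctx.Err()
}
```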

Member:
In that case, could you make it so the createTailnet function recreates the SSH server when it wants to recreate the tailnet?

Member:
Typically in Go, when Close is called on a struct it's dead forever, so this seems to not match what most people would expect.
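For reference, the standard library follows this convention: net/http's Server cannot be served again once closed. A small sketch relying only on documented stdlib behavior:

```go
package main

import (
	"errors"
	"fmt"
	"net/http"
)

func main() {
	srv := &http.Server{Addr: "127.0.0.1:0"}
	_ = srv.Close()

	// Once Close (or Shutdown) has been called, the server is dead:
	// ListenAndServe returns ErrServerClosed immediately.
	err := srv.ListenAndServe()
	fmt.Println(errors.Is(err, http.ErrServerClosed)) // true
}
```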

mafredri (Member, Author):
If we want to do that, perhaps it's something for a future refactor? For now I'd like to keep the functionality similar to what it was before, and I think a little complexity contained in a package is fine.

Member:
IDK, I don't think it's just a little complexity. The s.closing loop took me multiple read-throughs to understand what it was trying to do. You should get a second opinion.

```diff
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if add {
+		for s.closing != nil {
+			closing := s.closing
+			// Wait until close is complete before
+			// serving a new listener.
+			s.mu.Unlock()
+			<-closing
+			s.mu.Lock()
+		}
+		s.wg.Add(1)
+		s.listeners[l] = struct{}{}
+		return
+	}
+	s.wg.Done()
+	delete(s.listeners, l)
+}
```
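The unlock/wait/relock dance above is the part the review thread debates: it is effectively a condition wait on the closing channel. A standalone, hypothetical distillation of just that pattern (simplified; not code from the PR):

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// gate distills the closing-channel pattern used by trackListener: a
// nil channel means "not closing"; while a close is in flight the
// channel is non-nil, and it is closed (then reset to nil) when the
// teardown finishes.
type gate struct {
	mu      sync.Mutex
	closing chan struct{}
}

// register blocks while a shutdown is in progress, like
// trackListener(l, true) does before adding a new listener.
func (g *gate) register() {
	g.mu.Lock()
	defer g.mu.Unlock()
	for g.closing != nil {
		closing := g.closing
		// Drop the lock so shutdown can finish, wait for its
		// signal, then re-take the lock and re-check.
		g.mu.Unlock()
		<-closing
		g.mu.Lock()
	}
	// Registration would happen here, still under the lock.
}

// shutdown mirrors Server.Close: mark closing, do the teardown work
// without holding the lock, then signal waiters and reset.
func (g *gate) shutdown() {
	g.mu.Lock()
	g.closing = make(chan struct{})
	g.mu.Unlock()

	time.Sleep(10 * time.Millisecond) // Stand-in for teardown work.

	g.mu.Lock()
	close(g.closing)
	g.closing = nil
	g.mu.Unlock()
}

func main() {
	g := &gate{}
	go g.shutdown()
	time.Sleep(time.Millisecond) // Let shutdown begin.
	g.register()                 // Blocks until shutdown completes.
	fmt.Println("registered after close finished")
}
```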

```diff
+// trackConn registers the connection with the server. If the server is
+// closed or the listener is closed, the connection is not registered
+// and should be closed.
+//
+//nolint:revive
+func (s *Server) trackConn(l net.Listener, c net.Conn, add bool) (ok bool) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if add {
+		found := false
+		for ll := range s.listeners {
+			if l == ll {
+				found = true
+				break
+			}
+		}
+		if s.closing != nil || !found {
+			// Server or listener closed.
+			return false
+		}
+		s.wg.Add(1)
+		s.conns[c] = struct{}{}
+		return true
+	}
+	s.wg.Done()
+	delete(s.conns, c)
+	return true
+}
```

```diff
+// Close the server and all active connections. Server can be re-used
+// after Close is done.
 func (s *Server) Close() error {
+	s.mu.Lock()
+
+	// Guard against multiple calls to Close and
+	// accepting new connections during close.
+	if s.closing != nil {
+		s.mu.Unlock()
+		return xerrors.New("server is closing")
+	}
+	s.closing = make(chan struct{})
+
+	// Close all active listeners and connections.
+	for l := range s.listeners {
+		_ = l.Close()
+	}
+	for c := range s.conns {
+		_ = c.Close()
+	}
+
 	// Close the underlying SSH server.
 	err := s.srv.Close()
-	s.serveWg.Wait()
+
+	s.mu.Unlock()
+	s.wg.Wait() // Wait for all goroutines to exit.
+
+	s.mu.Lock()
+	close(s.closing)
+	s.closing = nil
+	s.mu.Unlock()
+
 	return err
 }
```
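To make the documented contract concrete, here is a hypothetical caller-side sketch (the reuse helper and listener addresses are invented, and the import path is assumed from this PR's repo layout):

```go
package agentexample

import (
	"net"

	"github.com/coder/coder/agent/agentssh"
)

// reuse demonstrates the close-then-serve-again contract documented
// above (hypothetical helper; error handling kept minimal).
func reuse(srv *agentssh.Server) error {
	ln1, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return err
	}
	go func() { _ = srv.Serve(ln1) }() // First accept loop.

	// Close closes ln1 and all tracked connections, then waits for
	// the Serve/handleConn goroutines to exit. A second concurrent
	// Close would return the "server is closing" error instead.
	if err := srv.Close(); err != nil {
		return err
	}

	// Per the doc comment, the server can be re-used after Close:
	ln2, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return err
	}
	return srv.Serve(ln2) // A new Serve call is valid again.
}
```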
