diff --git a/agent/agentssh/agentssh.go b/agent/agentssh/agentssh.go index d6b6613ed92de..a22f86836d147 100644 --- a/agent/agentssh/agentssh.go +++ b/agent/agentssh/agentssh.go @@ -172,6 +172,7 @@ func (s *Server) sessionHandler(session ssh.Session) { ctx := session.Context() + extraEnv := make([]string, 0) x11, hasX11 := session.X11() if hasX11 { handled := s.x11Handler(session.Context(), x11) @@ -179,6 +180,7 @@ func (s *Server) sessionHandler(session ssh.Session) { _ = session.Exit(1) return } + extraEnv = append(extraEnv, fmt.Sprintf("DISPLAY=:%d.0", x11.ScreenNumber)) } switch ss := session.Subsystem(); ss { @@ -192,7 +194,7 @@ func (s *Server) sessionHandler(session ssh.Session) { return } - err := s.sessionStart(session) + err := s.sessionStart(session, extraEnv) var exitError *exec.ExitError if xerrors.As(err, &exitError) { s.logger.Debug(ctx, "ssh session returned", slog.Error(exitError)) @@ -209,9 +211,9 @@ func (s *Server) sessionHandler(session ssh.Session) { _ = session.Exit(0) } -func (s *Server) sessionStart(session ssh.Session) (retErr error) { +func (s *Server) sessionStart(session ssh.Session, extraEnv []string) (retErr error) { ctx := session.Context() - env := session.Environ() + env := append(session.Environ(), extraEnv...) var magicType string for index, kv := range env { if !strings.HasPrefix(kv, MagicSessionTypeEnvironmentVariable) { diff --git a/agent/agentssh/x11.go b/agent/agentssh/x11.go index 6d50d4a99078f..b301326a0acb3 100644 --- a/agent/agentssh/x11.go +++ b/agent/agentssh/x11.go @@ -52,7 +52,14 @@ func (s *Server) x11Handler(ctx ssh.Context, x11 ssh.X11) bool { s.logger.Warn(ctx, "failed to get server connection") return false } - listener, err := net.Listen("unix", filepath.Join(s.x11SocketDir, fmt.Sprintf("X%d", x11.ScreenNumber))) + // We want to overwrite the socket so that subsequent connections will succeed. + socketPath := filepath.Join(s.x11SocketDir, fmt.Sprintf("X%d", x11.ScreenNumber)) + err := os.Remove(socketPath) + if err != nil && !errors.Is(err, os.ErrNotExist) { + s.logger.Warn(ctx, "failed to remove existing X11 socket", slog.Error(err)) + return false + } + listener, err := net.Listen("unix", socketPath) if err != nil { s.logger.Warn(ctx, "failed to listen for X11", slog.Error(err)) return false