Skip to content

feat: add support for multiple banners #13081

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 58 additions & 52 deletions agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,35 +155,35 @@ func New(options Options) Agent {
hardCtx, hardCancel := context.WithCancel(context.Background())
gracefulCtx, gracefulCancel := context.WithCancel(hardCtx)
a := &agent{
tailnetListenPort: options.TailnetListenPort,
reconnectingPTYTimeout: options.ReconnectingPTYTimeout,
logger: options.Logger,
gracefulCtx: gracefulCtx,
gracefulCancel: gracefulCancel,
hardCtx: hardCtx,
hardCancel: hardCancel,
coordDisconnected: make(chan struct{}),
environmentVariables: options.EnvironmentVariables,
client: options.Client,
exchangeToken: options.ExchangeToken,
filesystem: options.Filesystem,
logDir: options.LogDir,
tempDir: options.TempDir,
scriptDataDir: options.ScriptDataDir,
lifecycleUpdate: make(chan struct{}, 1),
lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1),
lifecycleStates: []agentsdk.PostLifecycleRequest{{State: codersdk.WorkspaceAgentLifecycleCreated}},
ignorePorts: options.IgnorePorts,
portCacheDuration: options.PortCacheDuration,
reportMetadataInterval: options.ReportMetadataInterval,
serviceBannerRefreshInterval: options.ServiceBannerRefreshInterval,
sshMaxTimeout: options.SSHMaxTimeout,
subsystems: options.Subsystems,
addresses: options.Addresses,
syscaller: options.Syscaller,
modifiedProcs: options.ModifiedProcesses,
processManagementTick: options.ProcessManagementTick,
logSender: agentsdk.NewLogSender(options.Logger),
tailnetListenPort: options.TailnetListenPort,
reconnectingPTYTimeout: options.ReconnectingPTYTimeout,
logger: options.Logger,
gracefulCtx: gracefulCtx,
gracefulCancel: gracefulCancel,
hardCtx: hardCtx,
hardCancel: hardCancel,
coordDisconnected: make(chan struct{}),
environmentVariables: options.EnvironmentVariables,
client: options.Client,
exchangeToken: options.ExchangeToken,
filesystem: options.Filesystem,
logDir: options.LogDir,
tempDir: options.TempDir,
scriptDataDir: options.ScriptDataDir,
lifecycleUpdate: make(chan struct{}, 1),
lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1),
lifecycleStates: []agentsdk.PostLifecycleRequest{{State: codersdk.WorkspaceAgentLifecycleCreated}},
ignorePorts: options.IgnorePorts,
portCacheDuration: options.PortCacheDuration,
reportMetadataInterval: options.ReportMetadataInterval,
notificationBannersRefreshInterval: options.ServiceBannerRefreshInterval,
sshMaxTimeout: options.SSHMaxTimeout,
subsystems: options.Subsystems,
addresses: options.Addresses,
syscaller: options.Syscaller,
modifiedProcs: options.ModifiedProcesses,
processManagementTick: options.ProcessManagementTick,
logSender: agentsdk.NewLogSender(options.Logger),

prometheusRegistry: prometheusRegistry,
metrics: newAgentMetrics(prometheusRegistry),
Expand All @@ -193,7 +193,7 @@ func New(options Options) Agent {
// that gets closed on disconnection. This is used to wait for graceful disconnection from the
// coordinator during shut down.
close(a.coordDisconnected)
a.serviceBanner.Store(new(codersdk.ServiceBannerConfig))
a.notificationBanners.Store(new([]codersdk.BannerConfig))
a.sessionToken.Store(new(string))
a.init()
return a
Expand Down Expand Up @@ -231,14 +231,14 @@ type agent struct {

environmentVariables map[string]string

manifest atomic.Pointer[agentsdk.Manifest] // manifest is atomic because values can change after reconnection.
reportMetadataInterval time.Duration
scriptRunner *agentscripts.Runner
serviceBanner atomic.Pointer[codersdk.ServiceBannerConfig] // serviceBanner is atomic because it is periodically updated.
serviceBannerRefreshInterval time.Duration
sessionToken atomic.Pointer[string]
sshServer *agentssh.Server
sshMaxTimeout time.Duration
manifest atomic.Pointer[agentsdk.Manifest] // manifest is atomic because values can change after reconnection.
reportMetadataInterval time.Duration
scriptRunner *agentscripts.Runner
notificationBanners atomic.Pointer[[]codersdk.BannerConfig] // notificationBanners is atomic because it is periodically updated.
notificationBannersRefreshInterval time.Duration
sessionToken atomic.Pointer[string]
sshServer *agentssh.Server
sshMaxTimeout time.Duration

lifecycleUpdate chan struct{}
lifecycleReported chan codersdk.WorkspaceAgentLifecycle
Expand Down Expand Up @@ -272,11 +272,11 @@ func (a *agent) TailnetConn() *tailnet.Conn {
func (a *agent) init() {
// pass the "hard" context because we explicitly close the SSH server as part of graceful shutdown.
sshSrv, err := agentssh.NewServer(a.hardCtx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, &agentssh.Config{
MaxTimeout: a.sshMaxTimeout,
MOTDFile: func() string { return a.manifest.Load().MOTDFile },
ServiceBanner: func() *codersdk.ServiceBannerConfig { return a.serviceBanner.Load() },
UpdateEnv: a.updateCommandEnv,
WorkingDirectory: func() string { return a.manifest.Load().Directory },
MaxTimeout: a.sshMaxTimeout,
MOTDFile: func() string { return a.manifest.Load().MOTDFile },
NotificationBanners: func() *[]codersdk.BannerConfig { return a.notificationBanners.Load() },
UpdateEnv: a.updateCommandEnv,
WorkingDirectory: func() string { return a.manifest.Load().Directory },
})
if err != nil {
panic(err)
Expand Down Expand Up @@ -709,23 +709,26 @@ func (a *agent) setLifecycle(state codersdk.WorkspaceAgentLifecycle) {
// (and must be done before the session actually starts).
func (a *agent) fetchServiceBannerLoop(ctx context.Context, conn drpc.Conn) error {
aAPI := proto.NewDRPCAgentClient(conn)
ticker := time.NewTicker(a.serviceBannerRefreshInterval)
ticker := time.NewTicker(a.notificationBannersRefreshInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
sbp, err := aAPI.GetServiceBanner(ctx, &proto.GetServiceBannerRequest{})
bannersProto, err := aAPI.GetNotificationBanners(ctx, &proto.GetNotificationBannersRequest{})
if err != nil {
if ctx.Err() != nil {
return ctx.Err()
}
a.logger.Error(ctx, "failed to update service banner", slog.Error(err))
a.logger.Error(ctx, "failed to update notification banners", slog.Error(err))
return err
}
serviceBanner := agentsdk.ServiceBannerFromProto(sbp)
a.serviceBanner.Store(&serviceBanner)
banners := make([]codersdk.BannerConfig, 0, len(bannersProto.NotificationBanners))
for _, bannerProto := range bannersProto.NotificationBanners {
banners = append(banners, agentsdk.BannerConfigFromProto(bannerProto))
}
a.notificationBanners.Store(&banners)
}
}
}
Expand Down Expand Up @@ -757,15 +760,18 @@ func (a *agent) run() (retErr error) {
// redial the coder server and retry.
connMan := newAPIConnRoutineManager(a.gracefulCtx, a.hardCtx, a.logger, conn)

connMan.start("init service banner", gracefulShutdownBehaviorStop,
connMan.start("init notification banners", gracefulShutdownBehaviorStop,
func(ctx context.Context, conn drpc.Conn) error {
aAPI := proto.NewDRPCAgentClient(conn)
sbp, err := aAPI.GetServiceBanner(ctx, &proto.GetServiceBannerRequest{})
bannersProto, err := aAPI.GetNotificationBanners(ctx, &proto.GetNotificationBannersRequest{})
if err != nil {
return xerrors.Errorf("fetch service banner: %w", err)
}
serviceBanner := agentsdk.ServiceBannerFromProto(sbp)
a.serviceBanner.Store(&serviceBanner)
banners := make([]codersdk.BannerConfig, 0, len(bannersProto.NotificationBanners))
for _, bannerProto := range bannersProto.NotificationBanners {
banners = append(banners, agentsdk.BannerConfigFromProto(bannerProto))
}
a.notificationBanners.Store(&banners)
return nil
},
)
Expand Down
10 changes: 5 additions & 5 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -614,12 +614,12 @@ func TestAgent_Session_TTY_MOTD_Update(t *testing.T) {
// Set new banner func and wait for the agent to call it to update the
// banner.
ready := make(chan struct{}, 2)
client.SetServiceBannerFunc(func() (codersdk.ServiceBannerConfig, error) {
client.SetNotificationBannersFunc(func() ([]codersdk.BannerConfig, error) {
select {
case ready <- struct{}{}:
default:
}
return test.banner, nil
return []codersdk.BannerConfig{test.banner}, nil
})
<-ready
<-ready // Wait for two updates to ensure the value has propagated.
Expand Down Expand Up @@ -2193,15 +2193,15 @@ func setupAgentSSHClient(ctx context.Context, t *testing.T) *ssh.Client {
func setupSSHSession(
t *testing.T,
manifest agentsdk.Manifest,
serviceBanner codersdk.ServiceBannerConfig,
banner codersdk.BannerConfig,
prepareFS func(fs afero.Fs),
opts ...func(*agenttest.Client, *agent.Options),
) *ssh.Session {
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
opts = append(opts, func(c *agenttest.Client, o *agent.Options) {
c.SetServiceBannerFunc(func() (codersdk.ServiceBannerConfig, error) {
return serviceBanner, nil
c.SetNotificationBannersFunc(func() ([]codersdk.BannerConfig, error) {
return []codersdk.BannerConfig{banner}, nil
})
})
//nolint:dogsled
Expand Down
25 changes: 14 additions & 11 deletions agent/agentssh/agentssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ type Config struct {
// file will be displayed to the user upon login.
MOTDFile func() string
// ServiceBanner returns the configuration for the Coder service banner.
ServiceBanner func() *codersdk.ServiceBannerConfig
NotificationBanners func() *[]codersdk.BannerConfig
// UpdateEnv updates the environment variables for the command to be
// executed. It can be used to add, modify or replace environment variables.
UpdateEnv func(current []string) (updated []string, err error)
Expand Down Expand Up @@ -123,8 +123,8 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
if config.MOTDFile == nil {
config.MOTDFile = func() string { return "" }
}
if config.ServiceBanner == nil {
config.ServiceBanner = func() *codersdk.ServiceBannerConfig { return &codersdk.ServiceBannerConfig{} }
if config.NotificationBanners == nil {
config.NotificationBanners = func() *[]codersdk.BannerConfig { return &[]codersdk.BannerConfig{} }
}
if config.WorkingDirectory == nil {
config.WorkingDirectory = func() string {
Expand Down Expand Up @@ -441,12 +441,15 @@ func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTy
session.DisablePTYEmulation()

if isLoginShell(session.RawCommand()) {
serviceBanner := s.config.ServiceBanner()
if serviceBanner != nil {
err := showServiceBanner(session, serviceBanner)
if err != nil {
logger.Error(ctx, "agent failed to show service banner", slog.Error(err))
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "service_banner").Add(1)
banners := s.config.NotificationBanners()
if banners != nil {
for _, banner := range *banners {
err := showNotificationBanner(session, banner)
if err != nil {
logger.Error(ctx, "agent failed to show service banner", slog.Error(err))
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "notification_banner").Add(1)
break
}
}
}
}
Expand Down Expand Up @@ -891,9 +894,9 @@ func isQuietLogin(fs afero.Fs, rawCommand string) bool {
return err == nil
}

// showServiceBanner will write the service banner if enabled and not blank
// showNotificationBanner will write the service banner if enabled and not blank
// along with a blank line for spacing.
func showServiceBanner(session io.Writer, banner *codersdk.ServiceBannerConfig) error {
func showNotificationBanner(session io.Writer, banner codersdk.BannerConfig) error {
if banner.Enabled && banner.Message != "" {
// The banner supports Markdown so we might want to parse it but Markdown is
// still fairly readable in its raw form.
Expand Down
30 changes: 19 additions & 11 deletions agent/agenttest/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ func (c *Client) GetStartupLogs() []agentsdk.Log {
return c.logs
}

func (c *Client) SetServiceBannerFunc(f func() (codersdk.ServiceBannerConfig, error)) {
c.fakeAgentAPI.SetServiceBannerFunc(f)
func (c *Client) SetNotificationBannersFunc(f func() ([]codersdk.ServiceBannerConfig, error)) {
c.fakeAgentAPI.SetNotificationBannersFunc(f)
}

func (c *Client) PushDERPMapUpdate(update *tailcfg.DERPMap) error {
Expand Down Expand Up @@ -171,31 +171,39 @@ type FakeAgentAPI struct {
lifecycleStates []codersdk.WorkspaceAgentLifecycle
metadata map[string]agentsdk.Metadata

getServiceBannerFunc func() (codersdk.ServiceBannerConfig, error)
getNotificationBannersFunc func() ([]codersdk.BannerConfig, error)
}

func (f *FakeAgentAPI) GetManifest(context.Context, *agentproto.GetManifestRequest) (*agentproto.Manifest, error) {
return f.manifest, nil
}

func (f *FakeAgentAPI) SetServiceBannerFunc(fn func() (codersdk.ServiceBannerConfig, error)) {
func (*FakeAgentAPI) GetServiceBanner(context.Context, *agentproto.GetServiceBannerRequest) (*agentproto.ServiceBanner, error) {
return &agentproto.ServiceBanner{}, nil
}

func (f *FakeAgentAPI) SetNotificationBannersFunc(fn func() ([]codersdk.BannerConfig, error)) {
f.Lock()
defer f.Unlock()
f.getServiceBannerFunc = fn
f.logger.Info(context.Background(), "updated ServiceBannerFunc")
f.getNotificationBannersFunc = fn
f.logger.Info(context.Background(), "updated notification banners")
}

func (f *FakeAgentAPI) GetServiceBanner(context.Context, *agentproto.GetServiceBannerRequest) (*agentproto.ServiceBanner, error) {
func (f *FakeAgentAPI) GetNotificationBanners(context.Context, *agentproto.GetNotificationBannersRequest) (*agentproto.GetNotificationBannersResponse, error) {
f.Lock()
defer f.Unlock()
if f.getServiceBannerFunc == nil {
return &agentproto.ServiceBanner{}, nil
if f.getNotificationBannersFunc == nil {
return &agentproto.GetNotificationBannersResponse{NotificationBanners: []*agentproto.BannerConfig{}}, nil
}
sb, err := f.getServiceBannerFunc()
banners, err := f.getNotificationBannersFunc()
if err != nil {
return nil, err
}
return agentsdk.ProtoFromServiceBanner(sb), nil
bannersProto := make([]*agentproto.BannerConfig, 0, len(banners))
for _, banner := range banners {
bannersProto = append(bannersProto, agentsdk.ProtoFromBannerConfig(banner))
}
return &agentproto.GetNotificationBannersResponse{NotificationBanners: bannersProto}, nil
}

func (f *FakeAgentAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsRequest) (*agentproto.UpdateStatsResponse, error) {
Expand Down
Loading
Loading