Skip to content

Commit d8e0be6

Browse files
authored
feat: add support for multiple banners (#13081)
1 parent a4bd50c commit d8e0be6

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+1473
-810
lines changed

agent/agent.go

+58-52
Original file line numberDiff line numberDiff line change
@@ -155,35 +155,35 @@ func New(options Options) Agent {
155155
hardCtx, hardCancel := context.WithCancel(context.Background())
156156
gracefulCtx, gracefulCancel := context.WithCancel(hardCtx)
157157
a := &agent{
158-
tailnetListenPort: options.TailnetListenPort,
159-
reconnectingPTYTimeout: options.ReconnectingPTYTimeout,
160-
logger: options.Logger,
161-
gracefulCtx: gracefulCtx,
162-
gracefulCancel: gracefulCancel,
163-
hardCtx: hardCtx,
164-
hardCancel: hardCancel,
165-
coordDisconnected: make(chan struct{}),
166-
environmentVariables: options.EnvironmentVariables,
167-
client: options.Client,
168-
exchangeToken: options.ExchangeToken,
169-
filesystem: options.Filesystem,
170-
logDir: options.LogDir,
171-
tempDir: options.TempDir,
172-
scriptDataDir: options.ScriptDataDir,
173-
lifecycleUpdate: make(chan struct{}, 1),
174-
lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1),
175-
lifecycleStates: []agentsdk.PostLifecycleRequest{{State: codersdk.WorkspaceAgentLifecycleCreated}},
176-
ignorePorts: options.IgnorePorts,
177-
portCacheDuration: options.PortCacheDuration,
178-
reportMetadataInterval: options.ReportMetadataInterval,
179-
serviceBannerRefreshInterval: options.ServiceBannerRefreshInterval,
180-
sshMaxTimeout: options.SSHMaxTimeout,
181-
subsystems: options.Subsystems,
182-
addresses: options.Addresses,
183-
syscaller: options.Syscaller,
184-
modifiedProcs: options.ModifiedProcesses,
185-
processManagementTick: options.ProcessManagementTick,
186-
logSender: agentsdk.NewLogSender(options.Logger),
158+
tailnetListenPort: options.TailnetListenPort,
159+
reconnectingPTYTimeout: options.ReconnectingPTYTimeout,
160+
logger: options.Logger,
161+
gracefulCtx: gracefulCtx,
162+
gracefulCancel: gracefulCancel,
163+
hardCtx: hardCtx,
164+
hardCancel: hardCancel,
165+
coordDisconnected: make(chan struct{}),
166+
environmentVariables: options.EnvironmentVariables,
167+
client: options.Client,
168+
exchangeToken: options.ExchangeToken,
169+
filesystem: options.Filesystem,
170+
logDir: options.LogDir,
171+
tempDir: options.TempDir,
172+
scriptDataDir: options.ScriptDataDir,
173+
lifecycleUpdate: make(chan struct{}, 1),
174+
lifecycleReported: make(chan codersdk.WorkspaceAgentLifecycle, 1),
175+
lifecycleStates: []agentsdk.PostLifecycleRequest{{State: codersdk.WorkspaceAgentLifecycleCreated}},
176+
ignorePorts: options.IgnorePorts,
177+
portCacheDuration: options.PortCacheDuration,
178+
reportMetadataInterval: options.ReportMetadataInterval,
179+
notificationBannersRefreshInterval: options.ServiceBannerRefreshInterval,
180+
sshMaxTimeout: options.SSHMaxTimeout,
181+
subsystems: options.Subsystems,
182+
addresses: options.Addresses,
183+
syscaller: options.Syscaller,
184+
modifiedProcs: options.ModifiedProcesses,
185+
processManagementTick: options.ProcessManagementTick,
186+
logSender: agentsdk.NewLogSender(options.Logger),
187187

188188
prometheusRegistry: prometheusRegistry,
189189
metrics: newAgentMetrics(prometheusRegistry),
@@ -193,7 +193,7 @@ func New(options Options) Agent {
193193
// that gets closed on disconnection. This is used to wait for graceful disconnection from the
194194
// coordinator during shut down.
195195
close(a.coordDisconnected)
196-
a.serviceBanner.Store(new(codersdk.ServiceBannerConfig))
196+
a.notificationBanners.Store(new([]codersdk.BannerConfig))
197197
a.sessionToken.Store(new(string))
198198
a.init()
199199
return a
@@ -231,14 +231,14 @@ type agent struct {
231231

232232
environmentVariables map[string]string
233233

234-
manifest atomic.Pointer[agentsdk.Manifest] // manifest is atomic because values can change after reconnection.
235-
reportMetadataInterval time.Duration
236-
scriptRunner *agentscripts.Runner
237-
serviceBanner atomic.Pointer[codersdk.ServiceBannerConfig] // serviceBanner is atomic because it is periodically updated.
238-
serviceBannerRefreshInterval time.Duration
239-
sessionToken atomic.Pointer[string]
240-
sshServer *agentssh.Server
241-
sshMaxTimeout time.Duration
234+
manifest atomic.Pointer[agentsdk.Manifest] // manifest is atomic because values can change after reconnection.
235+
reportMetadataInterval time.Duration
236+
scriptRunner *agentscripts.Runner
237+
notificationBanners atomic.Pointer[[]codersdk.BannerConfig] // notificationBanners is atomic because it is periodically updated.
238+
notificationBannersRefreshInterval time.Duration
239+
sessionToken atomic.Pointer[string]
240+
sshServer *agentssh.Server
241+
sshMaxTimeout time.Duration
242242

243243
lifecycleUpdate chan struct{}
244244
lifecycleReported chan codersdk.WorkspaceAgentLifecycle
@@ -272,11 +272,11 @@ func (a *agent) TailnetConn() *tailnet.Conn {
272272
func (a *agent) init() {
273273
// pass the "hard" context because we explicitly close the SSH server as part of graceful shutdown.
274274
sshSrv, err := agentssh.NewServer(a.hardCtx, a.logger.Named("ssh-server"), a.prometheusRegistry, a.filesystem, &agentssh.Config{
275-
MaxTimeout: a.sshMaxTimeout,
276-
MOTDFile: func() string { return a.manifest.Load().MOTDFile },
277-
ServiceBanner: func() *codersdk.ServiceBannerConfig { return a.serviceBanner.Load() },
278-
UpdateEnv: a.updateCommandEnv,
279-
WorkingDirectory: func() string { return a.manifest.Load().Directory },
275+
MaxTimeout: a.sshMaxTimeout,
276+
MOTDFile: func() string { return a.manifest.Load().MOTDFile },
277+
NotificationBanners: func() *[]codersdk.BannerConfig { return a.notificationBanners.Load() },
278+
UpdateEnv: a.updateCommandEnv,
279+
WorkingDirectory: func() string { return a.manifest.Load().Directory },
280280
})
281281
if err != nil {
282282
panic(err)
@@ -709,23 +709,26 @@ func (a *agent) setLifecycle(state codersdk.WorkspaceAgentLifecycle) {
709709
// (and must be done before the session actually starts).
710710
func (a *agent) fetchServiceBannerLoop(ctx context.Context, conn drpc.Conn) error {
711711
aAPI := proto.NewDRPCAgentClient(conn)
712-
ticker := time.NewTicker(a.serviceBannerRefreshInterval)
712+
ticker := time.NewTicker(a.notificationBannersRefreshInterval)
713713
defer ticker.Stop()
714714
for {
715715
select {
716716
case <-ctx.Done():
717717
return ctx.Err()
718718
case <-ticker.C:
719-
sbp, err := aAPI.GetServiceBanner(ctx, &proto.GetServiceBannerRequest{})
719+
bannersProto, err := aAPI.GetNotificationBanners(ctx, &proto.GetNotificationBannersRequest{})
720720
if err != nil {
721721
if ctx.Err() != nil {
722722
return ctx.Err()
723723
}
724-
a.logger.Error(ctx, "failed to update service banner", slog.Error(err))
724+
a.logger.Error(ctx, "failed to update notification banners", slog.Error(err))
725725
return err
726726
}
727-
serviceBanner := agentsdk.ServiceBannerFromProto(sbp)
728-
a.serviceBanner.Store(&serviceBanner)
727+
banners := make([]codersdk.BannerConfig, 0, len(bannersProto.NotificationBanners))
728+
for _, bannerProto := range bannersProto.NotificationBanners {
729+
banners = append(banners, agentsdk.BannerConfigFromProto(bannerProto))
730+
}
731+
a.notificationBanners.Store(&banners)
729732
}
730733
}
731734
}
@@ -757,15 +760,18 @@ func (a *agent) run() (retErr error) {
757760
// redial the coder server and retry.
758761
connMan := newAPIConnRoutineManager(a.gracefulCtx, a.hardCtx, a.logger, conn)
759762

760-
connMan.start("init service banner", gracefulShutdownBehaviorStop,
763+
connMan.start("init notification banners", gracefulShutdownBehaviorStop,
761764
func(ctx context.Context, conn drpc.Conn) error {
762765
aAPI := proto.NewDRPCAgentClient(conn)
763-
sbp, err := aAPI.GetServiceBanner(ctx, &proto.GetServiceBannerRequest{})
766+
bannersProto, err := aAPI.GetNotificationBanners(ctx, &proto.GetNotificationBannersRequest{})
764767
if err != nil {
765768
return xerrors.Errorf("fetch service banner: %w", err)
766769
}
767-
serviceBanner := agentsdk.ServiceBannerFromProto(sbp)
768-
a.serviceBanner.Store(&serviceBanner)
770+
banners := make([]codersdk.BannerConfig, 0, len(bannersProto.NotificationBanners))
771+
for _, bannerProto := range bannersProto.NotificationBanners {
772+
banners = append(banners, agentsdk.BannerConfigFromProto(bannerProto))
773+
}
774+
a.notificationBanners.Store(&banners)
769775
return nil
770776
},
771777
)

agent/agent_test.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -614,12 +614,12 @@ func TestAgent_Session_TTY_MOTD_Update(t *testing.T) {
614614
// Set new banner func and wait for the agent to call it to update the
615615
// banner.
616616
ready := make(chan struct{}, 2)
617-
client.SetServiceBannerFunc(func() (codersdk.ServiceBannerConfig, error) {
617+
client.SetNotificationBannersFunc(func() ([]codersdk.BannerConfig, error) {
618618
select {
619619
case ready <- struct{}{}:
620620
default:
621621
}
622-
return test.banner, nil
622+
return []codersdk.BannerConfig{test.banner}, nil
623623
})
624624
<-ready
625625
<-ready // Wait for two updates to ensure the value has propagated.
@@ -2193,15 +2193,15 @@ func setupAgentSSHClient(ctx context.Context, t *testing.T) *ssh.Client {
21932193
func setupSSHSession(
21942194
t *testing.T,
21952195
manifest agentsdk.Manifest,
2196-
serviceBanner codersdk.ServiceBannerConfig,
2196+
banner codersdk.BannerConfig,
21972197
prepareFS func(fs afero.Fs),
21982198
opts ...func(*agenttest.Client, *agent.Options),
21992199
) *ssh.Session {
22002200
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
22012201
defer cancel()
22022202
opts = append(opts, func(c *agenttest.Client, o *agent.Options) {
2203-
c.SetServiceBannerFunc(func() (codersdk.ServiceBannerConfig, error) {
2204-
return serviceBanner, nil
2203+
c.SetNotificationBannersFunc(func() ([]codersdk.BannerConfig, error) {
2204+
return []codersdk.BannerConfig{banner}, nil
22052205
})
22062206
})
22072207
//nolint:dogsled

agent/agentssh/agentssh.go

+14-11
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ type Config struct {
6363
// file will be displayed to the user upon login.
6464
MOTDFile func() string
6565
// ServiceBanner returns the configuration for the Coder service banner.
66-
ServiceBanner func() *codersdk.ServiceBannerConfig
66+
NotificationBanners func() *[]codersdk.BannerConfig
6767
// UpdateEnv updates the environment variables for the command to be
6868
// executed. It can be used to add, modify or replace environment variables.
6969
UpdateEnv func(current []string) (updated []string, err error)
@@ -123,8 +123,8 @@ func NewServer(ctx context.Context, logger slog.Logger, prometheusRegistry *prom
123123
if config.MOTDFile == nil {
124124
config.MOTDFile = func() string { return "" }
125125
}
126-
if config.ServiceBanner == nil {
127-
config.ServiceBanner = func() *codersdk.ServiceBannerConfig { return &codersdk.ServiceBannerConfig{} }
126+
if config.NotificationBanners == nil {
127+
config.NotificationBanners = func() *[]codersdk.BannerConfig { return &[]codersdk.BannerConfig{} }
128128
}
129129
if config.WorkingDirectory == nil {
130130
config.WorkingDirectory = func() string {
@@ -441,12 +441,15 @@ func (s *Server) startPTYSession(logger slog.Logger, session ptySession, magicTy
441441
session.DisablePTYEmulation()
442442

443443
if isLoginShell(session.RawCommand()) {
444-
serviceBanner := s.config.ServiceBanner()
445-
if serviceBanner != nil {
446-
err := showServiceBanner(session, serviceBanner)
447-
if err != nil {
448-
logger.Error(ctx, "agent failed to show service banner", slog.Error(err))
449-
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "service_banner").Add(1)
444+
banners := s.config.NotificationBanners()
445+
if banners != nil {
446+
for _, banner := range *banners {
447+
err := showNotificationBanner(session, banner)
448+
if err != nil {
449+
logger.Error(ctx, "agent failed to show service banner", slog.Error(err))
450+
s.metrics.sessionErrors.WithLabelValues(magicTypeLabel, "yes", "notification_banner").Add(1)
451+
break
452+
}
450453
}
451454
}
452455
}
@@ -891,9 +894,9 @@ func isQuietLogin(fs afero.Fs, rawCommand string) bool {
891894
return err == nil
892895
}
893896

894-
// showServiceBanner will write the service banner if enabled and not blank
897+
// showNotificationBanner will write the service banner if enabled and not blank
895898
// along with a blank line for spacing.
896-
func showServiceBanner(session io.Writer, banner *codersdk.ServiceBannerConfig) error {
899+
func showNotificationBanner(session io.Writer, banner codersdk.BannerConfig) error {
897900
if banner.Enabled && banner.Message != "" {
898901
// The banner supports Markdown so we might want to parse it but Markdown is
899902
// still fairly readable in its raw form.

agent/agenttest/client.go

+19-11
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,8 @@ func (c *Client) GetStartupLogs() []agentsdk.Log {
138138
return c.logs
139139
}
140140

141-
func (c *Client) SetServiceBannerFunc(f func() (codersdk.ServiceBannerConfig, error)) {
142-
c.fakeAgentAPI.SetServiceBannerFunc(f)
141+
func (c *Client) SetNotificationBannersFunc(f func() ([]codersdk.ServiceBannerConfig, error)) {
142+
c.fakeAgentAPI.SetNotificationBannersFunc(f)
143143
}
144144

145145
func (c *Client) PushDERPMapUpdate(update *tailcfg.DERPMap) error {
@@ -171,31 +171,39 @@ type FakeAgentAPI struct {
171171
lifecycleStates []codersdk.WorkspaceAgentLifecycle
172172
metadata map[string]agentsdk.Metadata
173173

174-
getServiceBannerFunc func() (codersdk.ServiceBannerConfig, error)
174+
getNotificationBannersFunc func() ([]codersdk.BannerConfig, error)
175175
}
176176

177177
func (f *FakeAgentAPI) GetManifest(context.Context, *agentproto.GetManifestRequest) (*agentproto.Manifest, error) {
178178
return f.manifest, nil
179179
}
180180

181-
func (f *FakeAgentAPI) SetServiceBannerFunc(fn func() (codersdk.ServiceBannerConfig, error)) {
181+
func (*FakeAgentAPI) GetServiceBanner(context.Context, *agentproto.GetServiceBannerRequest) (*agentproto.ServiceBanner, error) {
182+
return &agentproto.ServiceBanner{}, nil
183+
}
184+
185+
func (f *FakeAgentAPI) SetNotificationBannersFunc(fn func() ([]codersdk.BannerConfig, error)) {
182186
f.Lock()
183187
defer f.Unlock()
184-
f.getServiceBannerFunc = fn
185-
f.logger.Info(context.Background(), "updated ServiceBannerFunc")
188+
f.getNotificationBannersFunc = fn
189+
f.logger.Info(context.Background(), "updated notification banners")
186190
}
187191

188-
func (f *FakeAgentAPI) GetServiceBanner(context.Context, *agentproto.GetServiceBannerRequest) (*agentproto.ServiceBanner, error) {
192+
func (f *FakeAgentAPI) GetNotificationBanners(context.Context, *agentproto.GetNotificationBannersRequest) (*agentproto.GetNotificationBannersResponse, error) {
189193
f.Lock()
190194
defer f.Unlock()
191-
if f.getServiceBannerFunc == nil {
192-
return &agentproto.ServiceBanner{}, nil
195+
if f.getNotificationBannersFunc == nil {
196+
return &agentproto.GetNotificationBannersResponse{NotificationBanners: []*agentproto.BannerConfig{}}, nil
193197
}
194-
sb, err := f.getServiceBannerFunc()
198+
banners, err := f.getNotificationBannersFunc()
195199
if err != nil {
196200
return nil, err
197201
}
198-
return agentsdk.ProtoFromServiceBanner(sb), nil
202+
bannersProto := make([]*agentproto.BannerConfig, 0, len(banners))
203+
for _, banner := range banners {
204+
bannersProto = append(bannersProto, agentsdk.ProtoFromBannerConfig(banner))
205+
}
206+
return &agentproto.GetNotificationBannersResponse{NotificationBanners: bannersProto}, nil
199207
}
200208

201209
func (f *FakeAgentAPI) UpdateStats(ctx context.Context, req *agentproto.UpdateStatsRequest) (*agentproto.UpdateStatsResponse, error) {

0 commit comments

Comments
 (0)