Skip to content

Commit 02b6743

Browse files
committed
add heartbeat
1 parent ff1032a commit 02b6743

File tree

4 files changed

+117
-14
lines changed

4 files changed

+117
-14
lines changed

coderd/coderd.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,9 +1183,9 @@ func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, name string
11831183
api.Logger.Info(ctx, "starting in-memory provisioner daemon", slog.F("name", name))
11841184
logger := api.Logger.Named(fmt.Sprintf("inmem-provisionerd-%s", name))
11851185
srv, err := provisionerdserver.NewServer(
1186-
api.ctx,
1186+
api.ctx, // use the same ctx as the API
11871187
api.AccessURL,
1188-
uuid.New(),
1188+
daemon.ID,
11891189
logger,
11901190
daemon.Provisioners,
11911191
provisionerdserver.Tags(daemon.Tags),

coderd/provisionerdserver/provisionerdserver.go

Lines changed: 71 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,15 @@ import (
4343
sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
4444
)
4545

46-
// DefaultAcquireJobLongPollDur is the time the (deprecated) AcquireJob rpc waits to try to obtain a job before
47-
// canceling and returning an empty job.
48-
const DefaultAcquireJobLongPollDur = time.Second * 5
46+
const (
47+
// DefaultAcquireJobLongPollDur is the time the (deprecated) AcquireJob rpc waits to try to obtain a job before
48+
// canceling and returning an empty job.
// NewServer applies it when Options.AcquireJobLongPollDur is zero.
49+
DefaultAcquireJobLongPollDur = time.Second * 5
50+
51+
// DefaultHeartbeatInterval is the interval at which the provisioner daemon
52+
// will update its last seen at timestamp in the database.
// NewServer applies it when Options.HeartbeatInterval is zero.
53+
DefaultHeartbeatInterval = time.Minute
54+
)
4955

5056
type Options struct {
5157
OIDCConfig httpmw.OAuth2Config
@@ -55,6 +61,15 @@ type Options struct {
5561

5662
// AcquireJobLongPollDur is used in tests
5763
AcquireJobLongPollDur time.Duration
64+
65+
// HeartbeatInterval is the interval at which the provisioner daemon
66+
// will update its last seen at timestamp in the database.
67+
HeartbeatInterval time.Duration
68+
69+
// HeartbeatFn is the function that will be called at the interval
70+
// specified by HeartbeatInterval.
71+
// This is only used in tests.
72+
HeartbeatFn func(context.Context) error
5873
}
5974

6075
type server struct {
@@ -84,6 +99,9 @@ type server struct {
8499
TimeNowFn func() time.Time
85100

86101
acquireJobLongPollDur time.Duration
102+
103+
HeartbeatInterval time.Duration
104+
HeartbeatFn func(ctx context.Context) error
87105
}
88106

89107
// We use the null byte (0x00) in generating a canonical map key for tags, so
@@ -160,7 +178,11 @@ func NewServer(
160178
if options.AcquireJobLongPollDur == 0 {
161179
options.AcquireJobLongPollDur = DefaultAcquireJobLongPollDur
162180
}
163-
return &server{
181+
if options.HeartbeatInterval == 0 {
182+
options.HeartbeatInterval = DefaultHeartbeatInterval
183+
}
184+
185+
s := &server{
164186
lifecycleCtx: lifecycleCtx,
165187
AccessURL: accessURL,
166188
ID: id,
@@ -181,7 +203,13 @@ func NewServer(
181203
OIDCConfig: options.OIDCConfig,
182204
TimeNowFn: options.TimeNowFn,
183205
acquireJobLongPollDur: options.AcquireJobLongPollDur,
184-
}, nil
206+
HeartbeatInterval: options.HeartbeatInterval,
207+
HeartbeatFn: options.HeartbeatFn,
208+
}
209+
210+
go s.heartbeat()
211+
212+
return s, nil
185213
}
186214

187215
// timeNow should be used when trying to get the current time for math
@@ -193,6 +221,44 @@ func (s *server) timeNow() time.Time {
193221
return dbtime.Now()
194222
}
195223

224+
// heartbeat runs heartbeatOnce at the interval specified by HeartbeatInterval
225+
// until the lifecycle context is canceled.
226+
func (s *server) heartbeat() {
227+
tick := time.NewTicker(time.Nanosecond)
228+
defer tick.Stop()
229+
for {
230+
select {
231+
case <-s.lifecycleCtx.Done():
232+
return
233+
case <-tick.C:
234+
hbCtx, hbCancel := context.WithTimeout(s.lifecycleCtx, s.HeartbeatInterval)
235+
if err := s.heartbeatOnce(hbCtx); err != nil {
236+
s.Logger.Error(hbCtx, "heartbeat failed", slog.Error(err))
237+
}
238+
hbCancel()
239+
tick.Reset(s.HeartbeatInterval)
240+
}
241+
}
242+
}
243+
244+
// heartbeatOnce updates the last seen at timestamp in the database.
245+
// If HeartbeatFn is set, it will be called instead.
246+
func (s *server) heartbeatOnce(ctx context.Context) error {
247+
if s.HeartbeatFn != nil {
248+
return s.HeartbeatFn(ctx)
249+
}
250+
251+
if s.lifecycleCtx.Err() != nil {
252+
return nil
253+
}
254+
255+
//nolint:gocritic // Provisionerd has specific authz rules.
256+
return s.Database.UpdateProvisionerDaemonLastSeenAt(dbauthz.AsProvisionerd(ctx), database.UpdateProvisionerDaemonLastSeenAtParams{
257+
ID: s.ID,
258+
LastSeenAt: sql.NullTime{Time: s.timeNow(), Valid: true},
259+
})
260+
}
261+
196262
// AcquireJob queries the database to lock a job.
197263
//
198264
// Deprecated: This method is only available for back-level provisioner daemons.

coderd/provisionerdserver/provisionerdserver_test.go

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,29 @@ func TestAcquireJobWithCancel_Cancel(t *testing.T) {
9595
require.Equal(t, "", job.JobId)
9696
}
9797

98+
func TestHeartbeat(t *testing.T) {
99+
t.Parallel()
100+
101+
ctx, cancel := context.WithCancel(context.Background())
102+
t.Cleanup(cancel)
103+
heartbeatChan := make(chan struct{})
104+
heartbeatFn := func(context.Context) error {
105+
heartbeatChan <- struct{}{}
106+
return nil
107+
}
108+
//nolint:dogsled // this is a test
109+
_, _, _ = setup(t, false, &overrides{
110+
ctx: ctx,
111+
heartbeatFn: heartbeatFn,
112+
heartbeatInterval: testutil.IntervalFast,
113+
})
114+
115+
<-heartbeatChan
116+
cancel()
117+
close(heartbeatChan)
118+
<-time.After(testutil.IntervalFast)
119+
}
120+
98121
func TestAcquireJob(t *testing.T) {
99122
t.Parallel()
100123

@@ -1686,19 +1709,20 @@ func TestInsertWorkspaceResource(t *testing.T) {
16861709
}
16871710

16881711
// overrides carries optional settings that setup applies when constructing
// the server under test.
type overrides struct {
1712+
// ctx is passed to NewServer and NewAcquirer. When nil, setup creates a
// cancelable background context and cancels it on test cleanup.
ctx context.Context
16891713
deploymentValues *codersdk.DeploymentValues
16901714
externalAuthConfigs []*externalauth.Config
16911715
id *uuid.UUID
16921716
templateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore]
16931717
userQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore]
16941718
timeNowFn func() time.Time
16951719
acquireJobLongPollDuration time.Duration
1720+
// heartbeatFn is forwarded to the server as Options.HeartbeatFn.
heartbeatFn func(ctx context.Context) error
1721+
// heartbeatInterval is forwarded as Options.HeartbeatInterval; setup
// substitutes testutil.IntervalMedium when it is zero.
heartbeatInterval time.Duration
16961722
}
16971723

16981724
func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisionerDaemonServer, database.Store, pubsub.Pubsub) {
16991725
t.Helper()
1700-
ctx, cancel := context.WithCancel(context.Background())
1701-
t.Cleanup(cancel)
17021726
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
17031727
db := dbmem.New()
17041728
ps := pubsub.NewInMemory()
@@ -1710,6 +1734,14 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi
17101734
var timeNowFn func() time.Time
17111735
pollDur := time.Duration(0)
17121736
if ov != nil {
1737+
if ov.ctx == nil {
1738+
ctx, cancel := context.WithCancel(context.Background())
1739+
t.Cleanup(cancel)
1740+
ov.ctx = ctx
1741+
}
1742+
if ov.heartbeatInterval == 0 {
1743+
ov.heartbeatInterval = testutil.IntervalMedium
1744+
}
17131745
if ov.deploymentValues != nil {
17141746
deploymentValues = ov.deploymentValues
17151747
}
@@ -1744,15 +1776,15 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi
17441776
}
17451777

17461778
srv, err := provisionerdserver.NewServer(
1747-
ctx,
1779+
ov.ctx,
17481780
&url.URL{},
17491781
srvID,
17501782
slogtest.Make(t, &slogtest.Options{IgnoreErrors: ignoreLogErrors}),
17511783
[]database.ProvisionerType{database.ProvisionerTypeEcho},
17521784
provisionerdserver.Tags{},
17531785
db,
17541786
ps,
1755-
provisionerdserver.NewAcquirer(ctx, logger.Named("acquirer"), db, ps),
1787+
provisionerdserver.NewAcquirer(ov.ctx, logger.Named("acquirer"), db, ps),
17561788
telemetry.NewNoop(),
17571789
trace.NewNoopTracerProvider().Tracer("noop"),
17581790
&atomic.Pointer[proto.QuotaCommitter]{},
@@ -1765,6 +1797,8 @@ func setup(t *testing.T, ignoreLogErrors bool, ov *overrides) (proto.DRPCProvisi
17651797
TimeNowFn: timeNowFn,
17661798
OIDCConfig: &oauth2.Config{},
17671799
AcquireJobLongPollDur: pollDur,
1800+
HeartbeatInterval: ov.heartbeatInterval,
1801+
HeartbeatFn: ov.heartbeatFn,
17681802
},
17691803
)
17701804
require.NoError(t, err)

enterprise/coderd/provisionerdaemons.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
234234
}
235235

236236
// Create the daemon in the database.
237-
_, err := api.Database.UpsertProvisionerDaemon(authCtx, database.UpsertProvisionerDaemonParams{
237+
daemon, err := api.Database.UpsertProvisionerDaemon(authCtx, database.UpsertProvisionerDaemonParams{
238238
Name: name,
239239
Provisioners: provisioners,
240240
Tags: tags,
@@ -295,11 +295,13 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
295295
}
296296
mux := drpcmux.New()
297297
logger := api.Logger.Named(fmt.Sprintf("ext-provisionerd-%s", name))
298+
srvCtx, srvCancel := context.WithCancel(ctx)
299+
defer srvCancel()
298300
logger.Info(ctx, "starting external provisioner daemon")
299301
srv, err := provisionerdserver.NewServer(
300-
api.ctx,
302+
srvCtx,
301303
api.AccessURL,
302-
id,
304+
daemon.ID,
303305
logger,
304306
provisioners,
305307
tags,
@@ -339,6 +341,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
339341
},
340342
})
341343
err = server.Serve(ctx, session)
344+
srvCancel()
342345
logger.Info(ctx, "provisioner daemon disconnected", slog.Error(err))
343346
if err != nil && !xerrors.Is(err, io.EOF) {
344347
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err))

0 commit comments

Comments
 (0)