Skip to content

Commit 27f53aa

Browse files
committed
Add pointer.Handle to atomically obtain references
This uses a context to ensure the same value persists through multiple executions to `Load()`.
1 parent d0526ed commit 27f53aa

File tree

8 files changed

+112
-36
lines changed

8 files changed

+112
-36
lines changed

coderd/coderd.go

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import (
77
"net/url"
88
"path/filepath"
99
"sync"
10-
"sync/atomic"
1110
"time"
1211

1312
"github.com/andybalholm/brotli"
@@ -33,6 +32,7 @@ import (
3332
"github.com/coder/coder/coderd/httpapi"
3433
"github.com/coder/coder/coderd/httpmw"
3534
"github.com/coder/coder/coderd/metricscache"
35+
"github.com/coder/coder/coderd/pointer"
3636
"github.com/coder/coder/coderd/rbac"
3737
"github.com/coder/coder/coderd/telemetry"
3838
"github.com/coder/coder/coderd/tracing"
@@ -148,9 +148,8 @@ func New(options *Options) *API {
148148
Logger: options.Logger,
149149
},
150150
metricsCache: metricsCache,
151-
Auditor: atomic.Pointer[audit.Auditor]{},
151+
Auditor: pointer.New(options.Auditor),
152152
}
153-
api.Auditor.Store(&options.Auditor)
154153

155154
if options.TailscaleEnable {
156155
api.workspaceAgentCache = wsconncache.New(api.dialWorkspaceAgentTailnet, 0)
@@ -495,7 +494,6 @@ func New(options *Options) *API {
495494
r.Use(apiKeyMiddleware)
496495
r.Get("/", nopEntitlements)
497496
})
498-
r.HandleFunc("/licenses", unsupported)
499497
})
500498

501499
r.NotFound(compressHandler(http.HandlerFunc(api.siteHandler.ServeHTTP)).ServeHTTP)
@@ -504,19 +502,19 @@ func New(options *Options) *API {
504502

505503
type API struct {
506504
*Options
505+
Auditor *pointer.Handle[audit.Auditor]
506+
HTTPAuth *HTTPAuthorizer
507507

508-
derpServer *derp.Server
508+
// APIHandler serves "/api/v2" and all children routes.
509+
APIHandler chi.Router
510+
RootHandler chi.Router
509511

510-
Auditor atomic.Pointer[audit.Auditor]
511-
RootHandler chi.Router
512-
APIHandler chi.Router
512+
derpServer *derp.Server
513+
metricsCache *metricscache.Cache
513514
siteHandler http.Handler
514515
websocketWaitMutex sync.Mutex
515516
websocketWaitGroup sync.WaitGroup
516517
workspaceAgentCache *wsconncache.Cache
517-
HTTPAuth *HTTPAuthorizer
518-
519-
metricsCache *metricscache.Cache
520518
}
521519

522520
// Close waits for all WebSocket connections to drain before returning.
@@ -564,11 +562,3 @@ func nopEntitlements(rw http.ResponseWriter, _ *http.Request) {
564562
HasLicense: false,
565563
})
566564
}
567-
568-
func unsupported(rw http.ResponseWriter, _ *http.Request) {
569-
httpapi.Write(rw, http.StatusNotFound, codersdk.Response{
570-
Message: "Unsupported",
571-
Detail: "These endpoints are not supported in AGPL-licensed Coder",
572-
Validations: nil,
573-
})
574-
}

coderd/pointer/pointer.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package pointer
2+
3+
import (
4+
"context"
5+
6+
"go.uber.org/atomic"
7+
)
8+
9+
func New[T any](value T) *Handle[T] {
10+
h := &Handle[T]{
11+
key: struct{}{},
12+
ptr: atomic.Pointer[T]{},
13+
}
14+
h.Store(value)
15+
return h
16+
}
17+
18+
// Handle loads the stored value into a context, and returns
19+
// a context with the attached value. It's intention is to
20+
// hold a single handle for the lifecycle of a request.
21+
type Handle[T any] struct {
22+
key struct{}
23+
ptr atomic.Pointer[T]
24+
}
25+
26+
func (p *Handle[T]) Load(ctx context.Context) (context.Context, T) {
27+
value, ok := ctx.Value(&p.key).(T)
28+
if !ok {
29+
ctx = context.WithValue(ctx, &p.key, *p.ptr.Load())
30+
return p.Load(ctx)
31+
}
32+
return ctx, value
33+
}
34+
35+
func (p *Handle[T]) Store(t T) {
36+
p.ptr.Store(&t)
37+
}

coderd/pointer/pointer_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
package pointer_test
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/stretchr/testify/require"
8+
9+
"github.com/coder/coder/coderd/pointer"
10+
)
11+
12+
func TestHandle(t *testing.T) {
13+
t.Parallel()
14+
t.Run("Single", func(t *testing.T) {
15+
t.Parallel()
16+
ptr := pointer.New("hello")
17+
ctx := context.Background()
18+
ctx, value := ptr.Load(ctx)
19+
require.Equal(t, "hello", value)
20+
ptr.Store("world")
21+
_, value = ptr.Load(ctx)
22+
require.Equal(t, "hello", value)
23+
})
24+
t.Run("Multiple", func(t *testing.T) {
25+
t.Parallel()
26+
ptr1 := pointer.New("1")
27+
ptr2 := pointer.New("2")
28+
ctx := context.Background()
29+
ctx, v1 := ptr1.Load(ctx)
30+
require.Equal(t, "1", v1)
31+
_, v2 := ptr2.Load(ctx)
32+
require.Equal(t, "2", v2)
33+
})
34+
}

coderd/templates.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,9 @@ func (api *API) template(rw http.ResponseWriter, r *http.Request) {
8585
func (api *API) deleteTemplate(rw http.ResponseWriter, r *http.Request) {
8686
var (
8787
template = httpmw.TemplateParam(r)
88+
_, auditor = api.Auditor.Load(r.Context())
8889
aReq, commitAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{
89-
Audit: *api.Auditor.Load(),
90+
Audit: auditor,
9091
Log: api.Logger,
9192
Request: r,
9293
Action: database.AuditActionDelete,
@@ -139,14 +140,15 @@ func (api *API) postTemplateByOrganization(rw http.ResponseWriter, r *http.Reque
139140
createTemplate codersdk.CreateTemplateRequest
140141
organization = httpmw.OrganizationParam(r)
141142
apiKey = httpmw.APIKey(r)
143+
_, auditor = api.Auditor.Load(r.Context())
142144
templateAudit, commitTemplateAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{
143-
Audit: *api.Auditor.Load(),
145+
Audit: auditor,
144146
Log: api.Logger,
145147
Request: r,
146148
Action: database.AuditActionCreate,
147149
})
148150
templateVersionAudit, commitTemplateVersionAudit = audit.InitRequest[database.TemplateVersion](rw, &audit.RequestParams{
149-
Audit: *api.Auditor.Load(),
151+
Audit: auditor,
150152
Log: api.Logger,
151153
Request: r,
152154
Action: database.AuditActionWrite,
@@ -435,8 +437,9 @@ func (api *API) templateByOrganizationAndName(rw http.ResponseWriter, r *http.Re
435437
func (api *API) patchTemplateMeta(rw http.ResponseWriter, r *http.Request) {
436438
var (
437439
template = httpmw.TemplateParam(r)
440+
_, auditor = api.Auditor.Load(r.Context())
438441
aReq, commitAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{
439-
Audit: *api.Auditor.Load(),
442+
Audit: auditor,
440443
Log: api.Logger,
441444
Request: r,
442445
Action: database.AuditActionWrite,

coderd/templateversions.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -559,8 +559,9 @@ func (api *API) templateVersionByName(rw http.ResponseWriter, r *http.Request) {
559559
func (api *API) patchActiveTemplateVersion(rw http.ResponseWriter, r *http.Request) {
560560
var (
561561
template = httpmw.TemplateParam(r)
562+
_, auditor = api.Auditor.Load(r.Context())
562563
aReq, commitAudit = audit.InitRequest[database.Template](rw, &audit.RequestParams{
563-
Audit: *api.Auditor.Load(),
564+
Audit: auditor,
564565
Log: api.Logger,
565566
Request: r,
566567
Action: database.AuditActionWrite,
@@ -631,8 +632,9 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht
631632
var (
632633
apiKey = httpmw.APIKey(r)
633634
organization = httpmw.OrganizationParam(r)
635+
_, auditor = api.Auditor.Load(r.Context())
634636
aReq, commitAudit = audit.InitRequest[database.TemplateVersion](rw, &audit.RequestParams{
635-
Audit: *api.Auditor.Load(),
637+
Audit: auditor,
636638
Log: api.Logger,
637639
Request: r,
638640
Action: database.AuditActionCreate,

coderd/users.go

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -255,8 +255,9 @@ func (api *API) users(rw http.ResponseWriter, r *http.Request) {
255255

256256
// Creates a new user.
257257
func (api *API) postUser(rw http.ResponseWriter, r *http.Request) {
258+
_, auditor := api.Auditor.Load(r.Context())
258259
aReq, commitAudit := audit.InitRequest[database.User](rw, &audit.RequestParams{
259-
Audit: *api.Auditor.Load(),
260+
Audit: auditor,
260261
Log: api.Logger,
261262
Request: r,
262263
Action: database.AuditActionCreate,
@@ -339,9 +340,10 @@ func (api *API) postUser(rw http.ResponseWriter, r *http.Request) {
339340
}
340341

341342
func (api *API) deleteUser(rw http.ResponseWriter, r *http.Request) {
343+
_, auditor := api.Auditor.Load(r.Context())
342344
user := httpmw.UserParam(r)
343345
aReq, commitAudit := audit.InitRequest[database.User](rw, &audit.RequestParams{
344-
Audit: *api.Auditor.Load(),
346+
Audit: auditor,
345347
Log: api.Logger,
346348
Request: r,
347349
Action: database.AuditActionDelete,
@@ -414,8 +416,9 @@ func (api *API) userByName(rw http.ResponseWriter, r *http.Request) {
414416
func (api *API) putUserProfile(rw http.ResponseWriter, r *http.Request) {
415417
var (
416418
user = httpmw.UserParam(r)
419+
_, auditor = api.Auditor.Load(r.Context())
417420
aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{
418-
Audit: *api.Auditor.Load(),
421+
Audit: auditor,
419422
Log: api.Logger,
420423
Request: r,
421424
Action: database.AuditActionWrite,
@@ -494,8 +497,9 @@ func (api *API) putUserStatus(status database.UserStatus) func(rw http.ResponseW
494497
var (
495498
user = httpmw.UserParam(r)
496499
apiKey = httpmw.APIKey(r)
500+
_, auditor = api.Auditor.Load(r.Context())
497501
aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{
498-
Audit: *api.Auditor.Load(),
502+
Audit: auditor,
499503
Log: api.Logger,
500504
Request: r,
501505
Action: database.AuditActionWrite,
@@ -560,8 +564,9 @@ func (api *API) putUserPassword(rw http.ResponseWriter, r *http.Request) {
560564
var (
561565
user = httpmw.UserParam(r)
562566
params codersdk.UpdateUserPasswordRequest
567+
_, auditor = api.Auditor.Load(r.Context())
563568
aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{
564-
Audit: *api.Auditor.Load(),
569+
Audit: auditor,
565570
Log: api.Logger,
566571
Request: r,
567572
Action: database.AuditActionWrite,
@@ -698,8 +703,9 @@ func (api *API) putUserRoles(rw http.ResponseWriter, r *http.Request) {
698703
user = httpmw.UserParam(r)
699704
actorRoles = httpmw.AuthorizationUserRoles(r)
700705
apiKey = httpmw.APIKey(r)
706+
_, auditor = api.Auditor.Load(r.Context())
701707
aReq, commitAudit = audit.InitRequest[database.User](rw, &audit.RequestParams{
702-
Audit: *api.Auditor.Load(),
708+
Audit: auditor,
703709
Log: api.Logger,
704710
Request: r,
705711
Action: database.AuditActionWrite,

coderd/workspaces.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -254,8 +254,9 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req
254254
var (
255255
organization = httpmw.OrganizationParam(r)
256256
apiKey = httpmw.APIKey(r)
257+
_, auditor = api.Auditor.Load(r.Context())
257258
aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{
258-
Audit: *api.Auditor.Load(),
259+
Audit: auditor,
259260
Log: api.Logger,
260261
Request: r,
261262
Action: database.AuditActionCreate,
@@ -495,8 +496,9 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req
495496
func (api *API) patchWorkspace(rw http.ResponseWriter, r *http.Request) {
496497
var (
497498
workspace = httpmw.WorkspaceParam(r)
499+
_, auditor = api.Auditor.Load(r.Context())
498500
aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{
499-
Audit: *api.Auditor.Load(),
501+
Audit: auditor,
500502
Log: api.Logger,
501503
Request: r,
502504
Action: database.AuditActionWrite,
@@ -571,8 +573,9 @@ func (api *API) patchWorkspace(rw http.ResponseWriter, r *http.Request) {
571573
func (api *API) putWorkspaceAutostart(rw http.ResponseWriter, r *http.Request) {
572574
var (
573575
workspace = httpmw.WorkspaceParam(r)
576+
_, auditor = api.Auditor.Load(r.Context())
574577
aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{
575-
Audit: *api.Auditor.Load(),
578+
Audit: auditor,
576579
Log: api.Logger,
577580
Request: r,
578581
Action: database.AuditActionWrite,
@@ -631,8 +634,9 @@ func (api *API) putWorkspaceAutostart(rw http.ResponseWriter, r *http.Request) {
631634
func (api *API) putWorkspaceTTL(rw http.ResponseWriter, r *http.Request) {
632635
var (
633636
workspace = httpmw.WorkspaceParam(r)
637+
_, auditor = api.Auditor.Load(r.Context())
634638
aReq, commitAudit = audit.InitRequest[database.Workspace](rw, &audit.RequestParams{
635-
Audit: *api.Auditor.Load(),
639+
Audit: auditor,
636640
Log: api.Logger,
637641
Request: r,
638642
Action: database.AuditActionWrite,

enterprise/coderd/coderd.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ func (api *API) updateEntitlements(ctx context.Context) error {
161161
backends.NewSlog(api.Logger),
162162
)
163163
}
164-
api.AGPL.Auditor.Store(&auditor)
164+
api.AGPL.Auditor.Store(auditor)
165165
}
166166
return nil
167167
}

0 commit comments

Comments
 (0)