Skip to content

Commit c67605b

Browse files
committed
Add entitlements struct
1 parent 7fd0903 commit c67605b

File tree

1 file changed

+48
-42
lines changed

1 file changed

+48
-42
lines changed

enterprise/coderd/coderd.go

Lines changed: 48 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,13 @@ func New(ctx context.Context, options *Options) (*API, error) {
3838
AGPL: coderd.New(options.Options),
3939
Options: options,
4040

41-
activeUsers: codersdk.Feature{
42-
Entitlement: codersdk.EntitlementNotEntitled,
43-
Enabled: false,
41+
entitlements: entitlements{
42+
activeUsers: codersdk.Feature{
43+
Entitlement: codersdk.EntitlementNotEntitled,
44+
Enabled: false,
45+
},
46+
auditLogs: codersdk.EntitlementNotEntitled,
4447
},
45-
auditLogs: codersdk.EntitlementNotEntitled,
4648
cancelEntitlementsLoop: cancelFunc,
4749
}
4850
oauthConfigs := &httpmw.OAuth2Configs{
@@ -52,7 +54,7 @@ func New(ctx context.Context, options *Options) (*API, error) {
5254
apiKeyMiddleware := httpmw.ExtractAPIKey(options.Database, oauthConfigs, false)
5355

5456
api.AGPL.APIHandler.Group(func(r chi.Router) {
55-
r.Get("/entitlements", api.entitlements)
57+
r.Get("/entitlements", api.serveEntitlements)
5658
r.Route("/licenses", func(r chi.Router) {
5759
r.Use(apiKeyMiddleware)
5860
r.Post("/", api.postLicense)
@@ -83,10 +85,14 @@ type API struct {
8385
*Options
8486

8587
cancelEntitlementsLoop func()
86-
mutex sync.RWMutex
87-
hasLicense bool
88-
activeUsers codersdk.Feature
89-
auditLogs codersdk.Entitlement
88+
entitlementsMu sync.RWMutex
89+
entitlements entitlements
90+
}
91+
92+
type entitlements struct {
93+
hasLicense bool
94+
activeUsers codersdk.Feature
95+
auditLogs codersdk.Entitlement
9096
}
9197

9298
func (api *API) Close() error {
@@ -99,17 +105,19 @@ func (api *API) updateEntitlements(ctx context.Context) error {
99105
if err != nil {
100106
return err
101107
}
102-
api.mutex.Lock()
103-
defer api.mutex.Unlock()
108+
api.entitlementsMu.Lock()
109+
defer api.entitlementsMu.Unlock()
104110
now := time.Now()
105111

106112
// Default all entitlements to be disabled.
107-
hasLicense := false
108-
activeUsers := codersdk.Feature{
109-
Enabled: false,
110-
Entitlement: codersdk.EntitlementNotEntitled,
113+
entitlements := entitlements{
114+
hasLicense: false,
115+
activeUsers: codersdk.Feature{
116+
Enabled: false,
117+
Entitlement: codersdk.EntitlementNotEntitled,
118+
},
119+
auditLogs: codersdk.EntitlementNotEntitled,
111120
}
112-
auditLogs := codersdk.EntitlementNotEntitled
113121

114122
// Here we loop through licenses to detect enabled features.
115123
for _, l := range licenses {
@@ -119,33 +127,35 @@ func (api *API) updateEntitlements(ctx context.Context) error {
119127
slog.F("id", l.ID), slog.Error(err))
120128
continue
121129
}
122-
hasLicense = true
130+
entitlements.hasLicense = true
123131
entitlement := codersdk.EntitlementEntitled
124132
if now.After(claims.LicenseExpires.Time) {
125133
// if the grace period were over, the validation fails, so if we are after
126134
// LicenseExpires we must be in grace period.
127135
entitlement = codersdk.EntitlementGracePeriod
128136
}
129137
if claims.Features.UserLimit > 0 {
130-
activeUsers.Enabled = true
131-
activeUsers.Entitlement = entitlement
138+
entitlements.activeUsers = codersdk.Feature{
139+
Enabled: true,
140+
Entitlement: entitlement,
141+
}
132142
currentLimit := int64(0)
133-
if activeUsers.Limit != nil {
134-
currentLimit = *activeUsers.Limit
143+
if entitlements.activeUsers.Limit != nil {
144+
currentLimit = *entitlements.activeUsers.Limit
135145
}
136146
limit := max(currentLimit, claims.Features.UserLimit)
137-
activeUsers.Limit = &limit
147+
entitlements.activeUsers.Limit = &limit
138148
}
139149
if claims.Features.AuditLog > 0 {
140-
auditLogs = entitlement
150+
entitlements.auditLogs = entitlement
141151
}
142152
}
143153

144-
if auditLogs != api.auditLogs {
154+
if entitlements.auditLogs != api.entitlements.auditLogs {
145155
auditor := agplaudit.NewNop()
146156
// A flag could be added to the options that would allow disabling
147157
// enhanced audit logging here!
148-
if auditLogs == codersdk.EntitlementEntitled && api.AuditLogging {
158+
if entitlements.auditLogs == codersdk.EntitlementEntitled && api.AuditLogging {
149159
auditor = audit.NewAuditor(
150160
audit.DefaultFilter,
151161
backends.NewPostgres(api.Database, true),
@@ -155,27 +165,23 @@ func (api *API) updateEntitlements(ctx context.Context) error {
155165
api.AGPL.Auditor.Store(&auditor)
156166
}
157167

158-
api.hasLicense = hasLicense
159-
api.activeUsers = activeUsers
160-
api.auditLogs = auditLogs
168+
api.entitlements = entitlements
161169

162170
return nil
163171
}
164172

165-
func (api *API) entitlements(rw http.ResponseWriter, r *http.Request) {
166-
api.mutex.RLock()
167-
hasLicense := api.hasLicense
168-
activeUsers := api.activeUsers
169-
auditLogs := api.auditLogs
170-
api.mutex.RUnlock()
173+
func (api *API) serveEntitlements(rw http.ResponseWriter, r *http.Request) {
174+
api.entitlementsMu.RLock()
175+
entitlements := api.entitlements
176+
api.entitlementsMu.RUnlock()
171177

172178
resp := codersdk.Entitlements{
173179
Features: make(map[string]codersdk.Feature),
174180
Warnings: make([]string, 0),
175-
HasLicense: hasLicense,
181+
HasLicense: entitlements.hasLicense,
176182
}
177183

178-
if activeUsers.Limit != nil {
184+
if entitlements.activeUsers.Limit != nil {
179185
activeUserCount, err := api.Database.GetActiveUserCount(r.Context())
180186
if err != nil {
181187
httpapi.Write(rw, http.StatusInternalServerError, codersdk.Response{
@@ -184,22 +190,22 @@ func (api *API) entitlements(rw http.ResponseWriter, r *http.Request) {
184190
})
185191
return
186192
}
187-
activeUsers.Actual = &activeUserCount
188-
if activeUserCount > *activeUsers.Limit {
193+
entitlements.activeUsers.Actual = &activeUserCount
194+
if activeUserCount > *entitlements.activeUsers.Limit {
189195
resp.Warnings = append(resp.Warnings,
190196
fmt.Sprintf(
191197
"Your deployment has %d active users but is only licensed for %d.",
192-
activeUserCount, *activeUsers.Limit))
198+
activeUserCount, *entitlements.activeUsers.Limit))
193199
}
194200
}
195-
resp.Features[codersdk.FeatureUserLimit] = activeUsers
201+
resp.Features[codersdk.FeatureUserLimit] = entitlements.activeUsers
196202

197203
// Audit logs
198204
resp.Features[codersdk.FeatureAuditLog] = codersdk.Feature{
199-
Entitlement: auditLogs,
205+
Entitlement: entitlements.auditLogs,
200206
Enabled: api.AuditLogging,
201207
}
202-
if auditLogs == codersdk.EntitlementGracePeriod && api.AuditLogging {
208+
if entitlements.auditLogs == codersdk.EntitlementGracePeriod && api.AuditLogging {
203209
resp.Warnings = append(resp.Warnings,
204210
"Audit logging is enabled but your license for this feature is expired.")
205211
}

0 commit comments

Comments
 (0)