Skip to content

Commit 15b637d

Browse files
committed
fix: suppress license expiry warning if a new license covers the gap
Previously, if you had a new license that would start before the current one fully expired, you would get a warning. Now, the license validity periods are merged together, and a warning is only generated based on the end of the current contiguous period of license coverage.
1 parent be40b8c commit 15b637d

File tree

3 files changed

+349
-8
lines changed

3 files changed

+349
-8
lines changed

enterprise/coderd/license/license.go

Lines changed: 106 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"database/sql"
77
"fmt"
88
"math"
9+
"sort"
910
"time"
1011

1112
"github.com/golang-jwt/jwt/v4"
@@ -192,6 +193,13 @@ func LicensesEntitlements(
192193
})
193194
}
194195

196+
// nextLicenseValidityPeriod holds the current or next contiguous period
197+
// where there will be at least one active license. This is used for
198+
// generating license expiry warnings. Previously we would generate licenses
199+
// expiry warnings for each license, but it means that the warning will show
200+
// even if you've loaded up a new license that doesn't have any gap.
201+
nextLicenseValidityPeriod := &licenseValidityPeriod{}
202+
195203
// TODO: License specific warnings and errors should be tied to the license, not the
196204
// 'Entitlements' group as a whole.
197205
for _, license := range licenses {
@@ -201,6 +209,17 @@ func LicensesEntitlements(
201209
// The license isn't valid yet. We don't consider any entitlements contained in it, but
202210
// it's also not an error. Just skip it silently. This can happen if an administrator
203211
// uploads a license for a new term that hasn't started yet.
212+
//
213+
// We still want to factor this into our validity period, though.
214+
// This ensures we can suppress license expiry warnings for expiring
215+
// licenses while a new license is ready to take its place.
216+
//
217+
// claims is nil, so reparse the claims with the IgnoreNbf function.
218+
claims, err = ParseClaimsIgnoreNbf(license.JWT, keys)
219+
if err != nil {
220+
continue
221+
}
222+
nextLicenseValidityPeriod.ApplyClaims(claims)
204223
continue
205224
}
206225
if err != nil {
@@ -209,6 +228,10 @@ func LicensesEntitlements(
209228
continue
210229
}
211230

231+
// Obviously, valid licenses should be considered for the license
232+
// validity period.
233+
nextLicenseValidityPeriod.ApplyClaims(claims)
234+
212235
usagePeriodStart := claims.NotBefore.Time // checked not-nil when validating claims
213236
usagePeriodEnd := claims.ExpiresAt.Time // checked not-nil when validating claims
214237
if usagePeriodStart.After(usagePeriodEnd) {
@@ -237,10 +260,6 @@ func LicensesEntitlements(
237260
entitlement = codersdk.EntitlementGracePeriod
238261
}
239262

240-
// Will add a warning if the license is expiring soon.
241-
// This warning can be raised multiple times if there is more than 1 license.
242-
licenseExpirationWarning(&entitlements, now, claims)
243-
244263
// 'claims.AllFeature' is the legacy way to set 'claims.FeatureSet = codersdk.FeatureSetEnterprise'
245264
// If both are set, ignore the legacy 'claims.AllFeature'
246265
if claims.AllFeatures && claims.FeatureSet == "" {
@@ -405,6 +424,10 @@ func LicensesEntitlements(
405424

406425
// Now the license specific warnings and errors are added to the entitlements.
407426

427+
// Add a single warning if we are currently in the license validity period
428+
// and it's expiring soon.
429+
nextLicenseValidityPeriod.LicenseExpirationWarning(&entitlements, now)
430+
408431
// If HA is enabled, ensure the feature is entitled.
409432
if featureArguments.ReplicaCount > 1 {
410433
feature := entitlements.Features[codersdk.FeatureHighAvailability]
@@ -742,10 +765,85 @@ func keyFunc(keys map[string]ed25519.PublicKey) func(*jwt.Token) (interface{}, e
742765
}
743766
}
744767

745-
// licenseExpirationWarning adds a warning message if the license is expiring soon.
746-
func licenseExpirationWarning(entitlements *codersdk.Entitlements, now time.Time, claims *Claims) {
747-
// Add warning if license is expiring soon
748-
daysToExpire := int(math.Ceil(claims.LicenseExpires.Sub(now).Hours() / 24))
768+
// licenseValidityPeriod keeps track of all license validity periods, and
769+
// generates warnings over contiguous periods across multiple licenses.
770+
//
771+
// Note: this does not track the actual entitlements of each license to ensure
772+
// newer licenses cover the same features as older licenses before merging. It
773+
// is assumed that all licenses cover the same features.
774+
type licenseValidityPeriod struct {
775+
// parts contains all tracked license periods prior to merging.
776+
parts [][2]time.Time
777+
}
778+
779+
// ApplyClaims tracks a license validity period. This should only be called with
780+
// valid (including not-yet-valid), unexpired licenses.
781+
func (p *licenseValidityPeriod) ApplyClaims(claims *Claims) {
782+
if claims == nil || claims.NotBefore == nil || claims.LicenseExpires == nil {
783+
// Bad data
784+
return
785+
}
786+
p.Apply(claims.NotBefore.Time, claims.LicenseExpires.Time)
787+
}
788+
789+
// Apply adds a license validity period.
790+
func (p *licenseValidityPeriod) Apply(start, end time.Time) {
791+
if end.Before(start) {
792+
// Bad data
793+
return
794+
}
795+
p.parts = append(p.parts, [2]time.Time{start, end})
796+
}
797+
798+
// merged merges the license validity periods into contiguous blocks, and sorts
799+
// the merged blocks.
800+
func (p *licenseValidityPeriod) merged() [][2]time.Time {
801+
if len(p.parts) == 0 {
802+
return nil
803+
}
804+
805+
// Sort the input periods by start time.
806+
sorted := make([][2]time.Time, len(p.parts))
807+
copy(sorted, p.parts)
808+
sort.Slice(sorted, func(i, j int) bool {
809+
return sorted[i][0].Before(sorted[j][0])
810+
})
811+
812+
out := make([][2]time.Time, 0, len(sorted))
813+
cur := sorted[0]
814+
for i := 1; i < len(sorted); i++ {
815+
next := sorted[i]
816+
817+
// If the current period's end time is before or equal to the next
818+
// period's start time, they should be merged.
819+
if !next[0].After(cur[1]) {
820+
// Pick the maximum end time.
821+
if next[1].After(cur[1]) {
822+
cur[1] = next[1]
823+
}
824+
continue
825+
}
826+
827+
// They don't overlap, so commit the current period and start a new one.
828+
out = append(out, cur)
829+
cur = next
830+
}
831+
// Commit the final period.
832+
out = append(out, cur)
833+
return out
834+
}
835+
836+
// LicenseExpirationWarning adds a warning message if we are currently in the
837+
// license validity period and it's expiring soon.
838+
func (p *licenseValidityPeriod) LicenseExpirationWarning(entitlements *codersdk.Entitlements, now time.Time) {
839+
merged := p.merged()
840+
if len(merged) == 0 {
841+
// No licenses
842+
return
843+
}
844+
end := merged[0][1]
845+
846+
daysToExpire := int(math.Ceil(end.Sub(now).Hours() / 24))
749847
showWarningDays := 30
750848
isTrial := entitlements.Trial
751849
if isTrial {
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
package license
2+
3+
import (
4+
"testing"
5+
"time"
6+
7+
"github.com/stretchr/testify/assert"
8+
)
9+
10+
func TestNextLicenseValidityPeriod(t *testing.T) {
11+
t.Parallel()
12+
13+
t.Run("Apply", func(t *testing.T) {
14+
t.Parallel()
15+
16+
testCases := []struct {
17+
name string
18+
19+
licensePeriods [][2]time.Time
20+
expectedPeriods [][2]time.Time
21+
}{
22+
{
23+
name: "None",
24+
licensePeriods: [][2]time.Time{},
25+
expectedPeriods: [][2]time.Time{},
26+
},
27+
{
28+
name: "One",
29+
licensePeriods: [][2]time.Time{
30+
{time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC)},
31+
},
32+
expectedPeriods: [][2]time.Time{
33+
{time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC)},
34+
},
35+
},
36+
{
37+
name: "TwoOverlapping",
38+
licensePeriods: [][2]time.Time{
39+
{time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 3, 0, 0, 0, 0, time.UTC)},
40+
{time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 4, 0, 0, 0, 0, time.UTC)},
41+
},
42+
expectedPeriods: [][2]time.Time{
43+
{time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 4, 0, 0, 0, 0, time.UTC)},
44+
},
45+
},
46+
{
47+
name: "TwoNonOverlapping",
48+
licensePeriods: [][2]time.Time{
49+
{time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC)},
50+
{time.Date(2025, 1, 3, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 4, 0, 0, 0, 0, time.UTC)},
51+
},
52+
expectedPeriods: [][2]time.Time{
53+
{time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC)},
54+
{time.Date(2025, 1, 3, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 4, 0, 0, 0, 0, time.UTC)},
55+
},
56+
},
57+
{
58+
name: "ThreeOverlapping",
59+
licensePeriods: [][2]time.Time{
60+
{time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 3, 0, 0, 0, 0, time.UTC)},
61+
{time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 5, 0, 0, 0, 0, time.UTC)},
62+
{time.Date(2025, 1, 4, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 6, 0, 0, 0, 0, time.UTC)},
63+
},
64+
expectedPeriods: [][2]time.Time{
65+
{time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 6, 0, 0, 0, 0, time.UTC)},
66+
},
67+
},
68+
{
69+
name: "ThreeNonOverlapping",
70+
licensePeriods: [][2]time.Time{
71+
{time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC)},
72+
{time.Date(2025, 1, 3, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 4, 0, 0, 0, 0, time.UTC)},
73+
{time.Date(2025, 1, 5, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 6, 0, 0, 0, 0, time.UTC)},
74+
},
75+
expectedPeriods: [][2]time.Time{
76+
{time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC)},
77+
{time.Date(2025, 1, 3, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 4, 0, 0, 0, 0, time.UTC)},
78+
{time.Date(2025, 1, 5, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 6, 0, 0, 0, 0, time.UTC)},
79+
},
80+
},
81+
{
82+
name: "PeriodContainsAnotherPeriod",
83+
licensePeriods: [][2]time.Time{
84+
{time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 8, 0, 0, 0, 0, time.UTC)},
85+
{time.Date(2025, 1, 3, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 6, 0, 0, 0, 0, time.UTC)},
86+
},
87+
expectedPeriods: [][2]time.Time{
88+
{time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 8, 0, 0, 0, 0, time.UTC)},
89+
},
90+
},
91+
{
92+
name: "EndBeforeStart",
93+
licensePeriods: [][2]time.Time{
94+
{time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)},
95+
},
96+
expectedPeriods: nil,
97+
},
98+
}
99+
100+
for _, tc := range testCases {
101+
t.Run(tc.name, func(t *testing.T) {
102+
t.Parallel()
103+
104+
// Test with all possible permutations of the periods to ensure
105+
// consistency regardless of the order.
106+
ps := permutations(tc.licensePeriods)
107+
for _, p := range ps {
108+
t.Logf("permutation: %v", p)
109+
period := &licenseValidityPeriod{}
110+
for _, times := range p {
111+
t.Logf("applying %v", times)
112+
period.Apply(times[0], times[1])
113+
}
114+
assert.Equal(t, tc.expectedPeriods, period.merged(), "merged")
115+
}
116+
})
117+
}
118+
})
119+
}
120+
121+
func permutations[T any](arr []T) [][]T {
122+
var res [][]T
123+
var helper func([]T, int)
124+
helper = func(a []T, i int) {
125+
if i == len(a)-1 {
126+
// make a copy before appending
127+
tmp := make([]T, len(a))
128+
copy(tmp, a)
129+
res = append(res, tmp)
130+
return
131+
}
132+
for j := i; j < len(a); j++ {
133+
a[i], a[j] = a[j], a[i]
134+
helper(a, i+1)
135+
a[i], a[j] = a[j], a[i] // backtrack
136+
}
137+
}
138+
helper(arr, 0)
139+
return res
140+
}

0 commit comments

Comments
 (0)