Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fix: suppress license expiry warning if a new license covers the gap
Previously, if you had a new license that would start before the current
one fully expired, you would get a warning. Now, the license validity
periods are merged together, and a warning is only generated based on
the end of the current contiguous period of license coverage.
  • Loading branch information
deansheather committed Aug 28, 2025
commit 15b637d043dc26f8380ee0197658ab192f3d69bf
114 changes: 106 additions & 8 deletions enterprise/coderd/license/license.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"database/sql"
"fmt"
"math"
"sort"
"time"

"github.com/golang-jwt/jwt/v4"
Expand Down Expand Up @@ -192,6 +193,13 @@ func LicensesEntitlements(
})
}

// nextLicenseValidityPeriod holds the current or next contiguous period
// where there will be at least one active license. This is used for
// generating license expiry warnings. Previously we would generate licenses
// expiry warnings for each license, but it means that the warning will show
// even if you've loaded up a new license that doesn't have any gap.
nextLicenseValidityPeriod := &licenseValidityPeriod{}

// TODO: License specific warnings and errors should be tied to the license, not the
// 'Entitlements' group as a whole.
for _, license := range licenses {
Expand All @@ -201,6 +209,17 @@ func LicensesEntitlements(
// The license isn't valid yet. We don't consider any entitlements contained in it, but
// it's also not an error. Just skip it silently. This can happen if an administrator
// uploads a license for a new term that hasn't started yet.
//
// We still want to factor this into our validity period, though.
// This ensures we can suppress license expiry warnings for expiring
// licenses while a new license is ready to take its place.
//
// claims is nil, so reparse the claims with the IgnoreNbf function.
claims, err = ParseClaimsIgnoreNbf(license.JWT, keys)
if err != nil {
continue
}
nextLicenseValidityPeriod.ApplyClaims(claims)
continue
}
if err != nil {
Expand All @@ -209,6 +228,10 @@ func LicensesEntitlements(
continue
}

// Obviously, valid licenses should be considered for the license
// validity period.
nextLicenseValidityPeriod.ApplyClaims(claims)

usagePeriodStart := claims.NotBefore.Time // checked not-nil when validating claims
usagePeriodEnd := claims.ExpiresAt.Time // checked not-nil when validating claims
if usagePeriodStart.After(usagePeriodEnd) {
Expand Down Expand Up @@ -237,10 +260,6 @@ func LicensesEntitlements(
entitlement = codersdk.EntitlementGracePeriod
}

// Will add a warning if the license is expiring soon.
// This warning can be raised multiple times if there is more than 1 license.
licenseExpirationWarning(&entitlements, now, claims)

// 'claims.AllFeature' is the legacy way to set 'claims.FeatureSet = codersdk.FeatureSetEnterprise'
// If both are set, ignore the legacy 'claims.AllFeature'
if claims.AllFeatures && claims.FeatureSet == "" {
Expand Down Expand Up @@ -405,6 +424,10 @@ func LicensesEntitlements(

// Now the license specific warnings and errors are added to the entitlements.

// Add a single warning if we are currently in the license validity period
// and it's expiring soon.
nextLicenseValidityPeriod.LicenseExpirationWarning(&entitlements, now)

// If HA is enabled, ensure the feature is entitled.
if featureArguments.ReplicaCount > 1 {
feature := entitlements.Features[codersdk.FeatureHighAvailability]
Expand Down Expand Up @@ -742,10 +765,85 @@ func keyFunc(keys map[string]ed25519.PublicKey) func(*jwt.Token) (interface{}, e
}
}

// licenseExpirationWarning adds a warning message if the license is expiring soon.
func licenseExpirationWarning(entitlements *codersdk.Entitlements, now time.Time, claims *Claims) {
// Add warning if license is expiring soon
daysToExpire := int(math.Ceil(claims.LicenseExpires.Sub(now).Hours() / 24))
// licenseValidityPeriod keeps track of all license validity periods, and
// generates warnings over contiguous periods across multiple licenses.
//
// Note: this does not track the actual entitlements of each license to ensure
// newer licenses cover the same features as older licenses before merging. It
// is assumed that all licenses cover the same features.
type licenseValidityPeriod struct {
// parts contains all tracked license periods prior to merging.
parts [][2]time.Time
}

// ApplyClaims tracks a license validity period. This should only be called with
// valid (including not-yet-valid), unexpired licenses.
func (p *licenseValidityPeriod) ApplyClaims(claims *Claims) {
if claims == nil || claims.NotBefore == nil || claims.LicenseExpires == nil {
// Bad data
return
}
p.Apply(claims.NotBefore.Time, claims.LicenseExpires.Time)
}

// Apply adds a license validity period.
func (p *licenseValidityPeriod) Apply(start, end time.Time) {
if end.Before(start) {
// Bad data
return
}
p.parts = append(p.parts, [2]time.Time{start, end})
}

// merged merges the license validity periods into contiguous blocks, and sorts
// the merged blocks.
func (p *licenseValidityPeriod) merged() [][2]time.Time {
if len(p.parts) == 0 {
return nil
}

// Sort the input periods by start time.
sorted := make([][2]time.Time, len(p.parts))
copy(sorted, p.parts)
sort.Slice(sorted, func(i, j int) bool {
return sorted[i][0].Before(sorted[j][0])
})

out := make([][2]time.Time, 0, len(sorted))
cur := sorted[0]
for i := 1; i < len(sorted); i++ {
next := sorted[i]

// If the current period's end time is before or equal to the next
// period's start time, they should be merged.
if !next[0].After(cur[1]) {
// Pick the maximum end time.
if next[1].After(cur[1]) {
cur[1] = next[1]
}
continue
}

// They don't overlap, so commit the current period and start a new one.
out = append(out, cur)
cur = next
}
// Commit the final period.
out = append(out, cur)
return out
}

// LicenseExpirationWarning adds a warning message if we are currently in the
// license validity period and it's expiring soon.
func (p *licenseValidityPeriod) LicenseExpirationWarning(entitlements *codersdk.Entitlements, now time.Time) {
merged := p.merged()
if len(merged) == 0 {
// No licenses
return
}
end := merged[0][1]

daysToExpire := int(math.Ceil(end.Sub(now).Hours() / 24))
showWarningDays := 30
isTrial := entitlements.Trial
if isTrial {
Expand Down
140 changes: 140 additions & 0 deletions enterprise/coderd/license/license_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package license

import (
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestNextLicenseValidityPeriod(t *testing.T) {
t.Parallel()

t.Run("Apply", func(t *testing.T) {
t.Parallel()

testCases := []struct {
name string

licensePeriods [][2]time.Time
expectedPeriods [][2]time.Time
}{
{
name: "None",
licensePeriods: [][2]time.Time{},
expectedPeriods: [][2]time.Time{},
},
{
name: "One",
licensePeriods: [][2]time.Time{
{time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC)},
},
expectedPeriods: [][2]time.Time{
{time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC)},
},
},
{
name: "TwoOverlapping",
licensePeriods: [][2]time.Time{
{time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 3, 0, 0, 0, 0, time.UTC)},
{time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 4, 0, 0, 0, 0, time.UTC)},
},
expectedPeriods: [][2]time.Time{
{time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 4, 0, 0, 0, 0, time.UTC)},
},
},
{
name: "TwoNonOverlapping",
licensePeriods: [][2]time.Time{
{time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC)},
{time.Date(2025, 1, 3, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 4, 0, 0, 0, 0, time.UTC)},
},
expectedPeriods: [][2]time.Time{
{time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC)},
{time.Date(2025, 1, 3, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 4, 0, 0, 0, 0, time.UTC)},
},
},
{
name: "ThreeOverlapping",
licensePeriods: [][2]time.Time{
{time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 3, 0, 0, 0, 0, time.UTC)},
{time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 5, 0, 0, 0, 0, time.UTC)},
{time.Date(2025, 1, 4, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 6, 0, 0, 0, 0, time.UTC)},
},
expectedPeriods: [][2]time.Time{
{time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 6, 0, 0, 0, 0, time.UTC)},
},
},
{
name: "ThreeNonOverlapping",
licensePeriods: [][2]time.Time{
{time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC)},
{time.Date(2025, 1, 3, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 4, 0, 0, 0, 0, time.UTC)},
{time.Date(2025, 1, 5, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 6, 0, 0, 0, 0, time.UTC)},
},
expectedPeriods: [][2]time.Time{
{time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC)},
{time.Date(2025, 1, 3, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 4, 0, 0, 0, 0, time.UTC)},
{time.Date(2025, 1, 5, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 6, 0, 0, 0, 0, time.UTC)},
},
},
{
name: "PeriodContainsAnotherPeriod",
licensePeriods: [][2]time.Time{
{time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 8, 0, 0, 0, 0, time.UTC)},
{time.Date(2025, 1, 3, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 6, 0, 0, 0, 0, time.UTC)},
},
expectedPeriods: [][2]time.Time{
{time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 8, 0, 0, 0, 0, time.UTC)},
},
},
{
name: "EndBeforeStart",
licensePeriods: [][2]time.Time{
{time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC), time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)},
},
expectedPeriods: nil,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

// Test with all possible permutations of the periods to ensure
// consistency regardless of the order.
ps := permutations(tc.licensePeriods)
for _, p := range ps {
t.Logf("permutation: %v", p)
period := &licenseValidityPeriod{}
for _, times := range p {
t.Logf("applying %v", times)
period.Apply(times[0], times[1])
}
assert.Equal(t, tc.expectedPeriods, period.merged(), "merged")
}
})
}
})
}

func permutations[T any](arr []T) [][]T {
var res [][]T
var helper func([]T, int)
helper = func(a []T, i int) {
if i == len(a)-1 {
// make a copy before appending
tmp := make([]T, len(a))
copy(tmp, a)
res = append(res, tmp)
return
}
for j := i; j < len(a); j++ {
a[i], a[j] = a[j], a[i]
helper(a, i+1)
a[i], a[j] = a[j], a[i] // backtrack
}
}
helper(arr, 0)
return res
}
Loading
Loading