Skip to content

fix: allow posting licenses that will be valid in future #14491

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion enterprise/coderd/coderdenttest/coderdenttest.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,10 @@ type LicenseOptions struct {
// ExpiresAt is the time at which the license will hard expire.
// ExpiresAt should always be greater then GraceAt.
ExpiresAt time.Time
// NotBefore is the time at which the license becomes valid. If set to the
// zero value, the `nbf` claim on the license is set to 1 minute in the
// past.
NotBefore time.Time
Features license.Features
}

Expand Down Expand Up @@ -233,13 +237,16 @@ func GenerateLicense(t *testing.T, options LicenseOptions) string {
if options.GraceAt.IsZero() {
options.GraceAt = time.Now().Add(time.Hour)
}
if options.NotBefore.IsZero() {
options.NotBefore = time.Now().Add(-time.Minute)
}

c := &license.Claims{
RegisteredClaims: jwt.RegisteredClaims{
ID: uuid.NewString(),
Issuer: "test@testing.test",
ExpiresAt: jwt.NewNumericDate(options.ExpiresAt),
NotBefore: jwt.NewNumericDate(time.Now().Add(-time.Minute)),
NotBefore: jwt.NewNumericDate(options.NotBefore),
IssuedAt: jwt.NewNumericDate(time.Now().Add(-time.Minute)),
},
LicenseExpires: jwt.NewNumericDate(options.GraceAt),
Expand Down
41 changes: 39 additions & 2 deletions enterprise/coderd/license/license.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,8 @@ var (
ErrInvalidVersion = xerrors.New("license must be version 3")
ErrMissingKeyID = xerrors.Errorf("JOSE header must contain %s", HeaderKeyID)
ErrMissingLicenseExpires = xerrors.New("license missing license_expires")
ErrMissingExp = xerrors.New("exp claim missing or not parsable")
ErrMultipleIssues = xerrors.New("license has multiple issues; contact support")
)

type Features map[codersdk.FeatureName]int64
Expand Down Expand Up @@ -336,7 +338,7 @@ func ParseRaw(l string, keys map[string]ed25519.PublicKey) (jwt.MapClaims, error
return nil, xerrors.New("unable to parse Claims")
}

// ParseClaims validates a database.License record, and if valid, returns the claims. If
// ParseClaims validates a raw JWT, and if valid, returns the claims. If
// unparsable or invalid, it returns an error
func ParseClaims(rawJWT string, keys map[string]ed25519.PublicKey) (*Claims, error) {
tok, err := jwt.ParseWithClaims(
Expand All @@ -348,18 +350,53 @@ func ParseClaims(rawJWT string, keys map[string]ed25519.PublicKey) (*Claims, err
if err != nil {
return nil, err
}
if claims, ok := tok.Claims.(*Claims); ok && tok.Valid {
return validateClaims(tok)
}

func validateClaims(tok *jwt.Token) (*Claims, error) {
if claims, ok := tok.Claims.(*Claims); ok {
if claims.Version != uint64(CurrentVersion) {
return nil, ErrInvalidVersion
}
if claims.LicenseExpires == nil {
return nil, ErrMissingLicenseExpires
}
if claims.ExpiresAt == nil {
return nil, ErrMissingExp
}
return claims, nil
}
return nil, xerrors.New("unable to parse Claims")
}

// ParseClaimsIgnoreNbf validates a raw JWT, but ignores `nbf` claim. If otherwise valid, it returns
// the claims. If unparsable or invalid, it returns an error. Ignoring the `nbf` (not before) is
// useful to determine if a JWT _will_ become valid at any point now or in the future.
func ParseClaimsIgnoreNbf(rawJWT string, keys map[string]ed25519.PublicKey) (*Claims, error) {
tok, err := jwt.ParseWithClaims(
rawJWT,
&Claims{},
keyFunc(keys),
jwt.WithValidMethods(ValidMethods),
)
var vErr *jwt.ValidationError
if xerrors.As(err, &vErr) {
// zero out the NotValidYet error to check if there were other problems
vErr.Errors = vErr.Errors & (^jwt.ValidationErrorNotValidYet)
if vErr.Errors != 0 {
// There are other errors besides not being valid yet. We _could_ go
// through all the jwt.ValidationError bits and try to work out the
// correct error, but if we get here something very strange is
// going on so let's just return a generic error that says to get in
// touch with our support team.
return nil, ErrMultipleIssues
}
} else if err != nil {
return nil, err
}
return validateClaims(tok)
}

func keyFunc(keys map[string]ed25519.PublicKey) func(*jwt.Token) (interface{}, error) {
return func(j *jwt.Token) (interface{}, error) {
keyID, ok := j.Header[HeaderKeyID].(string)
Expand Down
32 changes: 11 additions & 21 deletions enterprise/coderd/licenses.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,25 +86,7 @@ func (api *API) postLicense(rw http.ResponseWriter, r *http.Request) {
return
}

rawClaims, err := license.ParseRaw(addLicense.License, api.LicenseKeys)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid license",
Detail: err.Error(),
})
return
}
exp, ok := rawClaims["exp"].(float64)
if !ok {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid license",
Detail: "exp claim missing or not parsable",
})
return
}
expTime := time.Unix(int64(exp), 0)

claims, err := license.ParseClaims(addLicense.License, api.LicenseKeys)
claims, err := license.ParseClaimsIgnoreNbf(addLicense.License, api.LicenseKeys)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid license",
Expand Down Expand Up @@ -134,7 +116,7 @@ func (api *API) postLicense(rw http.ResponseWriter, r *http.Request) {
dl, err := api.Database.InsertLicense(ctx, database.InsertLicenseParams{
UploadedAt: dbtime.Now(),
JWT: addLicense.License,
Exp: expTime,
Exp: claims.ExpiresAt.Time,
UUID: id,
})
if err != nil {
Expand All @@ -160,7 +142,15 @@ func (api *API) postLicense(rw http.ResponseWriter, r *http.Request) {
// don't fail the HTTP request, since we did write it successfully to the database
}

httpapi.Write(ctx, rw, http.StatusCreated, convertLicense(dl, rawClaims))
c, err := decodeClaims(dl)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to decode database response",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusCreated, convertLicense(dl, c))
}

// postRefreshEntitlements forces an `updateEntitlements` call and publishes
Expand Down
48 changes: 48 additions & 0 deletions enterprise/coderd/licenses_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"net/http"
"testing"
"time"

"github.com/google/uuid"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -82,6 +83,53 @@ func TestPostLicense(t *testing.T) {
t.Error("expected to get error status 400")
}
})

// Test a license that isn't yet valid, but will be in the future. We should allow this so that
// operators can upload a license ahead of time.
t.Run("NotYet", func(t *testing.T) {
t.Parallel()
client, _ := coderdenttest.New(t, &coderdenttest.Options{DontAddLicense: true})
respLic := coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
AccountType: license.AccountTypeSalesforce,
AccountID: "testing",
Features: license.Features{
codersdk.FeatureAuditLog: 1,
},
NotBefore: time.Now().Add(time.Hour),
GraceAt: time.Now().Add(2 * time.Hour),
ExpiresAt: time.Now().Add(3 * time.Hour),
})
assert.GreaterOrEqual(t, respLic.ID, int32(0))
// just a couple spot checks for sanity
assert.Equal(t, "testing", respLic.Claims["account_id"])
features, err := respLic.FeaturesClaims()
require.NoError(t, err)
assert.EqualValues(t, 1, features[codersdk.FeatureAuditLog])
})

// Test we still reject a license that isn't valid yet, but has other issues (e.g. expired
// before it starts).
t.Run("NotEver", func(t *testing.T) {
t.Parallel()
client, _ := coderdenttest.New(t, &coderdenttest.Options{DontAddLicense: true})
lic := coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
AccountType: license.AccountTypeSalesforce,
AccountID: "testing",
Features: license.Features{
codersdk.FeatureAuditLog: 1,
},
NotBefore: time.Now().Add(time.Hour),
GraceAt: time.Now().Add(2 * time.Hour),
ExpiresAt: time.Now().Add(-time.Hour),
})
_, err := client.AddLicense(context.Background(), codersdk.AddLicenseRequest{
License: lic,
})
errResp := &codersdk.Error{}
require.ErrorAs(t, err, &errResp)
require.Equal(t, http.StatusBadRequest, errResp.StatusCode())
require.Contains(t, errResp.Detail, license.ErrMultipleIssues.Error())
})
}

func TestGetLicense(t *testing.T) {
Expand Down
Loading