diff --git a/cli/server.go b/cli/server.go index 11a979d11297f..844e6f1ef4aee 100644 --- a/cli/server.go +++ b/cli/server.go @@ -788,6 +788,22 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. Prometheus: vals.Prometheus.Enable.Value(), STUN: len(vals.DERP.Server.STUNAddresses) != 0, Tunnel: tunnel != nil, + ParseLicenseJWT: func(lic *telemetry.License) error { + // This will be nil when running in AGPL-only mode. + if options.ParseLicenseClaims == nil { + return nil + } + + email, trial, err := options.ParseLicenseClaims(lic.JWT) + if err != nil { + return err + } + if email != "" { + lic.Email = &email + } + lic.Trial = &trial + return nil + }, }) if err != nil { return xerrors.Errorf("create telemetry reporter: %w", err) diff --git a/coderd/coderd.go b/coderd/coderd.go index 747ac04ee8407..3a9430d930373 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -172,6 +172,11 @@ type Options struct { StatsBatcher *batchstats.Batcher WorkspaceAppsStatsCollectorOptions workspaceapps.StatsCollectorOptions + + // This janky function is used in telemetry to parse fields out of the raw + // JWT. It needs to be passed through like this because license parsing is + // under the enterprise license, and can't be imported into AGPL. + ParseLicenseClaims func(rawJWT string) (email string, trial bool, err error) } // @title Coder API diff --git a/coderd/httpmw/workspaceagentparam_test.go b/coderd/httpmw/workspaceagentparam_test.go index 127dda843f1e8..c1a5aef4e85f6 100644 --- a/coderd/httpmw/workspaceagentparam_test.go +++ b/coderd/httpmw/workspaceagentparam_test.go @@ -6,12 +6,12 @@ import ( "net/http/httptest" "testing" - "cdr.dev/slog" "github.com/go-chi/chi/v5" "github.com/google/uuid" "github.com/stretchr/testify/require" "golang.org/x/xerrors" + "cdr.dev/slog" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" diff --git a/coderd/telemetry/telemetry.go b/coderd/telemetry/telemetry.go index 39f3b892c2150..71dce3a77ce08 100644 --- a/coderd/telemetry/telemetry.go +++ b/coderd/telemetry/telemetry.go @@ -52,6 +52,7 @@ type Options struct { STUN bool SnapshotFrequency time.Duration Tunnel bool + ParseLicenseJWT func(lic *License) error } // New constructs a reporter for telemetry data. @@ -446,7 +447,13 @@ func (r *remoteReporter) createSnapshot() (*Snapshot, error) { } snapshot.Licenses = make([]License, 0, len(licenses)) for _, license := range licenses { - snapshot.Licenses = append(snapshot.Licenses, ConvertLicense(license)) + tl := ConvertLicense(license) + if r.options.ParseLicenseJWT != nil { + if err := r.options.ParseLicenseJWT(&tl); err != nil { + r.options.Logger.Warn(ctx, "parse license JWT", slog.Error(err)) + } + } + snapshot.Licenses = append(snapshot.Licenses, tl) } return nil }) @@ -904,6 +911,10 @@ type License struct { UploadedAt time.Time `json:"uploaded_at"` Exp time.Time `json:"exp"` UUID uuid.UUID `json:"uuid"` + // These two fields are set by decoding the JWT. If the signing keys aren't + // passed in, these will always be nil. + Email *string `json:"email"` + Trial *bool `json:"trial"` } type WorkspaceProxy struct { diff --git a/enterprise/coderd/coderd.go b/enterprise/coderd/coderd.go index 028ae5a6768c1..32e96ec25975e 100644 --- a/enterprise/coderd/coderd.go +++ b/enterprise/coderd/coderd.go @@ -109,6 +109,13 @@ func New(ctx context.Context, options *Options) (_ *API, err error) { } }() + api.AGPL.Options.ParseLicenseClaims = func(rawJWT string) (email string, trial bool, err error) { + c, err := license.ParseClaims(rawJWT, Keys) + if err != nil { + return "", false, err + } + return c.Subject, c.Trial, nil + } api.AGPL.Options.SetUserGroups = api.setUserGroups api.AGPL.Options.SetUserSiteRoles = api.setUserSiteRoles api.AGPL.SiteHandler.AppearanceFetcher = api.fetchAppearanceConfig diff --git a/enterprise/trialer/trialer.go b/enterprise/trialer/trialer.go index 14a8fa7b50ce0..e143225b886cb 100644 --- a/enterprise/trialer/trialer.go +++ b/enterprise/trialer/trialer.go @@ -9,9 +9,8 @@ import ( "net/http" "time" - "golang.org/x/xerrors" - "github.com/google/uuid" + "golang.org/x/xerrors" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbtime"