diff --git a/coderd/features.go b/coderd/features.go index 55ddd2af895f9..ecc720e4db4e2 100644 --- a/coderd/features.go +++ b/coderd/features.go @@ -2,7 +2,11 @@ package coderd import ( "net/http" + "reflect" + "golang.org/x/xerrors" + + "github.com/coder/coder/coderd/audit" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/codersdk" ) @@ -11,11 +15,10 @@ import ( type FeaturesService interface { EntitlementsAPI(w http.ResponseWriter, r *http.Request) - // TODO - // Get returns the implementations for feature interfaces. Parameter `s `must be a pointer to a + // Get returns the implementations for feature interfaces. Parameter `s` must be a pointer to a // struct type containing feature interfaces as fields. The FeatureService sets all fields to // the correct implementations depending on whether the features are turned on. - // Get(s any) error + Get(s any) error } type featuresService struct{} @@ -34,3 +37,57 @@ func (featuresService) EntitlementsAPI(rw http.ResponseWriter, _ *http.Request) HasLicense: false, }) } + +// Get returns the implementations for feature interfaces. Parameter `s` must be a pointer to a +// struct type containing feature interfaces as fields. The AGPL featureService always returns the +// "disabled" version of the feature interface because it doesn't include any enterprise features +// by definition. +func (featuresService) Get(ps any) error { + if reflect.TypeOf(ps).Kind() != reflect.Pointer { + return xerrors.New("input must be pointer to struct") + } + vs := reflect.ValueOf(ps).Elem() + if vs.Kind() != reflect.Struct { + return xerrors.New("input must be pointer to struct") + } + for i := 0; i < vs.NumField(); i++ { + vf := vs.Field(i) + tf := vf.Type() + if tf.Kind() != reflect.Interface { + return xerrors.Errorf("fields of input struct must be interfaces: %s", tf.String()) + } + err := setImplementation(vf, tf) + if err != nil { + return err + } + } + return nil +} + +// setImplementation finds the correct implementation for the field's type, and sets it on the +// struct. It returns an error if unsuccessful +func setImplementation(vf reflect.Value, tf reflect.Type) error { + // when we get more than a few features it might make sense to have a data structure for finding + // the correct implementation that's faster than just a linear search, but for now just spin + // through the implementations we have. + vd := reflect.ValueOf(DisabledImplementations) + for j := 0; j < vd.NumField(); j++ { + vdf := vd.Field(j) + if vdf.Type() == tf { + vf.Set(vdf) + return nil + } + } + return xerrors.Errorf("unable to find implementation of interface %s", tf.String()) +} + +// FeatureInterfaces contains a field for each interface controlled by an enterprise feature. +type FeatureInterfaces struct { + Auditor audit.Auditor +} + +// DisabledImplementations includes all the implementations of turned-off features. There are no +// turned-on implementations in AGPL code. +var DisabledImplementations = FeatureInterfaces{ + Auditor: audit.NewNop(), +} diff --git a/coderd/features_internal_test.go b/coderd/features_internal_test.go index d06fc96e19626..0c6a7f052b73a 100644 --- a/coderd/features_internal_test.go +++ b/coderd/features_internal_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/coder/coder/coderd/audit" "github.com/coder/coder/codersdk" ) @@ -36,3 +37,64 @@ func TestEntitlements(t *testing.T) { } }) } + +func TestFeaturesServiceGet(t *testing.T) { + t.Parallel() + t.Run("Auditor", func(t *testing.T) { + t.Parallel() + uut := featuresService{} + target := struct { + Auditor audit.Auditor + }{} + err := uut.Get(&target) + require.NoError(t, err) + assert.NotNil(t, target.Auditor) + }) + + t.Run("NotPointer", func(t *testing.T) { + t.Parallel() + uut := featuresService{} + target := struct { + Auditor audit.Auditor + }{} + err := uut.Get(target) + require.Error(t, err) + assert.Nil(t, target.Auditor) + }) + + t.Run("UnknownInterface", func(t *testing.T) { + t.Parallel() + uut := featuresService{} + target := struct { + test testInterface + }{} + err := uut.Get(&target) + require.Error(t, err) + assert.Nil(t, target.test) + }) + + t.Run("PointerToNonStruct", func(t *testing.T) { + t.Parallel() + uut := featuresService{} + var target audit.Auditor + err := uut.Get(&target) + require.Error(t, err) + assert.Nil(t, target) + }) + + t.Run("StructWithNonInterfaces", func(t *testing.T) { + t.Parallel() + uut := featuresService{} + target := struct { + N int64 + Auditor audit.Auditor + }{} + err := uut.Get(&target) + require.Error(t, err) + assert.Nil(t, target.Auditor) + }) +} + +type testInterface interface { + Test() error +} diff --git a/enterprise/coderd/features.go b/enterprise/coderd/features.go index 2102cdc0eb122..511e3eb05cdc9 100644 --- a/enterprise/coderd/features.go +++ b/enterprise/coderd/features.go @@ -5,17 +5,23 @@ import ( "crypto/ed25519" "fmt" "net/http" + "reflect" "sync" "time" + "github.com/coder/coder/enterprise/audit/backends" + "github.com/cenkalti/backoff/v4" + "golang.org/x/xerrors" "cdr.dev/slog" agpl "github.com/coder/coder/coderd" + agplAudit "github.com/coder/coder/coderd/audit" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/codersdk" + "github.com/coder/coder/enterprise/audit" ) type Enablements struct { @@ -29,6 +35,13 @@ type featuresService struct { keys map[string]ed25519.PublicKey enablements Enablements resyncInterval time.Duration + // enabledImplementations includes an "enabled" implementation of every feature. This is + // initialized at start of day and remains static. The consequence of this is that these things + // are hanging around using memory even if not licensed or in use, but it greatly simplifies the + // logic because we don't have to bother creating and destroying them as entitlements change. + // If we have a particularly memory-hungry feature in future, we might wish to reconsider this + // choice. + enabledImplementations agpl.FeatureInterfaces mu sync.RWMutex entitlements entitlements @@ -44,11 +57,18 @@ func newFeaturesService( enablements Enablements, ) agpl.FeaturesService { fs := &featuresService{ - logger: logger, - database: db, - pubsub: pubsub, - keys: keys, - enablements: enablements, + logger: logger, + database: db, + pubsub: pubsub, + keys: keys, + enablements: enablements, + enabledImplementations: agpl.FeatureInterfaces{ + Auditor: audit.NewAuditor( + audit.DefaultFilter, + backends.NewPostgres(db, true), + backends.NewSlog(logger), + ), + }, resyncInterval: 10 * time.Minute, entitlements: entitlements{ activeUsers: numericalEntitlement{ @@ -259,3 +279,48 @@ func max(a, b int64) int64 { } return b } + +func (s *featuresService) Get(ps any) error { + if reflect.TypeOf(ps).Kind() != reflect.Pointer { + return xerrors.New("input must be pointer to struct") + } + vs := reflect.ValueOf(ps).Elem() + if vs.Kind() != reflect.Struct { + return xerrors.New("input must be pointer to struct") + } + // grab a local copy of entitlements so that we have a consistent set, but aren't keeping it + // locked from updates while we process. + s.mu.RLock() + ent := s.entitlements + s.mu.RUnlock() + + for i := 0; i < vs.NumField(); i++ { + vf := vs.Field(i) + tf := vf.Type() + if tf.Kind() != reflect.Interface { + return xerrors.Errorf("fields of input struct must be interfaces: %s", tf.String()) + } + + err := s.setImplementation(ent, vf, tf) + if err != nil { + return err + } + } + return nil +} + +func (s *featuresService) setImplementation(ent entitlements, vf reflect.Value, tf reflect.Type) error { + // c.f. https://stackoverflow.com/questions/7132848/how-to-get-the-reflect-type-of-an-interface + switch tf { + case reflect.TypeOf((*agplAudit.Auditor)(nil)).Elem(): + // Audit logging + if !s.enablements.AuditLogs || ent.auditLogs.state == notEntitled { + vf.Set(reflect.ValueOf(agpl.DisabledImplementations.Auditor)) + return nil + } + vf.Set(reflect.ValueOf(s.enabledImplementations.Auditor)) + return nil + default: + return xerrors.Errorf("unable to find implementation of interface %s", tf.String()) + } +} diff --git a/enterprise/coderd/features_internal_test.go b/enterprise/coderd/features_internal_test.go index b8fbe7d30e8e9..da5bb39176bd1 100644 --- a/enterprise/coderd/features_internal_test.go +++ b/enterprise/coderd/features_internal_test.go @@ -7,21 +7,24 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "reflect" "testing" "time" "github.com/golang-jwt/jwt/v4" - "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "cdr.dev/slog/sloggers/slogtest" - "github.com/coder/coder/coderd" + agplCoderd "github.com/coder/coder/coderd" + agplAudit "github.com/coder/coder/coderd/audit" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/databasefake" "github.com/coder/coder/codersdk" + "github.com/coder/coder/enterprise/audit" + "github.com/coder/coder/enterprise/audit/backends" "github.com/coder/coder/testutil" ) @@ -282,7 +285,7 @@ func TestFeaturesServiceSyncEntitlements(t *testing.T) { }) } -func requestEntitlements(t *testing.T, uut coderd.FeaturesService) codersdk.Entitlements { +func requestEntitlements(t *testing.T, uut agplCoderd.FeaturesService) codersdk.Entitlements { t.Helper() r := httptest.NewRequest("GET", "https://example.com/api/v2/entitlements", nil) rw := httptest.NewRecorder() @@ -335,3 +338,207 @@ func userLimitIs(fs *featuresService, limit int64) func(context.Context) bool { return fs.entitlements.activeUsers.limit == limit } } + +func TestFeaturesServiceGet(t *testing.T) { + t.Parallel() + logger := slogtest.Make(t, nil) + + // Note that these are not actually used because we don't run the syncEntitlements + // routine in this test. + pubsub := database.NewPubsubInMemory() + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + keyID := "testing" + db := databasefake.New() + + t.Run("AuditorOff", func(t *testing.T) { + t.Parallel() + uut := &featuresService{ + logger: logger, + database: db, + pubsub: pubsub, + keys: map[string]ed25519.PublicKey{keyID: pub}, + enablements: Enablements{AuditLogs: true}, + enabledImplementations: agplCoderd.FeatureInterfaces{ + Auditor: audit.NewAuditor(audit.DefaultFilter), + }, + entitlements: entitlements{ + hasLicense: false, + activeUsers: numericalEntitlement{ + entitlement{notEntitled}, + entitlementLimit{ + unlimited: true, + }, + }, + auditLogs: entitlement{notEntitled}, + }, + } + target := struct { + Auditor agplAudit.Auditor + }{} + err := uut.Get(&target) + require.NoError(t, err) + assert.NotNil(t, target.Auditor) + nop := agplAudit.NewNop() + assert.Equal(t, reflect.ValueOf(nop).Type(), reflect.ValueOf(target.Auditor).Type()) + }) + + t.Run("AuditorOn", func(t *testing.T) { + t.Parallel() + uut := &featuresService{ + logger: logger, + database: db, + pubsub: pubsub, + keys: map[string]ed25519.PublicKey{keyID: pub}, + enablements: Enablements{AuditLogs: true}, + enabledImplementations: agplCoderd.FeatureInterfaces{ + Auditor: audit.NewAuditor(audit.DefaultFilter), + }, + entitlements: entitlements{ + hasLicense: false, + activeUsers: numericalEntitlement{ + entitlement{notEntitled}, + entitlementLimit{ + unlimited: true, + }, + }, + auditLogs: entitlement{entitled}, + }, + } + target := struct { + Auditor agplAudit.Auditor + }{} + err := uut.Get(&target) + require.NoError(t, err) + assert.NotNil(t, target.Auditor) + ea := audit.NewAuditor( + audit.DefaultFilter, + backends.NewPostgres(db, true), + backends.NewSlog(logger), + ) + assert.Equal(t, reflect.ValueOf(ea).Type(), reflect.ValueOf(target.Auditor).Type()) + }) + + t.Run("NotPointer", func(t *testing.T) { + t.Parallel() + uut := &featuresService{ + logger: logger, + database: db, + pubsub: pubsub, + keys: map[string]ed25519.PublicKey{keyID: pub}, + enablements: Enablements{AuditLogs: true}, + enabledImplementations: agplCoderd.FeatureInterfaces{ + Auditor: audit.NewAuditor(audit.DefaultFilter), + }, + entitlements: entitlements{ + hasLicense: false, + activeUsers: numericalEntitlement{ + entitlement{notEntitled}, + entitlementLimit{ + unlimited: true, + }, + }, + auditLogs: entitlement{notEntitled}, + }, + } + target := struct { + Auditor agplAudit.Auditor + }{} + err := uut.Get(target) + require.Error(t, err) + assert.Nil(t, target.Auditor) + }) + + t.Run("UnknownInterface", func(t *testing.T) { + t.Parallel() + uut := &featuresService{ + logger: logger, + database: db, + pubsub: pubsub, + keys: map[string]ed25519.PublicKey{keyID: pub}, + enablements: Enablements{AuditLogs: true}, + enabledImplementations: agplCoderd.FeatureInterfaces{ + Auditor: audit.NewAuditor(audit.DefaultFilter), + }, + entitlements: entitlements{ + hasLicense: false, + activeUsers: numericalEntitlement{ + entitlement{notEntitled}, + entitlementLimit{ + unlimited: true, + }, + }, + auditLogs: entitlement{notEntitled}, + }, + } + target := struct { + test testInterface + }{} + err := uut.Get(&target) + require.Error(t, err) + assert.Nil(t, target.test) + }) + + t.Run("PointerToNonStruct", func(t *testing.T) { + t.Parallel() + uut := &featuresService{ + logger: logger, + database: db, + pubsub: pubsub, + keys: map[string]ed25519.PublicKey{keyID: pub}, + enablements: Enablements{AuditLogs: true}, + enabledImplementations: agplCoderd.FeatureInterfaces{ + Auditor: audit.NewAuditor(audit.DefaultFilter), + }, + entitlements: entitlements{ + hasLicense: false, + activeUsers: numericalEntitlement{ + entitlement{notEntitled}, + entitlementLimit{ + unlimited: true, + }, + }, + auditLogs: entitlement{notEntitled}, + }, + } + var target agplAudit.Auditor + err := uut.Get(&target) + require.Error(t, err) + assert.Nil(t, target) + }) + + t.Run("StructWithNonInterfaces", func(t *testing.T) { + t.Parallel() + uut := &featuresService{ + logger: logger, + database: db, + pubsub: pubsub, + keys: map[string]ed25519.PublicKey{keyID: pub}, + enablements: Enablements{AuditLogs: true}, + enabledImplementations: agplCoderd.FeatureInterfaces{ + Auditor: audit.NewAuditor(audit.DefaultFilter), + }, + entitlements: entitlements{ + hasLicense: false, + activeUsers: numericalEntitlement{ + entitlement{notEntitled}, + entitlementLimit{ + unlimited: true, + }, + }, + auditLogs: entitlement{notEntitled}, + }, + } + target := struct { + N int64 + Auditor agplAudit.Auditor + }{} + err := uut.Get(&target) + require.Error(t, err) + assert.Nil(t, target.Auditor) + }) +} + +type testInterface interface { + Test() error +}