diff --git a/coderd/coderd.go b/coderd/coderd.go index 929c9f44a7a8b..4507cd1dd7605 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -574,7 +574,7 @@ func New(options *Options) *API { TemplateScheduleStore: options.TemplateScheduleStore, UserQuietHoursScheduleStore: options.UserQuietHoursScheduleStore, AccessControlStore: options.AccessControlStore, - FileCache: files.NewFromStore(options.Database, options.PrometheusRegistry, options.Authorizer), + FileCache: files.New(options.PrometheusRegistry, options.Authorizer), Experiments: experiments, WebpushDispatcher: options.WebPushDispatcher, healthCheckGroup: &singleflight.Group[string, *healthsdk.HealthcheckReport]{}, diff --git a/coderd/dynamicparameters/render.go b/coderd/dynamicparameters/render.go index 8e7df929505f1..733b2f2ab5f5d 100644 --- a/coderd/dynamicparameters/render.go +++ b/coderd/dynamicparameters/render.go @@ -169,14 +169,14 @@ func (r *loader) dynamicRenderer(ctx context.Context, db database.Store, cache * var templateFS fs.FS var err error - templateFS, err = cache.Acquire(fileCtx, r.job.FileID) + templateFS, err = cache.Acquire(fileCtx, db, r.job.FileID) if err != nil { return nil, xerrors.Errorf("acquire template file: %w", err) } var moduleFilesFS *files.CloseFS if r.terraformValues.CachedModuleFiles.Valid { - moduleFilesFS, err = cache.Acquire(fileCtx, r.terraformValues.CachedModuleFiles.UUID) + moduleFilesFS, err = cache.Acquire(fileCtx, db, r.terraformValues.CachedModuleFiles.UUID) if err != nil { return nil, xerrors.Errorf("acquire module files: %w", err) } diff --git a/coderd/files/cache.go b/coderd/files/cache.go index 170abb10b1ff7..32e03e0b8f209 100644 --- a/coderd/files/cache.go +++ b/coderd/files/cache.go @@ -20,38 +20,15 @@ import ( ) type FileAcquirer interface { - Acquire(ctx context.Context, fileID uuid.UUID) (*CloseFS, error) + Acquire(ctx context.Context, db database.Store, fileID uuid.UUID) (*CloseFS, error) } -// NewFromStore returns a file cache that will fetch files 
from the provided -// database. -func NewFromStore(store database.Store, registerer prometheus.Registerer, authz rbac.Authorizer) *Cache { - fetch := func(ctx context.Context, fileID uuid.UUID) (CacheEntryValue, error) { - // Make sure the read does not fail due to authorization issues. - // Authz is checked on the Acquire call, so this is safe. - //nolint:gocritic - file, err := store.GetFileByID(dbauthz.AsFileReader(ctx), fileID) - if err != nil { - return CacheEntryValue{}, xerrors.Errorf("failed to read file from database: %w", err) - } - - content := bytes.NewBuffer(file.Data) - return CacheEntryValue{ - Object: file.RBACObject(), - FS: archivefs.FromTarReader(content), - Size: int64(len(file.Data)), - }, nil - } - - return New(fetch, registerer, authz) -} - -func New(fetch fetcher, registerer prometheus.Registerer, authz rbac.Authorizer) *Cache { +// New returns a file cache that will fetch files from the database passed to Acquire. +func New(registerer prometheus.Registerer, authz rbac.Authorizer) *Cache { return (&Cache{ - lock: sync.Mutex{}, - data: make(map[uuid.UUID]*cacheEntry), - fetcher: fetch, - authz: authz, + lock: sync.Mutex{}, + data: make(map[uuid.UUID]*cacheEntry), + authz: authz, }).registerMetrics(registerer) } @@ -110,9 +87,8 @@ func (c *Cache) registerMetrics(registerer prometheus.Registerer) *Cache { // loaded into memory exactly once. We hold those files until there are no // longer any open connections, and then we remove the value from the map. type Cache struct { - lock sync.Mutex - data map[uuid.UUID]*cacheEntry - fetcher + lock sync.Mutex + data map[uuid.UUID]*cacheEntry authz rbac.Authorizer // metrics @@ -142,8 +118,6 @@ type cacheEntry struct { value *lazy.ValueWithError[CacheEntryValue] } -type fetcher func(context.Context, uuid.UUID) (CacheEntryValue, error) - var _ fs.FS = (*CloseFS)(nil) // CloseFS is a wrapper around fs.FS that implements io.Closer. 
The Close() @@ -163,12 +137,12 @@ func (f *CloseFS) Close() { f.close() } // // Safety: Every call to Acquire that does not return an error must have a // matching call to Release. -func (c *Cache) Acquire(ctx context.Context, fileID uuid.UUID) (*CloseFS, error) { +func (c *Cache) Acquire(ctx context.Context, db database.Store, fileID uuid.UUID) (*CloseFS, error) { // It's important that this `Load` call occurs outside `prepare`, after the // mutex has been released, or we would continue to hold the lock until the // entire file has been fetched, which may be slow, and would prevent other // files from being fetched in parallel. - it, err := c.prepare(ctx, fileID).Load() + it, err := c.prepare(ctx, db, fileID).Load() if err != nil { c.release(fileID) return nil, err @@ -195,14 +169,14 @@ func (c *Cache) Acquire(ctx context.Context, fileID uuid.UUID) (*CloseFS, error) }, nil } -func (c *Cache) prepare(ctx context.Context, fileID uuid.UUID) *lazy.ValueWithError[CacheEntryValue] { +func (c *Cache) prepare(ctx context.Context, db database.Store, fileID uuid.UUID) *lazy.ValueWithError[CacheEntryValue] { c.lock.Lock() defer c.lock.Unlock() entry, ok := c.data[fileID] if !ok { value := lazy.NewWithError(func() (CacheEntryValue, error) { - val, err := c.fetcher(ctx, fileID) + val, err := fetch(ctx, db, fileID) // Always add to the cache size the bytes of the file loaded. if err == nil { @@ -269,3 +243,20 @@ func (c *Cache) Count() int { return len(c.data) } + +func fetch(ctx context.Context, store database.Store, fileID uuid.UUID) (CacheEntryValue, error) { + // Make sure the read does not fail due to authorization issues. + // Authz is checked on the Acquire call, so this is safe. 
+ //nolint:gocritic + file, err := store.GetFileByID(dbauthz.AsFileReader(ctx), fileID) + if err != nil { + return CacheEntryValue{}, xerrors.Errorf("failed to read file from database: %w", err) + } + + content := bytes.NewBuffer(file.Data) + return CacheEntryValue{ + Object: file.RBACObject(), + FS: archivefs.FromTarReader(content), + Size: int64(len(file.Data)), + }, nil +} diff --git a/coderd/files/cache_test.go b/coderd/files/cache_test.go index 5efb4ba19be28..a5a5dfae268ca 100644 --- a/coderd/files/cache_test.go +++ b/coderd/files/cache_test.go @@ -8,8 +8,8 @@ import ( "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" - "github.com/spf13/afero" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" "golang.org/x/sync/errgroup" "cdr.dev/slog/sloggers/slogtest" @@ -18,6 +18,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/files" "github.com/coder/coder/v2/coderd/rbac" @@ -58,7 +59,7 @@ func TestCacheRBAC(t *testing.T) { require.Equal(t, 0, cache.Count()) rec.Reset() - _, err := cache.Acquire(nobody, file.ID) + _, err := cache.Acquire(nobody, db, file.ID) require.Error(t, err) require.True(t, rbac.IsUnauthorizedError(err)) @@ -75,18 +76,18 @@ func TestCacheRBAC(t *testing.T) { require.Equal(t, 0, cache.Count()) // Read the file with a file reader to put it into the cache. - a, err := cache.Acquire(cacheReader, file.ID) + a, err := cache.Acquire(cacheReader, db, file.ID) require.NoError(t, err) require.Equal(t, 1, cache.Count()) // "nobody" should not be able to read the file. 
- _, err = cache.Acquire(nobody, file.ID) + _, err = cache.Acquire(nobody, db, file.ID) require.Error(t, err) require.True(t, rbac.IsUnauthorizedError(err)) require.Equal(t, 1, cache.Count()) // UserReader can - b, err := cache.Acquire(userReader, file.ID) + b, err := cache.Acquire(userReader, db, file.ID) require.NoError(t, err) require.Equal(t, 1, cache.Count()) @@ -110,16 +111,21 @@ func TestConcurrency(t *testing.T) { ctx := dbauthz.AsFileReader(t.Context()) const fileSize = 10 - emptyFS := afero.NewIOFS(afero.NewReadOnlyFs(afero.NewMemMapFs())) var fetches atomic.Int64 reg := prometheus.NewRegistry() - c := files.New(func(_ context.Context, _ uuid.UUID) (files.CacheEntryValue, error) { + + dbM := dbmock.NewMockStore(gomock.NewController(t)) + dbM.EXPECT().GetFileByID(gomock.Any(), gomock.Any()).DoAndReturn(func(mTx context.Context, fileID uuid.UUID) (database.File, error) { fetches.Add(1) - // Wait long enough before returning to make sure that all of the goroutines + // Wait long enough before returning to make sure that all the goroutines // will be waiting in line, ensuring that no one duplicated a fetch. time.Sleep(testutil.IntervalMedium) - return files.CacheEntryValue{FS: emptyFS, Size: fileSize}, nil - }, reg, &coderdtest.FakeAuthorizer{}) + return database.File{ + Data: make([]byte, fileSize), + }, nil + }).AnyTimes() + + c := files.New(reg, &coderdtest.FakeAuthorizer{}) batches := 1000 groups := make([]*errgroup.Group, 0, batches) @@ -137,7 +143,7 @@ func TestConcurrency(t *testing.T) { g.Go(func() error { // We don't bother to Release these references because the Cache will be // released at the end of the test anyway. 
- _, err := c.Acquire(ctx, id) + _, err := c.Acquire(ctx, dbM, id) return err }) } @@ -164,14 +170,15 @@ func TestRelease(t *testing.T) { ctx := dbauthz.AsFileReader(t.Context()) const fileSize = 10 - emptyFS := afero.NewIOFS(afero.NewReadOnlyFs(afero.NewMemMapFs())) reg := prometheus.NewRegistry() - c := files.New(func(_ context.Context, _ uuid.UUID) (files.CacheEntryValue, error) { - return files.CacheEntryValue{ - FS: emptyFS, - Size: fileSize, + dbM := dbmock.NewMockStore(gomock.NewController(t)) + dbM.EXPECT().GetFileByID(gomock.Any(), gomock.Any()).DoAndReturn(func(mTx context.Context, fileID uuid.UUID) (database.File, error) { + return database.File{ + Data: make([]byte, fileSize), }, nil - }, reg, &coderdtest.FakeAuthorizer{}) + }).AnyTimes() + + c := files.New(reg, &coderdtest.FakeAuthorizer{}) batches := 100 ids := make([]uuid.UUID, 0, batches) @@ -184,9 +191,8 @@ func TestRelease(t *testing.T) { batchSize := 10 for openedIdx, id := range ids { for batchIdx := range batchSize { - it, err := c.Acquire(ctx, id) + it, err := c.Acquire(ctx, dbM, id) require.NoError(t, err) - require.Equal(t, emptyFS, it.FS) releases[id] = append(releases[id], it.Close) // Each time a new file is opened, the metrics should be updated as so: @@ -257,7 +263,7 @@ func cacheAuthzSetup(t *testing.T) (database.Store, *files.Cache, *coderdtest.Re // Dbauthz wrap the db db = dbauthz.New(db, rec, logger, coderdtest.AccessControlStorePointer()) - c := files.NewFromStore(db, reg, rec) + c := files.New(reg, rec) return db, c, rec } diff --git a/coderd/files/closer.go b/coderd/files/closer.go index 9bd98fdd60caf..560786c78f80e 100644 --- a/coderd/files/closer.go +++ b/coderd/files/closer.go @@ -6,6 +6,8 @@ import ( "github.com/google/uuid" "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" ) // CacheCloser is a cache wrapper used to close all acquired files. 
@@ -38,7 +40,7 @@ func (c *CacheCloser) Close() { c.closers = nil } -func (c *CacheCloser) Acquire(ctx context.Context, fileID uuid.UUID) (*CloseFS, error) { +func (c *CacheCloser) Acquire(ctx context.Context, db database.Store, fileID uuid.UUID) (*CloseFS, error) { c.mu.Lock() defer c.mu.Unlock() @@ -46,7 +48,7 @@ func (c *CacheCloser) Acquire(ctx context.Context, fileID uuid.UUID) (*CloseFS, return nil, xerrors.New("cache is closed, and cannot acquire new files") } - f, err := c.cache.Acquire(ctx, fileID) + f, err := c.cache.Acquire(ctx, db, fileID) if err != nil { return nil, err }