Skip to content

chore: use database in current context for file cache #18490

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion coderd/coderd.go
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ func New(options *Options) *API {
TemplateScheduleStore: options.TemplateScheduleStore,
UserQuietHoursScheduleStore: options.UserQuietHoursScheduleStore,
AccessControlStore: options.AccessControlStore,
FileCache: files.NewFromStore(options.Database, options.PrometheusRegistry, options.Authorizer),
FileCache: files.New(options.PrometheusRegistry, options.Authorizer),
Experiments: experiments,
WebpushDispatcher: options.WebPushDispatcher,
healthCheckGroup: &singleflight.Group[string, *healthsdk.HealthcheckReport]{},
Expand Down
4 changes: 2 additions & 2 deletions coderd/dynamicparameters/render.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,14 +169,14 @@ func (r *loader) dynamicRenderer(ctx context.Context, db database.Store, cache *
var templateFS fs.FS
var err error

templateFS, err = cache.Acquire(fileCtx, r.job.FileID)
templateFS, err = cache.Acquire(fileCtx, db, r.job.FileID)
if err != nil {
return nil, xerrors.Errorf("acquire template file: %w", err)
}

var moduleFilesFS *files.CloseFS
if r.terraformValues.CachedModuleFiles.Valid {
moduleFilesFS, err = cache.Acquire(fileCtx, r.terraformValues.CachedModuleFiles.UUID)
moduleFilesFS, err = cache.Acquire(fileCtx, db, r.terraformValues.CachedModuleFiles.UUID)
if err != nil {
return nil, xerrors.Errorf("acquire module files: %w", err)
}
Expand Down
67 changes: 29 additions & 38 deletions coderd/files/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,38 +20,15 @@ import (
)

type FileAcquirer interface {
Acquire(ctx context.Context, fileID uuid.UUID) (*CloseFS, error)
Acquire(ctx context.Context, db database.Store, fileID uuid.UUID) (*CloseFS, error)
}

// NewFromStore returns a file cache that will fetch files from the provided
// database.
func NewFromStore(store database.Store, registerer prometheus.Registerer, authz rbac.Authorizer) *Cache {
fetch := func(ctx context.Context, fileID uuid.UUID) (CacheEntryValue, error) {
// Make sure the read does not fail due to authorization issues.
// Authz is checked on the Acquire call, so this is safe.
//nolint:gocritic
file, err := store.GetFileByID(dbauthz.AsFileReader(ctx), fileID)
if err != nil {
return CacheEntryValue{}, xerrors.Errorf("failed to read file from database: %w", err)
}

content := bytes.NewBuffer(file.Data)
return CacheEntryValue{
Object: file.RBACObject(),
FS: archivefs.FromTarReader(content),
Size: int64(len(file.Data)),
}, nil
}

return New(fetch, registerer, authz)
}

func New(fetch fetcher, registerer prometheus.Registerer, authz rbac.Authorizer) *Cache {
// New returns a file cache that will fetch files from a database
func New(registerer prometheus.Registerer, authz rbac.Authorizer) *Cache {
return (&Cache{
lock: sync.Mutex{},
data: make(map[uuid.UUID]*cacheEntry),
fetcher: fetch,
authz: authz,
lock: sync.Mutex{},
data: make(map[uuid.UUID]*cacheEntry),
authz: authz,
}).registerMetrics(registerer)
}

Expand Down Expand Up @@ -110,9 +87,8 @@ func (c *Cache) registerMetrics(registerer prometheus.Registerer) *Cache {
// loaded into memory exactly once. We hold those files until there are no
// longer any open connections, and then we remove the value from the map.
type Cache struct {
lock sync.Mutex
data map[uuid.UUID]*cacheEntry
fetcher
lock sync.Mutex
data map[uuid.UUID]*cacheEntry
authz rbac.Authorizer

// metrics
Expand Down Expand Up @@ -142,8 +118,6 @@ type cacheEntry struct {
value *lazy.ValueWithError[CacheEntryValue]
}

type fetcher func(context.Context, uuid.UUID) (CacheEntryValue, error)

var _ fs.FS = (*CloseFS)(nil)

// CloseFS is a wrapper around fs.FS that implements io.Closer. The Close()
Expand All @@ -163,12 +137,12 @@ func (f *CloseFS) Close() { f.close() }
//
// Safety: Every call to Acquire that does not return an error must have a
// matching call to Release.
func (c *Cache) Acquire(ctx context.Context, fileID uuid.UUID) (*CloseFS, error) {
func (c *Cache) Acquire(ctx context.Context, db database.Store, fileID uuid.UUID) (*CloseFS, error) {
// It's important that this `Load` call occurs outside `prepare`, after the
// mutex has been released, or we would continue to hold the lock until the
// entire file has been fetched, which may be slow, and would prevent other
// files from being fetched in parallel.
it, err := c.prepare(ctx, fileID).Load()
it, err := c.prepare(ctx, db, fileID).Load()
if err != nil {
c.release(fileID)
return nil, err
Expand All @@ -195,14 +169,14 @@ func (c *Cache) Acquire(ctx context.Context, fileID uuid.UUID) (*CloseFS, error)
}, nil
}

func (c *Cache) prepare(ctx context.Context, fileID uuid.UUID) *lazy.ValueWithError[CacheEntryValue] {
func (c *Cache) prepare(ctx context.Context, db database.Store, fileID uuid.UUID) *lazy.ValueWithError[CacheEntryValue] {
c.lock.Lock()
defer c.lock.Unlock()

entry, ok := c.data[fileID]
if !ok {
value := lazy.NewWithError(func() (CacheEntryValue, error) {
val, err := c.fetcher(ctx, fileID)
val, err := fetch(ctx, db, fileID)

// Always add to the cache size the bytes of the file loaded.
if err == nil {
Expand Down Expand Up @@ -269,3 +243,20 @@ func (c *Cache) Count() int {

return len(c.data)
}

func fetch(ctx context.Context, store database.Store, fileID uuid.UUID) (CacheEntryValue, error) {
// Make sure the read does not fail due to authorization issues.
// Authz is checked on the Acquire call, so this is safe.
//nolint:gocritic
file, err := store.GetFileByID(dbauthz.AsFileReader(ctx), fileID)
if err != nil {
return CacheEntryValue{}, xerrors.Errorf("failed to read file from database: %w", err)
}

content := bytes.NewBuffer(file.Data)
return CacheEntryValue{
Object: file.RBACObject(),
FS: archivefs.FromTarReader(content),
Size: int64(len(file.Data)),
}, nil
}
46 changes: 26 additions & 20 deletions coderd/files/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ import (

"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"github.com/spf13/afero"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/sync/errgroup"

"cdr.dev/slog/sloggers/slogtest"
Expand All @@ -18,6 +18,7 @@ import (
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/files"
"github.com/coder/coder/v2/coderd/rbac"
Expand Down Expand Up @@ -58,7 +59,7 @@ func TestCacheRBAC(t *testing.T) {
require.Equal(t, 0, cache.Count())
rec.Reset()

_, err := cache.Acquire(nobody, file.ID)
_, err := cache.Acquire(nobody, db, file.ID)
require.Error(t, err)
require.True(t, rbac.IsUnauthorizedError(err))

Expand All @@ -75,18 +76,18 @@ func TestCacheRBAC(t *testing.T) {
require.Equal(t, 0, cache.Count())

// Read the file with a file reader to put it into the cache.
a, err := cache.Acquire(cacheReader, file.ID)
a, err := cache.Acquire(cacheReader, db, file.ID)
require.NoError(t, err)
require.Equal(t, 1, cache.Count())

// "nobody" should not be able to read the file.
_, err = cache.Acquire(nobody, file.ID)
_, err = cache.Acquire(nobody, db, file.ID)
require.Error(t, err)
require.True(t, rbac.IsUnauthorizedError(err))
require.Equal(t, 1, cache.Count())

// UserReader can
b, err := cache.Acquire(userReader, file.ID)
b, err := cache.Acquire(userReader, db, file.ID)
require.NoError(t, err)
require.Equal(t, 1, cache.Count())

Expand All @@ -110,16 +111,21 @@ func TestConcurrency(t *testing.T) {
ctx := dbauthz.AsFileReader(t.Context())

const fileSize = 10
emptyFS := afero.NewIOFS(afero.NewReadOnlyFs(afero.NewMemMapFs()))
var fetches atomic.Int64
reg := prometheus.NewRegistry()
c := files.New(func(_ context.Context, _ uuid.UUID) (files.CacheEntryValue, error) {

dbM := dbmock.NewMockStore(gomock.NewController(t))
dbM.EXPECT().GetFileByID(gomock.Any(), gomock.Any()).DoAndReturn(func(mTx context.Context, fileID uuid.UUID) (database.File, error) {
fetches.Add(1)
// Wait long enough before returning to make sure that all of the goroutines
// Wait long enough before returning to make sure that all the goroutines
// will be waiting in line, ensuring that no one duplicated a fetch.
time.Sleep(testutil.IntervalMedium)
return files.CacheEntryValue{FS: emptyFS, Size: fileSize}, nil
}, reg, &coderdtest.FakeAuthorizer{})
return database.File{
Data: make([]byte, fileSize),
}, nil
}).AnyTimes()

c := files.New(reg, &coderdtest.FakeAuthorizer{})

batches := 1000
groups := make([]*errgroup.Group, 0, batches)
Expand All @@ -137,7 +143,7 @@ func TestConcurrency(t *testing.T) {
g.Go(func() error {
// We don't bother to Release these references because the Cache will be
// released at the end of the test anyway.
_, err := c.Acquire(ctx, id)
_, err := c.Acquire(ctx, dbM, id)
return err
})
}
Expand All @@ -164,14 +170,15 @@ func TestRelease(t *testing.T) {
ctx := dbauthz.AsFileReader(t.Context())

const fileSize = 10
emptyFS := afero.NewIOFS(afero.NewReadOnlyFs(afero.NewMemMapFs()))
reg := prometheus.NewRegistry()
c := files.New(func(_ context.Context, _ uuid.UUID) (files.CacheEntryValue, error) {
return files.CacheEntryValue{
FS: emptyFS,
Size: fileSize,
dbM := dbmock.NewMockStore(gomock.NewController(t))
dbM.EXPECT().GetFileByID(gomock.Any(), gomock.Any()).DoAndReturn(func(mTx context.Context, fileID uuid.UUID) (database.File, error) {
return database.File{
Data: make([]byte, fileSize),
}, nil
}, reg, &coderdtest.FakeAuthorizer{})
}).AnyTimes()

c := files.New(reg, &coderdtest.FakeAuthorizer{})

batches := 100
ids := make([]uuid.UUID, 0, batches)
Expand All @@ -184,9 +191,8 @@ func TestRelease(t *testing.T) {
batchSize := 10
for openedIdx, id := range ids {
for batchIdx := range batchSize {
it, err := c.Acquire(ctx, id)
it, err := c.Acquire(ctx, dbM, id)
require.NoError(t, err)
require.Equal(t, emptyFS, it.FS)
releases[id] = append(releases[id], it.Close)

// Each time a new file is opened, the metrics should be updated as so:
Expand Down Expand Up @@ -257,7 +263,7 @@ func cacheAuthzSetup(t *testing.T) (database.Store, *files.Cache, *coderdtest.Re

// Dbauthz wrap the db
db = dbauthz.New(db, rec, logger, coderdtest.AccessControlStorePointer())
c := files.NewFromStore(db, reg, rec)
c := files.New(reg, rec)
return db, c, rec
}

Expand Down
6 changes: 4 additions & 2 deletions coderd/files/closer.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (

"github.com/google/uuid"
"golang.org/x/xerrors"

"github.com/coder/coder/v2/coderd/database"
)

// CacheCloser is a cache wrapper used to close all acquired files.
Expand Down Expand Up @@ -38,15 +40,15 @@ func (c *CacheCloser) Close() {
c.closers = nil
}

func (c *CacheCloser) Acquire(ctx context.Context, fileID uuid.UUID) (*CloseFS, error) {
func (c *CacheCloser) Acquire(ctx context.Context, db database.Store, fileID uuid.UUID) (*CloseFS, error) {
c.mu.Lock()
defer c.mu.Unlock()

if c.cache == nil {
return nil, xerrors.New("cache is closed, and cannot acquire new files")
}

f, err := c.cache.Acquire(ctx, fileID)
f, err := c.cache.Acquire(ctx, db, fileID)
if err != nil {
return nil, err
}
Expand Down
Loading