Skip to content

Commit 3cb39ea

Browse files
committed
test: add unit test for closing files
1 parent a5ee374 commit 3cb39ea

File tree

3 files changed

+90
-10
lines changed

3 files changed

+90
-10
lines changed

coderd/files/cache.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import (
1616

1717
// NewFromStore returns a file cache that will fetch files from the provided
1818
// database.
19-
func NewFromStore(store database.Store) Cache {
19+
func NewFromStore(store database.Store) *Cache {
2020
fetcher := func(ctx context.Context, fileID uuid.UUID) (fs.FS, error) {
2121
file, err := store.GetFileByID(ctx, fileID)
2222
if err != nil {
@@ -27,7 +27,7 @@ func NewFromStore(store database.Store) Cache {
2727
return archivefs.FromTarReader(content), nil
2828
}
2929

30-
return Cache{
30+
return &Cache{
3131
lock: sync.Mutex{},
3232
data: make(map[uuid.UUID]*cacheEntry),
3333
fetcher: fetcher,

coderd/parameters.go

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,23 @@ func (api *API) templateVersionDynamicParameters(rw http.ResponseWriter, r *http
7171
return
7272
}
7373

74+
// staticDiagnostics is a set of diagnostics to be applied to all rendered results.
7475
staticDiagnostics := parameterProvisionerVersionDiagnostic(tf)
7576

77+
// render is the function that given a set of input values, will return the
78+
// parameter state. There is 2 rendering functions.
79+
//
80+
// prepareStaticPreview uses the static set of parameters saved from the template
81+
// import. These parameters are returned on every request, and have no dynamic
82+
// functionality. This exists for backwards compatibility with older template versions
83+
// which have not uploaded their plan & module files.
84+
//
85+
// prepareDynamicPreview uses the dynamic preview engine.
7686
var render previewFunction
7787
major, minor, err := apiversion.Parse(tf.ProvisionerdVersion)
7888
if err != nil || major < 1 || (major == 1 && minor < 5) {
89+
// Versions < 1.5 do not upload the required files.
90+
// Versions == "" are < 1.5, but we don't know the exact version.
7991
staticRender, err := prepareStaticPreview(ctx, api.Database, templateVersion.ID)
8092
if err != nil {
8193
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
@@ -113,7 +125,7 @@ func (api *API) templateVersionDynamicParameters(rw http.ResponseWriter, r *http
113125
// Send an initial form state, computed without any user input.
114126
result, diagnostics := render(ctx, map[string]string{})
115127
response := codersdk.DynamicParametersResponse{
116-
ID: -1,
128+
ID: -1, // Always start with -1.
117129
Diagnostics: previewtypes.Diagnostics(diagnostics.Extend(staticDiagnostics)),
118130
}
119131
if result != nil {
@@ -138,6 +150,7 @@ func (api *API) templateVersionDynamicParameters(rw http.ResponseWriter, r *http
138150
// The connection has been closed, so there is no one to write to
139151
return
140152
}
153+
141154
result, diagnostics := render(ctx, update.Inputs)
142155
response := codersdk.DynamicParametersResponse{
143156
ID: update.ID,
@@ -158,12 +171,16 @@ func (api *API) templateVersionDynamicParameters(rw http.ResponseWriter, r *http
158171
type previewFunction func(ctx context.Context, values map[string]string) (*preview.Output, hcl.Diagnostics)
159172

160173
func prepareDynamicPreview(ctx context.Context, rw http.ResponseWriter, db database.Store, fc *files.Cache, tf database.TemplateVersionTerraformValue, templateVersion database.TemplateVersion, user database.User) (render previewFunction, closer func(), success bool) {
174+
// keep track of all files opened
161175
openFiles := make([]uuid.UUID, 0)
162176
closeFiles := func() {
163177
for _, it := range openFiles {
164178
fc.Release(it)
165179
}
166180
}
181+
182+
// This defer will close the files if the function exits early without success.
183+
// Closing the files is important to avoid having a memory leak.
167184
defer func() {
168185
if !success {
169186
closeFiles()
@@ -182,6 +199,8 @@ func prepareDynamicPreview(ctx context.Context, rw http.ResponseWriter, db datab
182199
return nil, nil, false
183200
}
184201

202+
// Add the file first. Calling `Release` if it fails is a no-op, so this is safe.
203+
openFiles = append(openFiles, fileID)
185204
templateFS, err := fc.Acquire(fileCtx, fileID)
186205
if err != nil {
187206
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
@@ -190,14 +209,14 @@ func prepareDynamicPreview(ctx context.Context, rw http.ResponseWriter, db datab
190209
})
191210
return nil, nil, false
192211
}
193-
openFiles = append(openFiles, fileID)
194212

195213
// Having the Terraform plan available for the evaluation engine is helpful
196214
// for populating values from data blocks, but isn't strictly required. If
197215
// we don't have a cached plan available, we just use an empty one instead.
198216
plan := json.RawMessage("{}")
199217
plan = tf.CachedPlan
200218

219+
openFiles = append(openFiles, tf.CachedModuleFiles.UUID)
201220
if tf.CachedModuleFiles.Valid {
202221
moduleFilesFS, err := fc.Acquire(fileCtx, tf.CachedModuleFiles.UUID)
203222
if err != nil {
@@ -207,7 +226,6 @@ func prepareDynamicPreview(ctx context.Context, rw http.ResponseWriter, db datab
207226
})
208227
return nil, nil, false
209228
}
210-
openFiles = append(openFiles, tf.CachedModuleFiles.UUID)
211229

212230
templateFS = files.NewOverlayFS(templateFS, []files.Overlay{{Path: ".terraform/modules", FS: moduleFilesFS}})
213231
}
@@ -371,7 +389,10 @@ func getWorkspaceOwnerData(
371389

372390
var publicKey string
373391
g.Go(func() error {
374-
key, err := db.GetGitSSHKey(ctx, user.ID)
392+
// The correct public key has to be sent. This will not be leaked
393+
// unless the template leaks it.
394+
// nolint:gocritic
395+
key, err := db.GetGitSSHKey(dbauthz.AsSystemRestricted(ctx), user.ID)
375396
if err != nil {
376397
return err
377398
}
@@ -381,7 +402,11 @@ func getWorkspaceOwnerData(
381402

382403
var groupNames []string
383404
g.Go(func() error {
384-
groups, err := db.GetGroups(ctx, database.GetGroupsParams{
405+
// The groups need to be sent to preview. These groups are not exposed to the
406+
// user, unless the template does it through the parameters. Regardless, we need
407+
// the correct groups, and a user might not have read access.
408+
// nolint:gocritic
409+
groups, err := db.GetGroups(dbauthz.AsSystemRestricted(ctx), database.GetGroupsParams{
385410
OrganizationID: organizationID,
386411
HasMemberID: user.ID,
387412
})

coderd/parameters_test.go

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
11
package coderd_test
22

33
import (
4+
"context"
45
"os"
56
"testing"
67

8+
"github.com/google/uuid"
79
"github.com/stretchr/testify/require"
10+
"golang.org/x/xerrors"
811

912
"github.com/coder/coder/v2/coderd"
1013
"github.com/coder/coder/v2/coderd/coderdtest"
14+
"github.com/coder/coder/v2/coderd/database"
15+
"github.com/coder/coder/v2/coderd/database/dbtestutil"
16+
"github.com/coder/coder/v2/coderd/database/pubsub"
1117
"github.com/coder/coder/v2/coderd/rbac"
1218
"github.com/coder/coder/v2/codersdk"
1319
"github.com/coder/coder/v2/codersdk/wsjson"
@@ -141,6 +147,8 @@ func TestDynamicParametersWithTerraformValues(t *testing.T) {
141147
t.Parallel()
142148

143149
t.Run("OK_Modules", func(t *testing.T) {
150+
t.Parallel()
151+
144152
dynamicParametersTerraformSource, err := os.ReadFile("testdata/parameters/modules/main.tf")
145153
require.NoError(t, err)
146154

@@ -172,6 +180,8 @@ func TestDynamicParametersWithTerraformValues(t *testing.T) {
172180

173181
// OldProvisioners use the static parameters in the dynamic param flow
174182
t.Run("OldProvisioner", func(t *testing.T) {
183+
t.Parallel()
184+
175185
setup := setupDynamicParamsTest(t, setupDynamicParamsTestParams{
176186
provisionerDaemonVersion: "1.4",
177187
mainTF: nil,
@@ -244,15 +254,42 @@ func TestDynamicParametersWithTerraformValues(t *testing.T) {
244254
}
245255

246256
})
257+
258+
t.Run("FileError", func(t *testing.T) {
259+
// Verify files close even if the websocket terminates from an error
260+
t.Parallel()
261+
262+
db, ps := dbtestutil.NewDB(t)
263+
dynamicParametersTerraformSource, err := os.ReadFile("testdata/parameters/modules/main.tf")
264+
require.NoError(t, err)
265+
266+
modulesArchive, err := terraform.GetModulesArchive(os.DirFS("testdata/parameters/modules"))
267+
require.NoError(t, err)
268+
269+
setup := setupDynamicParamsTest(t, setupDynamicParamsTestParams{
270+
db: &dbRejectGitSSHKey{Store: db},
271+
ps: ps,
272+
provisionerDaemonVersion: provProto.CurrentVersion.String(),
273+
mainTF: dynamicParametersTerraformSource,
274+
modulesArchive: modulesArchive,
275+
expectWebsocketError: true,
276+
})
277+
// This is checked in setupDynamicParamsTest. Just doing this in the
278+
// test to make it obvious what this test is doing.
279+
require.Zero(t, setup.api.FileCache.Count())
280+
})
247281
}
248282

249283
type setupDynamicParamsTestParams struct {
284+
db database.Store
285+
ps pubsub.Pubsub
250286
provisionerDaemonVersion string
251287
mainTF []byte
252288
modulesArchive []byte
253289
plan []byte
254290

255-
static []*proto.RichParameter
291+
static []*proto.RichParameter
292+
expectWebsocketError bool
256293
}
257294

258295
type dynamicParamsTest struct {
@@ -265,6 +302,8 @@ func setupDynamicParamsTest(t *testing.T, args setupDynamicParamsTestParams) dyn
265302
cfg := coderdtest.DeploymentValues(t)
266303
cfg.Experiments = []string{string(codersdk.ExperimentDynamicParameters)}
267304
ownerClient, _, api := coderdtest.NewWithAPI(t, &coderdtest.Options{
305+
Database: args.db,
306+
Pubsub: args.ps,
268307
IncludeProvisionerDaemon: true,
269308
ProvisionerDaemonVersion: args.provisionerDaemonVersion,
270309
DeploymentValues: cfg,
@@ -292,10 +331,16 @@ func setupDynamicParamsTest(t *testing.T, args setupDynamicParamsTestParams) dyn
292331

293332
ctx := testutil.Context(t, testutil.WaitShort)
294333
stream, err := templateAdmin.TemplateVersionDynamicParameters(ctx, templateAdminUser.ID, version.ID)
295-
require.NoError(t, err)
334+
if args.expectWebsocketError {
335+
require.Errorf(t, err, "expected error forming websocket")
336+
} else {
337+
require.NoError(t, err)
338+
}
296339

297340
t.Cleanup(func() {
298-
_ = stream.Close(websocket.StatusGoingAway)
341+
if stream != nil {
342+
_ = stream.Close(websocket.StatusGoingAway)
343+
}
299344
// Cache should always have 0 files when the only stream is closed
300345
require.Eventually(t, func() bool {
301346
return api.FileCache.Count() == 0
@@ -308,3 +353,13 @@ func setupDynamicParamsTest(t *testing.T, args setupDynamicParamsTestParams) dyn
308353
api: api,
309354
}
310355
}
356+
357+
// dbRejectGitSSHKey is a cheeky way to force an error to occur in a place
358+
// that is generally impossible to force an error.
359+
type dbRejectGitSSHKey struct {
360+
database.Store
361+
}
362+
363+
func (d *dbRejectGitSSHKey) GetGitSSHKey(_ context.Context, _ uuid.UUID) (database.GitSSHKey, error) {
364+
return database.GitSSHKey{}, xerrors.New("forcing a fake error")
365+
}

0 commit comments

Comments
 (0)