Skip to content

Commit 78d88fe

Browse files
committed
do some more intelligent HCL parsing to determine required variables
1 parent e56cac7 commit 78d88fe

File tree

3 files changed

+101
-46
lines changed

3 files changed

+101
-46
lines changed

provisioner/terraform/parse.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func (s *server) Parse(sess *provisionersdk.Session, _ *proto.ParseRequest, _ <-
2626
return provisionersdk.ParseErrorf("load module: %s", formatDiagnostics(sess.WorkDirectory, diags))
2727
}
2828

29-
workspaceTags, err := parser.WorkspaceTags(ctx)
29+
workspaceTags, _, err := parser.WorkspaceTags(ctx)
3030
if err != nil {
3131
return provisionersdk.ParseErrorf("can't load workspace tags: %v", err)
3232
}

provisioner/terraform/tfparse/tfparse.go

Lines changed: 75 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,12 @@ func New(workdir string, opts ...Option) (*Parser, tfconfig.Diagnostics) {
8080
}
8181

8282
// WorkspaceTags looks for all coder_workspace_tags datasource in the module
83-
// and returns the raw values for the tags.
84-
func (p *Parser) WorkspaceTags(ctx context.Context) (map[string]string, error) {
83+
// and returns the raw values for the tags. It also returns the set of
84+
// variables referenced by any expressions in the raw values of tags.
85+
func (p *Parser) WorkspaceTags(ctx context.Context) (map[string]string, map[string]struct{}, error) {
8586
tags := map[string]string{}
86-
var skipped []string
87+
skipped := []string{}
88+
requiredVars := map[string]struct{}{}
8789
for _, dataResource := range p.module.DataResources {
8890
if dataResource.Type != "coder_workspace_tags" {
8991
skipped = append(skipped, strings.Join([]string{"data", dataResource.Type, dataResource.Name}, "."))
@@ -99,13 +101,13 @@ func (p *Parser) WorkspaceTags(ctx context.Context) (map[string]string, error) {
99101
// We know in which HCL file is the data resource defined.
100102
file, diags = p.underlying.ParseHCLFile(dataResource.Pos.Filename)
101103
if diags.HasErrors() {
102-
return nil, xerrors.Errorf("can't parse the resource file: %s", diags.Error())
104+
return nil, nil, xerrors.Errorf("can't parse the resource file: %s", diags.Error())
103105
}
104106

105107
// Parse root to find "coder_workspace_tags".
106108
content, _, diags := file.Body.PartialContent(rootTemplateSchema)
107109
if diags.HasErrors() {
108-
return nil, xerrors.Errorf("can't parse the resource file: %s", diags.Error())
110+
return nil, nil, xerrors.Errorf("can't parse the resource file: %s", diags.Error())
109111
}
110112

111113
// Iterate over blocks to locate the exact "coder_workspace_tags" data resource.
@@ -117,51 +119,99 @@ func (p *Parser) WorkspaceTags(ctx context.Context) (map[string]string, error) {
117119
// Parse "coder_workspace_tags" to find all key-value tags.
118120
resContent, _, diags := block.Body.PartialContent(coderWorkspaceTagsSchema)
119121
if diags.HasErrors() {
120-
return nil, xerrors.Errorf(`can't parse the resource coder_workspace_tags: %s`, diags.Error())
122+
return nil, nil, xerrors.Errorf(`can't parse the resource coder_workspace_tags: %s`, diags.Error())
121123
}
122124

123125
if resContent == nil {
124126
continue // workspace tags are not present
125127
}
126128

127129
if _, ok := resContent.Attributes["tags"]; !ok {
128-
return nil, xerrors.Errorf(`"tags" attribute is required by coder_workspace_tags`)
130+
return nil, nil, xerrors.Errorf(`"tags" attribute is required by coder_workspace_tags`)
129131
}
130132

131133
expr := resContent.Attributes["tags"].Expr
132134
tagsExpr, ok := expr.(*hclsyntax.ObjectConsExpr)
133135
if !ok {
134-
return nil, xerrors.Errorf(`"tags" attribute is expected to be a key-value map`)
136+
return nil, nil, xerrors.Errorf(`"tags" attribute is expected to be a key-value map`)
135137
}
136138

137139
// Parse key-value entries in "coder_workspace_tags"
138140
for _, tagItem := range tagsExpr.Items {
139141
key, err := previewFileContent(tagItem.KeyExpr.Range())
140142
if err != nil {
141-
return nil, xerrors.Errorf("can't preview the resource file: %v", err)
143+
return nil, nil, xerrors.Errorf("can't preview the resource file: %v", err)
142144
}
143145
key = strings.Trim(key, `"`)
144146

145147
value, err := previewFileContent(tagItem.ValueExpr.Range())
146148
if err != nil {
147-
return nil, xerrors.Errorf("can't preview the resource file: %v", err)
149+
return nil, nil, xerrors.Errorf("can't preview the resource file: %v", err)
148150
}
149151

150152
if _, ok := tags[key]; ok {
151-
return nil, xerrors.Errorf(`workspace tag %q is defined multiple times`, key)
153+
return nil, nil, xerrors.Errorf(`workspace tag %q is defined multiple times`, key)
152154
}
153155
tags[key] = value
156+
157+
// Find values referenced by the expression.
158+
refVars := referencedVariablesExpr(tagItem.ValueExpr)
159+
for _, refVar := range refVars {
160+
requiredVars[refVar] = struct{}{}
161+
}
154162
}
155163
}
156164
}
157-
p.logger.Debug(ctx, "found workspace tags", slog.F("tags", maps.Keys(tags)), slog.F("skipped", skipped))
158-
return tags, nil
165+
166+
requiredVarNames := maps.Keys(requiredVars)
167+
slices.Sort(requiredVarNames)
168+
p.logger.Debug(ctx, "found workspace tags", slog.F("tags", maps.Keys(tags)), slog.F("skipped", skipped), slog.F("required_vars", requiredVarNames))
169+
return tags, requiredVars, nil
170+
}
171+
172+
// referencedVariablesExpr determines the variables referenced in expr
173+
// and returns the names of those variables.
174+
func referencedVariablesExpr(expr hclsyntax.Expression) (names []string) {
175+
var parts []string
176+
for _, expVar := range expr.Variables() {
177+
for _, tr := range expVar {
178+
switch v := tr.(type) {
179+
case hcl.TraverseRoot:
180+
parts = append(parts, v.Name)
181+
case hcl.TraverseAttr:
182+
parts = append(parts, v.Name)
183+
default: // skip
184+
}
185+
}
186+
187+
cleaned := cleanupTraversalName(parts)
188+
names = append(names, strings.Join(cleaned, "."))
189+
}
190+
return names
191+
}
192+
193+
// cleanupTraversalName chops off extraneous pieces of the traversal.
194+
// for example:
195+
// - var.foo -> unchanged
196+
// - data.coder_parameter.bar.value -> data.coder_parameter.bar
197+
// - null_resource.baz.zap -> null_resource.baz
198+
func cleanupTraversalName(parts []string) []string {
199+
if len(parts) == 0 {
200+
return parts
201+
}
202+
if len(parts) > 3 && parts[0] == "data" {
203+
return parts[:3]
204+
}
205+
if len(parts) > 2 {
206+
return parts[:2]
207+
}
208+
return parts
159209
}
160210

161211
func (p *Parser) WorkspaceTagDefaults(ctx context.Context) (map[string]string, error) {
162212
// This only gets us the expressions. We need to evaluate them.
163213
// Example: var.region -> "us"
164-
tags, err := p.WorkspaceTags(ctx)
214+
tags, requiredVars, err := p.WorkspaceTags(ctx)
165215
if err != nil {
166216
return nil, xerrors.Errorf("extract workspace tags: %w", err)
167217
}
@@ -172,11 +222,11 @@ func (p *Parser) WorkspaceTagDefaults(ctx context.Context) (map[string]string, e
172222

173223
// To evaluate the expressions, we need to load the default values for
174224
// variables and parameters.
175-
varsDefaults, err := p.VariableDefaults(ctx, tags)
225+
varsDefaults, err := p.VariableDefaults(ctx)
176226
if err != nil {
177227
return nil, xerrors.Errorf("load variable defaults: %w", err)
178228
}
179-
paramsDefaults, err := p.CoderParameterDefaults(ctx, varsDefaults, tags)
229+
paramsDefaults, err := p.CoderParameterDefaults(ctx, varsDefaults, requiredVars)
180230
if err != nil {
181231
return nil, xerrors.Errorf("load parameter defaults: %w", err)
182232
}
@@ -251,39 +301,28 @@ func WriteArchive(bs []byte, mimetype string, path string) error {
251301
return nil
252302
}
253303

254-
// VariableDefaults returns the default values for all variables referenced in the values of tags.
255-
func (p *Parser) VariableDefaults(ctx context.Context, tags map[string]string) (map[string]string, error) {
256-
var skipped []string
304+
// VariableDefaults returns the default values for all variables in the module.
305+
func (p *Parser) VariableDefaults(ctx context.Context) (map[string]string, error) {
257306
// iterate through vars to get the default values for all
258307
// required variables.
259308
m := make(map[string]string)
260309
for _, v := range p.module.Variables {
261310
if v == nil {
262311
continue
263312
}
264-
var found bool
265-
for _, tv := range tags {
266-
if strings.Contains(tv, v.Name) {
267-
found = true
268-
}
269-
}
270-
if !found {
271-
skipped = append(skipped, "var."+v.Name)
272-
continue
273-
}
274313
sv, err := interfaceToString(v.Default)
275314
if err != nil {
276315
return nil, xerrors.Errorf("can't convert variable default value to string: %v", err)
277316
}
278317
m[v.Name] = strings.Trim(sv, `"`)
279318
}
280-
p.logger.Debug(ctx, "found default values for variables", slog.F("defaults", m), slog.F("skipped", skipped))
319+
p.logger.Debug(ctx, "found default values for variables", slog.F("defaults", m))
281320
return m, nil
282321
}
283322

284323
// CoderParameterDefaults returns the default values of all coder_parameter data sources
285324
// in the parsed module.
286-
func (p *Parser) CoderParameterDefaults(ctx context.Context, varsDefaults map[string]string, tags map[string]string) (map[string]string, error) {
325+
func (p *Parser) CoderParameterDefaults(ctx context.Context, varsDefaults map[string]string, names map[string]struct{}) (map[string]string, error) {
287326
defaultsM := make(map[string]string)
288327
var (
289328
skipped []string
@@ -296,23 +335,17 @@ func (p *Parser) CoderParameterDefaults(ctx context.Context, varsDefaults map[st
296335
continue
297336
}
298337

299-
if dataResource.Type != "coder_parameter" {
300-
skipped = append(skipped, strings.Join([]string{"data", dataResource.Type, dataResource.Name}, "."))
301-
continue
302-
}
303-
304338
if !strings.HasSuffix(dataResource.Pos.Filename, ".tf") {
305339
continue
306340
}
307341

308-
var found bool
309342
needle := strings.Join([]string{"data", dataResource.Type, dataResource.Name}, ".")
310-
for _, tv := range tags {
311-
if strings.Contains(tv, needle) {
312-
found = true
313-
}
343+
if dataResource.Type != "coder_parameter" {
344+
skipped = append(skipped, needle)
345+
continue
314346
}
315-
if !found {
347+
348+
if _, found := names[needle]; !found {
316349
skipped = append(skipped, needle)
317350
continue
318351
}

provisioner/terraform/tfparse/tfparse_test.go

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ func Test_WorkspaceTagDefaultsFromFile(t *testing.T) {
3434
name: "single text file",
3535
files: map[string]string{
3636
"file.txt": `
37-
hello world`,
37+
hello world`,
3838
},
3939
expectTags: map[string]string{},
4040
expectError: "",
@@ -539,6 +539,28 @@ func Test_WorkspaceTagDefaultsFromFile(t *testing.T) {
539539
expectTags: nil,
540540
expectError: `can't convert variable default value to string: unsupported type map[string]interface {}`,
541541
},
542+
{
543+
name: "overlapping var name",
544+
files: map[string]string{
545+
`main.tf`: `
546+
variable "a" {
547+
type = string
548+
default = "1"
549+
}
550+
variable "ab" {
551+
description = "This is a variable of type string"
552+
type = string
553+
default = "ab"
554+
}
555+
data "coder_workspace_tags" "tags" {
556+
tags = {
557+
"foo": "bar",
558+
"a": var.a,
559+
}
560+
}`,
561+
},
562+
expectTags: map[string]string{"foo": "bar", "a": "1"},
563+
},
542564
} {
543565
tc := tc
544566
t.Run(tc.name+"/tar", func(t *testing.T) {
@@ -622,7 +644,7 @@ func BenchmarkWorkspaceTagDefaultsFromFile(b *testing.B) {
622644
tfparse.WriteArchive(tarFile, "application/x-tar", tmpDir)
623645
parser, diags := tfparse.New(tmpDir, tfparse.WithLogger(logger))
624646
require.NoError(b, diags.Err())
625-
_, err := parser.WorkspaceTags(ctx)
647+
_, _, err := parser.WorkspaceTags(ctx)
626648
if err != nil {
627649
b.Fatal(err)
628650
}
@@ -636,7 +658,7 @@ func BenchmarkWorkspaceTagDefaultsFromFile(b *testing.B) {
636658
tfparse.WriteArchive(zipFile, "application/zip", tmpDir)
637659
parser, diags := tfparse.New(tmpDir, tfparse.WithLogger(logger))
638660
require.NoError(b, diags.Err())
639-
_, err := parser.WorkspaceTags(ctx)
661+
_, _, err := parser.WorkspaceTags(ctx)
640662
if err != nil {
641663
b.Fatal(err)
642664
}

0 commit comments

Comments
 (0)