Skip to content

Commit 208d830

Browse files
Merge pull request #1370 from Workiva/nestedTypes
Adding nested type support, more updates will follow, see summary on this PR
2 parents c95887b + 3e9afa2 commit 208d830

20 files changed

+987
-137
lines changed

compiler/compiler_test.go

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"go/types"
66
"regexp"
77
"sort"
8+
"strings"
89
"testing"
910
"time"
1011

@@ -749,6 +750,143 @@ func TestArchiveSelectionAfterSerialization(t *testing.T) {
749750
}
750751
}
751752

753+
func TestNestedConcreteTypeInGenericFunc(t *testing.T) {
754+
// This is a test of a type defined inside a generic function
755+
// that uses the type parameter of the function as a field type.
756+
// The `T` type is unique for each instance of `F`.
757+
// The use of `A` as a field is do demonstrate the difference in the types
758+
// however even if T had no fields, the type would still be different.
759+
//
760+
// Change `print(F[?]())` to `fmt.Printf("%T\n", F[?]())` for
761+
// golang playground to print the type of T in the different F instances.
762+
// (I just didn't want this test to depend on `fmt` when it doesn't need to.)
763+
764+
src := `
765+
package main
766+
func F[A any]() any {
767+
type T struct{
768+
a A
769+
}
770+
return T{}
771+
}
772+
func main() {
773+
type Int int
774+
print(F[int]())
775+
print(F[Int]())
776+
}
777+
`
778+
779+
srcFiles := []srctesting.Source{{Name: `main.go`, Contents: []byte(src)}}
780+
root := srctesting.ParseSources(t, srcFiles, nil)
781+
archives := compileProject(t, root, false)
782+
mainPkg := archives[root.PkgPath]
783+
insts := collectDeclInstances(t, mainPkg)
784+
785+
exp := []string{
786+
`F[int]`,
787+
`F[main.Int]`, // Go prints `F[main.Int·2]`
788+
`T[int;]`, // `T` from `F[int]` (Go prints `T[int]`)
789+
`T[main.Int;]`, // `T` from `F[main.Int]` (Go prints `T[main.Int·2]`)
790+
}
791+
if diff := cmp.Diff(exp, insts); len(diff) > 0 {
792+
t.Errorf("the instances of generics are different:\n%s", diff)
793+
}
794+
}
795+
796+
func TestNestedGenericTypeInGenericFunc(t *testing.T) {
797+
// This is a subset of the type param nested test from the go repo.
798+
// See https://github.com/golang/go/blob/go1.19.13/test/typeparam/nested.go
799+
// The test is failing because nested types aren't being typed differently.
800+
// For example the type of `T[int]` below is different based on `F[X]`
801+
// instance for different `X` type parameters, hence Go prints the type as
802+
// `T[X;int]` instead of `T[int]`.
803+
804+
src := `
805+
package main
806+
func F[A any]() any {
807+
type T[B any] struct{
808+
a A
809+
b B
810+
}
811+
return T[int]{}
812+
}
813+
func main() {
814+
type Int int
815+
print(F[int]())
816+
print(F[Int]())
817+
}
818+
`
819+
820+
srcFiles := []srctesting.Source{{Name: `main.go`, Contents: []byte(src)}}
821+
root := srctesting.ParseSources(t, srcFiles, nil)
822+
archives := compileProject(t, root, false)
823+
mainPkg := archives[root.PkgPath]
824+
insts := collectDeclInstances(t, mainPkg)
825+
826+
exp := []string{
827+
`F[int]`,
828+
`F[main.Int]`,
829+
`T[int; int]`,
830+
`T[main.Int; int]`,
831+
}
832+
if diff := cmp.Diff(exp, insts); len(diff) > 0 {
833+
t.Errorf("the instances of generics are different:\n%s", diff)
834+
}
835+
}
836+
837+
func TestNestedGenericTypeInGenericFuncWithSharedTArgs(t *testing.T) {
838+
src := `
839+
package main
840+
func F[A any]() any {
841+
type T[B any] struct {
842+
b B
843+
}
844+
return T[A]{}
845+
}
846+
func main() {
847+
type Int int
848+
print(F[int]())
849+
print(F[Int]())
850+
}`
851+
852+
srcFiles := []srctesting.Source{{Name: `main.go`, Contents: []byte(src)}}
853+
root := srctesting.ParseSources(t, srcFiles, nil)
854+
archives := compileProject(t, root, false)
855+
mainPkg := archives[root.PkgPath]
856+
insts := collectDeclInstances(t, mainPkg)
857+
858+
exp := []string{
859+
`F[int]`,
860+
`F[main.Int]`,
861+
`T[int; int]`,
862+
`T[main.Int; main.Int]`,
863+
// Make sure that T[int;main.Int] and T[main.Int;int] aren't created.
864+
}
865+
if diff := cmp.Diff(exp, insts); len(diff) > 0 {
866+
t.Errorf("the instances of generics are different:\n%s", diff)
867+
}
868+
}
869+
870+
func collectDeclInstances(t *testing.T, pkg *Archive) []string {
871+
t.Helper()
872+
873+
// Regex to match strings like `Foo[42 /* bar */] =` and capture
874+
// the name (`Foo`), the index (`42`), and the instance type (`bar`).
875+
rex := regexp.MustCompile(`^\s*(\w+)\s*\[\s*(\d+)\s*\/\*(.+)\*\/\s*\]\s*\=`)
876+
877+
// Collect all instances of generics (e.g. `Foo[bar] @ 2`) written to the decl code.
878+
insts := []string{}
879+
for _, decl := range pkg.Declarations {
880+
if match := rex.FindAllStringSubmatch(string(decl.DeclCode), 1); len(match) > 0 {
881+
instance := match[0][1] + `[` + strings.TrimSpace(match[0][3]) + `]`
882+
instance = strings.ReplaceAll(instance, `command-line-arguments`, pkg.Name)
883+
insts = append(insts, instance)
884+
}
885+
}
886+
sort.Strings(insts)
887+
return insts
888+
}
889+
752890
func compareOrder(t *testing.T, sourceFiles []srctesting.Source, minify bool) {
753891
t.Helper()
754892
outputNormal := compile(t, sourceFiles, minify)

compiler/decls.go

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ func (fc *funcContext) newNamedTypeVarDecl(obj *types.TypeName) *Decl {
433433
FullName: typeVarDeclFullName(obj),
434434
Vars: []string{name},
435435
}
436-
if typeparams.HasTypeParams(obj.Type()) {
436+
if fc.pkgCtx.instanceSet.Pkg(obj.Pkg()).ObjHasInstances(obj) {
437437
varDecl.DeclCode = fc.CatchOutput(0, func() {
438438
fc.Printf("%s = {};", name)
439439
})
@@ -451,16 +451,28 @@ func (fc *funcContext) newNamedTypeVarDecl(obj *types.TypeName) *Decl {
451451
func (fc *funcContext) newNamedTypeInstDecl(inst typeparams.Instance) (*Decl, error) {
452452
originType := inst.Object.Type().(*types.Named)
453453

454-
fc.typeResolver = typeparams.NewResolver(fc.pkgCtx.typesCtx, typeparams.ToSlice(originType.TypeParams()), inst.TArgs)
454+
var nestResolver *typeparams.Resolver
455+
if len(inst.TNest) > 0 {
456+
fn := typeparams.FindNestingFunc(inst.Object)
457+
tp := typeparams.SignatureTypeParams(fn.Type().(*types.Signature))
458+
nestResolver = typeparams.NewResolver(fc.pkgCtx.typesCtx, tp, inst.TNest, nil)
459+
}
460+
fc.typeResolver = typeparams.NewResolver(fc.pkgCtx.typesCtx, originType.TypeParams(), inst.TArgs, nestResolver)
455461
defer func() { fc.typeResolver = nil }()
456462

457463
instanceType := originType
458464
if !inst.IsTrivial() {
459-
instantiated, err := types.Instantiate(fc.pkgCtx.typesCtx, originType, inst.TArgs, true)
460-
if err != nil {
461-
return nil, fmt.Errorf("failed to instantiate type %v with args %v: %w", originType, inst.TArgs, err)
465+
if len(inst.TArgs) > 0 {
466+
instantiated, err := types.Instantiate(fc.pkgCtx.typesCtx, originType, inst.TArgs, true)
467+
if err != nil {
468+
return nil, fmt.Errorf("failed to instantiate type %v with args %v: %w", originType, inst.TArgs, err)
469+
}
470+
instanceType = instantiated.(*types.Named)
471+
}
472+
if len(inst.TNest) > 0 {
473+
instantiated := nestResolver.Substitute(instanceType)
474+
instanceType = instantiated.(*types.Named)
462475
}
463-
instanceType = instantiated.(*types.Named)
464476
}
465477

466478
underlying := instanceType.Underlying()
@@ -541,7 +553,8 @@ func (fc *funcContext) structConstructor(t *types.Struct) string {
541553
// If no arguments were passed, zero-initialize all fields.
542554
fmt.Fprintf(constructor, "\t\tif (arguments.length === 0) {\n")
543555
for i := 0; i < t.NumFields(); i++ {
544-
fmt.Fprintf(constructor, "\t\t\tthis.%s = %s;\n", fieldName(t, i), fc.translateExpr(fc.zeroValue(t.Field(i).Type())).String())
556+
zeroValue := fc.zeroValue(fc.fieldType(t, i))
557+
fmt.Fprintf(constructor, "\t\t\tthis.%s = %s;\n", fieldName(t, i), fc.translateExpr(zeroValue).String())
545558
}
546559
fmt.Fprintf(constructor, "\t\t\treturn;\n")
547560
fmt.Fprintf(constructor, "\t\t}\n")

compiler/expressions.go

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -178,18 +178,18 @@ func (fc *funcContext) translateExpr(expr ast.Expr) *expression {
178178
}
179179
if !isKeyValue {
180180
for i, element := range e.Elts {
181-
elements[i] = fc.translateImplicitConversionWithCloning(element, t.Field(i).Type()).String()
181+
elements[i] = fc.translateImplicitConversionWithCloning(element, fc.fieldType(t, i)).String()
182182
}
183183
}
184184
if isKeyValue {
185185
for i := range elements {
186-
elements[i] = fc.translateExpr(fc.zeroValue(t.Field(i).Type())).String()
186+
elements[i] = fc.translateExpr(fc.zeroValue(fc.fieldType(t, i))).String()
187187
}
188188
for _, element := range e.Elts {
189189
kve := element.(*ast.KeyValueExpr)
190190
for j := range elements {
191191
if kve.Key.(*ast.Ident).Name == t.Field(j).Name() {
192-
elements[j] = fc.translateImplicitConversionWithCloning(kve.Value, t.Field(j).Type()).String()
192+
elements[j] = fc.translateImplicitConversionWithCloning(kve.Value, fc.fieldType(t, j)).String()
193193
break
194194
}
195195
}
@@ -801,7 +801,7 @@ func (fc *funcContext) translateExpr(expr ast.Expr) *expression {
801801
switch t := exprType.Underlying().(type) {
802802
case *types.Basic:
803803
if t.Kind() != types.UnsafePointer {
804-
panic("unexpected basic type")
804+
panic(fmt.Errorf(`unexpected basic type: %v in %v`, t, e.Name))
805805
}
806806
return fc.formatExpr("0")
807807
case *types.Slice, *types.Pointer:
@@ -917,7 +917,7 @@ func (fc *funcContext) makeReceiver(e *ast.SelectorExpr) *expression {
917917
recvType = ptr.Elem()
918918
}
919919
s := recvType.Underlying().(*types.Struct)
920-
recvType = s.Field(index).Type()
920+
recvType = fc.fieldType(s, index)
921921
}
922922

923923
fakeSel := &ast.SelectorExpr{X: x, Sel: ast.NewIdent("o")}
@@ -1314,12 +1314,13 @@ func (fc *funcContext) loadStruct(array, target string, s *types.Struct) string
13141314
var collectFields func(s *types.Struct, path string)
13151315
collectFields = func(s *types.Struct, path string) {
13161316
for i := 0; i < s.NumFields(); i++ {
1317-
field := s.Field(i)
1318-
if fs, isStruct := field.Type().Underlying().(*types.Struct); isStruct {
1319-
collectFields(fs, path+"."+fieldName(s, i))
1317+
fieldName := path + "." + fieldName(s, i)
1318+
fieldType := fc.fieldType(s, i)
1319+
if fs, isStruct := fieldType.Underlying().(*types.Struct); isStruct {
1320+
collectFields(fs, fieldName)
13201321
continue
13211322
}
1322-
fields = append(fields, types.NewVar(0, nil, path+"."+fieldName(s, i), field.Type()))
1323+
fields = append(fields, types.NewVar(0, nil, fieldName, fieldType))
13231324
}
13241325
}
13251326
collectFields(s, target)

compiler/functions.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ func (fc *funcContext) nestedFunctionContext(info *analysis.FuncInfo, inst typep
4949
}
5050

5151
if sig.TypeParams().Len() > 0 {
52-
c.typeResolver = typeparams.NewResolver(c.pkgCtx.typesCtx, typeparams.ToSlice(sig.TypeParams()), inst.TArgs)
52+
c.typeResolver = typeparams.NewResolver(c.pkgCtx.typesCtx, sig.TypeParams(), inst.TArgs, nil)
5353
} else if sig.RecvTypeParams().Len() > 0 {
54-
c.typeResolver = typeparams.NewResolver(c.pkgCtx.typesCtx, typeparams.ToSlice(sig.RecvTypeParams()), inst.TArgs)
54+
c.typeResolver = typeparams.NewResolver(c.pkgCtx.typesCtx, sig.RecvTypeParams(), inst.TArgs, nil)
5555
}
5656
if c.objectNames == nil {
5757
c.objectNames = map[types.Object]string{}

compiler/internal/analysis/info.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,8 @@ func (info *Info) newFuncInfoInstances(fd *ast.FuncDecl) []*FuncInfo {
126126
for _, inst := range instances {
127127
var resolver *typeparams.Resolver
128128
if sig, ok := obj.Type().(*types.Signature); ok {
129-
tp := typeparams.ToSlice(typeparams.SignatureTypeParams(sig))
130-
resolver = typeparams.NewResolver(info.typeCtx, tp, inst.TArgs)
129+
tp := typeparams.SignatureTypeParams(sig)
130+
resolver = typeparams.NewResolver(info.typeCtx, tp, inst.TArgs, nil)
131131
}
132132
fi := info.newFuncInfo(fd, inst.Object, inst.TArgs, resolver)
133133
funcInfos = append(funcInfos, fi)

compiler/internal/analysis/info_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,7 @@ func TestBlocking_Defers_WithMultipleReturns(t *testing.T) {
382382
// of which flow control statements (e.g. if-statements) are terminating
383383
// or not. Any defers added in a terminating control flow would not
384384
// propagate to returns that are not in that block.
385+
// See golang.org/x/tools/go/ssa for flow control analysis.
385386
//
386387
// For now we simply build up the list of defers as we go making
387388
// the return on line 31 also blocking.

compiler/internal/symbol/symbol.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,25 +21,30 @@ type Name struct {
2121

2222
// New constructs SymName for a given named symbol.
2323
func New(o types.Object) Name {
24+
pkgPath := `_`
25+
if pkg := o.Pkg(); pkg != nil {
26+
pkgPath = pkg.Path()
27+
}
28+
2429
if fun, ok := o.(*types.Func); ok {
2530
sig := fun.Type().(*types.Signature)
2631
if recv := sig.Recv(); recv != nil {
2732
// Special case: disambiguate names for different types' methods.
2833
typ := recv.Type()
2934
if ptr, ok := typ.(*types.Pointer); ok {
3035
return Name{
31-
PkgPath: o.Pkg().Path(),
36+
PkgPath: pkgPath,
3237
Name: "(*" + ptr.Elem().(*types.Named).Obj().Name() + ")." + o.Name(),
3338
}
3439
}
3540
return Name{
36-
PkgPath: o.Pkg().Path(),
41+
PkgPath: pkgPath,
3742
Name: typ.(*types.Named).Obj().Name() + "." + o.Name(),
3843
}
3944
}
4045
}
4146
return Name{
42-
PkgPath: o.Pkg().Path(),
47+
PkgPath: pkgPath,
4348
Name: o.Name(),
4449
}
4550
}

0 commit comments

Comments
 (0)