Skip to content

Adding nested type support #1370

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions compiler/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"go/types"
"regexp"
"sort"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -749,6 +750,143 @@ func TestArchiveSelectionAfterSerialization(t *testing.T) {
}
}

func TestNestedConcreteTypeInGenericFunc(t *testing.T) {
// This is a test of a type defined inside a generic function
// that uses the type parameter of the function as a field type.
// The `T` type is unique for each instance of `F`.
// The use of `A` as a field is do demonstrate the difference in the types
// however even if T had no fields, the type would still be different.
//
// Change `print(F[?]())` to `fmt.Printf("%T\n", F[?]())` for
// golang playground to print the type of T in the different F instances.
// (I just didn't want this test to depend on `fmt` when it doesn't need to.)

src := `
package main
func F[A any]() any {
type T struct{
a A
}
return T{}
}
func main() {
type Int int
print(F[int]())
print(F[Int]())
}
`

srcFiles := []srctesting.Source{{Name: `main.go`, Contents: []byte(src)}}
root := srctesting.ParseSources(t, srcFiles, nil)
archives := compileProject(t, root, false)
mainPkg := archives[root.PkgPath]
insts := collectDeclInstances(t, mainPkg)

exp := []string{
`F[int]`,
`F[main.Int]`, // Go prints `F[main.Int·2]`
`T[int;]`, // `T` from `F[int]` (Go prints `T[int]`)
`T[main.Int;]`, // `T` from `F[main.Int]` (Go prints `T[main.Int·2]`)
}
if diff := cmp.Diff(exp, insts); len(diff) > 0 {
t.Errorf("the instances of generics are different:\n%s", diff)
}
}

func TestNestedGenericTypeInGenericFunc(t *testing.T) {
// This is a subset of the type param nested test from the go repo.
// See https://github.com/golang/go/blob/go1.19.13/test/typeparam/nested.go
// The test is failing because nested types aren't being typed differently.
// For example the type of `T[int]` below is different based on `F[X]`
// instance for different `X` type parameters, hence Go prints the type as
// `T[X;int]` instead of `T[int]`.

src := `
package main
func F[A any]() any {
type T[B any] struct{
a A
b B
}
return T[int]{}
}
func main() {
type Int int
print(F[int]())
print(F[Int]())
}
`

srcFiles := []srctesting.Source{{Name: `main.go`, Contents: []byte(src)}}
root := srctesting.ParseSources(t, srcFiles, nil)
archives := compileProject(t, root, false)
mainPkg := archives[root.PkgPath]
insts := collectDeclInstances(t, mainPkg)

exp := []string{
`F[int]`,
`F[main.Int]`,
`T[int; int]`,
`T[main.Int; int]`,
}
if diff := cmp.Diff(exp, insts); len(diff) > 0 {
t.Errorf("the instances of generics are different:\n%s", diff)
}
}

func TestNestedGenericTypeInGenericFuncWithSharedTArgs(t *testing.T) {
src := `
package main
func F[A any]() any {
type T[B any] struct {
b B
}
return T[A]{}
}
func main() {
type Int int
print(F[int]())
print(F[Int]())
}`

srcFiles := []srctesting.Source{{Name: `main.go`, Contents: []byte(src)}}
root := srctesting.ParseSources(t, srcFiles, nil)
archives := compileProject(t, root, false)
mainPkg := archives[root.PkgPath]
insts := collectDeclInstances(t, mainPkg)

exp := []string{
`F[int]`,
`F[main.Int]`,
`T[int; int]`,
`T[main.Int; main.Int]`,
// Make sure that T[int;main.Int] and T[main.Int;int] aren't created.
}
if diff := cmp.Diff(exp, insts); len(diff) > 0 {
t.Errorf("the instances of generics are different:\n%s", diff)
}
}

func collectDeclInstances(t *testing.T, pkg *Archive) []string {
t.Helper()

// Regex to match strings like `Foo[42 /* bar */] =` and capture
// the name (`Foo`), the index (`42`), and the instance type (`bar`).
rex := regexp.MustCompile(`^\s*(\w+)\s*\[\s*(\d+)\s*\/\*(.+)\*\/\s*\]\s*\=`)

// Collect all instances of generics (e.g. `Foo[bar] @ 2`) written to the decl code.
insts := []string{}
for _, decl := range pkg.Declarations {
if match := rex.FindAllStringSubmatch(string(decl.DeclCode), 1); len(match) > 0 {
instance := match[0][1] + `[` + strings.TrimSpace(match[0][3]) + `]`
instance = strings.ReplaceAll(instance, `command-line-arguments`, pkg.Name)
insts = append(insts, instance)
}
}
sort.Strings(insts)
return insts
}

func compareOrder(t *testing.T, sourceFiles []srctesting.Source, minify bool) {
t.Helper()
outputNormal := compile(t, sourceFiles, minify)
Expand Down
27 changes: 20 additions & 7 deletions compiler/decls.go
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ func (fc *funcContext) newNamedTypeVarDecl(obj *types.TypeName) *Decl {
FullName: typeVarDeclFullName(obj),
Vars: []string{name},
}
if typeparams.HasTypeParams(obj.Type()) {
if fc.pkgCtx.instanceSet.Pkg(obj.Pkg()).ObjHasInstances(obj) {
varDecl.DeclCode = fc.CatchOutput(0, func() {
fc.Printf("%s = {};", name)
})
Expand All @@ -451,16 +451,28 @@ func (fc *funcContext) newNamedTypeVarDecl(obj *types.TypeName) *Decl {
func (fc *funcContext) newNamedTypeInstDecl(inst typeparams.Instance) (*Decl, error) {
originType := inst.Object.Type().(*types.Named)

fc.typeResolver = typeparams.NewResolver(fc.pkgCtx.typesCtx, typeparams.ToSlice(originType.TypeParams()), inst.TArgs)
var nestResolver *typeparams.Resolver
if len(inst.TNest) > 0 {
fn := typeparams.FindNestingFunc(inst.Object)
tp := typeparams.SignatureTypeParams(fn.Type().(*types.Signature))
nestResolver = typeparams.NewResolver(fc.pkgCtx.typesCtx, tp, inst.TNest, nil)
}
fc.typeResolver = typeparams.NewResolver(fc.pkgCtx.typesCtx, originType.TypeParams(), inst.TArgs, nestResolver)
defer func() { fc.typeResolver = nil }()

instanceType := originType
if !inst.IsTrivial() {
instantiated, err := types.Instantiate(fc.pkgCtx.typesCtx, originType, inst.TArgs, true)
if err != nil {
return nil, fmt.Errorf("failed to instantiate type %v with args %v: %w", originType, inst.TArgs, err)
if len(inst.TArgs) > 0 {
instantiated, err := types.Instantiate(fc.pkgCtx.typesCtx, originType, inst.TArgs, true)
if err != nil {
return nil, fmt.Errorf("failed to instantiate type %v with args %v: %w", originType, inst.TArgs, err)
}
instanceType = instantiated.(*types.Named)
}
if len(inst.TNest) > 0 {
instantiated := nestResolver.Substitute(instanceType)
instanceType = instantiated.(*types.Named)
}
instanceType = instantiated.(*types.Named)
}

underlying := instanceType.Underlying()
Expand Down Expand Up @@ -541,7 +553,8 @@ func (fc *funcContext) structConstructor(t *types.Struct) string {
// If no arguments were passed, zero-initialize all fields.
fmt.Fprintf(constructor, "\t\tif (arguments.length === 0) {\n")
for i := 0; i < t.NumFields(); i++ {
fmt.Fprintf(constructor, "\t\t\tthis.%s = %s;\n", fieldName(t, i), fc.translateExpr(fc.zeroValue(t.Field(i).Type())).String())
zeroValue := fc.zeroValue(fc.fieldType(t, i))
fmt.Fprintf(constructor, "\t\t\tthis.%s = %s;\n", fieldName(t, i), fc.translateExpr(zeroValue).String())
}
fmt.Fprintf(constructor, "\t\t\treturn;\n")
fmt.Fprintf(constructor, "\t\t}\n")
Expand Down
19 changes: 10 additions & 9 deletions compiler/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,18 +178,18 @@ func (fc *funcContext) translateExpr(expr ast.Expr) *expression {
}
if !isKeyValue {
for i, element := range e.Elts {
elements[i] = fc.translateImplicitConversionWithCloning(element, t.Field(i).Type()).String()
elements[i] = fc.translateImplicitConversionWithCloning(element, fc.fieldType(t, i)).String()
}
}
if isKeyValue {
for i := range elements {
elements[i] = fc.translateExpr(fc.zeroValue(t.Field(i).Type())).String()
elements[i] = fc.translateExpr(fc.zeroValue(fc.fieldType(t, i))).String()
}
for _, element := range e.Elts {
kve := element.(*ast.KeyValueExpr)
for j := range elements {
if kve.Key.(*ast.Ident).Name == t.Field(j).Name() {
elements[j] = fc.translateImplicitConversionWithCloning(kve.Value, t.Field(j).Type()).String()
elements[j] = fc.translateImplicitConversionWithCloning(kve.Value, fc.fieldType(t, j)).String()
break
}
}
Expand Down Expand Up @@ -801,7 +801,7 @@ func (fc *funcContext) translateExpr(expr ast.Expr) *expression {
switch t := exprType.Underlying().(type) {
case *types.Basic:
if t.Kind() != types.UnsafePointer {
panic("unexpected basic type")
panic(fmt.Errorf(`unexpected basic type: %v in %v`, t, e.Name))
}
return fc.formatExpr("0")
case *types.Slice, *types.Pointer:
Expand Down Expand Up @@ -917,7 +917,7 @@ func (fc *funcContext) makeReceiver(e *ast.SelectorExpr) *expression {
recvType = ptr.Elem()
}
s := recvType.Underlying().(*types.Struct)
recvType = s.Field(index).Type()
recvType = fc.fieldType(s, index)
}

fakeSel := &ast.SelectorExpr{X: x, Sel: ast.NewIdent("o")}
Expand Down Expand Up @@ -1314,12 +1314,13 @@ func (fc *funcContext) loadStruct(array, target string, s *types.Struct) string
var collectFields func(s *types.Struct, path string)
collectFields = func(s *types.Struct, path string) {
for i := 0; i < s.NumFields(); i++ {
field := s.Field(i)
if fs, isStruct := field.Type().Underlying().(*types.Struct); isStruct {
collectFields(fs, path+"."+fieldName(s, i))
fieldName := path + "." + fieldName(s, i)
fieldType := fc.fieldType(s, i)
if fs, isStruct := fieldType.Underlying().(*types.Struct); isStruct {
collectFields(fs, fieldName)
continue
}
fields = append(fields, types.NewVar(0, nil, path+"."+fieldName(s, i), field.Type()))
fields = append(fields, types.NewVar(0, nil, fieldName, fieldType))
}
}
collectFields(s, target)
Expand Down
4 changes: 2 additions & 2 deletions compiler/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ func (fc *funcContext) nestedFunctionContext(info *analysis.FuncInfo, inst typep
}

if sig.TypeParams().Len() > 0 {
c.typeResolver = typeparams.NewResolver(c.pkgCtx.typesCtx, typeparams.ToSlice(sig.TypeParams()), inst.TArgs)
c.typeResolver = typeparams.NewResolver(c.pkgCtx.typesCtx, sig.TypeParams(), inst.TArgs, nil)
} else if sig.RecvTypeParams().Len() > 0 {
c.typeResolver = typeparams.NewResolver(c.pkgCtx.typesCtx, typeparams.ToSlice(sig.RecvTypeParams()), inst.TArgs)
c.typeResolver = typeparams.NewResolver(c.pkgCtx.typesCtx, sig.RecvTypeParams(), inst.TArgs, nil)
}
if c.objectNames == nil {
c.objectNames = map[types.Object]string{}
Expand Down
4 changes: 2 additions & 2 deletions compiler/internal/analysis/info.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ func (info *Info) newFuncInfoInstances(fd *ast.FuncDecl) []*FuncInfo {
for _, inst := range instances {
var resolver *typeparams.Resolver
if sig, ok := obj.Type().(*types.Signature); ok {
tp := typeparams.ToSlice(typeparams.SignatureTypeParams(sig))
resolver = typeparams.NewResolver(info.typeCtx, tp, inst.TArgs)
tp := typeparams.SignatureTypeParams(sig)
resolver = typeparams.NewResolver(info.typeCtx, tp, inst.TArgs, nil)
}
fi := info.newFuncInfo(fd, inst.Object, inst.TArgs, resolver)
funcInfos = append(funcInfos, fi)
Expand Down
1 change: 1 addition & 0 deletions compiler/internal/analysis/info_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ func TestBlocking_Defers_WithMultipleReturns(t *testing.T) {
// of which flow control statements (e.g. if-statements) are terminating
// or not. Any defers added in a terminating control flow would not
// propagate to returns that are not in that block.
// See golang.org/x/tools/go/ssa for flow control analysis.
//
// For now we simply build up the list of defers as we go making
// the return on line 31 also blocking.
Expand Down
11 changes: 8 additions & 3 deletions compiler/internal/symbol/symbol.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,30 @@ type Name struct {

// New constructs SymName for a given named symbol.
func New(o types.Object) Name {
pkgPath := `_`
if pkg := o.Pkg(); pkg != nil {
pkgPath = pkg.Path()
}

if fun, ok := o.(*types.Func); ok {
sig := fun.Type().(*types.Signature)
if recv := sig.Recv(); recv != nil {
// Special case: disambiguate names for different types' methods.
typ := recv.Type()
if ptr, ok := typ.(*types.Pointer); ok {
return Name{
PkgPath: o.Pkg().Path(),
PkgPath: pkgPath,
Name: "(*" + ptr.Elem().(*types.Named).Obj().Name() + ")." + o.Name(),
}
}
return Name{
PkgPath: o.Pkg().Path(),
PkgPath: pkgPath,
Name: typ.(*types.Named).Obj().Name() + "." + o.Name(),
}
}
}
return Name{
PkgPath: o.Pkg().Path(),
PkgPath: pkgPath,
Name: o.Name(),
}
}
Expand Down
Loading
Loading