diff --git a/compiler/compiler_test.go b/compiler/compiler_test.go index e7add0015..88d8e525e 100644 --- a/compiler/compiler_test.go +++ b/compiler/compiler_test.go @@ -5,6 +5,7 @@ import ( "go/types" "regexp" "sort" + "strings" "testing" "time" @@ -749,6 +750,143 @@ func TestArchiveSelectionAfterSerialization(t *testing.T) { } } +func TestNestedConcreteTypeInGenericFunc(t *testing.T) { + // This is a test of a type defined inside a generic function + // that uses the type parameter of the function as a field type. + // The `T` type is unique for each instance of `F`. + // The use of `A` as a field is do demonstrate the difference in the types + // however even if T had no fields, the type would still be different. + // + // Change `print(F[?]())` to `fmt.Printf("%T\n", F[?]())` for + // golang playground to print the type of T in the different F instances. + // (I just didn't want this test to depend on `fmt` when it doesn't need to.) + + src := ` + package main + func F[A any]() any { + type T struct{ + a A + } + return T{} + } + func main() { + type Int int + print(F[int]()) + print(F[Int]()) + } + ` + + srcFiles := []srctesting.Source{{Name: `main.go`, Contents: []byte(src)}} + root := srctesting.ParseSources(t, srcFiles, nil) + archives := compileProject(t, root, false) + mainPkg := archives[root.PkgPath] + insts := collectDeclInstances(t, mainPkg) + + exp := []string{ + `F[int]`, + `F[main.Int]`, // Go prints `F[main.Int·2]` + `T[int;]`, // `T` from `F[int]` (Go prints `T[int]`) + `T[main.Int;]`, // `T` from `F[main.Int]` (Go prints `T[main.Int·2]`) + } + if diff := cmp.Diff(exp, insts); len(diff) > 0 { + t.Errorf("the instances of generics are different:\n%s", diff) + } +} + +func TestNestedGenericTypeInGenericFunc(t *testing.T) { + // This is a subset of the type param nested test from the go repo. + // See https://github.com/golang/go/blob/go1.19.13/test/typeparam/nested.go + // The test is failing because nested types aren't being typed differently. + // For example the type of `T[int]` below is different based on `F[X]` + // instance for different `X` type parameters, hence Go prints the type as + // `T[X;int]` instead of `T[int]`. + + src := ` + package main + func F[A any]() any { + type T[B any] struct{ + a A + b B + } + return T[int]{} + } + func main() { + type Int int + print(F[int]()) + print(F[Int]()) + } + ` + + srcFiles := []srctesting.Source{{Name: `main.go`, Contents: []byte(src)}} + root := srctesting.ParseSources(t, srcFiles, nil) + archives := compileProject(t, root, false) + mainPkg := archives[root.PkgPath] + insts := collectDeclInstances(t, mainPkg) + + exp := []string{ + `F[int]`, + `F[main.Int]`, + `T[int; int]`, + `T[main.Int; int]`, + } + if diff := cmp.Diff(exp, insts); len(diff) > 0 { + t.Errorf("the instances of generics are different:\n%s", diff) + } +} + +func TestNestedGenericTypeInGenericFuncWithSharedTArgs(t *testing.T) { + src := ` + package main + func F[A any]() any { + type T[B any] struct { + b B + } + return T[A]{} + } + func main() { + type Int int + print(F[int]()) + print(F[Int]()) + }` + + srcFiles := []srctesting.Source{{Name: `main.go`, Contents: []byte(src)}} + root := srctesting.ParseSources(t, srcFiles, nil) + archives := compileProject(t, root, false) + mainPkg := archives[root.PkgPath] + insts := collectDeclInstances(t, mainPkg) + + exp := []string{ + `F[int]`, + `F[main.Int]`, + `T[int; int]`, + `T[main.Int; main.Int]`, + // Make sure that T[int;main.Int] and T[main.Int;int] aren't created. + } + if diff := cmp.Diff(exp, insts); len(diff) > 0 { + t.Errorf("the instances of generics are different:\n%s", diff) + } +} + +func collectDeclInstances(t *testing.T, pkg *Archive) []string { + t.Helper() + + // Regex to match strings like `Foo[42 /* bar */] =` and capture + // the name (`Foo`), the index (`42`), and the instance type (`bar`). + rex := regexp.MustCompile(`^\s*(\w+)\s*\[\s*(\d+)\s*\/\*(.+)\*\/\s*\]\s*\=`) + + // Collect all instances of generics (e.g. `Foo[bar] @ 2`) written to the decl code. + insts := []string{} + for _, decl := range pkg.Declarations { + if match := rex.FindAllStringSubmatch(string(decl.DeclCode), 1); len(match) > 0 { + instance := match[0][1] + `[` + strings.TrimSpace(match[0][3]) + `]` + instance = strings.ReplaceAll(instance, `command-line-arguments`, pkg.Name) + insts = append(insts, instance) + } + } + sort.Strings(insts) + return insts +} + func compareOrder(t *testing.T, sourceFiles []srctesting.Source, minify bool) { t.Helper() outputNormal := compile(t, sourceFiles, minify) diff --git a/compiler/decls.go b/compiler/decls.go index 5b760fb15..eb95cd2f7 100644 --- a/compiler/decls.go +++ b/compiler/decls.go @@ -433,7 +433,7 @@ func (fc *funcContext) newNamedTypeVarDecl(obj *types.TypeName) *Decl { FullName: typeVarDeclFullName(obj), Vars: []string{name}, } - if typeparams.HasTypeParams(obj.Type()) { + if fc.pkgCtx.instanceSet.Pkg(obj.Pkg()).ObjHasInstances(obj) { varDecl.DeclCode = fc.CatchOutput(0, func() { fc.Printf("%s = {};", name) }) @@ -451,16 +451,28 @@ func (fc *funcContext) newNamedTypeVarDecl(obj *types.TypeName) *Decl { func (fc *funcContext) newNamedTypeInstDecl(inst typeparams.Instance) (*Decl, error) { originType := inst.Object.Type().(*types.Named) - fc.typeResolver = typeparams.NewResolver(fc.pkgCtx.typesCtx, typeparams.ToSlice(originType.TypeParams()), inst.TArgs) + var nestResolver *typeparams.Resolver + if len(inst.TNest) > 0 { + fn := typeparams.FindNestingFunc(inst.Object) + tp := typeparams.SignatureTypeParams(fn.Type().(*types.Signature)) + nestResolver = typeparams.NewResolver(fc.pkgCtx.typesCtx, tp, inst.TNest, nil) + } + fc.typeResolver = typeparams.NewResolver(fc.pkgCtx.typesCtx, originType.TypeParams(), inst.TArgs, nestResolver) defer func() { fc.typeResolver = nil }() instanceType := originType if !inst.IsTrivial() { - instantiated, err := types.Instantiate(fc.pkgCtx.typesCtx, originType, inst.TArgs, true) - if err != nil { - return nil, fmt.Errorf("failed to instantiate type %v with args %v: %w", originType, inst.TArgs, err) + if len(inst.TArgs) > 0 { + instantiated, err := types.Instantiate(fc.pkgCtx.typesCtx, originType, inst.TArgs, true) + if err != nil { + return nil, fmt.Errorf("failed to instantiate type %v with args %v: %w", originType, inst.TArgs, err) + } + instanceType = instantiated.(*types.Named) + } + if len(inst.TNest) > 0 { + instantiated := nestResolver.Substitute(instanceType) + instanceType = instantiated.(*types.Named) } - instanceType = instantiated.(*types.Named) } underlying := instanceType.Underlying() @@ -541,7 +553,8 @@ func (fc *funcContext) structConstructor(t *types.Struct) string { // If no arguments were passed, zero-initialize all fields. fmt.Fprintf(constructor, "\t\tif (arguments.length === 0) {\n") for i := 0; i < t.NumFields(); i++ { - fmt.Fprintf(constructor, "\t\t\tthis.%s = %s;\n", fieldName(t, i), fc.translateExpr(fc.zeroValue(t.Field(i).Type())).String()) + zeroValue := fc.zeroValue(fc.fieldType(t, i)) + fmt.Fprintf(constructor, "\t\t\tthis.%s = %s;\n", fieldName(t, i), fc.translateExpr(zeroValue).String()) } fmt.Fprintf(constructor, "\t\t\treturn;\n") fmt.Fprintf(constructor, "\t\t}\n") diff --git a/compiler/expressions.go b/compiler/expressions.go index f1c2e68f5..781a37a3e 100644 --- a/compiler/expressions.go +++ b/compiler/expressions.go @@ -178,18 +178,18 @@ func (fc *funcContext) translateExpr(expr ast.Expr) *expression { } if !isKeyValue { for i, element := range e.Elts { - elements[i] = fc.translateImplicitConversionWithCloning(element, t.Field(i).Type()).String() + elements[i] = fc.translateImplicitConversionWithCloning(element, fc.fieldType(t, i)).String() } } if isKeyValue { for i := range elements { - elements[i] = fc.translateExpr(fc.zeroValue(t.Field(i).Type())).String() + elements[i] = fc.translateExpr(fc.zeroValue(fc.fieldType(t, i))).String() } for _, element := range e.Elts { kve := element.(*ast.KeyValueExpr) for j := range elements { if kve.Key.(*ast.Ident).Name == t.Field(j).Name() { - elements[j] = fc.translateImplicitConversionWithCloning(kve.Value, t.Field(j).Type()).String() + elements[j] = fc.translateImplicitConversionWithCloning(kve.Value, fc.fieldType(t, j)).String() break } } @@ -801,7 +801,7 @@ func (fc *funcContext) translateExpr(expr ast.Expr) *expression { switch t := exprType.Underlying().(type) { case *types.Basic: if t.Kind() != types.UnsafePointer { - panic("unexpected basic type") + panic(fmt.Errorf(`unexpected basic type: %v in %v`, t, e.Name)) } return fc.formatExpr("0") case *types.Slice, *types.Pointer: @@ -917,7 +917,7 @@ func (fc *funcContext) makeReceiver(e *ast.SelectorExpr) *expression { recvType = ptr.Elem() } s := recvType.Underlying().(*types.Struct) - recvType = s.Field(index).Type() + recvType = fc.fieldType(s, index) } fakeSel := &ast.SelectorExpr{X: x, Sel: ast.NewIdent("o")} @@ -1314,12 +1314,13 @@ func (fc *funcContext) loadStruct(array, target string, s *types.Struct) string var collectFields func(s *types.Struct, path string) collectFields = func(s *types.Struct, path string) { for i := 0; i < s.NumFields(); i++ { - field := s.Field(i) - if fs, isStruct := field.Type().Underlying().(*types.Struct); isStruct { - collectFields(fs, path+"."+fieldName(s, i)) + fieldName := path + "." + fieldName(s, i) + fieldType := fc.fieldType(s, i) + if fs, isStruct := fieldType.Underlying().(*types.Struct); isStruct { + collectFields(fs, fieldName) continue } - fields = append(fields, types.NewVar(0, nil, path+"."+fieldName(s, i), field.Type())) + fields = append(fields, types.NewVar(0, nil, fieldName, fieldType)) } } collectFields(s, target) diff --git a/compiler/functions.go b/compiler/functions.go index 592992efc..361c92f0f 100644 --- a/compiler/functions.go +++ b/compiler/functions.go @@ -49,9 +49,9 @@ func (fc *funcContext) nestedFunctionContext(info *analysis.FuncInfo, inst typep } if sig.TypeParams().Len() > 0 { - c.typeResolver = typeparams.NewResolver(c.pkgCtx.typesCtx, typeparams.ToSlice(sig.TypeParams()), inst.TArgs) + c.typeResolver = typeparams.NewResolver(c.pkgCtx.typesCtx, sig.TypeParams(), inst.TArgs, nil) } else if sig.RecvTypeParams().Len() > 0 { - c.typeResolver = typeparams.NewResolver(c.pkgCtx.typesCtx, typeparams.ToSlice(sig.RecvTypeParams()), inst.TArgs) + c.typeResolver = typeparams.NewResolver(c.pkgCtx.typesCtx, sig.RecvTypeParams(), inst.TArgs, nil) } if c.objectNames == nil { c.objectNames = map[types.Object]string{} diff --git a/compiler/internal/analysis/info.go b/compiler/internal/analysis/info.go index d05f9a6d1..e400c870c 100644 --- a/compiler/internal/analysis/info.go +++ b/compiler/internal/analysis/info.go @@ -126,8 +126,8 @@ func (info *Info) newFuncInfoInstances(fd *ast.FuncDecl) []*FuncInfo { for _, inst := range instances { var resolver *typeparams.Resolver if sig, ok := obj.Type().(*types.Signature); ok { - tp := typeparams.ToSlice(typeparams.SignatureTypeParams(sig)) - resolver = typeparams.NewResolver(info.typeCtx, tp, inst.TArgs) + tp := typeparams.SignatureTypeParams(sig) + resolver = typeparams.NewResolver(info.typeCtx, tp, inst.TArgs, nil) } fi := info.newFuncInfo(fd, inst.Object, inst.TArgs, resolver) funcInfos = append(funcInfos, fi) diff --git a/compiler/internal/analysis/info_test.go b/compiler/internal/analysis/info_test.go index 73428207e..0df26b0b9 100644 --- a/compiler/internal/analysis/info_test.go +++ b/compiler/internal/analysis/info_test.go @@ -382,6 +382,7 @@ func TestBlocking_Defers_WithMultipleReturns(t *testing.T) { // of which flow control statements (e.g. if-statements) are terminating // or not. Any defers added in a terminating control flow would not // propagate to returns that are not in that block. + // See golang.org/x/tools/go/ssa for flow control analysis. // // For now we simply build up the list of defers as we go making // the return on line 31 also blocking. diff --git a/compiler/internal/symbol/symbol.go b/compiler/internal/symbol/symbol.go index 851ca1ef6..d460ea86d 100644 --- a/compiler/internal/symbol/symbol.go +++ b/compiler/internal/symbol/symbol.go @@ -21,6 +21,11 @@ type Name struct { // New constructs SymName for a given named symbol. func New(o types.Object) Name { + pkgPath := `_` + if pkg := o.Pkg(); pkg != nil { + pkgPath = pkg.Path() + } + if fun, ok := o.(*types.Func); ok { sig := fun.Type().(*types.Signature) if recv := sig.Recv(); recv != nil { @@ -28,18 +33,18 @@ func New(o types.Object) Name { typ := recv.Type() if ptr, ok := typ.(*types.Pointer); ok { return Name{ - PkgPath: o.Pkg().Path(), + PkgPath: pkgPath, Name: "(*" + ptr.Elem().(*types.Named).Obj().Name() + ")." + o.Name(), } } return Name{ - PkgPath: o.Pkg().Path(), + PkgPath: pkgPath, Name: typ.(*types.Named).Obj().Name() + "." + o.Name(), } } } return Name{ - PkgPath: o.Pkg().Path(), + PkgPath: pkgPath, Name: o.Name(), } } diff --git a/compiler/internal/typeparams/collect.go b/compiler/internal/typeparams/collect.go index 723172d4f..940690e83 100644 --- a/compiler/internal/typeparams/collect.go +++ b/compiler/internal/typeparams/collect.go @@ -4,6 +4,7 @@ import ( "fmt" "go/ast" "go/types" + "strings" "github.com/gopherjs/gopherjs/compiler/typesutil" "github.com/gopherjs/gopherjs/internal/govendor/subst" @@ -12,27 +13,67 @@ import ( // Resolver translates types defined in terms of type parameters into concrete // types, given a mapping from type params to type arguments. type Resolver struct { + tParams *types.TypeParamList + tArgs []types.Type + parent *Resolver + + // subster is the substitution helper that will perform the actual + // substitutions. This maybe nil when there are no substitutions but + // will still usable when nil. subster *subst.Subster selMemo map[typesutil.Selection]typesutil.Selection } // NewResolver creates a new Resolver with tParams entries mapping to tArgs // entries with the same index. -func NewResolver(tc *types.Context, tParams []*types.TypeParam, tArgs []types.Type) *Resolver { +func NewResolver(tc *types.Context, tParams *types.TypeParamList, tArgs []types.Type, parent *Resolver) *Resolver { r := &Resolver{ + tParams: tParams, + tArgs: tArgs, + parent: parent, subster: subst.New(tc, tParams, tArgs), selMemo: map[typesutil.Selection]typesutil.Selection{}, } return r } +// TypeParams is the list of type parameters that this resolver +// (not any parent) will substitute. +func (r *Resolver) TypeParams() *types.TypeParamList { + if r == nil { + return nil + } + return r.tParams +} + +// TypeArgs is the list of type arguments that this resolver +// (not any parent) will resolve to. +func (r *Resolver) TypeArgs() []types.Type { + if r == nil { + return nil + } + return r.tArgs +} + +// Parent is the resolver for the function or method that this resolver +// is nested in. This may be nil if the context for this resolver is not +// nested in another generic function or method. +func (r *Resolver) Parent() *Resolver { + if r == nil { + return nil + } + return r.parent +} + // Substitute replaces references to type params in the provided type definition // with the corresponding concrete types. func (r *Resolver) Substitute(typ types.Type) types.Type { - if r == nil || r.subster == nil || typ == nil { + if r == nil || typ == nil { return typ // No substitutions to be made. } - return r.subster.Type(typ) + typ = r.subster.Type(typ) + typ = r.parent.Substitute(typ) + return typ } // SubstituteAll same as Substitute, but accepts a TypeList are returns @@ -49,7 +90,7 @@ func (r *Resolver) SubstituteAll(list *types.TypeList) []types.Type { // defined in terms of type parameters with a method selection on a concrete // instantiation of the type. func (r *Resolver) SubstituteSelection(sel typesutil.Selection) typesutil.Selection { - if r == nil || r.subster == nil || sel == nil { + if r == nil || sel == nil { return sel // No substitutions to be made. } if concrete, ok := r.selMemo[sel]; ok { @@ -82,13 +123,22 @@ func (r *Resolver) SubstituteSelection(sel typesutil.Selection) typesutil.Select } } -// ToSlice converts TypeParamList into a slice with the same order of entries. -func ToSlice(tpl *types.TypeParamList) []*types.TypeParam { - result := make([]*types.TypeParam, tpl.Len()) - for i := range result { - result[i] = tpl.At(i) +// String gets a strings representation of the resolver for debugging. +func (r *Resolver) String() string { + if r == nil { + return `{}` } - return result + + parts := make([]string, 0, len(r.tArgs)) + for i, ta := range r.tArgs { + parts = append(parts, fmt.Sprintf("%s->%s", r.tParams.At(i), ta)) + } + + nestStr := `` + if r.parent != nil { + nestStr = r.parent.String() + `:` + } + return nestStr + `{` + strings.Join(parts, `, `) + `}` } // visitor implements ast.Visitor and collects instances of generic types and @@ -101,24 +151,35 @@ type visitor struct { instances *PackageInstanceSets resolver *Resolver info *types.Info + tNest []types.Type // The type arguments for a nested context. } var _ ast.Visitor = &visitor{} -func (c *visitor) Visit(n ast.Node) (w ast.Visitor) { - w = c // Always traverse the full depth of the AST tree. +func (c *visitor) Visit(n ast.Node) ast.Visitor { + if ident, ok := n.(*ast.Ident); ok { + c.visitIdent(ident) + } + return c +} - ident, ok := n.(*ast.Ident) - if !ok { - return +func (c *visitor) visitIdent(ident *ast.Ident) { + if inst, ok := c.info.Instances[ident]; ok { + // Found the use of a generic type or function. + c.visitInstance(ident, inst) } - instance, ok := c.info.Instances[ident] - if !ok { - return + if len(c.resolver.TypeArgs()) > 0 { + if obj, ok := c.info.Defs[ident]; ok && obj != nil { + // Found instance of a type defined inside a generic context. + c.visitNestedType(obj) + } } +} - obj := c.info.ObjectOf(ident) +func (c *visitor) visitInstance(ident *ast.Ident, inst types.Instance) { + obj := c.info.Uses[ident] + tArgs := inst.TypeArgs // For types embedded in structs, the object the identifier resolves to is a // *types.Var representing the implicitly declared struct field. However, the @@ -131,9 +192,53 @@ func (c *visitor) Visit(n ast.Node) (w ast.Visitor) { if t, ok := typ.(*types.Named); ok { obj = t.Obj() } + + // If the object is defined in the same scope as the instance, + // then we apply the current nested type arguments. + var tNest []types.Type + if obj.Parent().Contains(ident.Pos()) { + tNest = c.tNest + } + + c.addInstance(obj, tArgs, tNest) +} + +func (c *visitor) visitNestedType(obj types.Object) { + if _, ok := obj.(*types.TypeName); !ok { + // Found a variable or function, not a type, so skip it. + return + } + + typ := obj.Type() + if ptr, ok := typ.(*types.Pointer); ok { + typ = ptr.Elem() + } + + t, ok := typ.(*types.Named) + if !ok || t.TypeParams().Len() > 0 { + // Found a generic type or an unnamed type (e.g. type parameter). + // Don't add generic types yet because they + // will be added when we find an instance of them. + return + } + + c.addInstance(obj, nil, c.resolver.TypeArgs()) +} + +func (c *visitor) addInstance(obj types.Object, tArgList *types.TypeList, tNest []types.Type) { + tArgs := c.resolver.SubstituteAll(tArgList) + if isGeneric(tArgs...) { + // Skip any instances that still have type parameters in them after + // substitution. This occurs when a type is defined while nested + // in a generic context and is not fully instantiated yet. + // We need to wait until we find a full instantiation of the type. + return + } + c.instances.Add(Instance{ Object: obj, - TArgs: c.resolver.SubstituteAll(instance.TypeArgs), + TArgs: tArgs, + TNest: tNest, }) if t, ok := obj.Type().(*types.Named); ok { @@ -141,11 +246,11 @@ func (c *visitor) Visit(n ast.Node) (w ast.Visitor) { method := t.Method(i) c.instances.Add(Instance{ Object: method.Origin(), - TArgs: c.resolver.SubstituteAll(instance.TypeArgs), + TArgs: tArgs, + TNest: tNest, }) } } - return } // seedVisitor implements ast.Visitor that collects information necessary to @@ -241,22 +346,49 @@ func (c *Collector) Scan(pkg *types.Package, files ...*ast.File) { for iset := c.Instances.Pkg(pkg); !iset.exhausted(); { inst, _ := iset.next() + switch typ := inst.Object.Type().(type) { case *types.Signature: - v := visitor{ - instances: c.Instances, - resolver: NewResolver(c.TContext, ToSlice(SignatureTypeParams(typ)), inst.TArgs), - info: c.Info, - } - ast.Walk(&v, objMap[inst.Object]) + c.scanSignature(inst, typ, objMap) + case *types.Named: - obj := typ.Obj() - v := visitor{ - instances: c.Instances, - resolver: NewResolver(c.TContext, ToSlice(typ.TypeParams()), inst.TArgs), - info: c.Info, - } - ast.Walk(&v, objMap[obj]) + c.scanNamed(inst, typ, objMap) } } } + +func (c *Collector) scanSignature(inst Instance, typ *types.Signature, objMap map[types.Object]ast.Node) { + tParams := SignatureTypeParams(typ) + v := visitor{ + instances: c.Instances, + resolver: NewResolver(c.TContext, tParams, inst.TArgs, nil), + info: c.Info, + tNest: inst.TArgs, + } + ast.Walk(&v, objMap[inst.Object]) +} + +func (c *Collector) scanNamed(inst Instance, typ *types.Named, objMap map[types.Object]ast.Node) { + obj := typ.Obj() + node := objMap[obj] + if node == nil { + // Types without an entry in objMap are concrete types + // that are defined in a generic context. Skip them. + return + } + + var nestResolver *Resolver + if len(inst.TNest) > 0 { + fn := FindNestingFunc(inst.Object) + tp := SignatureTypeParams(fn.Type().(*types.Signature)) + nestResolver = NewResolver(c.TContext, tp, inst.TNest, nil) + } + + v := visitor{ + instances: c.Instances, + resolver: NewResolver(c.TContext, typ.TypeParams(), inst.TArgs, nestResolver), + info: c.Info, + tNest: inst.TNest, + } + ast.Walk(&v, node) +} diff --git a/compiler/internal/typeparams/collect_test.go b/compiler/internal/typeparams/collect_test.go index 9bd5faee4..6864e5ead 100644 --- a/compiler/internal/typeparams/collect_test.go +++ b/compiler/internal/typeparams/collect_test.go @@ -2,7 +2,9 @@ package typeparams import ( "go/ast" + "go/token" "go/types" + "strings" "testing" "github.com/google/go-cmp/cmp" @@ -35,7 +37,11 @@ func TestVisitor(t *testing.T) { t := typ[int, A]{} t.method(0) (*typ[int32, A]).method(nil, 0) + type x struct{ T []typ[int64, A] } + type y[X any] struct{ T []typ[A, X] } + _ = y[int8]{} + _ = y[A]{} return } @@ -49,7 +55,11 @@ func TestVisitor(t *testing.T) { t := typ[int, T]{} t.method(0) (*typ[int32, T]).method(nil, 0) + type x struct{ T []typ[int64, T] } + type y[X any] struct{ T []typ[T, X] } + _ = y[int8]{} + _ = y[T]{} return } @@ -67,7 +77,11 @@ func TestVisitor(t *testing.T) { t := typ[int, T]{} t.method(0) (*typ[int32, T]).method(nil, 0) + type x struct{ T []typ[int64, T] } + type y[X any] struct{ T []typ[T, X] } + _ = y[int8]{} + _ = y[T]{} return } @@ -189,22 +203,50 @@ func TestVisitor(t *testing.T) { descr: "non-generic function", resolver: nil, node: lookupDecl("entry1"), - want: instancesInFunc(lookupType("A")), + want: append( + instancesInFunc(lookupType("A")), + Instance{ + Object: lookupObj("entry1.y"), + TArgs: []types.Type{types.Typ[types.Int8]}, + }, + Instance{ + Object: lookupObj("entry1.y"), + TArgs: []types.Type{lookupType("A")}, + }, + ), }, { descr: "generic function", resolver: NewResolver( types.NewContext(), - ToSlice(lookupType("entry2").(*types.Signature).TypeParams()), + lookupType("entry2").(*types.Signature).TypeParams(), []types.Type{lookupType("B")}, + nil, ), node: lookupDecl("entry2"), - want: instancesInFunc(lookupType("B")), + want: append( + instancesInFunc(lookupType("B")), + Instance{ + Object: lookupObj("entry2.x"), + TNest: []types.Type{lookupType("B")}, + }, + Instance{ + Object: lookupObj("entry1.y"), + TNest: []types.Type{lookupType("B")}, + TArgs: []types.Type{types.Typ[types.Int8]}, + }, + Instance{ + Object: lookupObj("entry2.y"), + TNest: []types.Type{lookupType("B")}, + TArgs: []types.Type{lookupType("B")}, + }, + ), }, { descr: "generic method", resolver: NewResolver( types.NewContext(), - ToSlice(lookupType("entry3.method").(*types.Signature).RecvTypeParams()), + lookupType("entry3.method").(*types.Signature).RecvTypeParams(), []types.Type{lookupType("C")}, + nil, ), node: lookupDecl("entry3.method"), want: append( @@ -217,13 +259,28 @@ func TestVisitor(t *testing.T) { Object: lookupObj("entry3.method"), TArgs: []types.Type{lookupType("C")}, }, + Instance{ + Object: lookupObj("entry3.method.x"), + TNest: []types.Type{lookupType("C")}, + }, + Instance{ + Object: lookupObj("entry3.method.y"), + TNest: []types.Type{lookupType("C")}, + TArgs: []types.Type{types.Typ[types.Int8]}, + }, + Instance{ + Object: lookupObj("entry3.method.y"), + TNest: []types.Type{lookupType("C")}, + TArgs: []types.Type{lookupType("C")}, + }, ), }, { descr: "generic type declaration", resolver: NewResolver( types.NewContext(), - ToSlice(lookupType("entry3").(*types.Named).TypeParams()), + lookupType("entry3").(*types.Named).TypeParams(), []types.Type{lookupType("D")}, + nil, ), node: lookupDecl("entry3"), want: instancesInType(lookupType("D")), @@ -256,6 +313,11 @@ func TestVisitor(t *testing.T) { resolver: test.resolver, info: info, } + if test.resolver != nil { + // Since we know all the tests are for functions and methods, + // set the nested type to the type parameter from the resolver. + v.tNest = test.resolver.tArgs + } ast.Walk(&v, test.node) got := v.instances.Pkg(pkg).Values() if diff := cmp.Diff(test.want, got, instanceOpts()); diff != "" { @@ -319,7 +381,7 @@ func TestSeedVisitor(t *testing.T) { } got := sv.instances.Pkg(pkg).Values() if diff := cmp.Diff(want, got, instanceOpts()); diff != "" { - t.Errorf("Instances from initialSeeder contain diff (-want,+got):\n%s", diff) + t.Errorf("Instances from seedVisitor contain diff (-want,+got):\n%s", diff) } } @@ -353,28 +415,333 @@ func TestCollector(t *testing.T) { } c.Scan(pkg, file) - inst := func(name string, tArg types.Type) Instance { + inst := func(name, tNest, tArg string) Instance { return Instance{ Object: srctesting.LookupObj(pkg, name), - TArgs: []types.Type{tArg}, + TNest: evalTypeArgs(t, f.FileSet, pkg, tNest), + TArgs: evalTypeArgs(t, f.FileSet, pkg, tArg), } } want := []Instance{ - inst("typ", types.Typ[types.Int]), - inst("typ.method", types.Typ[types.Int]), - inst("fun", types.Typ[types.Int8]), - inst("fun.nested", types.Typ[types.Int8]), - inst("typ", types.Typ[types.Int16]), - inst("typ.method", types.Typ[types.Int16]), - inst("typ", types.Typ[types.Int32]), - inst("typ.method", types.Typ[types.Int32]), - inst("fun", types.Typ[types.Int64]), - inst("fun.nested", types.Typ[types.Int64]), + inst(`typ`, ``, `int`), + inst(`typ.method`, ``, `int`), + inst(`fun`, ``, `int8`), + inst(`fun.nested`, `int8`, `int8`), + inst(`typ`, ``, `int16`), + inst(`typ.method`, ``, `int16`), + inst(`typ`, ``, `int32`), + inst(`typ.method`, ``, `int32`), + inst(`fun`, ``, `int64`), + inst(`fun.nested`, `int64`, `int64`), } got := c.Instances.Pkg(pkg).Values() if diff := cmp.Diff(want, got, instanceOpts()); diff != "" { - t.Errorf("Instances from initialSeeder contain diff (-want,+got):\n%s", diff) + t.Errorf("Instances from Collector contain diff (-want,+got):\n%s", diff) + } +} + +func TestCollector_MoreNesting(t *testing.T) { + src := `package test + + func fun[T any]() { + type nestedCon struct{ X T } + _ = nestedCon{} + + type nestedGen[U any] struct{ Y T; Z U } + _ = nestedGen[T]{} + _ = nestedGen[int8]{} + + type nestedCover[T any] struct{ W T } + _ = nestedCover[T]{} + _ = nestedCover[int16]{} + } + + func a() { + fun[int32]() + fun[int64]() + } + ` + + f := srctesting.New(t) + file := f.Parse(`test.go`, src) + info, pkg := f.Check(`pkg/test`, file) + + c := Collector{ + TContext: types.NewContext(), + Info: info, + Instances: &PackageInstanceSets{}, + } + c.Scan(pkg, file) + + inst := func(name, tNest, tArg string) Instance { + return Instance{ + Object: srctesting.LookupObj(pkg, name), + TNest: evalTypeArgs(t, f.FileSet, pkg, tNest), + TArgs: evalTypeArgs(t, f.FileSet, pkg, tArg), + } + } + want := []Instance{ + inst(`fun`, ``, `int32`), + inst(`fun`, ``, `int64`), + + inst(`fun.nestedCon`, `int32`, ``), + inst(`fun.nestedCon`, `int64`, ``), + + inst(`fun.nestedGen`, `int32`, `int32`), + inst(`fun.nestedGen`, `int32`, `int8`), + inst(`fun.nestedGen`, `int64`, `int64`), + inst(`fun.nestedGen`, `int64`, `int8`), + + inst(`fun.nestedCover`, `int32`, `int32`), + inst(`fun.nestedCover`, `int32`, `int16`), + inst(`fun.nestedCover`, `int64`, `int64`), + inst(`fun.nestedCover`, `int64`, `int16`), + } + got := c.Instances.Pkg(pkg).Values() + if diff := cmp.Diff(want, got, instanceOpts()); diff != `` { + t.Errorf("Instances from Collector contain diff (-want,+got):\n%s", diff) + } +} + +func TestCollector_NestingWithVars(t *testing.T) { + // This is loosely based off of go1.19.13/test/typeparam/issue47740b.go + // I was getting an error where `Q.print[int;]` was showing up when + // `Q.print` is not in a nesting context with `int` and this helped debug + // it. The problem was that `q` was being treated like a type not a var. + src := `package test + + type Q struct{ v any } + func (q Q) print() { + println(q.v) + } + + func newQ(v any) Q { + return Q{v} + } + + type S[T any] struct{ x T } + func (s S[T]) echo() { + q := newQ(s.x) + q.print() + } + + func a() { + s := S[int]{x: 0} + s.echo() + } + ` + + f := srctesting.New(t) + file := f.Parse(`test.go`, src) + info, pkg := f.Check(`pkg/test`, file) + + c := Collector{ + TContext: types.NewContext(), + Info: info, + Instances: &PackageInstanceSets{}, + } + c.Scan(pkg, file) + + inst := func(name, tNest, tArg string) Instance { + return Instance{ + Object: srctesting.LookupObj(pkg, name), + TNest: evalTypeArgs(t, f.FileSet, pkg, tNest), + TArgs: evalTypeArgs(t, f.FileSet, pkg, tArg), + } + } + want := []Instance{ + inst(`S`, ``, `int`), + inst(`S.echo`, ``, `int`), + } + got := c.Instances.Pkg(pkg).Values() + if diff := cmp.Diff(want, got, instanceOpts()); diff != `` { + t.Errorf("Instances from Collector contain diff (-want,+got):\n%s", diff) + } +} + +func TestCollector_RecursiveTypeParams(t *testing.T) { + // This is based off of part of go1.19.13/test/typeparam/nested.go + src := `package test + func F[A any]() {} + func main() { + type U[_ any] int + type X[A any] U[X[A]] + F[X[int]]() + } + ` + + f := srctesting.New(t) + file := f.Parse(`test.go`, src) + info, pkg := f.Check(`test`, file) + + c := Collector{ + TContext: types.NewContext(), + Info: info, + Instances: &PackageInstanceSets{}, + } + c.Scan(pkg, file) + + tInt := types.Typ[types.Int] + xAny := srctesting.LookupObj(pkg, `main.X`) + xInt, err := types.Instantiate(types.NewContext(), xAny.Type(), []types.Type{tInt}, true) + if err != nil { + t.Fatalf("Failed to instantiate X[int]: %v", err) + } + + want := []Instance{ + { + Object: srctesting.LookupObj(pkg, `F`), + TArgs: []types.Type{xInt}, + }, { + Object: srctesting.LookupObj(pkg, `main.U`), + TArgs: []types.Type{xInt}, + }, { + Object: xAny, + TArgs: []types.Type{tInt}, + }, + } + got := c.Instances.Pkg(pkg).Values() + if diff := cmp.Diff(want, got, instanceOpts()); diff != `` { + t.Errorf("Instances from Collector contain diff (-want,+got):\n%s", diff) + } +} + +func TestCollector_NestedRecursiveTypeParams(t *testing.T) { + t.Skip(`Skipping test due to known issue with nested recursive type parameters.`) + // TODO(grantnelson-wf): This test is failing because the type parameters + // inside of U are not being resolved to concrete types. This is because + // when instantiating X in the collector, we are not resolving the + // nested type of U that is X's type argument. This leave the A in U + // as a type parameter instead of resolving it to string. + + // This is based off of part of go1.19.13/test/typeparam/nested.go + src := `package test + func F[A any]() any { + type U[_ any] struct{ x A } + type X[B any] U[X[B]] + return X[int]{} + } + func main() { + print(F[string]()) + } + ` + + f := srctesting.New(t) + file := f.Parse(`test.go`, src) + info, pkg := f.Check(`test`, file) + + c := Collector{ + TContext: types.NewContext(), + Info: info, + Instances: &PackageInstanceSets{}, + } + c.Scan(pkg, file) + + xAny := srctesting.LookupObj(pkg, `F.X`) + xInt, err := types.Instantiate(types.NewContext(), xAny.Type(), []types.Type{types.Typ[types.Int]}, true) + if err != nil { + t.Fatalf("Failed to instantiate X[int]: %v", err) + } + // TODO(grantnelson-wf): Need to instantiate xInt to replace `A` with `int` in the struct. + if isGeneric(xInt) { + t.Errorf("Expected uInt to be non-generic, got %v", xInt.Underlying()) + } + + want := []Instance{ + { + Object: srctesting.LookupObj(pkg, `F`), + TArgs: []types.Type{types.Typ[types.String]}, + }, { + Object: srctesting.LookupObj(pkg, `F.U`), + TNest: []types.Type{types.Typ[types.String]}, + TArgs: []types.Type{xInt}, + }, { + Object: xAny, + TNest: []types.Type{types.Typ[types.String]}, + TArgs: []types.Type{types.Typ[types.Int]}, + }, + } + got := c.Instances.Pkg(pkg).Values() + if diff := cmp.Diff(want, got, instanceOpts()); diff != `` { + t.Errorf("Instances from Collector contain diff (-want,+got):\n%s", diff) + } +} + +func TestCollector_NestedTypeParams(t *testing.T) { + t.Skip(`Skipping test due to known issue with nested recursive type parameters.`) + // TODO(grantnelson-wf): This test is failing because the type parameters + // inside of U are not being resolved to concrete types. This is because + // when instantiating X in the collector, we are not resolving the + // nested type of U that is X's type argument. This leave the A in U + // as a type parameter instead of resolving it to string. + + // This is based off of part of go1.19.13/test/typeparam/nested.go + src := `package test + func F[A any]() any { + type T[B any] struct{} + type U[_ any] struct{ X A } + return T[U[A]]{} + } + func main() { + print(F[int]()) + } + ` + + f := srctesting.New(t) + file := f.Parse(`test.go`, src) + info, pkg := f.Check(`test`, file) + + c := Collector{ + TContext: types.NewContext(), + Info: info, + Instances: &PackageInstanceSets{}, + } + c.Scan(pkg, file) + + uAny := srctesting.LookupObj(pkg, `F.U`) + uInt, err := types.Instantiate(types.NewContext(), uAny.Type(), []types.Type{types.Typ[types.Int]}, true) + if err != nil { + t.Fatalf("Failed to instantiate U[int]: %v", err) + } + //TODO(grantnelson-wf): Need to instantiate uInt to replace `A` with `int` in the struct. + if isGeneric(uInt) { + t.Errorf("Expected uInt to be non-generic, got %v", uInt.Underlying()) + } + + want := []Instance{ + { + Object: srctesting.LookupObj(pkg, `F`), + TArgs: []types.Type{types.Typ[types.Int]}, + }, { + Object: srctesting.LookupObj(pkg, `F.U`), + TNest: []types.Type{types.Typ[types.Int]}, + TArgs: []types.Type{types.Typ[types.Int]}, + }, { + Object: srctesting.LookupObj(pkg, `F.T`), + TNest: []types.Type{types.Typ[types.Int]}, + TArgs: []types.Type{uInt}, + }, + } + got := c.Instances.Pkg(pkg).Values() + if diff := cmp.Diff(want, got, instanceOpts()); diff != `` { + t.Errorf("Instances from Collector contain diff (-want,+got):\n%s", diff) + } +} + +func evalTypeArgs(t *testing.T, fSet *token.FileSet, pkg *types.Package, expr string) []types.Type { + if len(expr) == 0 { + return nil + } + args := strings.Split(expr, ",") + targs := make([]types.Type, 0, len(args)) + for _, astr := range args { + tv, err := types.Eval(fSet, pkg, 0, astr) + if err != nil { + t.Fatalf("Eval(%s) failed: %v", astr, err) + } + targs = append(targs, tv.Type) } + return targs } func TestCollector_CrossPackage(t *testing.T) { @@ -492,7 +859,7 @@ func TestResolver_SubstituteSelection(t *testing.T) { info, pkg := f.Check("pkg/test", file) method := srctesting.LookupObj(pkg, "g.Method").(*types.Func).Type().(*types.Signature) - resolver := NewResolver(nil, ToSlice(method.RecvTypeParams()), []types.Type{srctesting.LookupObj(pkg, "x").Type()}) + resolver := NewResolver(nil, method.RecvTypeParams(), []types.Type{srctesting.LookupObj(pkg, "x").Type()}, nil) if l := len(info.Selections); l != 1 { t.Fatalf("Got: %d selections. Want: 1", l) diff --git a/compiler/internal/typeparams/instance.go b/compiler/internal/typeparams/instance.go index 3e4c04d2c..64c67b4b5 100644 --- a/compiler/internal/typeparams/instance.go +++ b/compiler/internal/typeparams/instance.go @@ -3,6 +3,7 @@ package typeparams import ( "fmt" "go/types" + "strings" "github.com/gopherjs/gopherjs/compiler/internal/symbol" "github.com/gopherjs/gopherjs/compiler/typesutil" @@ -15,39 +16,82 @@ import ( type Instance struct { Object types.Object // Object to be instantiated. TArgs typesutil.TypeList // Type params to instantiate with. + + // TNest is the type params of the function this object was nested with-in. + // e.g. In `func A[X any]() { type B[Y any] struct {} }` the `X` + // from `A` is the context of `B[Y]` thus creating `B[X;Y]`. + TNest typesutil.TypeList } // String returns a string representation of the Instance. // // Two semantically different instances may have the same string representation // if the instantiated object or its type arguments shadow other types. -func (i *Instance) String() string { - sym := symbol.New(i.Object).String() - if len(i.TArgs) == 0 { - return sym +func (i Instance) String() string { + return i.symbolicName() + i.TypeParamsString(`<`, `>`) +} + +// TypeString returns a Go type string representing the instance (suitable for %T verb). +func (i Instance) TypeString() string { + return i.qualifiedName() + i.TypeParamsString(`[`, `]`) +} + +// symbolicName returns a string representation of the instance's name +// including the package name and pointer indicators but +// excluding the type parameters. +func (i Instance) symbolicName() string { + if i.Object == nil { + return `` } + return symbol.New(i.Object).String() +} - return fmt.Sprintf("%s<%s>", sym, i.TArgs) +// qualifiedName returns a string representation of the instance's name +// including the package name but +// excluding the type parameters and pointer indicators. +func (i Instance) qualifiedName() string { + if i.Object == nil { + return `` + } + if i.Object.Pkg() == nil { + return i.Object.Name() + } + return fmt.Sprintf("%s.%s", i.Object.Pkg().Name(), i.Object.Name()) } -// TypeString returns a Go type string representing the instance (suitable for %T verb). -func (i *Instance) TypeString() string { - tArgs := "" - if len(i.TArgs) > 0 { - tArgs = "[" + i.TArgs.String() + "]" +// TypeParamsString returns part of a Go type string that represents the type +// parameters of the instance including the nesting type parameters, e.g. [X;Y,Z]. +func (i Instance) TypeParamsString(open, close string) string { + hasNest := len(i.TNest) > 0 + hasArgs := len(i.TArgs) > 0 + buf := strings.Builder{} + if hasNest || hasArgs { + buf.WriteString(open) + if hasNest { + buf.WriteString(i.TNest.String()) + buf.WriteRune(';') + if hasArgs { + buf.WriteRune(' ') + } + } + if hasArgs { + buf.WriteString(i.TArgs.String()) + } + buf.WriteString(close) } - return fmt.Sprintf("%s.%s%s", i.Object.Pkg().Name(), i.Object.Name(), tArgs) + return buf.String() } -// IsTrivial returns true if this is an instance of a non-generic object. -func (i *Instance) IsTrivial() bool { - return len(i.TArgs) == 0 +// IsTrivial returns true if this is an instance of a non-generic object +// and it is not nested in a generic function. +func (i Instance) IsTrivial() bool { + return len(i.TArgs) == 0 && len(i.TNest) == 0 } // Recv returns an instance of the receiver type of a method. // // Returns zero value if not a method. -func (i *Instance) Recv() Instance { +func (i Instance) Recv() Instance { sig, ok := i.Object.Type().(*types.Signature) if !ok { return Instance{} @@ -159,6 +203,17 @@ func (iset *InstanceSet) ForObj(obj types.Object) []Instance { return result } +// ObjHasInstances returns true if there are any instances (either trivial +// or non-trivial) that belong to the given object type, otherwise false. +func (iset *InstanceSet) ObjHasInstances(obj types.Object) bool { + for _, inst := range iset.values { + if inst.Object == obj { + return true + } + } + return false +} + // PackageInstanceSets stores an InstanceSet for each package in a program, keyed // by import path. type PackageInstanceSets map[string]*InstanceSet diff --git a/compiler/internal/typeparams/map.go b/compiler/internal/typeparams/map.go index dbe07a54e..7edbdc016 100644 --- a/compiler/internal/typeparams/map.go +++ b/compiler/internal/typeparams/map.go @@ -38,9 +38,9 @@ type InstanceMap[V any] struct { // If the given key isn't found, an empty bucket and -1 are returned. func (im *InstanceMap[V]) findIndex(key Instance) (mapBucket[V], int) { if im != nil && im.data != nil { - bucket := im.data[key.Object][typeHash(im.hasher, key.TArgs...)] + bucket := im.data[key.Object][typeHash(im.hasher, key.TNest, key.TArgs)] for i, candidate := range bucket { - if candidate != nil && candidate.key.TArgs.Equal(key.TArgs) { + if candidateArgsMatch(key, candidate) { return bucket, i } } @@ -82,7 +82,7 @@ func (im *InstanceMap[V]) Set(key Instance, value V) V { if _, ok := im.data[key.Object]; !ok { im.data[key.Object] = mapBuckets[V]{} } - bucketID := typeHash(im.hasher, key.TArgs...) + bucketID := typeHash(im.hasher, key.TNest, key.TArgs) // If there is already an identical key in the map, override the entry value. hole := -1 @@ -90,7 +90,7 @@ func (im *InstanceMap[V]) Set(key Instance, value V) V { for i, candidate := range bucket { if candidate == nil { hole = i - } else if candidate.key.TArgs.Equal(key.TArgs) { + } else if candidateArgsMatch(key, candidate) { old := candidate.value candidate.value = value return old @@ -180,13 +180,24 @@ func (im *InstanceMap[V]) String() string { return `{` + strings.Join(entries, `, `) + `}` } +// candidateArgsMatch checks if the candidate entry has the same type +// arguments as the given key. +func candidateArgsMatch[V any](key Instance, candidate *mapEntry[V]) bool { + return candidate != nil && + candidate.key.TNest.Equal(key.TNest) && + candidate.key.TArgs.Equal(key.TArgs) +} + // typeHash returns a combined hash of several types. // // Provided hasher is used to compute hashes of individual types, which are // xor'ed together. Xor preserves bit distribution property, so the combined // hash should be as good for bucketing, as the original. -func typeHash(hasher typeutil.Hasher, types ...types.Type) uint32 { +func typeHash(hasher typeutil.Hasher, nestTypes, types []types.Type) uint32 { var hash uint32 + for _, typ := range nestTypes { + hash ^= hasher.Hash(typ) + } for _, typ := range types { hash ^= hasher.Hash(typ) } diff --git a/compiler/internal/typeparams/map_test.go b/compiler/internal/typeparams/map_test.go index baa31f64a..d67a1884d 100644 --- a/compiler/internal/typeparams/map_test.go +++ b/compiler/internal/typeparams/map_test.go @@ -7,8 +7,10 @@ import ( ) func TestInstanceMap(t *testing.T) { + pkg := types.NewPackage(`testPkg`, `testPkg`) + i1 := Instance{ - Object: types.NewTypeName(token.NoPos, nil, "i1", nil), + Object: types.NewTypeName(token.NoPos, pkg, "i1", nil), TArgs: []types.Type{ types.Typ[types.Int], types.Typ[types.Int8], @@ -23,7 +25,7 @@ func TestInstanceMap(t *testing.T) { } i2 := Instance{ - Object: types.NewTypeName(token.NoPos, nil, "i2", nil), // Different pointer. + Object: types.NewTypeName(token.NoPos, pkg, "i2", nil), // Different pointer. TArgs: []types.Type{ types.Typ[types.Int], types.Typ[types.Int8], @@ -70,7 +72,7 @@ func TestInstanceMap(t *testing.T) { if got := m.Len(); got != 1 { t.Errorf("Got: map length %d. Want: 1.", got) } - if got, want := m.String(), `{{type i1 int, int8}:abc}`; got != want { + if got, want := m.String(), `{testPkg.i1:abc}`; got != want { t.Errorf("Got: map string %q. Want: map string %q.", got, want) } if got, want := m.Keys(), []Instance{i1}; !keysMatch(got, want) { @@ -95,7 +97,7 @@ func TestInstanceMap(t *testing.T) { if got := m.Get(i1clone); got != "def" { t.Errorf(`Got: getting set key returned %q. Want: "def"`, got) } - if got, want := m.String(), `{{type i1 int, int8}:def}`; got != want { + if got, want := m.String(), `{testPkg.i1:def}`; got != want { t.Errorf("Got: map string %q. Want: map string %q.", got, want) } if got, want := m.Keys(), []Instance{i1}; !keysMatch(got, want) { @@ -165,7 +167,7 @@ func TestInstanceMap(t *testing.T) { if got := m.Len(); got != 5 { t.Errorf("Got: map length %d. Want: 5.", got) } - if got, want := m.String(), `{{type i1 int, int8}:def, {type i1 int, int}:456, {type i1 string, string}:789, {type i1 }:ghi, {type i2 int, int8}:123}`; got != want { + if got, want := m.String(), `{testPkg.i1:ghi, testPkg.i1:def, testPkg.i1:456, testPkg.i1:789, testPkg.i2:123}`; got != want { t.Errorf("Got: map string %q. Want: map string %q.", got, want) } if got, want := m.Keys(), []Instance{i1, i2, i3, i4, i5}; !keysMatch(got, want) { diff --git a/compiler/internal/typeparams/utils.go b/compiler/internal/typeparams/utils.go index 6930fbf23..ea528314e 100644 --- a/compiler/internal/typeparams/utils.go +++ b/compiler/internal/typeparams/utils.go @@ -3,6 +3,7 @@ package typeparams import ( "errors" "fmt" + "go/token" "go/types" ) @@ -19,6 +20,31 @@ func SignatureTypeParams(sig *types.Signature) *types.TypeParamList { } } +// FindNestingFunc returns the function or method that the given object +// is nested in, or nil if the object was defined at the package level. +func FindNestingFunc(obj types.Object) *types.Func { + objPos := obj.Pos() + if objPos == token.NoPos { + return nil + } + + scope := obj.Parent() + for scope != nil { + // Iterate over all declarations in the scope. + for _, name := range scope.Names() { + decl := scope.Lookup(name) + if fn, ok := decl.(*types.Func); ok { + // Check if the object's position is within the function's scope. + if objPos >= fn.Pos() && objPos <= fn.Scope().End() { + return fn + } + } + } + scope = scope.Parent() + } + return nil +} + var ( errInstantiatesGenerics = errors.New("instantiates generic type or function") errDefinesGenerics = errors.New("defines generic type or function") @@ -58,3 +84,58 @@ func RequiresGenericsSupport(info *types.Info) error { return nil } + +// isGeneric will search all the given types and their subtypes for a +// *types.TypeParam. This will not check if a type could be generic, +// but if each instantiation is not completely concrete yet. +// +// This is useful to check for generics types like `X[B[T]]`, where +// `X` appears concrete because it is instantiated with the type argument `B[T]`, +// however the `T` inside `B[T]` is a type parameter making `X[B[T]]` a generic +// type since it required instantiation to a concrete type, e.g. `X[B[int]]`. +func isGeneric(typ ...types.Type) bool { + var containsTypeParam func(t types.Type) bool + + foreach := func(count int, getter func(index int) types.Type) bool { + for i := 0; i < count; i++ { + if containsTypeParam(getter(i)) { + return true + } + } + return false + } + + seen := make(map[types.Type]struct{}) + containsTypeParam = func(t types.Type) bool { + if _, ok := seen[t]; ok { + return false + } + seen[t] = struct{}{} + + switch t := t.(type) { + case *types.TypeParam: + return true + case *types.Named: + return t.TypeParams().Len() != t.TypeArgs().Len() || + foreach(t.TypeArgs().Len(), func(i int) types.Type { return t.TypeArgs().At(i) }) || + containsTypeParam(t.Underlying()) + case *types.Struct: + return foreach(t.NumFields(), func(i int) types.Type { return t.Field(i).Type() }) + case *types.Interface: + return foreach(t.NumMethods(), func(i int) types.Type { return t.Method(i).Type() }) + case *types.Signature: + return foreach(t.Params().Len(), func(i int) types.Type { return t.Params().At(i).Type() }) || + foreach(t.Results().Len(), func(i int) types.Type { return t.Results().At(i).Type() }) + case *types.Map: + return containsTypeParam(t.Key()) || containsTypeParam(t.Elem()) + case interface{ Elem() types.Type }: + // Handles *types.Pointer, *types.Slice, *types.Array, *types.Chan + return containsTypeParam(t.Elem()) + default: + // Other types (e.g., basic types) do not contain type parameters. + return false + } + } + + return foreach(len(typ), func(i int) types.Type { return typ[i] }) +} diff --git a/compiler/package.go b/compiler/package.go index 8f336130d..bb94962da 100644 --- a/compiler/package.go +++ b/compiler/package.go @@ -323,11 +323,17 @@ func (fc *funcContext) initArgs(ty types.Type) string { if !field.Exported() { pkgPath = field.Pkg().Path() } - fields[i] = fmt.Sprintf(`{prop: "%s", name: %s, embedded: %t, exported: %t, typ: %s, tag: %s}`, fieldName(t, i), encodeString(field.Name()), field.Anonymous(), field.Exported(), fc.typeName(field.Type()), encodeString(t.Tag(i))) + ft := fc.fieldType(t, i) + fields[i] = fmt.Sprintf(`{prop: "%s", name: %s, embedded: %t, exported: %t, typ: %s, tag: %s}`, + fieldName(t, i), encodeString(field.Name()), field.Anonymous(), field.Exported(), fc.typeName(ft), encodeString(t.Tag(i))) } return fmt.Sprintf(`"%s", [%s]`, pkgPath, strings.Join(fields, ", ")) case *types.TypeParam: - err := bailout(fmt.Errorf(`%v has unexpected generic type parameter %T`, ty, ty)) + tr := fc.typeResolver.Substitute(ty) + if tr != ty { + return fc.initArgs(tr) + } + err := bailout(fmt.Errorf(`"%v" has unexpected generic type parameter %T`, ty, ty)) panic(err) default: err := bailout(fmt.Errorf("%v has unexpected type %T", ty, ty)) diff --git a/compiler/utils.go b/compiler/utils.go index 83b826ce2..7d286f447 100644 --- a/compiler/utils.go +++ b/compiler/utils.go @@ -198,7 +198,7 @@ func (fc *funcContext) translateSelection(sel typesutil.Selection, pos token.Pos jsFieldName := s.Field(index).Name() for { fields = append(fields, fieldName(s, 0)) - ft := s.Field(0).Type() + ft := fc.fieldType(s, 0) if typesutil.IsJsObject(ft) { return fields, jsTag } @@ -215,7 +215,7 @@ func (fc *funcContext) translateSelection(sel typesutil.Selection, pos token.Pos } } fields = append(fields, fieldName(s, index)) - t = s.Field(index).Type() + t = fc.fieldType(s, index) } return fields, "" } @@ -441,13 +441,16 @@ func (fc *funcContext) objectName(o types.Object) string { // knownInstances returns a list of known instantiations of the object. // -// For objects without type params always returns a single trivial instance. +// For objects without type params and not nested in a generic function or +// method, this always returns a single trivial instance. +// If the object is generic, or in a generic function or method, but there are +// no instances, then the object is unused and an empty list is returned. func (fc *funcContext) knownInstances(o types.Object) []typeparams.Instance { - if !typeparams.HasTypeParams(o.Type()) { + instances := fc.pkgCtx.instanceSet.Pkg(o.Pkg()).ForObj(o) + if len(instances) == 0 && !typeparams.HasTypeParams(o.Type()) { return []typeparams.Instance{{Object: o}} } - - return fc.pkgCtx.instanceSet.Pkg(o.Pkg()).ForObj(o) + return instances } // instName returns a JS expression that refers to the provided instance of a @@ -459,7 +462,8 @@ func (fc *funcContext) instName(inst typeparams.Instance) string { return objName } fc.pkgCtx.DeclareDCEDep(inst.Object, inst.TArgs...) - return fmt.Sprintf("%s[%d /* %v */]", objName, fc.pkgCtx.instanceSet.ID(inst), inst.TArgs) + label := inst.TypeParamsString(` /* `, ` */`) + return fmt.Sprintf("%s[%d%s]", objName, fc.pkgCtx.instanceSet.ID(inst), label) } // methodName returns a JS identifier (specifically, object property name) @@ -504,14 +508,31 @@ func (fc *funcContext) typeName(ty types.Type) string { return "$error" } inst := typeparams.Instance{Object: t.Obj()} + + // Get type arguments for the type if there are any. for i := 0; i < t.TypeArgs().Len(); i++ { inst.TArgs = append(inst.TArgs, t.TypeArgs().At(i)) } + + // Get the nesting type arguments if there are any. + if fn := typeparams.FindNestingFunc(t.Obj()); fn != nil { + if fn.Scope().Contains(t.Obj().Pos()) { + tp := typeparams.SignatureTypeParams(fn.Type().(*types.Signature)) + tNest := make([]types.Type, tp.Len()) + for i := 0; i < tp.Len(); i++ { + tNest[i] = fc.typeResolver.Substitute(tp.At(i)) + } + inst.TNest = typesutil.TypeList(tNest) + } + } + return fc.instName(inst) case *types.Interface: if t.Empty() { return "$emptyInterface" } + case *types.TypeParam: + panic(fmt.Errorf("unexpected type parameter: %v", t)) } // For anonymous composite types, generate a synthetic package-level type @@ -575,6 +596,12 @@ func (fc *funcContext) typeOf(expr ast.Expr) types.Type { return fc.typeResolver.Substitute(typ) } +// fieldType returns the type of the i-th field of the given struct +// after substituting type parameters with concrete types for nested context. +func (fc *funcContext) fieldType(t *types.Struct, i int) types.Type { + return fc.typeResolver.Substitute(t.Field(i).Type()) +} + func (fc *funcContext) selectionOf(e *ast.SelectorExpr) (typesutil.Selection, bool) { if sel, ok := fc.pkgCtx.Selections[e]; ok { return fc.typeResolver.SubstituteSelection(sel), true diff --git a/internal/govendor/subst/export.go b/internal/govendor/subst/export.go index 38e394bda..00a77ca49 100644 --- a/internal/govendor/subst/export.go +++ b/internal/govendor/subst/export.go @@ -4,6 +4,7 @@ package subst import ( + "fmt" "go/types" ) @@ -17,33 +18,33 @@ type Subster struct { } // New creates a new Subster with a given list of type parameters and matching args. -func New(tc *types.Context, tParams []*types.TypeParam, tArgs []types.Type) *Subster { - assert(len(tParams) == len(tArgs), "New() argument count must match") +func New(tc *types.Context, tParams *types.TypeParamList, tArgs []types.Type) *Subster { + if tParams.Len() != len(tArgs) { + panic(fmt.Errorf("number of type parameters and arguments must match: %d => %d", tParams.Len(), len(tArgs))) + } - if len(tParams) == 0 { + if tParams.Len() == 0 && len(tArgs) == 0 { return nil } - subst := &subster{ - replacements: make(map[*types.TypeParam]types.Type, len(tParams)), - cache: make(map[types.Type]types.Type), - ctxt: tc, - scope: nil, - debug: false, - } - for i := 0; i < len(tParams); i++ { - subst.replacements[tParams[i]] = tArgs[i] - } - return &Subster{ - impl: subst, - } + subst := makeSubster(tc, nil, tParams, tArgs, false) + return &Subster{impl: subst} } -// Type returns a version of typ with all references to type parameters replaced -// with the corresponding type arguments. +// Type returns a version of typ with all references to type parameters +// replaced with the corresponding type arguments. func (s *Subster) Type(typ types.Type) types.Type { if s == nil { return typ } return s.impl.typ(typ) } + +// Types returns a version of ts with all references to type parameters +// replaced with the corresponding type arguments. +func (s *Subster) Types(ts []types.Type) []types.Type { + if s == nil { + return ts + } + return s.impl.types(ts) +} diff --git a/internal/govendor/subst/subst.go b/internal/govendor/subst/subst.go index 9020e94f9..825e3c7f1 100644 --- a/internal/govendor/subst/subst.go +++ b/internal/govendor/subst/subst.go @@ -2,6 +2,8 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +// Copy of https://cs.opensource.google/go/x/tools/+/refs/tags/v0.17.0:go/ssa/subst.go +// Any changes to this copy are labelled with GOPHERJS. package subst import ( @@ -81,9 +83,12 @@ func (subst *subster) typ(t types.Type) (res types.Type) { // fall through if result r will be identical to t, types.Identical(r, t). switch t := t.(type) { case *types.TypeParam: - r := subst.replacements[t] - assert(r != nil, "type param without replacement encountered") - return r + // GOPHERJS: Replaced an assert that was causing a panic for nested types with code from + // https://cs.opensource.google/go/x/tools/+/refs/tags/v0.33.0:go/ssa/subst.go;l=92 + if r := subst.replacements[t]; r != nil { + return r + } + return t case *types.Basic: return t diff --git a/internal/govendor/subst/subst_test.go b/internal/govendor/subst/subst_test.go index 53fadbcf0..832f0ebd4 100644 --- a/internal/govendor/subst/subst_test.go +++ b/internal/govendor/subst/subst_test.go @@ -2,6 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +// Copy of https://cs.opensource.google/go/x/tools/+/refs/tags/v0.17.0:go/ssa/subst_test.go package subst import ( diff --git a/internal/govendor/subst/util.go b/internal/govendor/subst/util.go index 22072e39f..5b55c0310 100644 --- a/internal/govendor/subst/util.go +++ b/internal/govendor/subst/util.go @@ -6,18 +6,16 @@ package subst import "go/types" -// This file defines a number of miscellaneous utility functions. - -//// Sanity checking utilities - // assert panics with the mesage msg if p is false. // Avoid combining with expensive string formatting. +// From https://cs.opensource.google/go/x/tools/+/refs/tags/v0.17.0:go/ssa/util.go;l=27 func assert(p bool, msg string) { if !p { panic(msg) } } +// From https://cs.opensource.google/go/x/tools/+/refs/tags/v0.33.0:go/ssa/wrappers.go;l=262 func changeRecv(s *types.Signature, recv *types.Var) *types.Signature { return types.NewSignatureType(recv, nil, nil, s.Params(), s.Results(), s.Variadic()) } diff --git a/internal/srctesting/srctesting.go b/internal/srctesting/srctesting.go index bf74bce51..e4242991c 100644 --- a/internal/srctesting/srctesting.go +++ b/internal/srctesting/srctesting.go @@ -158,6 +158,9 @@ func LookupObj(pkg *types.Package, name string) types.Object { for len(path) > 0 { obj = scope.Lookup(path[0]) + if obj == nil { + panic(fmt.Sprintf("failed to find %q in %q", path[0], name)) + } path = path[1:] if fun, ok := obj.(*types.Func); ok { @@ -170,6 +173,9 @@ func LookupObj(pkg *types.Package, name string) types.Object { if len(path) > 0 { obj, _, _ = types.LookupFieldOrMethod(obj.Type(), true, obj.Pkg(), path[0]) path = path[1:] + if fun, ok := obj.(*types.Func); ok { + scope = fun.Scope() + } } } return obj