Skip to content

Handling nested type arguments #1374

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Handling deep nested types
  • Loading branch information
grantnelson-wf committed May 28, 2025
commit f5ca11fb8baadbdeb39c7c5cb7c1d88bb82824be
13 changes: 2 additions & 11 deletions compiler/decls.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,13 +451,7 @@ func (fc *funcContext) newNamedTypeVarDecl(obj *types.TypeName) *Decl {
func (fc *funcContext) newNamedTypeInstDecl(inst typeparams.Instance) (*Decl, error) {
originType := inst.Object.Type().(*types.Named)

var nestResolver *typeparams.Resolver
if len(inst.TNest) > 0 {
fn := typeparams.FindNestingFunc(inst.Object)
tp := typeparams.SignatureTypeParams(fn.Type().(*types.Signature))
nestResolver = typeparams.NewResolver(fc.pkgCtx.typesCtx, tp, inst.TNest, nil)
}
fc.typeResolver = typeparams.NewResolver(fc.pkgCtx.typesCtx, originType.TypeParams(), inst.TArgs, nestResolver)
fc.typeResolver = typeparams.NewResolver(fc.pkgCtx.typesCtx, inst)
defer func() { fc.typeResolver = nil }()

instanceType := originType
Expand All @@ -469,10 +463,7 @@ func (fc *funcContext) newNamedTypeInstDecl(inst typeparams.Instance) (*Decl, er
}
instanceType = instantiated.(*types.Named)
}
if len(inst.TNest) > 0 {
instantiated := nestResolver.Substitute(instanceType)
instanceType = instantiated.(*types.Named)
}
instanceType = fc.typeResolver.Substitute(instanceType).(*types.Named)
}

underlying := instanceType.Underlying()
Expand Down
2 changes: 1 addition & 1 deletion compiler/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ func (fc *funcContext) translateExpr(expr ast.Expr) *expression {
case *types.Signature:
return fc.formatExpr("%s", fc.instName(fc.instanceOf(e.X.(*ast.Ident))))
default:
panic(fmt.Errorf(`unhandled IndexExpr: %T`, t))
panic(fmt.Errorf(`unhandled IndexExpr: %T in %T`, t, fc.typeOf(e.X)))
}
case *ast.IndexListExpr:
switch t := fc.typeOf(e.X).Underlying().(type) {
Expand Down
10 changes: 3 additions & 7 deletions compiler/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,9 @@ func (fc *funcContext) nestedFunctionContext(info *analysis.FuncInfo, inst typep
c.allVars[k] = v
}

if sig.TypeParams().Len() > 0 {
c.typeResolver = typeparams.NewResolver(c.pkgCtx.typesCtx, sig.TypeParams(), inst.TArgs, nil)
} else if sig.RecvTypeParams().Len() > 0 {
c.typeResolver = typeparams.NewResolver(c.pkgCtx.typesCtx, sig.RecvTypeParams(), inst.TArgs, nil)
}
if c.objectNames == nil {
c.objectNames = map[types.Object]string{}
// Use the parent function's resolver unless the function has it's own type arguments.
if !inst.IsTrivial() {
c.typeResolver = typeparams.NewResolver(fc.pkgCtx.typesCtx, inst)
}

// Synthesize an identifier by which the function may reference itself. Since
Expand Down
6 changes: 1 addition & 5 deletions compiler/internal/analysis/info.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,7 @@ func (info *Info) newFuncInfoInstances(fd *ast.FuncDecl) []*FuncInfo {

funcInfos := make([]*FuncInfo, 0, len(instances))
for _, inst := range instances {
var resolver *typeparams.Resolver
if sig, ok := obj.Type().(*types.Signature); ok {
tp := typeparams.SignatureTypeParams(sig)
resolver = typeparams.NewResolver(info.typeCtx, tp, inst.TArgs, nil)
}
resolver := typeparams.NewResolver(info.typeCtx, inst)
fi := info.newFuncInfo(fd, inst.Object, inst.TArgs, resolver)
funcInfos = append(funcInfos, fi)
}
Expand Down
179 changes: 25 additions & 154 deletions compiler/internal/typeparams/collect.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,143 +4,8 @@ import (
"fmt"
"go/ast"
"go/types"
"strings"

"github.com/gopherjs/gopherjs/compiler/typesutil"
"github.com/gopherjs/gopherjs/internal/govendor/subst"
)

// Resolver translates types defined in terms of type parameters into concrete
// types, given a mapping from type params to type arguments.
type Resolver struct {
tParams *types.TypeParamList
tArgs []types.Type
parent *Resolver

// subster is the substitution helper that will perform the actual
// substitutions. This maybe nil when there are no substitutions but
// will still usable when nil.
subster *subst.Subster
selMemo map[typesutil.Selection]typesutil.Selection
}

// NewResolver creates a new Resolver with tParams entries mapping to tArgs
// entries with the same index.
func NewResolver(tc *types.Context, tParams *types.TypeParamList, tArgs []types.Type, parent *Resolver) *Resolver {
r := &Resolver{
tParams: tParams,
tArgs: tArgs,
parent: parent,
subster: subst.New(tc, tParams, tArgs),
selMemo: map[typesutil.Selection]typesutil.Selection{},
}
return r
}

// TypeParams is the list of type parameters that this resolver
// (not any parent) will substitute.
func (r *Resolver) TypeParams() *types.TypeParamList {
if r == nil {
return nil
}
return r.tParams
}

// TypeArgs is the list of type arguments that this resolver
// (not any parent) will resolve to.
func (r *Resolver) TypeArgs() []types.Type {
if r == nil {
return nil
}
return r.tArgs
}

// Parent is the resolver for the function or method that this resolver
// is nested in. This may be nil if the context for this resolver is not
// nested in another generic function or method.
func (r *Resolver) Parent() *Resolver {
if r == nil {
return nil
}
return r.parent
}

// Substitute replaces references to type params in the provided type definition
// with the corresponding concrete types.
func (r *Resolver) Substitute(typ types.Type) types.Type {
if r == nil || typ == nil {
return typ // No substitutions to be made.
}
typ = r.subster.Type(typ)
typ = r.parent.Substitute(typ)
return typ
}

// SubstituteAll same as Substitute, but accepts a TypeList are returns
// substitution results as a slice in the same order.
func (r *Resolver) SubstituteAll(list *types.TypeList) []types.Type {
result := make([]types.Type, list.Len())
for i := range result {
result[i] = r.Substitute(list.At(i))
}
return result
}

// SubstituteSelection replaces a method of field selection on a generic type
// defined in terms of type parameters with a method selection on a concrete
// instantiation of the type.
func (r *Resolver) SubstituteSelection(sel typesutil.Selection) typesutil.Selection {
if r == nil || sel == nil {
return sel // No substitutions to be made.
}
if concrete, ok := r.selMemo[sel]; ok {
return concrete
}

switch sel.Kind() {
case types.MethodExpr, types.MethodVal, types.FieldVal:
recv := r.Substitute(sel.Recv())
if types.Identical(recv, sel.Recv()) {
return sel // Non-generic receiver, no substitution necessary.
}

// Look up the method on the instantiated receiver.
pkg := sel.Obj().Pkg()
obj, index, _ := types.LookupFieldOrMethod(recv, true, pkg, sel.Obj().Name())
if obj == nil {
panic(fmt.Errorf("failed to lookup field %q in type %v", sel.Obj().Name(), recv))
}
typ := obj.Type()

if sel.Kind() == types.MethodExpr {
typ = typesutil.RecvAsFirstArg(typ.(*types.Signature))
}
concrete := typesutil.NewSelection(sel.Kind(), recv, index, obj, typ)
r.selMemo[sel] = concrete
return concrete
default:
panic(fmt.Errorf("unexpected selection kind %v: %v", sel.Kind(), sel))
}
}

// String gets a strings representation of the resolver for debugging.
func (r *Resolver) String() string {
if r == nil {
return `{}`
}

parts := make([]string, 0, len(r.tArgs))
for i, ta := range r.tArgs {
parts = append(parts, fmt.Sprintf("%s->%s", r.tParams.At(i), ta))
}

nestStr := ``
if r.parent != nil {
nestStr = r.parent.String() + `:`
}
return nestStr + `{` + strings.Join(parts, `, `) + `}`
}

// visitor implements ast.Visitor and collects instances of generic types and
// functions into an InstanceSet.
//
Expand All @@ -151,7 +16,9 @@ type visitor struct {
instances *PackageInstanceSets
resolver *Resolver
info *types.Info
tNest []types.Type // The type arguments for a nested context.

nestTParams *types.TypeParamList // The type parameters for a nested context.
nestTArgs []types.Type // The type arguments for a nested context.
}

var _ ast.Visitor = &visitor{}
Expand Down Expand Up @@ -195,12 +62,14 @@ func (c *visitor) visitInstance(ident *ast.Ident, inst types.Instance) {

// If the object is defined in the same scope as the instance,
// then we apply the current nested type arguments.
var tNest []types.Type
var nestTParams *types.TypeParamList
var nestTArgs []types.Type
if obj.Parent().Contains(ident.Pos()) {
tNest = c.tNest
nestTParams = c.nestTParams
nestTArgs = c.nestTArgs
}

c.addInstance(obj, tArgs, tNest)
c.addInstance(obj, tArgs, nestTParams, nestTArgs)
}

func (c *visitor) visitNestedType(obj types.Object) {
Expand All @@ -222,12 +91,12 @@ func (c *visitor) visitNestedType(obj types.Object) {
return
}

c.addInstance(obj, nil, c.resolver.TypeArgs())
c.addInstance(obj, nil, c.resolver.TypeParams(), c.resolver.TypeArgs())
}

func (c *visitor) addInstance(obj types.Object, tArgList *types.TypeList, tNest []types.Type) {
func (c *visitor) addInstance(obj types.Object, tArgList *types.TypeList, nestTParams *types.TypeParamList, nestTArgs []types.Type) {
tArgs := c.resolver.SubstituteAll(tArgList)
if isGeneric(tArgs...) {
if isGeneric(nestTParams, tArgs) {
// Skip any instances that still have type parameters in them after
// substitution. This occurs when a type is defined while nested
// in a generic context and is not fully instantiated yet.
Expand All @@ -238,7 +107,7 @@ func (c *visitor) addInstance(obj types.Object, tArgList *types.TypeList, tNest
c.instances.Add(Instance{
Object: obj,
TArgs: tArgs,
TNest: tNest,
TNest: nestTArgs,
})

if t, ok := obj.Type().(*types.Named); ok {
Expand All @@ -247,7 +116,7 @@ func (c *visitor) addInstance(obj types.Object, tArgList *types.TypeList, tNest
c.instances.Add(Instance{
Object: method.Origin(),
TArgs: tArgs,
TNest: tNest,
TNest: nestTArgs,
})
}
}
Expand Down Expand Up @@ -358,12 +227,13 @@ func (c *Collector) Scan(pkg *types.Package, files ...*ast.File) {
}

func (c *Collector) scanSignature(inst Instance, typ *types.Signature, objMap map[types.Object]ast.Node) {
tParams := SignatureTypeParams(typ)
v := visitor{
instances: c.Instances,
resolver: NewResolver(c.TContext, tParams, inst.TArgs, nil),
resolver: NewResolver(c.TContext, inst),
info: c.Info,
tNest: inst.TArgs,

nestTParams: SignatureTypeParams(typ),
nestTArgs: inst.TArgs,
}
ast.Walk(&v, objMap[inst.Object])
}
Expand All @@ -377,18 +247,19 @@ func (c *Collector) scanNamed(inst Instance, typ *types.Named, objMap map[types.
return
}

var nestResolver *Resolver
if len(inst.TNest) > 0 {
fn := FindNestingFunc(inst.Object)
tp := SignatureTypeParams(fn.Type().(*types.Signature))
nestResolver = NewResolver(c.TContext, tp, inst.TNest, nil)
var nestTParams *types.TypeParamList
nest := FindNestingFunc(obj)
if nest != nil {
nestTParams = SignatureTypeParams(nest.Type().(*types.Signature))
}

v := visitor{
instances: c.Instances,
resolver: NewResolver(c.TContext, typ.TypeParams(), inst.TArgs, nestResolver),
resolver: NewResolver(c.TContext, inst),
info: c.Info,
tNest: inst.TNest,

nestTParams: nestTParams,
nestTArgs: inst.TNest,
}
ast.Walk(&v, node)
}
Loading
Loading