Skip to content

Rudimentary support for passing type parameters into generic functions #1161

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 64 additions & 20 deletions compiler/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -362,14 +362,14 @@ func (fc *funcContext) translateExpr(expr ast.Expr) *expression {
if isUnsigned(basic) {
shift = ">>>"
}
return fc.formatExpr(`(%1s = %2e / %3e, (%1s === %1s && %1s !== 1/0 && %1s !== -1/0) ? %1s %4s 0 : $throwRuntimeError("integer divide by zero"))`, fc.newVariable("_q"), e.X, e.Y, shift)
return fc.formatExpr(`(%1s = %2e / %3e, (%1s === %1s && %1s !== 1/0 && %1s !== -1/0) ? %1s %4s 0 : $throwRuntimeError("integer divide by zero"))`, fc.newLocalVariable("_q"), e.X, e.Y, shift)
}
if basic.Kind() == types.Float32 {
return fc.fixNumber(fc.formatExpr("%e / %e", e.X, e.Y), basic)
}
return fc.formatExpr("%e / %e", e.X, e.Y)
case token.REM:
return fc.formatExpr(`(%1s = %2e %% %3e, %1s === %1s ? %1s : $throwRuntimeError("integer divide by zero"))`, fc.newVariable("_r"), e.X, e.Y)
return fc.formatExpr(`(%1s = %2e %% %3e, %1s === %1s ? %1s : $throwRuntimeError("integer divide by zero"))`, fc.newLocalVariable("_r"), e.X, e.Y)
case token.SHL, token.SHR:
op := e.Op.String()
if e.Op == token.SHR && isUnsigned(basic) {
Expand All @@ -385,7 +385,7 @@ func (fc *funcContext) translateExpr(expr ast.Expr) *expression {
if e.Op == token.SHR && !isUnsigned(basic) {
return fc.fixNumber(fc.formatParenExpr("%e >> $min(%f, 31)", e.X, e.Y), basic)
}
y := fc.newVariable("y")
y := fc.newLocalVariable("y")
return fc.fixNumber(fc.formatExpr("(%s = %f, %s < 32 ? (%e %s %s) : 0)", y, e.Y, y, e.X, op, y), basic)
case token.AND, token.OR:
if isUnsigned(basic) {
Expand All @@ -408,7 +408,7 @@ func (fc *funcContext) translateExpr(expr ast.Expr) *expression {
if fc.Blocking[e.Y] {
skipCase := fc.caseCounter
fc.caseCounter++
resultVar := fc.newVariable("_v")
resultVar := fc.newLocalVariable("_v")
fc.Printf("if (!(%s)) { %s = false; $s = %d; continue s; }", fc.translateExpr(e.X), resultVar, skipCase)
fc.Printf("%s = %s; case %d:", resultVar, fc.translateExpr(e.Y), skipCase)
return fc.formatExpr("%s", resultVar)
Expand All @@ -418,7 +418,7 @@ func (fc *funcContext) translateExpr(expr ast.Expr) *expression {
if fc.Blocking[e.Y] {
skipCase := fc.caseCounter
fc.caseCounter++
resultVar := fc.newVariable("_v")
resultVar := fc.newLocalVariable("_v")
fc.Printf("if (%s) { %s = true; $s = %d; continue s; }", fc.translateExpr(e.X), resultVar, skipCase)
fc.Printf("%s = %s; case %d:", resultVar, fc.translateExpr(e.Y), skipCase)
return fc.formatExpr("%s", resultVar)
Expand Down Expand Up @@ -477,25 +477,33 @@ func (fc *funcContext) translateExpr(expr ast.Expr) *expression {
if _, isTuple := exprType.(*types.Tuple); isTuple {
return fc.formatExpr(
`(%1s = $mapIndex(%2e,%3s), %1s !== undefined ? [%1s.v, true] : [%4e, false])`,
fc.newVariable("_entry"),
fc.newLocalVariable("_entry"),
e.X,
key,
fc.zeroValue(t.Elem()),
)
}
return fc.formatExpr(
`(%1s = $mapIndex(%2e,%3s), %1s !== undefined ? %1s.v : %4e)`,
fc.newVariable("_entry"),
fc.newLocalVariable("_entry"),
e.X,
key,
fc.zeroValue(t.Elem()),
)
case *types.Basic:
return fc.formatExpr("%e.charCodeAt(%f)", e.X, e.Index)
case *types.Signature:
return fc.translateGenericInstance(e)
default:
panic(fmt.Sprintf("Unhandled IndexExpr: %T\n", t))
panic(fmt.Errorf("unhandled IndexExpr: %T", t))
}
case *ast.IndexListExpr:
switch t := fc.pkgCtx.TypeOf(e.X).Underlying().(type) {
case *types.Signature:
return fc.translateGenericInstance(e)
default:
panic(fmt.Errorf("unhandled IndexListExpr: %T", t))
}

case *ast.SliceExpr:
if b, isBasic := fc.pkgCtx.TypeOf(e.X).Underlying().(*types.Basic); isBasic && isString(b) {
switch {
Expand Down Expand Up @@ -642,13 +650,13 @@ func (fc *funcContext) translateExpr(expr ast.Expr) *expression {
case "Call":
if id, ok := fc.identifierConstant(e.Args[0]); ok {
if e.Ellipsis.IsValid() {
objVar := fc.newVariable("obj")
objVar := fc.newLocalVariable("obj")
return fc.formatExpr("(%s = %s, %s.%s.apply(%s, %s))", objVar, recv, objVar, id, objVar, externalizeExpr(e.Args[1]))
}
return fc.formatExpr("%s(%s)", globalRef(id), externalizeArgs(e.Args[1:]))
}
if e.Ellipsis.IsValid() {
objVar := fc.newVariable("obj")
objVar := fc.newLocalVariable("obj")
return fc.formatExpr("(%s = %s, %s[$externalize(%e, $String)].apply(%s, %s))", objVar, recv, objVar, e.Args[0], objVar, externalizeExpr(e.Args[1]))
}
return fc.formatExpr("%s[$externalize(%e, $String)](%s)", recv, e.Args[0], externalizeArgs(e.Args[1:]))
Expand Down Expand Up @@ -749,6 +757,10 @@ func (fc *funcContext) translateExpr(expr ast.Expr) *expression {
case *types.Var, *types.Const:
return fc.formatExpr("%s", fc.objectName(o))
case *types.Func:
if _, ok := fc.pkgCtx.Info.Instances[e]; ok {
// Generic function call with auto-inferred types.
return fc.translateGenericInstance(e)
}
return fc.formatExpr("%s", fc.objectName(o))
case *types.TypeName:
return fc.formatExpr("%s", fc.typeName(o.Type()))
Expand Down Expand Up @@ -788,14 +800,46 @@ func (fc *funcContext) translateExpr(expr ast.Expr) *expression {
}
}

// translateGenericInstance translates a generic function instantiation.
//
// The returned JS expression evaluates into a callable function with type params
// substituted.
func (fc *funcContext) translateGenericInstance(e ast.Expr) *expression {
var identifier *ast.Ident
switch e := e.(type) {
case *ast.Ident:
identifier = e
case *ast.IndexExpr:
identifier = e.X.(*ast.Ident)
case *ast.IndexListExpr:
identifier = e.X.(*ast.Ident)
default:
err := bailout(fmt.Errorf("unexpected generic instantiation expression type %T at %s", e, fc.pkgCtx.fileSet.Position(e.Pos())))
panic(err)
}

instance, ok := fc.pkgCtx.Info.Instances[identifier]
if !ok {
err := fmt.Errorf("no matching generic instantiation for %q at %s", identifier, fc.pkgCtx.fileSet.Position(identifier.Pos()))
bailout(err)
}
typeParams := []string{}
for i := 0; i < instance.TypeArgs.Len(); i++ {
t := instance.TypeArgs.At(i)
typeParams = append(typeParams, fc.typeName(t))
}
o := fc.pkgCtx.Uses[identifier]
return fc.formatExpr("%s(%s)", fc.objectName(o), strings.Join(typeParams, ", "))
}

func (fc *funcContext) translateCall(e *ast.CallExpr, sig *types.Signature, fun *expression) *expression {
args := fc.translateArgs(sig, e.Args, e.Ellipsis.IsValid())
if fc.Blocking[e] {
resumeCase := fc.caseCounter
fc.caseCounter++
returnVar := "$r"
if sig.Results().Len() != 0 {
returnVar = fc.newVariable("_r")
returnVar = fc.newLocalVariable("_r")
}
fc.Printf("%[1]s = %[2]s(%[3]s); /* */ $s = %[4]d; case %[4]d: if($c) { $c = false; %[1]s = %[1]s.$blk(); } if (%[1]s && %[1]s.$blk !== undefined) { break s; }", returnVar, fun, strings.Join(args, ", "), resumeCase)
if sig.Results().Len() != 0 {
Expand Down Expand Up @@ -845,7 +889,7 @@ func (fc *funcContext) delegatedCall(expr *ast.CallExpr) (callable *expression,
ellipsis := expr.Ellipsis

for i := range expr.Args {
v := fc.newVariable("_arg")
v := fc.newLocalVariable("_arg")
vars[i] = v
// Subtle: the proxy lambda argument needs to be assigned with the type
// that the original function expects, and not with the argument
Expand Down Expand Up @@ -1124,8 +1168,8 @@ func (fc *funcContext) translateConversion(expr ast.Expr, desiredType types.Type
}
if ptr, isPtr := fc.pkgCtx.TypeOf(expr).(*types.Pointer); fc.pkgCtx.Pkg.Path() == "syscall" && isPtr {
if s, isStruct := ptr.Elem().Underlying().(*types.Struct); isStruct {
array := fc.newVariable("_array")
target := fc.newVariable("_struct")
array := fc.newLocalVariable("_array")
target := fc.newLocalVariable("_struct")
fc.Printf("%s = new Uint8Array(%d);", array, sizes32.Sizeof(s))
fc.Delayed(func() {
fc.Printf("%s = %s, %s;", target, fc.translateExpr(expr), fc.loadStruct(array, target, s))
Expand Down Expand Up @@ -1173,8 +1217,8 @@ func (fc *funcContext) translateConversion(expr ast.Expr, desiredType types.Type
// struct pointer when handling syscalls.
// TODO(nevkontakte): Add a runtime assertion that the unsafe.Pointer is
// indeed pointing at a byte array.
array := fc.newVariable("_array")
target := fc.newVariable("_struct")
array := fc.newLocalVariable("_array")
target := fc.newLocalVariable("_struct")
return fc.formatExpr("(%s = %e, %s = %e, %s, %s)", array, expr, target, fc.zeroValue(t.Elem()), fc.loadStruct(array, target, ptrElType), target)
}
// Convert between structs of different types but identical layouts,
Expand All @@ -1196,7 +1240,7 @@ func (fc *funcContext) translateConversion(expr ast.Expr, desiredType types.Type
// type iPtr *int; var c int = 42; println((iPtr)(&c));
// TODO(nevkontakte): Are there any other cases that fall into this case?
exprTypeElem := exprType.Underlying().(*types.Pointer).Elem()
ptrVar := fc.newVariable("_ptr")
ptrVar := fc.newLocalVariable("_ptr")
getterConv := fc.translateConversion(fc.setType(&ast.StarExpr{X: fc.newIdent(ptrVar, exprType)}, exprTypeElem), t.Elem())
setterConv := fc.translateConversion(fc.newIdent("$v", t.Elem()), exprTypeElem)
return fc.formatExpr("(%1s = %2e, new %3s(function() { return %4s; }, function($v) { %1s.$set(%5s); }, %1s.$target))", ptrVar, expr, fc.typeName(desiredType), getterConv, setterConv)
Expand Down Expand Up @@ -1268,7 +1312,7 @@ func (fc *funcContext) translateConversionToSlice(expr ast.Expr, desiredType typ
}

func (fc *funcContext) loadStruct(array, target string, s *types.Struct) string {
view := fc.newVariable("_view")
view := fc.newLocalVariable("_view")
code := fmt.Sprintf("%s = new DataView(%s.buffer, %s.byteOffset)", view, array, array)
var fields []*types.Var
var collectFields func(s *types.Struct, path string)
Expand Down Expand Up @@ -1398,7 +1442,7 @@ func (fc *funcContext) formatExprInternal(format string, a []interface{}, parens
out.WriteByte('(')
parens = false
}
v := fc.newVariable("x")
v := fc.newLocalVariable("x")
out.WriteString(v + " = " + fc.translateExpr(e.(ast.Expr)).String() + ", ")
vars[i] = v
}
Expand Down
28 changes: 23 additions & 5 deletions compiler/package.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ func Compile(importPath string, files []*ast.File, fileSet *token.FileSet, impor
Implicits: make(map[ast.Node]types.Object),
Selections: make(map[*ast.SelectorExpr]*types.Selection),
Scopes: make(map[ast.Node]*types.Scope),
Instances: make(map[*ast.Ident]types.Instance),
}

var errList ErrorList
Expand Down Expand Up @@ -294,7 +295,7 @@ func Compile(importPath string, files []*ast.File, fileSet *token.FileSet, impor
// but now we do it here to maintain previous behavior.
continue
}
funcCtx.pkgCtx.pkgVars[importedPkg.Path()] = funcCtx.newVariableWithLevel(importedPkg.Name(), true)
funcCtx.pkgCtx.pkgVars[importedPkg.Path()] = funcCtx.newVariable(importedPkg.Name(), varPackage)
importedPaths = append(importedPaths, importedPkg.Path())
}
sort.Strings(importedPaths)
Expand Down Expand Up @@ -521,7 +522,7 @@ func Compile(importPath string, files []*ast.File, fileSet *token.FileSet, impor
d.DeclCode = funcCtx.CatchOutput(0, func() {
typeName := funcCtx.objectName(o)
lhs := typeName
if isPkgLevel(o) {
if typeVarLevel(o) == varPackage {
lhs += " = $pkg." + encodeIdent(o.Name())
}
size := int64(0)
Expand Down Expand Up @@ -773,12 +774,12 @@ func translateFunction(typ *ast.FuncType, recv *ast.Ident, body *ast.BlockStmt,
var params []string
for _, param := range typ.Params.List {
if len(param.Names) == 0 {
params = append(params, c.newVariable("param"))
params = append(params, c.newLocalVariable("param"))
continue
}
for _, ident := range param.Names {
if isBlank(ident) {
params = append(params, c.newVariable("param"))
params = append(params, c.newLocalVariable("param"))
continue
}
params = append(params, c.objectName(c.pkgCtx.Defs[ident]))
Expand Down Expand Up @@ -898,5 +899,22 @@ func translateFunction(typ *ast.FuncType, recv *ast.Ident, body *ast.BlockStmt,

c.pkgCtx.escapingVars = prevEV

return params, fmt.Sprintf("function%s(%s) {\n%s%s}", functionName, strings.Join(params, ", "), bodyOutput, c.Indentation(0))
if !c.sigTypes.IsGeneric() {
return params, fmt.Sprintf("function%s(%s) {\n%s%s}", functionName, strings.Join(params, ", "), bodyOutput, c.Indentation(0))
}

// Generic functions are generated as factories to allow passing type parameters
// from the call site.
// TODO(nevkontakte): Cache function instances for a given combination of type
// parameters.
// TODO(nevkontakte): Generate type parameter arguments and derive all dependent
// types inside the function.
typeParams := []string{}
for i := 0; i < c.sigTypes.Sig.TypeParams().Len(); i++ {
typeParam := c.sigTypes.Sig.TypeParams().At(i)
typeParams = append(typeParams, c.typeName(typeParam))
}

return params, fmt.Sprintf("function%s(%s){ return function(%s) {\n%s%s}; }",
functionName, strings.Join(typeParams, ", "), strings.Join(params, ", "), bodyOutput, c.Indentation(0))
}
Loading