diff --git a/build/build.go b/build/build.go index 070d05df1..ccacee90c 100644 --- a/build/build.go +++ b/build/build.go @@ -117,6 +117,20 @@ func ImportDir(dir string, mode build.ImportMode, installSuffix string, buildTag return pkg, nil } +// overrideInfo is used by parseAndAugment methods to manage +// directives and how the overlay and original are merged. +type overrideInfo struct { + // KeepOriginal indicates that the original code should be kept + // but the identifier will be prefixed by `_gopherjs_original_foo`. + keepOriginal bool + + // purgeMethods indicates that this info is for a type and + // if a method has this type as a receiver should also be removed. + // If the method is defined in the overlays and therefore has its + // own overrides, this will be ignored. + purgeMethods bool +} + // parseAndAugment parses and returns all .go files of given pkg. // Standard Go library packages are augmented with files in compiler/natives folder. // If isTest is true and pkg.ImportPath has no _test suffix, package is built for running internal tests. @@ -125,31 +139,54 @@ func ImportDir(dir string, mode build.ImportMode, installSuffix string, buildTag // The native packages are augmented by the contents of natives.FS in the following way. // The file names do not matter except the usual `_test` suffix. The files for // native overrides get added to the package (even if they have the same name -// as an existing file from the standard library). For function identifiers that exist -// in the original AND the overrides AND that include the following directive in their comment: -// //gopherjs:keep-original, the original identifier in the AST gets prefixed by -// `_gopherjs_original_`. For other identifiers that exist in the original AND the overrides, -// the original identifier gets replaced by `_`. New identifiers that don't exist in original -// package get added. +// as an existing file from the standard library). +// +// - For function identifiers that exist in the original and the overrides and have the +// directive `gopherjs:keep-original`, the original identifier in the AST gets +// prefixed by `_gopherjs_original_`. +// - For identifiers that exist in the original and the overrides +// and have the directive `gopherjs:purge`, both the original and override are removed. +// This is for completely removing something which is currently invalid for GopherJS. +// For any purged types any methods with that type as the receiver are also removed. +// - Otherwise for identifiers that exist in the original and the overrides, +// the original identifier is removed. +// - New identifiers that don't exist in original package get added. func parseAndAugment(xctx XContext, pkg *PackageData, isTest bool, fileSet *token.FileSet) ([]*ast.File, []JSFile, error) { - var files []*ast.File + jsFiles, overlayFiles := parseOverlayFiles(xctx, pkg, isTest, fileSet) - type overrideInfo struct { - keepOriginal bool - pruneOriginal bool + originalFiles, err := parserOriginalFiles(pkg, fileSet) + if err != nil { + return nil, nil, err } - replacedDeclNames := make(map[string]overrideInfo) + overrides := make(map[string]overrideInfo) + for _, file := range overlayFiles { + augmentOverlayFile(fileSet, file, overrides) + pruneImports(fileSet, file) + } + delete(overrides, "init") + + for _, file := range originalFiles { + augmentOriginalImports(pkg.ImportPath, file) + augmentOriginalFile(fileSet, file, overrides) + pruneImports(fileSet, file) + } + + return append(overlayFiles, originalFiles...), jsFiles, nil +} + +// parseOverlayFiles loads and parses overlay files +// to augment the original files with. +func parseOverlayFiles(xctx XContext, pkg *PackageData, isTest bool, fileSet *token.FileSet) ([]JSFile, []*ast.File) { isXTest := strings.HasSuffix(pkg.ImportPath, "_test") importPath := pkg.ImportPath if isXTest { importPath = importPath[:len(importPath)-5] } - jsFiles := []JSFile{} - + var jsFiles []JSFile + var files []*ast.File nativesContext := overlayCtx(xctx.Env()) - if nativesPkg, err := nativesContext.Import(importPath, "", 0); err == nil { jsFiles = nativesPkg.JSFiles names := nativesPkg.GoFiles @@ -159,6 +196,7 @@ func parseAndAugment(xctx XContext, pkg *PackageData, isTest bool, fileSet *toke if isXTest { names = nativesPkg.XTestGoFiles } + for _, name := range names { fullPath := path.Join(nativesPkg.Dir, name) r, err := nativesContext.bctx.OpenFile(fullPath) @@ -173,34 +211,17 @@ func parseAndAugment(xctx XContext, pkg *PackageData, isTest bool, fileSet *toke panic(err) } r.Close() - for _, decl := range file.Decls { - switch d := decl.(type) { - case *ast.FuncDecl: - k := astutil.FuncKey(d) - replacedDeclNames[k] = overrideInfo{ - keepOriginal: astutil.KeepOriginal(d), - pruneOriginal: astutil.PruneOriginal(d), - } - case *ast.GenDecl: - switch d.Tok { - case token.TYPE: - for _, spec := range d.Specs { - replacedDeclNames[spec.(*ast.TypeSpec).Name.Name] = overrideInfo{} - } - case token.VAR, token.CONST: - for _, spec := range d.Specs { - for _, name := range spec.(*ast.ValueSpec).Names { - replacedDeclNames[name.Name] = overrideInfo{} - } - } - } - } - } + files = append(files, file) } } - delete(replacedDeclNames, "init") + return jsFiles, files +} + +// parserOriginalFiles loads and parses the original files to augment. +func parserOriginalFiles(pkg *PackageData, fileSet *token.FileSet) ([]*ast.File, error) { + var files []*ast.File var errList compiler.ErrorList for _, name := range pkg.GoFiles { if !filepath.IsAbs(name) { // name might be absolute if specified directly. E.g., `gopherjs build /abs/file.go`. @@ -208,7 +229,7 @@ func parseAndAugment(xctx XContext, pkg *PackageData, isTest bool, fileSet *toke } r, err := buildutil.OpenFile(pkg.bctx, name) if err != nil { - return nil, nil, err + return nil, err } file, err := parser.ParseFile(fileSet, name, r, parser.ParseComments) r.Close() @@ -226,68 +247,233 @@ func parseAndAugment(xctx XContext, pkg *PackageData, isTest bool, fileSet *toke continue } - switch pkg.ImportPath { - case "crypto/rand", "encoding/gob", "encoding/json", "expvar", "go/token", "log", "math/big", "math/rand", "regexp", "time": - for _, spec := range file.Imports { - path, _ := strconv.Unquote(spec.Path.Value) - if path == "sync" { - if spec.Name == nil { - spec.Name = ast.NewIdent("sync") + files = append(files, file) + } + + if errList != nil { + return nil, errList + } + return files, nil +} + +// augmentOverlayFile is the part of parseAndAugment that processes +// an overlay file AST to collect information such as compiler directives and +// perform any initial augmentation needed to the overlay. +func augmentOverlayFile(fileSet *token.FileSet, file *ast.File, overrides map[string]overrideInfo) { + for i, decl := range file.Decls { + purgeDecl := astutil.Purge(decl) + switch d := decl.(type) { + case *ast.FuncDecl: + k := astutil.FuncKey(d) + overrides[k] = overrideInfo{ + keepOriginal: astutil.KeepOriginal(d), + } + case *ast.GenDecl: + for j, spec := range d.Specs { + purgeSpec := purgeDecl || astutil.Purge(spec) + switch s := spec.(type) { + case *ast.TypeSpec: + overrides[s.Name.Name] = overrideInfo{ + purgeMethods: purgeSpec, } - spec.Path.Value = `"github.com/gopherjs/gopherjs/nosync"` + case *ast.ValueSpec: + for _, name := range s.Names { + overrides[name.Name] = overrideInfo{} + } + } + if purgeSpec { + d.Specs[j] = nil + } + } + } + if purgeDecl { + file.Decls[i] = nil + } + } + finalizeRemovals(file) +} + +// augmentOriginalImports is the part of parseAndAugment that processes +// an original file AST to modify the imports for that file. +func augmentOriginalImports(importPath string, file *ast.File) { + switch importPath { + case "crypto/rand", "encoding/gob", "encoding/json", "expvar", "go/token", "log", "math/big", "math/rand", "regexp", "time": + for _, spec := range file.Imports { + path, _ := strconv.Unquote(spec.Path.Value) + if path == "sync" { + if spec.Name == nil { + spec.Name = ast.NewIdent("sync") } + spec.Path.Value = `"github.com/gopherjs/gopherjs/nosync"` } } + } +} - for _, decl := range file.Decls { - switch d := decl.(type) { - case *ast.FuncDecl: - k := astutil.FuncKey(d) - if info, ok := replacedDeclNames[k]; ok { - if info.pruneOriginal { - // Prune function bodies, since it may contain code invalid for - // GopherJS and pin unwanted imports. - d.Body = nil +// augmentOriginalFile is the part of parseAndAugment that processes an +// original file AST to augment the source code using the overrides from +// the overlay files. +func augmentOriginalFile(fileSet *token.FileSet, file *ast.File, overrides map[string]overrideInfo) { + for i, decl := range file.Decls { + switch d := decl.(type) { + case *ast.FuncDecl: + if info, ok := overrides[astutil.FuncKey(d)]; ok { + if info.keepOriginal { + // Allow overridden function calls + // The standard library implementation of foo() becomes _gopherjs_original_foo() + d.Name.Name = `_gopherjs_original_` + d.Name.Name + } else { + file.Decls[i] = nil + } + } else if recvKey := astutil.FuncReceiverKey(d); len(recvKey) > 0 { + // check if the receiver has been purged, if so, remove the method too. + if info, ok := overrides[recvKey]; ok && info.purgeMethods { + file.Decls[i] = nil + } + } + case *ast.GenDecl: + for j, spec := range d.Specs { + switch s := spec.(type) { + case *ast.TypeSpec: + if _, ok := overrides[s.Name.Name]; ok { + d.Specs[j] = nil } - if info.keepOriginal { - // Allow overridden function calls - // The standard library implementation of foo() becomes _gopherjs_original_foo() - d.Name.Name = "_gopherjs_original_" + d.Name.Name + case *ast.ValueSpec: + if len(s.Names) == len(s.Values) { + // multi-value context + // e.g. var a, b = 2, 3 + for k, name := range s.Names { + if _, ok := overrides[name.Name]; ok { + s.Names[k] = nil + s.Values[k] = nil + } + } } else { - d.Name = ast.NewIdent("_") - } - } - case *ast.GenDecl: - switch d.Tok { - case token.TYPE: - for _, spec := range d.Specs { - s := spec.(*ast.TypeSpec) - if _, ok := replacedDeclNames[s.Name.Name]; ok { - s.Name = ast.NewIdent("_") - s.Type = &ast.StructType{Struct: s.Pos(), Fields: &ast.FieldList{}} - s.TypeParams = nil + // single-value context + // e.g. var a, b = func() (int, int) { return 2, 3 }() + nameRemoved := false + for _, name := range s.Names { + if _, ok := overrides[name.Name]; ok { + nameRemoved = true + name.Name = `_` + } } - } - case token.VAR, token.CONST: - for _, spec := range d.Specs { - s := spec.(*ast.ValueSpec) - for i, name := range s.Names { - if _, ok := replacedDeclNames[name.Name]; ok { - s.Names[i] = ast.NewIdent("_") + if nameRemoved { + removeSpec := true + for _, name := range s.Names { + if name.Name != `_` { + removeSpec = false + break + } + } + if removeSpec { + d.Specs[j] = nil } } } } } } + } + finalizeRemovals(file) +} - files = append(files, file) +// pruneImports will remove any unused imports from the file. +// +// This will not remove any dot (`.`) or blank (`_`) imports. +func pruneImports(fileSet *token.FileSet, file *ast.File) { + unused := make(map[string]int, len(file.Imports)) + for i, in := range file.Imports { + if name := astutil.ImportName(in); len(name) > 0 { + unused[name] = i + } } - if errList != nil { - return nil, nil, errList + // Remove any import which is used in a selector. + ast.Walk(astutil.NewCallbackVisitor(func(n ast.Node) bool { + if sel, ok := n.(*ast.SelectorExpr); ok { + if id, ok := sel.X.(*ast.Ident); ok && id.Obj == nil { + delete(unused, id.Name) + } + } + return len(unused) > 0 + }), file) + + if len(unused) == 0 { + return + } + + // Remove all import specifications + isUnusedSpec := map[*ast.ImportSpec]bool{} + for _, index := range unused { + isUnusedSpec[file.Imports[index]] = true + } + for _, decl := range file.Decls { + if d, ok := decl.(*ast.GenDecl); ok { + for i, spec := range d.Specs { + if other, ok := spec.(*ast.ImportSpec); ok && isUnusedSpec[other] { + d.Specs[i] = nil + } + } + } + } + + // Remove the import copies in the file + for _, index := range unused { + file.Imports[index] = nil + } + + finalizeRemovals(file) +} + +// finalizeRemovals fully removes any declaration, specification, imports +// that have been set to nil. This will also remove the file's top-level +// comment group to remove any unassociated comments, including the comments +// from removed code. +func finalizeRemovals(file *ast.File) { + fileChanged := false + for i, decl := range file.Decls { + switch d := decl.(type) { + case nil: + fileChanged = true + case *ast.GenDecl: + declChanged := false + for j, spec := range d.Specs { + switch s := spec.(type) { + case nil: + declChanged = true + case *ast.ValueSpec: + specChanged := false + for _, name := range s.Names { + if name == nil { + specChanged = true + break + } + } + if specChanged { + s.Names = astutil.Squeeze(s.Names) + s.Values = astutil.Squeeze(s.Values) + if len(s.Names) == 0 { + declChanged = true + d.Specs[j] = nil + } + } + } + } + if declChanged { + d.Specs = astutil.Squeeze(d.Specs) + if len(d.Specs) == 0 { + fileChanged = true + file.Decls[i] = nil + } + } + } + } + if fileChanged { + file.Decls = astutil.Squeeze(file.Decls) } - return files, jsFiles, nil + file.Imports = astutil.Squeeze(file.Imports) + file.Comments = nil } // Options controls build process behavior. @@ -678,7 +864,7 @@ func (s *Session) BuildPackage(pkg *PackageData) (*compiler.Archive, error) { archive := s.buildCache.LoadArchive(pkg.ImportPath) if archive != nil && !pkg.SrcModTime.After(archive.BuildTime) { if err := archive.RegisterTypes(s.Types); err != nil { - panic(fmt.Errorf("Failed to load type information from %v: %w", archive, err)) + panic(fmt.Errorf("failed to load type information from %v: %w", archive, err)) } s.UpToDateArchives[pkg.ImportPath] = archive // Existing archive is up to date, no need to build it from scratch. diff --git a/build/build_test.go b/build/build_test.go index 2fa17e2c5..8b00336bd 100644 --- a/build/build_test.go +++ b/build/build_test.go @@ -1,12 +1,15 @@ package build import ( + "bytes" "fmt" gobuild "go/build" + "go/printer" "go/token" "strconv" "testing" + "github.com/gopherjs/gopherjs/internal/srctesting" "github.com/shurcooL/go/importgraphutil" ) @@ -127,3 +130,388 @@ func (m stringSet) String() string { } return fmt.Sprintf("%q", s) } + +func TestOverlayAugmentation(t *testing.T) { + tests := []struct { + desc string + src string + want string + expInfo map[string]overrideInfo + }{ + { + desc: `prune an unused import`, + src: `import foo "some/other/bar"`, + want: ``, + expInfo: map[string]overrideInfo{}, + }, { + desc: `remove function`, + src: `func Foo(a, b int) int { + return a + b + }`, + want: `func Foo(a, b int) int { + return a + b + }`, + expInfo: map[string]overrideInfo{ + `Foo`: {}, + }, + }, { + desc: `keep function`, + src: `//gopherjs:keep-original + func Foo(a, b int) int { + return a + b + }`, + want: `//gopherjs:keep-original + func Foo(a, b int) int { + return a + b + }`, + expInfo: map[string]overrideInfo{ + `Foo`: {keepOriginal: true}, + }, + }, { + desc: `purge function`, + src: `//gopherjs:purge + func Foo(a, b int) int { + return a + b + }`, + want: ``, + expInfo: map[string]overrideInfo{ + `Foo`: {}, + }, + }, { + desc: `purge struct removes an import`, + src: `import "bytes" + import "math" + + //gopherjs:purge + type Foo struct { + bar *bytes.Buffer + } + + const Tau = math.Pi * 2.0`, + want: `import "math" + + const Tau = math.Pi * 2.0`, + expInfo: map[string]overrideInfo{ + `Foo`: {purgeMethods: true}, + `Tau`: {}, + }, + }, { + desc: `purge whole type decl`, + src: `//gopherjs:purge + type ( + Foo struct {} + bar interface{} + bob int + )`, + want: ``, + expInfo: map[string]overrideInfo{ + `Foo`: {purgeMethods: true}, + `bar`: {purgeMethods: true}, + `bob`: {purgeMethods: true}, + }, + }, { + desc: `purge part of type decl`, + src: `type ( + Foo struct {} + + //gopherjs:purge + bar interface{} + + //gopherjs:purge + bob int + )`, + want: `type ( + Foo struct {} + )`, + expInfo: map[string]overrideInfo{ + `Foo`: {}, + `bar`: {purgeMethods: true}, + `bob`: {purgeMethods: true}, + }, + }, { + desc: `purge all of a type decl`, + src: `type ( + //gopherjs:purge + Foo struct {} + )`, + want: ``, + expInfo: map[string]overrideInfo{ + `Foo`: {purgeMethods: true}, + }, + }, { + desc: `remove and purge values`, + src: `import "time" + + const ( + foo = 42 + //gopherjs:purge + bar = "gopherjs" + ) + + //gopherjs:purge + var now = time.Now`, + want: `const ( + foo = 42 + )`, + expInfo: map[string]overrideInfo{ + `foo`: {}, + `bar`: {}, + `now`: {}, + }, + }, { + desc: `imports not confused by local variables`, + src: `import ( + "cmp" + "time" + ) + + //gopherjs:purge + func Sort[S ~[]E, E cmp.Ordered](x S) {} + + func SecondsSince(start time.Time) int { + cmp := time.Now().Sub(start) + return int(cmp.Second()) + }`, + want: `import ( + "time" + ) + + func SecondsSince(start time.Time) int { + cmp := time.Now().Sub(start) + return int(cmp.Second()) + }`, + expInfo: map[string]overrideInfo{ + `Sort`: {}, + `SecondsSince`: {}, + }, + }, { + desc: `purge generics`, + src: `import "cmp" + + //gopherjs:purge + type Pointer[T any] struct {} + + //gopherjs:purge + func Sort[S ~[]E, E cmp.Ordered](x S) {} + + // stub for "func Equal[S ~[]E, E any](s1, s2 S) bool" + func Equal() {}`, + want: `// stub for "func Equal[S ~[]E, E any](s1, s2 S) bool" + func Equal() {}`, + expInfo: map[string]overrideInfo{ + `Pointer`: {purgeMethods: true}, + `Sort`: {}, + `Equal`: {}, + }, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + pkgName := "package testpackage\n\n" + fsetSrc := token.NewFileSet() + fileSrc := srctesting.Parse(t, fsetSrc, pkgName+test.src) + + overrides := map[string]overrideInfo{} + augmentOverlayFile(fsetSrc, fileSrc, overrides) + pruneImports(fsetSrc, fileSrc) + + buf := &bytes.Buffer{} + _ = printer.Fprint(buf, fsetSrc, fileSrc) + got := buf.String() + + fsetWant := token.NewFileSet() + fileWant := srctesting.Parse(t, fsetWant, pkgName+test.want) + + buf.Reset() + _ = printer.Fprint(buf, fsetWant, fileWant) + want := buf.String() + + if got != want { + t.Errorf("augmentOverlayFile and pruneImports got unexpected code:\n"+ + "returned:\n\t%q\nwant:\n\t%q", got, want) + } + + for key, expInfo := range test.expInfo { + if gotInfo, ok := overrides[key]; !ok { + t.Errorf(`%q was expected but not gotten`, key) + } else if expInfo != gotInfo { + t.Errorf(`%q had wrong info, got %+v`, key, gotInfo) + } + } + for key, gotInfo := range overrides { + if _, ok := test.expInfo[key]; !ok { + t.Errorf(`%q with %+v was not expected`, key, gotInfo) + } + } + }) + } +} + +func TestOriginalAugmentation(t *testing.T) { + tests := []struct { + desc string + info map[string]overrideInfo + src string + want string + }{ + { + desc: `prune an unused import`, + info: map[string]overrideInfo{}, + src: `import foo "some/other/bar"`, + want: ``, + }, { + desc: `do not affect function`, + info: map[string]overrideInfo{}, + src: `func Foo(a, b int) int { + return a + b + }`, + want: `func Foo(a, b int) int { + return a + b + }`, + }, { + desc: `change unnamed sync import`, + info: map[string]overrideInfo{}, + src: `import "sync" + + var _ = &sync.Mutex{}`, + want: `import sync "github.com/gopherjs/gopherjs/nosync" + + var _ = &sync.Mutex{}`, + }, { + desc: `change named sync import`, + info: map[string]overrideInfo{}, + src: `import foo "sync" + + var _ = &foo.Mutex{}`, + want: `import foo "github.com/gopherjs/gopherjs/nosync" + + var _ = &foo.Mutex{}`, + }, { + desc: `remove function`, + info: map[string]overrideInfo{ + `Foo`: {}, + }, + src: `func Foo(a, b int) int { + return a + b + }`, + want: ``, + }, { + desc: `keep original function`, + info: map[string]overrideInfo{ + `Foo`: {keepOriginal: true}, + }, + src: `func Foo(a, b int) int { + return a + b + }`, + want: `func _gopherjs_original_Foo(a, b int) int { + return a + b + }`, + }, { + desc: `remove types and values`, + info: map[string]overrideInfo{ + `Foo`: {}, + `now`: {}, + `bar1`: {}, + }, + src: `import "time" + + type Foo interface{} + var now = time.Now + const bar1, bar2 = 21, 42`, + want: `const bar2 = 42`, + }, { + desc: `remove in multi-value context`, + info: map[string]overrideInfo{ + `bar`: {}, + }, + src: `const foo, bar = func() (int, int) { + return 24, 12 + }()`, + want: `const foo, _ = func() (int, int) { + return 24, 12 + }()`, + }, { + desc: `full remove in multi-value context`, + info: map[string]overrideInfo{ + `bar`: {}, + }, + src: `const _, bar = func() (int, int) { + return 24, 12 + }()`, + want: ``, + }, { + desc: `purge struct and methods`, + info: map[string]overrideInfo{ + `Foo`: {purgeMethods: true}, + }, + src: `type Foo struct{ + bar int + } + + func (f Foo) GetBar() int { + return f.bar + } + + func (f *Foo) SetBar(bar int) { + f.bar = bar + } + + func NewFoo(bar int) *Foo { + return &Foo{bar: bar} + }`, + // NewFoo is not removed automatically since + // only functions with Foo as a receiver is removed. + want: `func NewFoo(bar int) *Foo { + return &Foo{bar: bar} + }`, + }, { + desc: `purge generics`, + info: map[string]overrideInfo{ + `Pointer`: {purgeMethods: true}, + `Sort`: {}, + `Equal`: {}, + }, + src: `import "cmp" + + type Pointer[T any] struct {} + func (x *Pointer[T]) Load() *T {} + func (x *Pointer[T]) Store(val *T) {} + + func Sort[S ~[]E, E cmp.Ordered](x S) {} + + // overlay had stub "func Equal() {}" + func Equal[S ~[]E, E any](s1, s2 S) bool {}`, + want: ``, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + pkgName := "package testpackage\n\n" + importPath := `math/rand` + fsetSrc := token.NewFileSet() + fileSrc := srctesting.Parse(t, fsetSrc, pkgName+test.src) + + augmentOriginalImports(importPath, fileSrc) + augmentOriginalFile(fsetSrc, fileSrc, test.info) + pruneImports(fsetSrc, fileSrc) + + buf := &bytes.Buffer{} + _ = printer.Fprint(buf, fsetSrc, fileSrc) + got := buf.String() + + fsetWant := token.NewFileSet() + fileWant := srctesting.Parse(t, fsetWant, pkgName+test.want) + + buf.Reset() + _ = printer.Fprint(buf, fsetWant, fileWant) + want := buf.String() + + if got != want { + t.Errorf("augmentOriginalImports, augmentOriginalFile, and pruneImports got unexpected code:\n"+ + "returned:\n\t%q\nwant:\n\t%q", got, want) + } + }) + } +} diff --git a/compiler/astutil/astutil.go b/compiler/astutil/astutil.go index 30febe1cb..98e35d872 100644 --- a/compiler/astutil/astutil.go +++ b/compiler/astutil/astutil.go @@ -5,7 +5,10 @@ import ( "go/ast" "go/token" "go/types" - "strings" + "path" + "reflect" + "regexp" + "strconv" ) func RemoveParens(e ast.Expr) ast.Expr { @@ -50,6 +53,7 @@ func IsTypeExpr(expr ast.Expr, info *types.Info) bool { } } +// ImportsUnsafe determines if this file imports the "unsafe" package. func ImportsUnsafe(file *ast.File) bool { for _, imp := range file.Imports { if imp.Path.Value == `"unsafe"` { @@ -59,38 +63,121 @@ func ImportsUnsafe(file *ast.File) bool { return false } +// ImportName tries to determine the package name for an import. +// +// If the package name isn't specified then this will make a best +// make a best guess using the import path. +// If the import name is dot (`.`), blank (`_`), or there +// was an issue determining the package name then empty is returned. +func ImportName(spec *ast.ImportSpec) string { + if spec == nil { + return `` + } + + var name string + if spec.Name != nil { + name = spec.Name.Name + } else { + importPath, _ := strconv.Unquote(spec.Path.Value) + name = path.Base(importPath) + } + + switch name { + case `_`, `.`, `/`: + return `` + default: + return name + } +} + // FuncKey returns a string, which uniquely identifies a top-level function or // method in a package. func FuncKey(d *ast.FuncDecl) string { - if d.Recv == nil || len(d.Recv.List) == 0 { - return d.Name.Name + if recvKey := FuncReceiverKey(d); len(recvKey) > 0 { + return recvKey + "." + d.Name.Name + } + return d.Name.Name +} + +// FuncReceiverKey returns a string that uniquely identifies the receiver +// struct of the function or an empty string if there is no receiver. +// This name will match the name of the struct in the struct's type spec. +func FuncReceiverKey(d *ast.FuncDecl) string { + if d == nil || d.Recv == nil || len(d.Recv.List) == 0 { + return `` } recv := d.Recv.List[0].Type - if star, ok := recv.(*ast.StarExpr); ok { - recv = star.X + for { + switch r := recv.(type) { + case *ast.IndexListExpr: + recv = r.X + continue + case *ast.IndexExpr: + recv = r.X + continue + case *ast.StarExpr: + recv = r.X + continue + case *ast.Ident: + return r.Name + default: + panic(fmt.Errorf(`unexpected type %T in receiver of function: %v`, recv, d)) + } } - return recv.(*ast.Ident).Name + "." + d.Name.Name } -// PruneOriginal returns true if gopherjs:prune-original directive is present -// before a function decl. -// -// `//gopherjs:prune-original` is a GopherJS-specific directive, which can be -// applied to functions in native overlays and will instruct the augmentation -// logic to delete the body of a standard library function that was replaced. -// This directive can be used to remove code that would be invalid in GopherJS, -// such as code expecting ints to be 64-bit. It should be used with caution -// since it may create unused imports in the original source file. -func PruneOriginal(d *ast.FuncDecl) bool { - if d.Doc == nil { - return false - } - for _, c := range d.Doc.List { - if strings.HasPrefix(c.Text, "//gopherjs:prune-original") { - return true +// anyDocLine calls the given predicate on all associated documentation +// lines and line-comment lines from the given node. +// If the predicate returns true for any line then true is returned. +func anyDocLine(node any, predicate func(line string) bool) bool { + switch a := node.(type) { + case *ast.Comment: + return a != nil && predicate(a.Text) + case *ast.CommentGroup: + if a != nil { + for _, c := range a.List { + if anyDocLine(c, predicate) { + return true + } + } } + return false + case *ast.Field: + return a != nil && (anyDocLine(a.Doc, predicate) || anyDocLine(a.Comment, predicate)) + case *ast.File: + return a != nil && anyDocLine(a.Doc, predicate) + case *ast.FuncDecl: + return a != nil && anyDocLine(a.Doc, predicate) + case *ast.GenDecl: + return a != nil && anyDocLine(a.Doc, predicate) + case *ast.ImportSpec: + return a != nil && (anyDocLine(a.Doc, predicate) || anyDocLine(a.Comment, predicate)) + case *ast.TypeSpec: + return a != nil && (anyDocLine(a.Doc, predicate) || anyDocLine(a.Comment, predicate)) + case *ast.ValueSpec: + return a != nil && (anyDocLine(a.Doc, predicate) || anyDocLine(a.Comment, predicate)) + default: + panic(fmt.Errorf(`unexpected node type to get doc from: %T`, node)) } - return false +} + +// directiveMatcher is a regex which matches a GopherJS directive +// and finds the directive action. +var directiveMatcher = regexp.MustCompile(`^\/(?:\/|\*)gopherjs:([\w-]+)`) + +// hasDirective returns true if the associated documentation +// or line comments for the given node have the given directive action. +// +// All GopherJS-specific directives must start with `//gopherjs:` or +// `/*gopherjs:` and followed by an action without any whitespace. The action +// must be one or more letter, decimal, underscore, or hyphen. +// +// see https://pkg.go.dev/cmd/compile#hdr-Compiler_Directives +func hasDirective(node any, directiveAction string) bool { + return anyDocLine(node, func(line string) bool { + m := directiveMatcher.FindStringSubmatch(line) + return len(m) == 2 && m[1] == directiveAction + }) } // KeepOriginal returns true if gopherjs:keep-original directive is present @@ -101,16 +188,23 @@ func PruneOriginal(d *ast.FuncDecl) bool { // logic to expose the original function such that it can be called. For a // function in the original called `foo`, it will be accessible by the name // `_gopherjs_original_foo`. -func KeepOriginal(d *ast.FuncDecl) bool { - if d.Doc == nil { - return false - } - for _, c := range d.Doc.List { - if strings.HasPrefix(c.Text, "//gopherjs:keep-original") { - return true - } - } - return false +func KeepOriginal(d any) bool { + return hasDirective(d, `keep-original`) +} + +// Purge returns true if gopherjs:purge directive is present +// on a struct, interface, type, variable, constant, or function. +// +// `//gopherjs:purge` is a GopherJS-specific directive, which can be +// applied in native overlays and will instruct the augmentation logic to +// delete part of the standard library without a replacement. This directive +// can be used to remove code that would be invalid in GopherJS, such as code +// using unsupported features (e.g. generic interfaces before generics were +// fully supported). It should be used with caution since it may remove needed +// dependencies. If a type is purged, all methods using that type as +// a receiver will also be purged. +func Purge(d any) bool { + return hasDirective(d, `purge`) } // FindLoopStmt tries to find the loop statement among the AST nodes in the @@ -167,3 +261,37 @@ func EndsWithReturn(stmts []ast.Stmt) bool { return false } } + +// Squeeze removes all nil nodes from the slice. +// +// The given slice will be modified. This is designed for squeezing +// declaration, specification, imports, and identifier lists. +func Squeeze[E ast.Node, S ~[]E](s S) S { + var zero E + count, dest := len(s), 0 + for src := 0; src < count; src++ { + if !reflect.DeepEqual(s[src], zero) { + // Swap the values, this will put the nil values to the end + // of the slice so that the tail isn't holding onto pointers. + s[dest], s[src] = s[src], s[dest] + dest++ + } + } + return s[:dest] +} + +type CallbackVisitor struct { + predicate func(node ast.Node) bool +} + +func NewCallbackVisitor(predicate func(node ast.Node) bool) *CallbackVisitor { + return &CallbackVisitor{predicate: predicate} +} + +func (v *CallbackVisitor) Visit(node ast.Node) ast.Visitor { + if v.predicate != nil && v.predicate(node) { + return v + } + v.predicate = nil + return nil +} diff --git a/compiler/astutil/astutil_test.go b/compiler/astutil/astutil_test.go index a996ae73f..20a9c2d1c 100644 --- a/compiler/astutil/astutil_test.go +++ b/compiler/astutil/astutil_test.go @@ -1,7 +1,10 @@ package astutil import ( + "fmt" + "go/ast" "go/token" + "strconv" "testing" "github.com/gopherjs/gopherjs/internal/srctesting" @@ -52,6 +55,47 @@ func TestImportsUnsafe(t *testing.T) { } } +func TestImportName(t *testing.T) { + tests := []struct { + desc string + src string + want string + }{ + { + desc: `named import`, + src: `import foo "some/other/bar"`, + want: `foo`, + }, { + desc: `unnamed import`, + src: `import "some/other/bar"`, + want: `bar`, + }, { + desc: `dot import`, + src: `import . "some/other/bar"`, + want: ``, + }, { + desc: `blank import`, + src: `import _ "some/other/bar"`, + want: ``, + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + src := "package testpackage\n\n" + test.src + fset := token.NewFileSet() + file := srctesting.Parse(t, fset, src) + if len(file.Imports) != 1 { + t.Fatal(`expected one and only one import`) + } + importSpec := file.Imports[0] + got := ImportName(importSpec) + if got != test.want { + t.Fatalf(`ImportName() returned %q, want %q`, got, test.want) + } + }) + } +} + func TestFuncKey(t *testing.T) { tests := []struct { desc string @@ -59,30 +103,47 @@ func TestFuncKey(t *testing.T) { want string }{ { - desc: "top-level function", - src: `package testpackage; func foo() {}`, - want: "foo", + desc: `top-level function`, + src: `func foo() {}`, + want: `foo`, }, { - desc: "top-level exported function", - src: `package testpackage; func Foo() {}`, - want: "Foo", + desc: `top-level exported function`, + src: `func Foo() {}`, + want: `Foo`, }, { - desc: "method", - src: `package testpackage; func (_ myType) bar() {}`, - want: "myType.bar", + desc: `method on reference`, + src: `func (_ myType) bar() {}`, + want: `myType.bar`, + }, { + desc: `method on pointer`, + src: ` func (_ *myType) bar() {}`, + want: `myType.bar`, + }, { + desc: `method on generic reference`, + src: ` func (_ myType[T]) bar() {}`, + want: `myType.bar`, + }, { + desc: `method on generic pointer`, + src: ` func (_ *myType[T]) bar() {}`, + want: `myType.bar`, + }, { + desc: `method on struct with multiple generics`, + src: ` func (_ *myType[T1, T2, T3, T4]) bar() {}`, + want: `myType.bar`, }, } for _, test := range tests { t.Run(test.desc, func(t *testing.T) { - fdecl := srctesting.ParseFuncDecl(t, test.src) + src := `package testpackage; ` + test.src + fdecl := srctesting.ParseFuncDecl(t, src) if got := FuncKey(fdecl); got != test.want { - t.Errorf("Got %q, want %q", got, test.want) + t.Errorf(`Got %q, want %q`, got, test.want) } }) } } -func TestPruneOriginal(t *testing.T) { +func TestKeepOriginal(t *testing.T) { tests := []struct { desc string src string @@ -102,20 +163,20 @@ func TestPruneOriginal(t *testing.T) { }, { desc: "only directive", src: `package testpackage; - //gopherjs:prune-original + //gopherjs:keep-original func foo() {}`, want: true, }, { desc: "directive with explanation", src: `package testpackage; - //gopherjs:prune-original because reasons + //gopherjs:keep-original because reasons func foo() {}`, want: true, }, { desc: "directive in godoc", src: `package testpackage; // foo does something - //gopherjs:prune-original + //gopherjs:keep-original func foo() {}`, want: true, }, @@ -123,8 +184,369 @@ func TestPruneOriginal(t *testing.T) { for _, test := range tests { t.Run(test.desc, func(t *testing.T) { fdecl := srctesting.ParseFuncDecl(t, test.src) - if got := PruneOriginal(fdecl); got != test.want { - t.Errorf("PruneOriginal() returned %t, want %t", got, test.want) + if got := KeepOriginal(fdecl); got != test.want { + t.Errorf("KeepOriginal() returned %t, want %t", got, test.want) + } + }) + } +} + +func TestHasDirectiveOnDecl(t *testing.T) { + tests := []struct { + desc string + src string + want bool + }{ + { + desc: `no comment on function`, + src: `package testpackage; + func foo() {}`, + want: false, + }, { + desc: `no directive on function with comment`, + src: `package testpackage; + // foo has no directive + func foo() {}`, + want: false, + }, { + desc: `wrong directive on function`, + src: `package testpackage; + //gopherjs:wrong-directive + func foo() {}`, + want: false, + }, { + desc: `correct directive on function`, + src: `package testpackage; + //gopherjs:do-stuff + // foo has a directive to do stuff + func foo() {}`, + want: true, + }, { + desc: `correct directive in multiline comment on function`, + src: `package testpackage; + /*gopherjs:do-stuff + foo has a directive to do stuff + */ + func foo() {}`, + want: true, + }, { + desc: `invalid directive in multiline comment on function`, + src: `package testpackage; + /* + gopherjs:do-stuff + */ + func foo() {}`, + want: false, + }, { + desc: `prefix directive on function`, + src: `package testpackage; + //gopherjs:do-stuffs + func foo() {}`, + want: false, + }, { + desc: `multiple directives on function`, + src: `package testpackage; + //gopherjs:wrong-directive + //gopherjs:do-stuff + //gopherjs:another-directive + func foo() {}`, + want: true, + }, { + desc: `directive with explanation on function`, + src: `package testpackage; + //gopherjs:do-stuff 'cause we can + func foo() {}`, + want: true, + }, { + desc: `no directive on type declaration`, + src: `package testpackage; + // Foo has a comment + type Foo int`, + want: false, + }, { + desc: `directive on type declaration`, + src: `package testpackage; + //gopherjs:do-stuff + type Foo int`, + want: true, + }, { + desc: `no directive on const declaration`, + src: `package testpackage; + const foo = 42`, + want: false, + }, { + desc: `directive on const documentation`, + src: `package testpackage; + //gopherjs:do-stuff + const foo = 42`, + want: true, + }, { + desc: `no directive on var declaration`, + src: `package testpackage; + var foo = 42`, + want: false, + }, { + desc: `directive on var documentation`, + src: `package testpackage; + //gopherjs:do-stuff + var foo = 42`, + want: true, + }, { + desc: `no directive on var declaration`, + src: `package testpackage; + import _ "embed"`, + want: false, + }, { + desc: `directive on var documentation`, + src: `package testpackage; + //gopherjs:do-stuff + import _ "embed"`, + want: true, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + const action = `do-stuff` + decl := srctesting.ParseDecl(t, test.src) + if got := hasDirective(decl, action); got != test.want { + t.Errorf(`hasDirective(%T, %q) returned %t, want %t`, decl, action, got, test.want) + } + }) + } +} + +func TestHasDirectiveOnSpec(t *testing.T) { + tests := []struct { + desc string + src string + want bool + }{ + { + desc: `no directive on type specification`, + src: `package testpackage; + type Foo int`, + want: false, + }, { + desc: `directive in doc on type specification`, + src: `package testpackage; + type ( + //gopherjs:do-stuff + Foo int + )`, + want: true, + }, { + desc: `directive in line on type specification`, + src: `package testpackage; + type Foo int //gopherjs:do-stuff`, + want: true, + }, { + desc: `no directive on const specification`, + src: `package testpackage; + const foo = 42`, + want: false, + }, { + desc: `directive in doc on const specification`, + src: `package testpackage; + const ( + //gopherjs:do-stuff + foo = 42 + )`, + want: true, + }, { + desc: `directive in line on const specification`, + src: `package testpackage; + const foo = 42 //gopherjs:do-stuff`, + want: true, + }, { + desc: `no directive on var specification`, + src: `package testpackage; + var foo = 42`, + want: false, + }, { + desc: `directive in doc on var specification`, + src: `package testpackage; + var ( + //gopherjs:do-stuff + foo = 42 + )`, + want: true, + }, { + desc: `directive in line on var specification`, + src: `package testpackage; + var foo = 42 //gopherjs:do-stuff`, + want: true, + }, { + desc: `no directive on import specification`, + src: `package testpackage; + import _ "embed"`, + want: false, + }, { + desc: `directive in doc on import specification`, + src: `package testpackage; + import ( + //gopherjs:do-stuff + _ "embed" + )`, + want: true, + }, { + desc: `directive in line on import specification`, + src: `package testpackage; + import _ "embed" //gopherjs:do-stuff`, + want: true, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + const action = `do-stuff` + spec := srctesting.ParseSpec(t, test.src) + if got := hasDirective(spec, action); got != test.want { + t.Errorf(`hasDirective(%T, %q) returned %t, want %t`, spec, action, got, test.want) + } + }) + } +} + +func TestHasDirectiveOnFile(t *testing.T) { + tests := []struct { + desc string + src string + want bool + }{ + { + desc: `no directive on file`, + src: `package testpackage; + //gopherjs:do-stuff + type Foo int`, + want: false, + }, { + desc: `directive on file`, + src: `//gopherjs:do-stuff + package testpackage; + type Foo int`, + want: true, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + const action = `do-stuff` + fset := token.NewFileSet() + file := srctesting.Parse(t, fset, test.src) + if got := hasDirective(file, action); got != test.want { + t.Errorf(`hasDirective(%T, %q) returned %t, want %t`, file, action, got, test.want) + } + }) + } +} + +func TestHasDirectiveOnField(t *testing.T) { + tests := []struct { + desc string + src string + want bool + }{ + { + desc: `no directive on struct field`, + src: `package testpackage; + type Foo struct { + bar int + }`, + want: false, + }, { + desc: `directive in doc on struct field`, + src: `package testpackage; + type Foo struct { + //gopherjs:do-stuff + bar int + }`, + want: true, + }, { + desc: `directive in line on struct field`, + src: `package testpackage; + type Foo struct { + bar int //gopherjs:do-stuff + }`, + want: true, + }, { + desc: `no directive on interface method`, + src: `package testpackage; + type Foo interface { + Bar(a int) int + }`, + want: false, + }, { + desc: `directive in doc on interface method`, + src: `package testpackage; + type Foo interface { + //gopherjs:do-stuff + Bar(a int) int + }`, + want: true, + }, { + desc: `directive in line on interface method`, + src: `package testpackage; + type Foo interface { + Bar(a int) int //gopherjs:do-stuff + }`, + want: true, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + const action = `do-stuff` + spec := srctesting.ParseSpec(t, test.src) + tspec := spec.(*ast.TypeSpec) + var field *ast.Field + switch typeNode := tspec.Type.(type) { + case *ast.StructType: + field = typeNode.Fields.List[0] + case *ast.InterfaceType: + field = typeNode.Methods.List[0] + default: + t.Errorf(`unexpected node type, %T, when finding field`, typeNode) + return + } + if got := hasDirective(field, action); got != test.want { + t.Errorf(`hasDirective(%T, %q) returned %t, want %t`, field, action, got, test.want) + } + }) + } +} + +func TestHasDirectiveBadCase(t *testing.T) { + tests := []struct { + desc string + node any + want string + }{ + { + desc: `untyped nil node`, + node: nil, + want: `unexpected node type to get doc from: `, + }, { + desc: `unexpected node type`, + node: &ast.ArrayType{}, + want: `unexpected node type to get doc from: *ast.ArrayType`, + }, { + desc: `nil expected node type`, + node: (*ast.FuncDecl)(nil), + want: ``, // no panic + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + const action = `do-stuff` + var got string + func() { + defer func() { got = fmt.Sprint(recover()) }() + hasDirective(test.node, action) + }() + if got != test.want { + t.Errorf(`hasDirective(%T, %q) returned %s, want %s`, test.node, action, got, test.want) } }) } @@ -181,3 +603,61 @@ func TestEndsWithReturn(t *testing.T) { }) } } + +func TestSqueezeIdents(t *testing.T) { + tests := []struct { + desc string + count int + assign []int + }{ + { + desc: `no squeezing`, + count: 5, + assign: []int{0, 1, 2, 3, 4}, + }, { + desc: `missing front`, + count: 5, + assign: []int{3, 4}, + }, { + desc: `missing back`, + count: 5, + assign: []int{0, 1, 2}, + }, { + desc: `missing several`, + count: 10, + assign: []int{1, 2, 3, 6, 8}, + }, { + desc: `empty`, + count: 0, + assign: []int{}, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + input := make([]*ast.Ident, test.count) + for _, i := range test.assign { + input[i] = ast.NewIdent(strconv.Itoa(i)) + } + + result := Squeeze(input) + if len(result) != len(test.assign) { + t.Errorf("Squeeze() returned a slice %d long, want %d", len(result), len(test.assign)) + } + for i, id := range input { + if i < len(result) { + if id == nil { + t.Errorf(`Squeeze() returned a nil in result at %d`, i) + } else { + value, err := strconv.Atoi(id.Name) + if err != nil || value != test.assign[i] { + t.Errorf(`Squeeze() returned %s at %d instead of %d`, id.Name, i, test.assign[i]) + } + } + } else if id != nil { + t.Errorf(`Squeeze() didn't clear out tail of slice, want %d nil`, i) + } + } + }) + } +} diff --git a/compiler/natives/src/net/http/http.go b/compiler/natives/src/net/http/http.go index 7843235b2..8fd607c4d 100644 --- a/compiler/natives/src/net/http/http.go +++ b/compiler/natives/src/net/http/http.go @@ -7,7 +7,7 @@ import ( "bufio" "bytes" "errors" - "io/ioutil" + "io" "net/textproto" "strconv" @@ -68,7 +68,7 @@ func (t *XHRTransport) RoundTrip(req *Request) (*Response, error) { StatusCode: xhr.Get("status").Int(), Header: Header(header), ContentLength: contentLength, - Body: ioutil.NopCloser(bytes.NewReader(body)), + Body: io.NopCloser(bytes.NewReader(body)), Request: req, } }) @@ -91,7 +91,7 @@ func (t *XHRTransport) RoundTrip(req *Request) (*Response, error) { if req.Body == nil { xhr.Call("send") } else { - body, err := ioutil.ReadAll(req.Body) + body, err := io.ReadAll(req.Body) if err != nil { req.Body.Close() // RoundTrip must always close the body, including on errors. return nil, err diff --git a/internal/srctesting/srctesting.go b/internal/srctesting/srctesting.go index 1d9cecd20..4e374845e 100644 --- a/internal/srctesting/srctesting.go +++ b/internal/srctesting/srctesting.go @@ -53,17 +53,45 @@ func Check(t *testing.T, fset *token.FileSet, files ...*ast.File) (*types.Info, // // Fails the test if there isn't exactly one function declared in the source. func ParseFuncDecl(t *testing.T, src string) *ast.FuncDecl { + t.Helper() + decl := ParseDecl(t, src) + fdecl, ok := decl.(*ast.FuncDecl) + if !ok { + t.Fatalf("Got %T decl, expected *ast.FuncDecl", decl) + } + return fdecl +} + +// ParseDecl parses source with a single declaration and +// returns that declaration AST. +// +// Fails the test if there isn't exactly one declaration in the source. +func ParseDecl(t *testing.T, src string) ast.Decl { t.Helper() fset := token.NewFileSet() file := Parse(t, fset, src) if l := len(file.Decls); l != 1 { - t.Fatalf("Got %d decls in the sources, expected exactly 1", l) + t.Fatalf(`Got %d decls in the sources, expected exactly 1`, l) } - fdecl, ok := file.Decls[0].(*ast.FuncDecl) + return file.Decls[0] +} + +// ParseSpec parses source with a single declaration containing +// a single specification and returns that specification AST. +// +// Fails the test if there isn't exactly one declaration and +// one specification in the source. +func ParseSpec(t *testing.T, src string) ast.Spec { + t.Helper() + decl := ParseDecl(t, src) + gdecl, ok := decl.(*ast.GenDecl) if !ok { - t.Fatalf("Got %T decl, expected *ast.FuncDecl", file.Decls[0]) + t.Fatalf("Got %T decl, expected *ast.GenDecl", decl) } - return fdecl + if l := len(gdecl.Specs); l != 1 { + t.Fatalf(`Got %d spec in the sources, expected exactly 1`, l) + } + return gdecl.Specs[0] } // Format AST node into a string.