aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTim King <taking@google.com>2023-02-13 20:12:58 -0800
committerTim King <taking@google.com>2023-02-17 19:56:39 +0000
commit193023cca0b707b4d668e5d332649cba285e99f7 (patch)
tree3f6ed57585407455df4aca8487675b0ec6821426
parent3102dad5faf9b21d36f1e58748b62515ef1ee194 (diff)
downloadgolang-x-tools-193023cca0b707b4d668e5d332649cba285e99f7.tar.gz
go/ssa: substitute type parameters in local types
ssa.subster now supports substituting type parameters that are within local Named types. Generalizes the definition of isParameterized to support local Named types that contain type parameters that are not instantiated. Fixes golang/go#58491 Updates golang/go#58573 Change-Id: Id7a4e45c41126056105e483b59eb057c05786d66 Reviewed-on: https://go-review.googlesource.com/c/tools/+/467995 TryBot-Result: Gopher Robot <gobot@golang.org> gopls-CI: kokoro <noreply+kokoro@google.com> Reviewed-by: Robert Findley <rfindley@google.com> Reviewed-by: Alan Donovan <adonovan@google.com> Run-TryBot: Tim King <taking@google.com>
-rw-r--r--go/ssa/builder_go120_test.go2
-rw-r--r--go/ssa/builder_test.go111
-rw-r--r--go/ssa/instantiate.go3
-rw-r--r--go/ssa/parameterized.go3
-rw-r--r--go/ssa/ssa.go2
-rw-r--r--go/ssa/stdlib_test.go12
-rw-r--r--go/ssa/subst.go89
-rw-r--r--go/ssa/subst_test.go2
8 files changed, 179 insertions, 45 deletions
diff --git a/go/ssa/builder_go120_test.go b/go/ssa/builder_go120_test.go
index a691f938c..acdd182c5 100644
--- a/go/ssa/builder_go120_test.go
+++ b/go/ssa/builder_go120_test.go
@@ -85,7 +85,7 @@ func TestBuildPackageGo120(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
fset := token.NewFileSet()
- f, err := parser.ParseFile(fset, "p.go", tc.src, parser.ParseComments)
+ f, err := parser.ParseFile(fset, "p.go", tc.src, 0)
if err != nil {
t.Error(err)
}
diff --git a/go/ssa/builder_test.go b/go/ssa/builder_test.go
index fd2d21c5d..b3bb09c5e 100644
--- a/go/ssa/builder_test.go
+++ b/go/ssa/builder_test.go
@@ -894,3 +894,114 @@ func TestGenericFunctionSelector(t *testing.T) {
}
}
}
+
+func TestIssue58491(t *testing.T) {
+ // Test that a local type reaches type param in instantiation.
+ testenv.NeedsGo1Point(t, 18)
+ src := `
+ package p
+
+ func foo[T any](blocking func() (T, error)) error {
+ type result struct {
+ res T
+ error // ensure the method set of result is non-empty
+ }
+
+ res := make(chan result, 1)
+ go func() {
+ var r result
+ r.res, r.error = blocking()
+ res <- r
+ }()
+ r := <-res
+ err := r // require the rtype for result when instantiated
+ return err
+ }
+ var Inst = foo[int]
+ `
+ fset := token.NewFileSet()
+ f, err := parser.ParseFile(fset, "p.go", src, 0)
+ if err != nil {
+ t.Error(err)
+ }
+ files := []*ast.File{f}
+
+ pkg := types.NewPackage("p", "")
+ conf := &types.Config{}
+ p, _, err := ssautil.BuildPackage(conf, fset, pkg, files, ssa.SanityCheckFunctions|ssa.InstantiateGenerics)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ // Find the local type result instantiated with int.
+ var found bool
+ for _, rt := range p.Prog.RuntimeTypes() {
+ if n, ok := rt.(*types.Named); ok {
+ if u, ok := n.Underlying().(*types.Struct); ok {
+ found = true
+ if got, want := n.String(), "p.result"; got != want {
+ t.Errorf("Expected the name %s got: %s", want, got)
+ }
+ if got, want := u.String(), "struct{res int; error}"; got != want {
+ t.Errorf("Expected the underlying type of %s to be %s. got %s", n, want, got)
+ }
+ }
+ }
+ }
+ if !found {
+ t.Error("Failed to find any Named to struct types")
+ }
+}
+
+func TestIssue58491Rec(t *testing.T) {
+ // Roughly the same as TestIssue58491 but with a recursive type.
+ testenv.NeedsGo1Point(t, 18)
+ src := `
+ package p
+
+ func foo[T any]() error {
+ type result struct {
+ res T
+ next *result
+ error // ensure the method set of result is non-empty
+ }
+
+ r := &result{}
+ err := r // require the rtype for result when instantiated
+ return err
+ }
+ var Inst = foo[int]
+ `
+ fset := token.NewFileSet()
+ f, err := parser.ParseFile(fset, "p.go", src, 0)
+ if err != nil {
+ t.Error(err)
+ }
+ files := []*ast.File{f}
+
+ pkg := types.NewPackage("p", "")
+ conf := &types.Config{}
+ p, _, err := ssautil.BuildPackage(conf, fset, pkg, files, ssa.SanityCheckFunctions|ssa.InstantiateGenerics)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ // Find the local type result instantiated with int.
+ var found bool
+ for _, rt := range p.Prog.RuntimeTypes() {
+ if n, ok := rt.(*types.Named); ok {
+ if u, ok := n.Underlying().(*types.Struct); ok {
+ found = true
+ if got, want := n.String(), "p.result"; got != want {
+ t.Errorf("Expected the name %s got: %s", want, got)
+ }
+ if got, want := u.String(), "struct{res int; next *p.result; error}"; got != want {
+ t.Errorf("Expected the underlying type of %s to be %s. got %s", n, want, got)
+ }
+ }
+ }
+ }
+ if !found {
+ t.Error("Failed to find any Named to struct types")
+ }
+}
diff --git a/go/ssa/instantiate.go b/go/ssa/instantiate.go
index f6b2533f2..38249dea2 100644
--- a/go/ssa/instantiate.go
+++ b/go/ssa/instantiate.go
@@ -148,7 +148,8 @@ func (insts *instanceSet) lookupOrCreate(targs []types.Type, parameterized *tpWa
if prog.mode&InstantiateGenerics != 0 && concrete {
synthetic = fmt.Sprintf("instance of %s", fn.Name())
- subst = makeSubster(prog.ctxt, fn.typeparams, targs, false)
+ scope := typeparams.OriginMethod(obj).Scope()
+ subst = makeSubster(prog.ctxt, scope, fn.typeparams, targs, false)
} else {
synthetic = fmt.Sprintf("instantiation wrapper of %s", fn.Name())
}
diff --git a/go/ssa/parameterized.go b/go/ssa/parameterized.go
index b11413c81..3fc4348fc 100644
--- a/go/ssa/parameterized.go
+++ b/go/ssa/parameterized.go
@@ -17,7 +17,7 @@ type tpWalker struct {
seen map[types.Type]bool
}
-// isParameterized returns true when typ contains any type parameters.
+// isParameterized returns true when typ reaches any type parameter.
func (w *tpWalker) isParameterized(typ types.Type) (res bool) {
// NOTE: Adapted from go/types/infer.go. Try to keep in sync.
@@ -101,6 +101,7 @@ func (w *tpWalker) isParameterized(typ types.Type) (res bool) {
return true
}
}
+ return w.isParameterized(t.Underlying()) // recurse for types local to parameterized functions
case *typeparams.TypeParam:
return true
diff --git a/go/ssa/ssa.go b/go/ssa/ssa.go
index 5904b817b..c3471c156 100644
--- a/go/ssa/ssa.go
+++ b/go/ssa/ssa.go
@@ -36,7 +36,7 @@ type Program struct {
bounds map[boundsKey]*Function // bounds for curried x.Method closures
thunks map[selectionKey]*Function // thunks for T.Method expressions
instances map[*Function]*instanceSet // instances of generic functions
- parameterized tpWalker // determines whether a type is parameterized.
+ parameterized tpWalker // determines whether a type reaches a type parameter.
}
// A Package is a single analyzed Go package containing Members for
diff --git a/go/ssa/stdlib_test.go b/go/ssa/stdlib_test.go
index f85f220b4..8b9f4238d 100644
--- a/go/ssa/stdlib_test.go
+++ b/go/ssa/stdlib_test.go
@@ -50,18 +50,6 @@ func TestStdlib(t *testing.T) {
t.Fatal(err)
}
- // TODO(golang/go#58491): fix breakage on new generic code in net.
- //
- // For now, exclude the 'net' package.
- netIdx := -1
- for i, pkg := range pkgs {
- if pkg.PkgPath == "net" {
- netIdx = i
- break
- }
- }
- pkgs = append(pkgs[:netIdx], pkgs[netIdx+1:]...)
-
t1 := time.Now()
alloc1 := bytesAllocated()
diff --git a/go/ssa/subst.go b/go/ssa/subst.go
index d7f8ae4a7..7efab3578 100644
--- a/go/ssa/subst.go
+++ b/go/ssa/subst.go
@@ -5,7 +5,6 @@
package ssa
import (
- "fmt"
"go/types"
"golang.org/x/tools/internal/typeparams"
@@ -19,41 +18,42 @@ import (
//
// Not concurrency-safe.
type subster struct {
- // TODO(zpavlinovic): replacements can contain type params
- // when generating instances inside of a generic function body.
replacements map[*typeparams.TypeParam]types.Type // values should contain no type params
cache map[types.Type]types.Type // cache of subst results
- ctxt *typeparams.Context
- debug bool // perform extra debugging checks
+ ctxt *typeparams.Context // cache for instantiation
+ scope *types.Scope // *types.Named declared within this scope can be substituted (optional)
+ debug bool // perform extra debugging checks
// TODO(taking): consider adding Pos
+ // TODO(zpavlinovic): replacements can contain type params
+ // when generating instances inside of a generic function body.
}
// Returns a subster that replaces tparams[i] with targs[i]. Uses ctxt as a cache.
// targs should not contain any types in tparams.
-func makeSubster(ctxt *typeparams.Context, tparams *typeparams.TypeParamList, targs []types.Type, debug bool) *subster {
+// scope is the (optional) lexical block of the generic function for which we are substituting.
+func makeSubster(ctxt *typeparams.Context, scope *types.Scope, tparams *typeparams.TypeParamList, targs []types.Type, debug bool) *subster {
assert(tparams.Len() == len(targs), "makeSubster argument count must match")
subst := &subster{
replacements: make(map[*typeparams.TypeParam]types.Type, tparams.Len()),
cache: make(map[types.Type]types.Type),
ctxt: ctxt,
+ scope: scope,
debug: debug,
}
for i := 0; i < tparams.Len(); i++ {
subst.replacements[tparams.At(i)] = targs[i]
}
if subst.debug {
- if err := subst.wellFormed(); err != nil {
- panic(err)
- }
+ subst.wellFormed()
}
return subst
}
-// wellFormed returns an error if subst was not properly initialized.
-func (subst *subster) wellFormed() error {
- if subst == nil || len(subst.replacements) == 0 {
- return nil
+// wellFormed asserts that subst was properly initialized.
+func (subst *subster) wellFormed() {
+ if subst == nil {
+ return
}
// Check that all of the type params do not appear in the arguments.
s := make(map[types.Type]bool, len(subst.replacements))
@@ -62,10 +62,9 @@ func (subst *subster) wellFormed() error {
}
for _, r := range subst.replacements {
if reaches(r, s) {
- return fmt.Errorf("\n‰r %s s %v replacements %v\n", r, s, subst.replacements)
+ panic(subst)
}
}
- return nil
}
// typ returns the type of t with the type parameter tparams[i] substituted
@@ -306,29 +305,56 @@ func (subst *subster) interface_(iface *types.Interface) *types.Interface {
}
func (subst *subster) named(t *types.Named) types.Type {
- // A name type may be:
- // (1) ordinary (no type parameters, no type arguments),
- // (2) generic (type parameters but no type arguments), or
- // (3) instantiated (type parameters and type arguments).
+ // A named type may be:
+ // (1) ordinary named type (non-local scope, no type parameters, no type arguments),
+ // (2) locally scoped type,
+ // (3) generic (type parameters but no type arguments), or
+ // (4) instantiated (type parameters and type arguments).
tparams := typeparams.ForNamed(t)
if tparams.Len() == 0 {
- // case (1) ordinary
+ if subst.scope != nil && !subst.scope.Contains(t.Obj().Pos()) {
+ // Outside the current function scope?
+ return t // case (1) ordinary
+ }
- // Note: If Go allows for local type declarations in generic
- // functions we may need to descend into underlying as well.
- return t
+ // case (2) locally scoped type.
+ // Create a new named type to represent this instantiation.
+ // We assume that local types of distinct instantiations of a
+ // generic function are distinct, even if they don't refer to
+ // type parameters, but the spec is unclear; see golang/go#58573.
+ //
+ // Subtle: We short circuit substitution and use a newly created type in
+ // subst, i.e. cache[t]=n, to pre-emptively replace t with n in recursive
+ // types during traversal. This both breaks infinite cycles and allows for
+ // constructing types with the replacement applied in subst.typ(under).
+ //
+ // Example:
+ // func foo[T any]() {
+ // type linkedlist struct {
+ // next *linkedlist
+ // val T
+ // }
+ // }
+ //
+ // When the field `next *linkedlist` is visited during subst.typ(under),
+ // we want the substituted type for the field `next` to be `*n`.
+ n := types.NewNamed(t.Obj(), nil, nil)
+ subst.cache[t] = n
+ subst.cache[n] = n
+ n.SetUnderlying(subst.typ(t.Underlying()))
+ return n
}
targs := typeparams.NamedTypeArgs(t)
// insts are arguments to instantiate using.
insts := make([]types.Type, tparams.Len())
- // case (2) generic ==> targs.Len() == 0
+ // case (3) generic ==> targs.Len() == 0
// Instantiating a generic with no type arguments should be unreachable.
// Please report a bug if you encounter this.
assert(targs.Len() != 0, "substition into a generic Named type is currently unsupported")
- // case (3) instantiated.
+ // case (4) instantiated.
// Substitute into the type arguments and instantiate the replacements/
// Example:
// type N[A any] func() A
@@ -378,19 +404,26 @@ func (subst *subster) signature(t *types.Signature) types.Type {
}
// reaches returns true if a type t reaches any type t' s.t. c[t'] == true.
-// Updates c to cache results.
+// It updates c to cache results.
+//
+// reaches is currently only part of the wellFormed debug logic, and
+// in practice c is initially only type parameters. It is not currently
+// relied on in production.
func reaches(t types.Type, c map[types.Type]bool) (res bool) {
if c, ok := c[t]; ok {
return c
}
- c[t] = false // prevent cycles
+
+ // c is populated with temporary false entries as types are visited.
+ // This avoids repeat visits and break cycles.
+ c[t] = false
defer func() {
c[t] = res
}()
switch t := t.(type) {
case *typeparams.TypeParam, *types.Basic:
- // no-op => c == false
+ return false
case *types.Array:
return reaches(t.Elem(), c)
case *types.Slice:
diff --git a/go/ssa/subst_test.go b/go/ssa/subst_test.go
index 5fa882700..14cda54e6 100644
--- a/go/ssa/subst_test.go
+++ b/go/ssa/subst_test.go
@@ -100,7 +100,7 @@ var _ L[int] = Fn0[L[int]](nil)
T := tv.Type.(*types.Named)
- subst := makeSubster(typeparams.NewContext(), typeparams.ForNamed(T), targs, true)
+ subst := makeSubster(typeparams.NewContext(), nil, typeparams.ForNamed(T), targs, true)
sub := subst.typ(T.Underlying())
if got := sub.String(); got != test.want {
t.Errorf("subst{%v->%v}.typ(%s) = %v, want %v", test.expr, test.args, T.Underlying(), got, test.want)