Skip to content

Commit 6a6fd99

Browse files
committed
go/ssa: substitute type parameterized aliases
Adds support to substitute type parameterized aliases in generic functions. Change-Id: I4fb2e5f5fd9b626781efdc4db808c52cb22ba241 Reviewed-on: https://go-review.googlesource.com/c/tools/+/602195 Reviewed-by: Alan Donovan <[email protected]> LUCI-TryBot-Result: Go LUCI <[email protected]>
1 parent f6a2390 commit 6a6fd99

File tree

6 files changed

+273
-26
lines changed

6 files changed

+273
-26
lines changed

go/ssa/builder_generic_test.go

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,13 @@ func TestGenericBodies(t *testing.T) {
550550
}
551551

552552
// Collect calls to the builtin print function.
553-
probes := callsTo(p, "print")
553+
fns := make(map[*ssa.Function]bool)
554+
for _, mem := range p.Members {
555+
if fn, ok := mem.(*ssa.Function); ok {
556+
fns[fn] = true
557+
}
558+
}
559+
probes := callsTo(fns, "print")
554560
expectations := matchNotes(prog.Fset, notes, probes)
555561

556562
for call := range probes {
@@ -576,17 +582,15 @@ func TestGenericBodies(t *testing.T) {
576582

577583
// callsTo finds all calls to an SSA value named fname,
578584
// and returns a map from each call site to its enclosing function.
579-
func callsTo(p *ssa.Package, fname string) map[*ssa.CallCommon]*ssa.Function {
585+
func callsTo(fns map[*ssa.Function]bool, fname string) map[*ssa.CallCommon]*ssa.Function {
580586
callsites := make(map[*ssa.CallCommon]*ssa.Function)
581-
for _, mem := range p.Members {
582-
if fn, ok := mem.(*ssa.Function); ok {
583-
for _, bb := range fn.Blocks {
584-
for _, i := range bb.Instrs {
585-
if i, ok := i.(ssa.CallInstruction); ok {
586-
call := i.Common()
587-
if call.Value.Name() == fname {
588-
callsites[call] = fn
589-
}
587+
for fn := range fns {
588+
for _, bb := range fn.Blocks {
589+
for _, i := range bb.Instrs {
590+
if i, ok := i.(ssa.CallInstruction); ok {
591+
call := i.Common()
592+
if call.Value.Name() == fname {
593+
callsites[call] = fn
590594
}
591595
}
592596
}

go/ssa/builder_go122_test.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,13 @@ func TestRangeOverInt(t *testing.T) {
168168
}
169169

170170
// Collect calls to the built-in print function.
171-
probes := callsTo(p, "print")
171+
fns := make(map[*ssa.Function]bool)
172+
for _, mem := range p.Members {
173+
if fn, ok := mem.(*ssa.Function); ok {
174+
fns[fn] = true
175+
}
176+
}
177+
probes := callsTo(fns, "print")
172178
expectations := matchNotes(fset, notes, probes)
173179

174180
for call := range probes {

go/ssa/builder_test.go

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"go/token"
1515
"go/types"
1616
"os"
17+
"os/exec"
1718
"path/filepath"
1819
"reflect"
1920
"sort"
@@ -1260,3 +1261,143 @@ func TestIssue67079(t *testing.T) {
12601261

12611262
g.Wait() // ignore error
12621263
}
1264+
1265+
func TestGenericAliases(t *testing.T) {
1266+
testenv.NeedsGo1Point(t, 23)
1267+
1268+
if os.Getenv("GENERICALIASTEST_CHILD") == "1" {
1269+
testGenericAliases(t)
1270+
return
1271+
}
1272+
1273+
testenv.NeedsExec(t)
1274+
testenv.NeedsTool(t, "go")
1275+
1276+
cmd := exec.Command(os.Args[0], "-test.run=TestGenericAliases")
1277+
cmd.Env = append(os.Environ(),
1278+
"GENERICALIASTEST_CHILD=1",
1279+
"GODEBUG=gotypesalias=1",
1280+
"GOEXPERIMENT=aliastypeparams",
1281+
)
1282+
out, err := cmd.CombinedOutput()
1283+
if len(out) > 0 {
1284+
t.Logf("out=<<%s>>", out)
1285+
}
1286+
var exitcode int
1287+
if err, ok := err.(*exec.ExitError); ok {
1288+
exitcode = err.ExitCode()
1289+
}
1290+
const want = 0
1291+
if exitcode != want {
1292+
t.Errorf("exited %d, want %d", exitcode, want)
1293+
}
1294+
}
1295+
1296+
func testGenericAliases(t *testing.T) {
1297+
t.Setenv("GOEXPERIMENT", "aliastypeparams=1")
1298+
1299+
const source = `
1300+
package P
1301+
1302+
type A = uint8
1303+
type B[T any] = [4]T
1304+
1305+
var F = f[string]
1306+
1307+
func f[S any]() {
1308+
// Two copies of f are made: p.f[S] and p.f[string]
1309+
1310+
var v A // application of A that is declared outside of f without no type arguments
1311+
print("p.f", "String", "p.A", v)
1312+
print("p.f", "==", v, uint8(0))
1313+
print("p.f[string]", "String", "p.A", v)
1314+
print("p.f[string]", "==", v, uint8(0))
1315+
1316+
1317+
var u B[S] // application of B that is declared outside declared outside of f with type arguments
1318+
print("p.f", "String", "p.B[S]", u)
1319+
print("p.f", "==", u, [4]S{})
1320+
print("p.f[string]", "String", "p.B[string]", u)
1321+
print("p.f[string]", "==", u, [4]string{})
1322+
1323+
type C[T any] = struct{ s S; ap *B[T]} // declaration within f with type params
1324+
var w C[int] // application of C with type arguments
1325+
print("p.f", "String", "p.C[int]", w)
1326+
print("p.f", "==", w, struct{ s S; ap *[4]int}{})
1327+
print("p.f[string]", "String", "p.C[int]", w)
1328+
print("p.f[string]", "==", w, struct{ s string; ap *[4]int}{})
1329+
}
1330+
`
1331+
1332+
conf := loader.Config{Fset: token.NewFileSet()}
1333+
f, err := parser.ParseFile(conf.Fset, "p.go", source, 0)
1334+
if err != nil {
1335+
t.Fatal(err)
1336+
}
1337+
conf.CreateFromFiles("p", f)
1338+
iprog, err := conf.Load()
1339+
if err != nil {
1340+
t.Fatal(err)
1341+
}
1342+
1343+
// Create and build SSA program.
1344+
prog := ssautil.CreateProgram(iprog, ssa.InstantiateGenerics)
1345+
prog.Build()
1346+
1347+
probes := callsTo(ssautil.AllFunctions(prog), "print")
1348+
if got, want := len(probes), 3*4*2; got != want {
1349+
t.Errorf("Found %v probes, expected %v", got, want)
1350+
}
1351+
1352+
const debug = false // enable to debug skips
1353+
skipped := 0
1354+
for probe, fn := range probes {
1355+
// Each probe is of the form:
1356+
// print("within", "test", head, tail)
1357+
// The probe only matches within a function whose fn.String() is within.
1358+
// This allows for different instantiations of fn to match different probes.
1359+
// On a match, it applies the test named "test" to head::tail.
1360+
if len(probe.Args) < 3 {
1361+
t.Fatalf("probe %v did not have enough arguments", probe)
1362+
}
1363+
within, test, head, tail := constString(probe.Args[0]), probe.Args[1], probe.Args[2], probe.Args[3:]
1364+
if within != fn.String() {
1365+
skipped++
1366+
if debug {
1367+
t.Logf("Skipping %q within %q", within, fn.String())
1368+
}
1369+
continue // does not match function
1370+
}
1371+
1372+
switch test := constString(test); test {
1373+
case "==": // All of the values are types.Identical.
1374+
for _, v := range tail {
1375+
if !types.Identical(head.Type(), v.Type()) {
1376+
t.Errorf("Expected %v and %v to have identical types", head, v)
1377+
}
1378+
}
1379+
case "String": // head is a string constant that all values in tail must match Type().String()
1380+
want := constString(head)
1381+
for _, v := range tail {
1382+
if got := v.Type().String(); got != want {
1383+
t.Errorf("%s: %v had the Type().String()=%q. expected %q", within, v, got, want)
1384+
}
1385+
}
1386+
default:
1387+
t.Errorf("%q is not a test subcommand", test)
1388+
}
1389+
}
1390+
if want := 3 * 4; skipped != want {
1391+
t.Errorf("Skipped %d probes, expected to skip %d", skipped, want)
1392+
}
1393+
}
1394+
1395+
// constString returns the value of a string constant
1396+
// or "<not a constant string>" if the value is not a string constant.
1397+
func constString(v ssa.Value) string {
1398+
if c, ok := v.(*ssa.Const); ok {
1399+
str := c.Value.String()
1400+
return strings.Trim(str, `"`)
1401+
}
1402+
return "<not a constant string>"
1403+
}

go/ssa/subst.go

Lines changed: 74 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -318,15 +318,80 @@ func (subst *subster) interface_(iface *types.Interface) *types.Interface {
318318
}
319319

320320
func (subst *subster) alias(t *aliases.Alias) types.Type {
321-
// TODO(go.dev/issues/46477): support TypeParameters once these are available from go/types.
322-
u := aliases.Unalias(t)
323-
if s := subst.typ(u); s != u {
324-
// If there is any change, do not create a new alias.
325-
return s
321+
// See subster.named. This follows the same strategy.
322+
tparams := aliases.TypeParams(t)
323+
targs := aliases.TypeArgs(t)
324+
tname := t.Obj()
325+
torigin := aliases.Origin(t)
326+
327+
if !declaredWithin(tname, subst.origin) {
328+
// t is declared outside of the function origin. So t is a package level type alias.
329+
if targs.Len() == 0 {
330+
// No type arguments so no instantiation needed.
331+
return t
332+
}
333+
334+
// Instantiate with the substituted type arguments.
335+
newTArgs := subst.typelist(targs)
336+
return subst.instantiate(torigin, newTArgs)
326337
}
327-
// If there is no change, t did not reach any type parameter.
328-
// Keep the Alias.
329-
return t
338+
339+
if targs.Len() == 0 {
340+
// t is declared within the function origin and has no type arguments.
341+
//
342+
// Example: This corresponds to A or B in F, but not A[int]:
343+
//
344+
// func F[T any]() {
345+
// type A[S any] = struct{t T, s S}
346+
// type B = T
347+
// var x A[int]
348+
// ...
349+
// }
350+
//
351+
// This is somewhat different than *Named as *Alias cannot be created recursively.
352+
353+
// Copy and substitute type params.
354+
var newTParams []*types.TypeParam
355+
for i := 0; i < tparams.Len(); i++ {
356+
cur := tparams.At(i)
357+
cobj := cur.Obj()
358+
cname := types.NewTypeName(cobj.Pos(), cobj.Pkg(), cobj.Name(), nil)
359+
ntp := types.NewTypeParam(cname, nil)
360+
subst.cache[cur] = ntp // See the comment "Note: Subtle" in subster.named.
361+
newTParams = append(newTParams, ntp)
362+
}
363+
364+
// Substitute rhs.
365+
rhs := subst.typ(aliases.Rhs(t))
366+
367+
// Create the fresh alias.
368+
obj := aliases.NewAlias(true, tname.Pos(), tname.Pkg(), tname.Name(), rhs)
369+
fresh := obj.Type()
370+
if fresh, ok := fresh.(*aliases.Alias); ok {
371+
// TODO: assume ok when aliases are always materialized (go1.27).
372+
aliases.SetTypeParams(fresh, newTParams)
373+
}
374+
375+
// Substitute into all of the constraints after they are created.
376+
for i, ntp := range newTParams {
377+
bound := tparams.At(i).Constraint()
378+
ntp.SetConstraint(subst.typ(bound))
379+
}
380+
return fresh
381+
}
382+
383+
// t is declared within the function origin and has type arguments.
384+
//
385+
// Example: This corresponds to A[int] in F. Cases A and B are handled above.
386+
// func F[T any]() {
387+
// type A[S any] = struct{t T, s S}
388+
// type B = T
389+
// var x A[int]
390+
// ...
391+
// }
392+
subOrigin := subst.typ(torigin)
393+
subTArgs := subst.typelist(targs)
394+
return subst.instantiate(subOrigin, subTArgs)
330395
}
331396

332397
func (subst *subster) named(t *types.Named) types.Type {
@@ -456,7 +521,7 @@ func (subst *subster) named(t *types.Named) types.Type {
456521

457522
func (subst *subster) instantiate(orig types.Type, targs []types.Type) types.Type {
458523
i, err := types.Instantiate(subst.ctxt, orig, targs, false)
459-
assert(err == nil, "failed to Instantiate Named type")
524+
assert(err == nil, "failed to Instantiate named (Named or Alias) type")
460525
if c, _ := subst.uniqueness.At(i).(types.Type); c != nil {
461526
return c.(types.Type)
462527
}

internal/aliases/aliases_go121.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,14 @@ import (
1515
// It will never be created by go/types.
1616
type Alias struct{}
1717

18-
func (*Alias) String() string { panic("unreachable") }
19-
func (*Alias) Underlying() types.Type { panic("unreachable") }
20-
func (*Alias) Obj() *types.TypeName { panic("unreachable") }
21-
func Rhs(alias *Alias) types.Type { panic("unreachable") }
22-
func TypeParams(alias *Alias) *types.TypeParamList { panic("unreachable") }
18+
func (*Alias) String() string { panic("unreachable") }
19+
func (*Alias) Underlying() types.Type { panic("unreachable") }
20+
func (*Alias) Obj() *types.TypeName { panic("unreachable") }
21+
func Rhs(alias *Alias) types.Type { panic("unreachable") }
22+
func TypeParams(alias *Alias) *types.TypeParamList { panic("unreachable") }
23+
func SetTypeParams(alias *Alias, tparams []*types.TypeParam) { panic("unreachable") }
24+
func TypeArgs(alias *Alias) *types.TypeList { panic("unreachable") }
25+
func Origin(alias *Alias) *Alias { panic("unreachable") }
2326

2427
// Unalias returns the type t for go <=1.21.
2528
func Unalias(t types.Type) types.Type { return t }

internal/aliases/aliases_go122.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,34 @@ func TypeParams(alias *Alias) *types.TypeParamList {
3636
return nil
3737
}
3838

39+
// SetTypeParams sets the type parameters of the alias type.
40+
func SetTypeParams(alias *Alias, tparams []*types.TypeParam) {
41+
if alias, ok := any(alias).(interface {
42+
SetTypeParams(tparams []*types.TypeParam)
43+
}); ok {
44+
alias.SetTypeParams(tparams) // go1.23+
45+
} else if len(tparams) > 0 {
46+
panic("cannot set type parameters of an Alias type in go1.22")
47+
}
48+
}
49+
50+
// TypeArgs returns the type arguments used to instantiate the Alias type.
51+
func TypeArgs(alias *Alias) *types.TypeList {
52+
if alias, ok := any(alias).(interface{ TypeArgs() *types.TypeList }); ok {
53+
return alias.TypeArgs() // go1.23+
54+
}
55+
return nil // empty (go1.22)
56+
}
57+
58+
// Origin returns the generic Alias type of which alias is an instance.
59+
// If alias is not an instance of a generic alias, Origin returns alias.
60+
func Origin(alias *Alias) *Alias {
61+
if alias, ok := any(alias).(interface{ Origin() *types.Alias }); ok {
62+
return alias.Origin() // go1.23+
63+
}
64+
return alias // not an instance of a generic alias (go1.22)
65+
}
66+
3967
// Unalias is a wrapper of types.Unalias.
4068
func Unalias(t types.Type) types.Type { return types.Unalias(t) }
4169

0 commit comments

Comments
 (0)