Skip to content

Commit db164d7

Browse files
committed
Ensure we are using mapped SIL type for switch_enum case and not the
original lowered one. Fixes #73018
1 parent c3488c6 commit db164d7

File tree

2 files changed

+70
-9
lines changed

2 files changed

+70
-9
lines changed

lib/IRGen/LoadableByAddress.cpp

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2998,7 +2998,8 @@ namespace {
29982998
}
29992999

30003000
static void runPeepholesAndReg2Mem(SILPassManager *pm, SILModule *silMod,
3001-
IRGenModule *irgenModule);
3001+
IRGenModule *irgenModule,
3002+
LargeSILTypeMapper &mapperCache);
30023003

30033004
/// The entry point to this function transformation.
30043005
void LoadableByAddress::run() {
@@ -3010,7 +3011,8 @@ void LoadableByAddress::run() {
30103011
runOnFunction(&F);
30113012

30123013
if (modFuncs.empty() && modApplies.empty()) {
3013-
runPeepholesAndReg2Mem(getPassManager(), getModule(), getIRGenModule());
3014+
runPeepholesAndReg2Mem(getPassManager(), getModule(), getIRGenModule(),
3015+
MapperCache);
30143016
return;
30153017
}
30163018

@@ -3248,7 +3250,8 @@ void LoadableByAddress::run() {
32483250
storeToBlockStorageInstrs.clear();
32493251

32503252
getPassManager()->invalidateAllAnalysis();
3251-
runPeepholesAndReg2Mem(getPassManager(), getModule(), getIRGenModule());
3253+
runPeepholesAndReg2Mem(getPassManager(), getModule(), getIRGenModule(),
3254+
MapperCache);
32523255
}
32533256

32543257
namespace {
@@ -3396,11 +3399,13 @@ class AddressAssignment {
33963399
GenericEnvironment *genEnv;
33973400
IRGenModule *irgenModule;
33983401
SILFunction &currFn;
3402+
LargeSILTypeMapper &mapperCache;
33993403

34003404
public:
34013405
AddressAssignment(IRGenModule *irgenModule, GenericEnvironment *genEnv,
3402-
SILFunction &currFn)
3403-
: genEnv(genEnv), irgenModule(irgenModule), currFn(currFn) {}
3406+
SILFunction &currFn, LargeSILTypeMapper &mapperCache)
3407+
: genEnv(genEnv), irgenModule(irgenModule), currFn(currFn),
3408+
mapperCache(mapperCache) {}
34043409

34053410
void assign(SILInstruction *inst);
34063411

@@ -3456,6 +3461,11 @@ class AddressAssignment {
34563461
return false;
34573462
}
34583463

3464+
SILType getNewSILType(CanSILFunctionType fnTy, SILType type) {
3465+
return mapperCache.getNewSILType(getSubstGenericEnvironment(fnTy), type,
3466+
*irgenModule);
3467+
}
3468+
34593469
bool isLargeLoadableType(SILType ty) {
34603470
if (ty.isAddress() || ty.isClassOrClassMetatype()) {
34613471
return false;
@@ -4077,9 +4087,15 @@ class RewriteUser : SILInstructionVisitor<RewriteUser> {
40774087
"caseBB has a payload argument");
40784088

40794089
SILBuilder caseBuilder = assignment.getBuilder(caseBB->begin());
4090+
SILType eltType =
4091+
opdAddr->getType().getEnumElementType(caseDecl,
4092+
caseBuilder.getModule(),
4093+
caseBuilder.getTypeExpansionContext());
4094+
SILType mappedEltType =
4095+
assignment.getNewSILType(sw->getFunction()->getLoweredFunctionType(),
4096+
eltType);
40804097
auto *caseAddr =
4081-
caseBuilder.createUncheckedTakeEnumDataAddr(loc, opdAddr, caseDecl);
4082-
4098+
caseBuilder.createUncheckedTakeEnumDataAddr(loc, opdAddr, caseDecl, mappedEltType);
40834099
if (assignment.isLargeLoadableType(caseArg->getType())) {
40844100
assignment.mapValueToAddress(caseArg, caseAddr);
40854101
assignment.markBlockArgumentForDeletion(caseBB);
@@ -4274,7 +4290,8 @@ void AddressAssignment::assign(SILInstruction *inst) {
42744290
}
42754291

42764292
static void runPeepholesAndReg2Mem(SILPassManager *pm, SILModule *silMod,
4277-
IRGenModule *irgenModule) {
4293+
IRGenModule *irgenModule,
4294+
LargeSILTypeMapper &mapperCache) {
42784295

42794296
if (!irgenModule->getOptions().EnableLargeLoadableTypesReg2Mem)
42804297
return;
@@ -4301,7 +4318,7 @@ static void runPeepholesAndReg2Mem(SILPassManager *pm, SILModule *silMod,
43014318
// Delete replaced instructions.
43024319
opts.deleteInstructions();
43034320

4304-
AddressAssignment assignment(irgenModule, genEnv, currF);
4321+
AddressAssignment assignment(irgenModule, genEnv, currF, mapperCache);
43054322

43064323
// Assign addresses to basic block arguments.
43074324

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// REQUIRES: differentiable_programming
2+
// RUN: %target-swift-frontend -emit-ir -verify %s
3+
4+
// This used to trigger an assertion due to LoadableByAddress not doing proper mapping of
5+
// switch_enum arguments during rewriting
6+
7+
import _Differentiation
8+
struct O: Differentiable {var a: B.G; var b: Array<SIMD2<Float>>; var c: B.M; var d: B.M}
9+
func w(y: B.M) -> B.M {return y}
10+
@differentiable(reverse) func q(sd: Float, i: B.G, j: Array<SIMD2<Float>>, s: B.M, t: B.M) -> O {
11+
let u1 = C(i.g4); let u2 = C(i.g2); let u3 = C(i.g3); let u4 = C(j)
12+
var u5 = C(s.m4); var u6 = C(t.m4); var u7 = w(y: s); let u8 = w(y: t)
13+
u6 = C(Array<Float>([0.0]))
14+
u5 = u6
15+
if sd > 0 {}
16+
u7.m4 = u5.r
17+
var u9 = i
18+
u9.g1 = C(u9.g1).r
19+
u9.g4 = u1.r
20+
u9.g2 = u2.r
21+
u9.g3 = u3.r
22+
return O(a: u9, b: u4.r, c: u7, d: u8)
23+
}
24+
protocol N {}; protocol H: Differentiable {}
25+
struct B: D, N, Differentiable
26+
{
27+
struct M: D, N, Differentiable {var m1 = Array<Float>(); var m2 = Array<SIMD2<Float>>(); var m3 = Array<SIMD2<Float>>(); var m4 = Array<Float>(); var m5 = Array<Float>()}
28+
struct G: D, N, Differentiable {var g1 = Array<Float>(); var g2 = Array<Float>(); var g3 = Array<Float>(); var g4 = Array<Float>()}
29+
}
30+
protocol D: Differentiable & H & N where Self.TangentVector: N {}
31+
struct C<E>: Differentiable, AdditiveArithmetic where E: Differentiable, E: AdditiveArithmetic {
32+
typealias TangentVector = C<E.TangentVector>
33+
var v: [E]; var a: E
34+
@differentiable(reverse) init(_ c: [E], t: E = .zero) {self.v = c; self.a = .zero}
35+
@differentiable(reverse) var r: [E] {return v}
36+
@inlinable @derivative(of: init(_:t:))
37+
static func vjpInit(_ values: [E], t: E = .zero) -> (value: C, pullback: (Self.TangentVector) -> ([E].TangentVector, E.TangentVector)) {(value: Self(values, t: t), pullback: {tangentVector in return ([E].TangentVector(tangentVector.v), tangentVector.a)})}
38+
@inlinable @derivative(of: r) func vjpArray() -> (value: [E], pullback: ([E].TangentVector) -> Self.TangentVector) {(value: self.v, pullback: {_ in return Self.TangentVector([E.TangentVector](), t: E.TangentVector.zero)})}
39+
mutating func move(by offset: Self.TangentVector) {}
40+
static func + (lhs: Self, rhs: Self) -> Self {return lhs}
41+
static func - (lhs: Self, rhs: Self) -> Self {return lhs}
42+
static var zero: Self { Self([], t: .zero) }
43+
}
44+

0 commit comments

Comments
 (0)