Skip to content

Commit 5fbca17

Browse files
committed
Ensure we are using mapped SIL type for switch_enum case and not the
original lowered one. Fixes #73018
1 parent c3488c6 commit 5fbca17

File tree

2 files changed

+46
-2
lines changed

2 files changed

+46
-2
lines changed

lib/IRGen/LoadableByAddress.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4078,8 +4078,8 @@ class RewriteUser : SILInstructionVisitor<RewriteUser> {
40784078

40794079
SILBuilder caseBuilder = assignment.getBuilder(caseBB->begin());
40804080
auto *caseAddr =
4081-
caseBuilder.createUncheckedTakeEnumDataAddr(loc, opdAddr, caseDecl);
4082-
4081+
caseBuilder.createUncheckedTakeEnumDataAddr(loc, opdAddr, caseDecl,
4082+
caseArg->getType().getAddressType());
40834083
if (assignment.isLargeLoadableType(caseArg->getType())) {
40844084
assignment.mapValueToAddress(caseArg, caseAddr);
40854085
assignment.markBlockArgumentForDeletion(caseBB);
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// REQUIRES: differentiable_programming
2+
// RUN: %target-swift-frontend -emit-ir -verify %s
3+
4+
// This used to trigger an assertion due to LoadableByAddress not doing proper mapping of
5+
// switch_enum arguments during rewriting
6+
7+
import _Differentiation
8+
struct O: Differentiable {var a: B.G; var b: Array<SIMD2<Float>>; var c: B.M; var d: B.M}
9+
func w(y: B.M) -> B.M {return y}
10+
@differentiable(reverse) func q(sd: Float, i: B.G, j: Array<SIMD2<Float>>, s: B.M, t: B.M) -> O {
11+
let u1 = C(i.g4); let u2 = C(i.g2); let u3 = C(i.g3); let u4 = C(j)
12+
var u5 = C(s.m4); var u6 = C(t.m4); var u7 = w(y: s); let u8 = w(y: t)
13+
u6 = C(Array<Float>([0.0]))
14+
u5 = u6
15+
if sd > 0 {}
16+
u7.m4 = u5.r
17+
var u9 = i
18+
u9.g1 = C(u9.g1).r
19+
u9.g4 = u1.r
20+
u9.g2 = u2.r
21+
u9.g3 = u3.r
22+
return O(a: u9, b: u4.r, c: u7, d: u8)
23+
}
24+
protocol N {}; protocol H: Differentiable {}
25+
struct B: D, N, Differentiable
26+
{
27+
struct M: D, N, Differentiable {var m1 = Array<Float>(); var m2 = Array<SIMD2<Float>>(); var m3 = Array<SIMD2<Float>>(); var m4 = Array<Float>(); var m5 = Array<Float>()}
28+
struct G: D, N, Differentiable {var g1 = Array<Float>(); var g2 = Array<Float>(); var g3 = Array<Float>(); var g4 = Array<Float>()}
29+
}
30+
protocol D: Differentiable & H & N where Self.TangentVector: N {}
31+
struct C<E>: Differentiable, AdditiveArithmetic where E: Differentiable, E: AdditiveArithmetic {
32+
typealias TangentVector = C<E.TangentVector>
33+
var v: [E]; var a: E
34+
@differentiable(reverse) init(_ c: [E], t: E = .zero) {self.v = c; self.a = .zero}
35+
@differentiable(reverse) var r: [E] {return v}
36+
@inlinable @derivative(of: init(_:t:))
37+
static func vjpInit(_ values: [E], t: E = .zero) -> (value: C, pullback: (Self.TangentVector) -> ([E].TangentVector, E.TangentVector)) {(value: Self(values, t: t), pullback: {tangentVector in return ([E].TangentVector(tangentVector.v), tangentVector.a)})}
38+
@inlinable @derivative(of: r) func vjpArray() -> (value: [E], pullback: ([E].TangentVector) -> Self.TangentVector) {(value: self.v, pullback: {_ in return Self.TangentVector([E.TangentVector](), t: E.TangentVector.zero)})}
39+
mutating func move(by offset: Self.TangentVector) {}
40+
static func + (lhs: Self, rhs: Self) -> Self {return lhs}
41+
static func - (lhs: Self, rhs: Self) -> Self {return lhs}
42+
static var zero: Self { Self([], t: .zero) }
43+
}
44+

0 commit comments

Comments
 (0)