Skip to content

Commit 399d1c9

Browse files
authored
Merge pull request #35236 from marcrasi/sr-13945
[AutoDiff] generated linear maps should be "convention(thin)"
2 parents 3361aff + 0142e52 commit 399d1c9

File tree

5 files changed

+41
-6
lines changed

5 files changed

+41
-6
lines changed

lib/SILOptimizer/Differentiation/JVPCloner.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1699,7 +1699,7 @@ void JVPCloner::Implementation::prepareForDifferentialGeneration() {
16991699
auto *diffGenericEnv =
17001700
diffGenericSig ? diffGenericSig->getGenericEnvironment() : nullptr;
17011701
auto diffType = SILFunctionType::get(
1702-
diffGenericSig, origTy->getExtInfo(), origTy->getCoroutineKind(),
1702+
diffGenericSig, SILExtInfo::getThin(), origTy->getCoroutineKind(),
17031703
origTy->getCalleeConvention(), dfParams, {}, dfResults, None,
17041704
origTy->getPatternSubstitutions(), origTy->getInvocationSubstitutions(),
17051705
original->getASTContext());

lib/SILOptimizer/Differentiation/VJPCloner.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -980,7 +980,7 @@ SILFunction *VJPCloner::Implementation::createEmptyPullback() {
980980
auto *pbGenericEnv =
981981
pbGenericSig ? pbGenericSig->getGenericEnvironment() : nullptr;
982982
auto pbType = SILFunctionType::get(
983-
pbGenericSig, origTy->getExtInfo(), origTy->getCoroutineKind(),
983+
pbGenericSig, SILExtInfo::getThin(), origTy->getCoroutineKind(),
984984
origTy->getCalleeConvention(), pbParams, {}, adjResults, None,
985985
origTy->getPatternSubstitutions(), origTy->getInvocationSubstitutions(),
986986
original->getASTContext());

test/AutoDiff/SILOptimizer/derivative_sil.swift

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,25 @@ func foo(_ x: Float) -> Float {
8686
// CHECK-SIL: dealloc_stack [[TMP_BUF_RES]] : $*Float
8787
// CHECK-SIL: return [[DX]] : $Float
8888
// CHECK-SIL: }
89+
90+
// Check the conventions of the generated functions for a method (SR-13945).
91+
struct ExampleStruct {
92+
@_silgen_name("fooMethod")
93+
@differentiable
94+
func fooMethod(_ x: Float) -> Float {
95+
let y = Float.add(x, x)
96+
return y
97+
}
98+
}
99+
100+
// CHECK-SIL-LABEL: sil hidden [ossa] @AD__fooMethod__jvp_src_0_wrt_0 : $@convention(method) (Float, ExampleStruct) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
101+
// CHECK-SIL: } // end sil function 'AD__fooMethod__jvp_src_0_wrt_0'
102+
103+
// CHECK-SIL-LABEL: sil hidden [ossa] @AD__fooMethod__differential_src_0_wrt_0 : $@convention(thin) (Float, @owned _AD__fooMethod_bb0__DF__src_0_wrt_0) -> Float {
104+
// CHECK-SIL: } // end sil function 'AD__fooMethod__differential_src_0_wrt_0'
105+
106+
// CHECK-SIL-LABEL: sil hidden [ossa] @AD__fooMethod__vjp_src_0_wrt_0 : $@convention(method) (Float, ExampleStruct) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
107+
// CHECK-SIL: } // end sil function 'AD__fooMethod__vjp_src_0_wrt_0'
108+
109+
// CHECK-SIL-LABEL: sil private [ossa] @AD__fooMethod__pullback_src_0_wrt_0 : $@convention(thin) (Float, @owned _AD__fooMethod_bb0__PB__src_0_wrt_0) -> Float {
110+
// CHECK-SIL: } // end sil function 'AD__fooMethod__pullback_src_0_wrt_0'

test/AutoDiff/SILOptimizer/semantic_member_accessors_sil.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ func trigger<T: Differentiable>(_ x: T.Type) {
4444
// CHECK-LABEL: // differentiability witness for Struct.x.getter
4545
// CHECK-NEXT: sil_differentiability_witness private [parameters 0] [results 0] @$s4null6StructV1xSfvg : $@convention(method) (Struct) -> Float {
4646

47-
// CHECK-LABEL: sil private [ossa] @AD__$s4null7GenericV1xxvs__pullback_src_0_wrt_0_1_{{16_Differentiation|s}}14DifferentiableRzl : $@convention(method) <τ_0_0 where τ_0_0 : Differentiable> (@inout Generic<τ_0_0>.TangentVector, @owned {{.*}}) -> @out τ_0_0.TangentVector {
47+
// CHECK-LABEL: sil private [ossa] @AD__$s4null7GenericV1xxvs__pullback_src_0_wrt_0_1_{{16_Differentiation|s}}14DifferentiableRzl : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@inout Generic<τ_0_0>.TangentVector, @owned {{.*}}) -> @out τ_0_0.TangentVector {
4848
// CHECK: bb0([[ADJ_X_RESULT:%.*]] : $*τ_0_0.TangentVector, [[ADJ_SELF:%.*]] : $*Generic<τ_0_0>.TangentVector, {{.*}} : {{.*}}):
4949
// CHECK: [[ADJ_X_TMP:%.*]] = alloc_stack $τ_0_0.TangentVector
5050
// CHECK: [[ZERO_FN:%.*]] = witness_method $τ_0_0.TangentVector, #AdditiveArithmetic.zero!getter
@@ -60,7 +60,7 @@ func trigger<T: Differentiable>(_ x: T.Type) {
6060
// CHECK: return {{.*}} : $()
6161
// CHECK: }
6262

63-
// CHECK-LABEL: sil private [ossa] @AD__$s4null7GenericV1xxvg__pullback_src_0_wrt_0_{{16_Differentiation|s}}14DifferentiableRzl : $@convention(method) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector, @owned {{.*}}) -> @out Generic<τ_0_0>.TangentVector {
63+
// CHECK-LABEL: sil private [ossa] @AD__$s4null7GenericV1xxvg__pullback_src_0_wrt_0_{{16_Differentiation|s}}14DifferentiableRzl : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector, @owned {{.*}}) -> @out Generic<τ_0_0>.TangentVector {
6464
// CHECK: bb0([[ADJ_SELF_RESULT:%.*]] : $*Generic<τ_0_0>.TangentVector, [[SEED:%.*]] : $*τ_0_0.TangentVector, {{.*}} : ${{.*}}):
6565
// CHECK: [[ADJ_SELF_TMP:%.*]] = alloc_stack $Generic<τ_0_0>.TangentVector
6666
// CHECK: [[SEED_COPY:%.*]] = alloc_stack $τ_0_0.TangentVector
@@ -76,7 +76,7 @@ func trigger<T: Differentiable>(_ x: T.Type) {
7676
// CHECK: return {{.*}} : $()
7777
// CHECK: }
7878

79-
// CHECK-LABEL: sil private [ossa] @AD__$s4null6StructV1xSfvs__pullback_src_0_wrt_0_1 : $@convention(method) (@inout Struct.TangentVector, @owned _AD__$s4null6StructV1xSfvs_bb0__PB__src_0_wrt_0_1) -> Float {
79+
// CHECK-LABEL: sil private [ossa] @AD__$s4null6StructV1xSfvs__pullback_src_0_wrt_0_1 : $@convention(thin) (@inout Struct.TangentVector, @owned _AD__$s4null6StructV1xSfvs_bb0__PB__src_0_wrt_0_1) -> Float {
8080
// CHECK: bb0([[ADJ_SELF:%.*]] : $*Struct.TangentVector, {{.*}} : $_AD__$s4null6StructV1xSfvs_bb0__PB__src_0_wrt_0_1):
8181
// CHECK: [[ADJ_X_ADDR:%.*]] = struct_element_addr [[ADJ_SELF]] : $*Struct.TangentVector, #Struct.TangentVector.x
8282
// CHECK: [[ADJ_X:%.*]] = load [trivial] [[ADJ_X_ADDR]] : $*Float
@@ -85,7 +85,7 @@ func trigger<T: Differentiable>(_ x: T.Type) {
8585
// CHECK: return [[ADJ_X]] : $Float
8686
// CHECK: }
8787

88-
// CHECK-LABEL: sil private [ossa] @AD__$s4null6StructV1xSfvg__pullback_src_0_wrt_0 : $@convention(method) (Float, @owned _AD__$s4null6StructV1xSfvg_bb0__PB__src_0_wrt_0) -> Struct.TangentVector {
88+
// CHECK-LABEL: sil private [ossa] @AD__$s4null6StructV1xSfvg__pullback_src_0_wrt_0 : $@convention(thin) (Float, @owned _AD__$s4null6StructV1xSfvg_bb0__PB__src_0_wrt_0) -> Struct.TangentVector {
8989
// CHECK: bb0([[ADJ_X:%.*]] : $Float, {{.*}} : $_AD__$s4null6StructV1xSfvg_bb0__PB__src_0_wrt_0):
9090
// CHECK: [[ADJ_Y_ADDR:%.*]] = alloc_stack $Float
9191
// CHECK: [[ZERO_FN:%.*]] = witness_method $Float, #AdditiveArithmetic.zero!getter

test/AutoDiff/validation-test/control_flow.swift

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -716,6 +716,19 @@ ControlFlowTests.test("Loops") {
716716
expectEqual((20, 22), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 2) }))
717717
expectEqual((52, 80), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 3) }))
718718
expectEqual((24, 28), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 4) }))
719+
720+
// SR13945: Loops in methods caused a runtime segfault.
721+
struct SR13945 {
722+
func loopInMethod(_ x: Float) -> Float {
723+
var result = x
724+
for _ in 0..<2 {
725+
result *= result
726+
}
727+
return result
728+
}
729+
}
730+
expectEqual((0, 0), valueWithGradient(at: 0, in: { SR13945().loopInMethod($0) }))
731+
expectEqual((1, 4), valueWithGradient(at: 1, in: { SR13945().loopInMethod($0) }))
719732
}
720733

721734
ControlFlowTests.test("BranchingCastInstructions") {

0 commit comments

Comments
 (0)