Skip to content

Commit 7f79ae5

Browse files
authored
Support differentiating closures with contexts (#20186)
1 parent c605516 commit 7f79ae5

File tree

4 files changed

+33
-8
lines changed

4 files changed

+33
-8
lines changed

lib/SILOptimizer/Mandatory/TFDifferentiation.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,8 @@ static FunctionRefInst *findReferenceToVisibleFunction(SILValue value) {
226226
return findReferenceToVisibleFunction(thinToThick->getOperand());
227227
if (auto *convertFn = dyn_cast<ConvertFunctionInst>(inst))
228228
return findReferenceToVisibleFunction(convertFn->getOperand());
229+
if (auto *partialApply = dyn_cast<PartialApplyInst>(inst))
230+
return findReferenceToVisibleFunction(partialApply->getCallee());
229231
return nullptr;
230232
}
231233

@@ -1568,7 +1570,7 @@ reapplyFunctionConversion(SILValue newFunc, SILValue oldFunc,
15681570
newFunc, oldFunc, pai->getCallee(), builder, loc, substituteOperand);
15691571
return builder.createPartialApply(
15701572
loc, innerNewFunc, pai->getSubstitutionMap(), newArgs,
1571-
pai->getOrigCalleeType()->getCalleeConvention());
1573+
ParameterConvention::Direct_Guaranteed);
15721574
}
15731575
llvm_unreachable("Unhandled function convertion instruction");
15741576
}
@@ -4004,7 +4006,7 @@ void AdjointEmitter::accumulateMaterializedAdjointsIndirect(
40044006
// Ensure the witness method is linked.
40054007
getModule().lookUpFunctionInWitnessTable(confRef, declRef);
40064008
auto subMap =
4007-
SubstitutionMap::getProtocolSubstitutions(proto, adjointASTTy, confRef);
4009+
SubstitutionMap::getProtocolSubstitutions(proto, adjointASTTy, confRef);
40084010
// %1 = metatype $T.Type
40094011
auto metatypeType =
40104012
CanMetatypeType::get(adjointASTTy, MetatypeRepresentation::Thick);

test/AutoDiff/autodiff_e2e_basic.swift

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ func adjointId(_ x: Float, originalValue: Float, seed: Float) -> Float {
1111

1212
_ = #gradient(id)(2)
1313

14-
// CHECK: @{{.*}}id{{.*}}__grad_src_0_wrt_0
14+
// CHECK-LABEL: @{{.*}}id{{.*}}__grad_src_0_wrt_0
1515
// CHECK-LABEL: @{{.*}}id{{.*}}__grad_src_0_wrt_0_s_p
1616

1717
#if os(macOS) || os(iOS) || os(tvOS) || os(watchOS)
@@ -41,7 +41,7 @@ let x = #gradient(sigmoid)(3)
4141
let (value: y, gradient: z) = #valueAndGradient(sigmoid)(4)
4242
print(x * z)
4343

44-
// CHECK: @{{.*}}sigmoid{{.*}}__grad_src_0_wrt_0
44+
// CHECK-LABEL: @{{.*}}sigmoid{{.*}}__grad_src_0_wrt_0
4545
// CHECK: @{{.*}}sigmoid{{.*}}__grad_src_0_wrt_0_s_p
4646
// CHECK: @{{.*}}sigmoid{{.*}}__grad_src_0_wrt_0_p
4747

@@ -51,7 +51,6 @@ public func publicFunc(_ x: Float) -> Float {
5151
}
5252
_ = #gradient(publicFunc)
5353

54-
// CHECK: sil non_abi @{{.*}}publicFunc{{.*}}__grad_src_0_wrt_0
55-
// CHECK: sil non_abi @{{.*}}publicFunc{{.*}}__primal_src_0_wrt_0
56-
// CHECK: sil non_abi @{{.*}}publicFunc{{.*}}__adjoint_src_0_wrt_0
57-
54+
// CHECK-LABEL: @{{.*}}publicFunc{{.*}}__grad_src_0_wrt_0
55+
// CHECK: @{{.*}}publicFunc{{.*}}__primal_src_0_wrt_0
56+
// CHECK: @{{.*}}publicFunc{{.*}}__adjoint_src_0_wrt_0

test/AutoDiff/closures.swift

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// RUN: %target-swift-frontend -emit-sil %s | %FileCheck %s
2+
3+
public func closureCapture() {
4+
let val: Float = 10
5+
let clo: (Float) -> Float = { x in
6+
val * x
7+
}
8+
_ = #gradient(clo)
9+
}
10+
11+
// CHECK-LABEL: @{{.*}}closureCapture{{.*}}
12+
// CHECK: [[ORIG_FN:%.*]] = function_ref @{{.*}}closureCapture{{.*}}
13+
// CHECK: [[PARTIAL_APPLIED:%.*]] = partial_apply [callee_guaranteed] [[ORIG_FN]]
14+
// CHECK: [[GRAD_FN:%.*]] = function_ref @{{.*}}closureCapture{{.*}}___grad
15+
// CHECK: [[PARTIAL_APPLIED_GRAD:%.*]] = partial_apply [callee_guaranteed] [[GRAD_FN]]

test/AutoDiff/simple_math.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,13 @@ SimpleMathTests.test("ResultSelection") {
5656
expectEqual((0, 1), #gradient(foo, result: .1)(3, 3))
5757
}
5858

59+
60+
SimpleMathTests.test("CaptureGlobal") {
61+
let z: Float = 10
62+
func foo(_ x: Float) -> Float {
63+
return z * x
64+
}
65+
expectEqual(10, #gradient(foo)(0))
66+
}
67+
5968
runAllTests()

0 commit comments

Comments
 (0)