Skip to content

Commit c967231

Browse files
authored
[AutoDiff] Use borrowed '@differentiable' function when reabstracting it.
Reabstracting a `@differentiable` function during SILGen works by extracting and reabstracting each component function from the `@differentiable` function. The current implementation creates a copy of the `@differentiable` function before each extraction, which is very inefficient. This patch changes it to borrowing the `@differentiable` function and copying the extracted component functions.
1 parent ef6ee9b commit c967231

File tree

2 files changed

+48
-12
lines changed

2 files changed

+48
-12
lines changed

lib/SILGen/SILGenPoly.cpp

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3240,7 +3240,7 @@ static ManagedValue createThunk(SILGenFunction &SGF,
32403240
}
32413241

32423242
// SWIFT_ENABLE_TENSORFLOW
3243-
/// Create a reabstraction thunk for an @differentiable function.
3243+
/// Create a reabstraction thunk for a @differentiable function.
32443244
static ManagedValue createAutoDiffThunk(SILGenFunction &SGF,
32453245
SILLocation loc,
32463246
ManagedValue fn,
@@ -3275,9 +3275,10 @@ static ManagedValue createAutoDiffThunk(SILGenFunction &SGF,
32753275
outputSubstTypeNotDiff);
32763276
// `autodiff_function_extract` is consuming; copy `fn` before passing as
32773277
// operand.
3278-
auto copiedFnValue = fn.copy(SGF, loc);
3279-
auto *original = SGF.B.createAutoDiffFunctionExtractOriginal(
3280-
loc, copiedFnValue.forward(SGF));
3278+
auto borrowedFnValue = fn.borrow(SGF, loc);
3279+
SILValue original = SGF.B.createAutoDiffFunctionExtractOriginal(
3280+
loc, borrowedFnValue.getValue());
3281+
original = SGF.B.emitCopyValueOperation(loc, original);
32813282
auto managedOriginal = SGF.emitManagedRValueWithCleanup(original);
32823283

32833284
ManagedValue originalThunk = createThunk(
@@ -3319,11 +3320,9 @@ static ManagedValue createAutoDiffThunk(SILGenFunction &SGF,
33193320
auto assocFnOutputSubstType = getAssocFnTy(outputSubstTypeNotDiff, kind);
33203321
auto &assocFnExpectedTL = SGF.getTypeLowering(assocFnOutputOrigType,
33213322
assocFnOutputSubstType);
3322-
// `autodiff_function_extract` is consuming; copy `fn` before passing as
3323-
// operand.
3324-
auto copiedFnValue = fn.copy(SGF, loc);
3325-
auto *assocFn = SGF.B.createAutoDiffFunctionExtract(
3326-
loc, kind, /*differentiationOrder*/ 1, copiedFnValue.forward(SGF));
3323+
SILValue assocFn = SGF.B.createAutoDiffFunctionExtract(
3324+
loc, kind, /*differentiationOrder*/ 1, borrowedFnValue.getValue());
3325+
assocFn = SGF.B.emitCopyValueOperation(loc, assocFn);
33273326
auto managedAssocFn = SGF.emitManagedRValueWithCleanup(assocFn);
33283327
return createThunk(SGF, loc, managedAssocFn, assocFnInputOrigType,
33293328
assocFnInputSubstType, assocFnOutputOrigType,
@@ -3336,9 +3335,8 @@ static ManagedValue createAutoDiffThunk(SILGenFunction &SGF,
33363335
SILValue convertedBundle = SGF.B.createAutoDiffFunction(
33373336
loc, sourceType->getDifferentiationParameterIndices(),
33383337
/*differentiationOrder*/ 1,
3339-
originalThunk.ensurePlusOne(SGF, loc).forward(SGF),
3340-
{jvpThunk.ensurePlusOne(SGF, loc).forward(SGF),
3341-
vjpThunk.ensurePlusOne(SGF, loc).forward(SGF)});
3338+
originalThunk.forward(SGF),
3339+
{jvpThunk.forward(SGF), vjpThunk.forward(SGF)});
33423340
return SGF.emitManagedRValueWithCleanup(convertedBundle);
33433341
}
33443342

test/AutoDiff/autodiff_function_silgen.swift

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
// RUN: %target-swift-frontend -emit-silgen %s | %FileCheck %s -check-prefix=CHECK-SILGEN
33
// RUN: %target-swift-frontend -emit-sil %s | %FileCheck %s -check-prefix=CHECK-SIL
44

5+
//===----------------------------------------------------------------------===//
6+
// Closure conversion
7+
//===----------------------------------------------------------------------===//
8+
59
func thin(x: Float) -> Float { return x }
610

711
func myfunction(_ f: @escaping @differentiable (Float) -> (Float)) -> (Float) -> Float {
@@ -43,3 +47,37 @@ func apply() {
4347

4448
// CHECK-SIL: [[DIFFED:%.*]] = autodiff_function [wrt 0] [order 1] {{%.*}} : $@callee_guaranteed (Float) -> Float
4549
// CHECK-SIL: release_value [[DIFFED]] : $@differentiable @callee_guaranteed (Float) -> Float
50+
51+
//===----------------------------------------------------------------------===//
52+
// Reabstraction
53+
//===----------------------------------------------------------------------===//
54+
55+
func pullback<T, R>(
56+
at x: T, in f: @escaping @differentiable (T) -> R
57+
) -> (R.TangentVector) -> T.TangentVector {
58+
fatalError()
59+
}
60+
61+
func appliesReabstraction(_ f: @escaping @differentiable (Float) -> Float) {
62+
_ = pullback(at: .zero, in: f)
63+
}
64+
65+
// CHECK-SILGEN-LABEL: @{{.*}}appliesReabstraction{{.*}}
66+
// CHECK-SILGEN: bb0([[DIFF_FUNC_ARG:%.*]] : @guaranteed $@differentiable @callee_guaranteed (Float) -> Float):
67+
// CHECK-SILGEN: [[DIFF_FUNC:%.*]] = copy_value [[DIFF_FUNC_ARG]] : $@differentiable @callee_guaranteed (Float) -> Float
68+
// CHECK-SILGEN: [[DIFF_FUNC_BORROWED:%.*]] = begin_borrow [[DIFF_FUNC]] : $@differentiable @callee_guaranteed (Float) -> Float
69+
// CHECK-SILGEN: [[ORIG:%.*]] = autodiff_function_extract [original] [[DIFF_FUNC_BORROWED]] : $@differentiable @callee_guaranteed (Float) -> Float
70+
// CHECK-SILGEN: [[ORIG_COPY:%.*]] = copy_value [[ORIG]] : $@callee_guaranteed (Float) -> Float
71+
// CHECK-SILGEN: [[REABS_ORIG:%.*]] = function_ref @$sS2fIegyd_S2fIegnr_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> @out Float
72+
// CHECK-SILGEN: [[NEW_ORIG:%.*]] = partial_apply [callee_guaranteed] [[REABS_ORIG]]([[ORIG_COPY]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> @out Float
73+
// CHECK-SILGEN: [[JVP:%.*]] = autodiff_function_extract [jvp] [order 1] [[DIFF_FUNC_BORROWED]] : $@differentiable @callee_guaranteed (Float) -> Float
74+
// CHECK-SILGEN: [[JVP_COPY:%.*]] = copy_value [[JVP]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
75+
// CHECK-SILGEN: [[REABS_JVP:%.*]] = function_ref @$sS4fIegyd_Iegydo_S4fIegnr_Iegnro_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)
76+
// CHECK-SILGEN: [[NEW_JVP:%.*]] = partial_apply [callee_guaranteed] [[REABS_JVP]]([[JVP_COPY]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)
77+
// CHECK-SILGEN: [[VJP:%.*]] = autodiff_function_extract [vjp] [order 1] [[DIFF_FUNC_BORROWED]] : $@differentiable @callee_guaranteed (Float) -> Float
78+
// CHECK-SILGEN: [[VJP_COPY:%.*]] = copy_value [[VJP]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
79+
// CHECK-SILGEN: [[REABS_VJP:%.*]] = function_ref @$sS4fIegyd_Iegydo_S4fIegnr_Iegnro_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)
80+
// CHECK-SILGEN: [[NEW_VJP:%.*]] = partial_apply [callee_guaranteed] [[REABS_VJP]]([[VJP_COPY]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)
81+
// CHECK-SILGEN: [[NEW_DIFF_FUNC:%.*]] = autodiff_function [wrt 0] [order 1] [[NEW_ORIG]] : $@callee_guaranteed (@in_guaranteed Float) -> @out Float with {[[NEW_JVP]] : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float), [[NEW_VJP]] : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)}
82+
// CHECK-SILGEN: [[DIFF_API:%.*]] = function_ref @${{.*}}pullback{{.*}}at{{.*}} : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @owned @callee_guaranteed (@in_guaranteed τ_0_1.TangentVector) -> @out τ_0_0.TangentVector
83+
// CHECK-SILGEN: apply [[DIFF_API]]<Float, Float>({{.*}}, [[NEW_DIFF_FUNC]]) : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @owned @callee_guaranteed (@in_guaranteed τ_0_1.TangentVector) -> @out τ_0_0.TangentVector

0 commit comments

Comments
 (0)