Skip to content

Commit 48980e2

Browse files
committed
[AutoDiff] Start linear_function canonicalization skeleton
Start `linear_function` canonicalization skeleton copying from `differentiable_function` canonicalization. For now, transpose function operands are filled in with `undef`.
1 parent 33de64e commit 48980e2

File tree

4 files changed

+54
-18
lines changed

4 files changed

+54
-18
lines changed

include/swift/SILOptimizer/Differentiation/ADContext.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,7 @@ class ADContext {
219219
/// pointer value as a previously processed and deleted instruction.
220220
LinearFunctionInst *
221221
createLinearFunction(SILBuilder &builder, SILLocation loc,
222-
IndexSubset *parameterIndices,
223-
IndexSubset *resultIndices, SILValue original,
222+
IndexSubset *parameterIndices, SILValue original,
224223
Optional<SILValue> transposeFunction = None);
225224

226225
// Given a `differentiable_function` instruction, finds the corresponding

lib/SILOptimizer/Differentiation/ADContext.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,15 @@ DifferentiableFunctionInst *ADContext::createDifferentiableFunction(
123123
return dfi;
124124
}
125125

126+
LinearFunctionInst *ADContext::createLinearFunction(
127+
SILBuilder &builder, SILLocation loc, IndexSubset *parameterIndices,
128+
SILValue original, Optional<SILValue> transposeFunction) {
129+
auto *lfi = builder.createLinearFunction(loc, parameterIndices, original,
130+
transposeFunction);
131+
processedLinearFunctionInsts.erase(lfi);
132+
return lfi;
133+
}
134+
126135
DifferentiableFunctionExpr *
127136
ADContext::findDifferentialOperator(DifferentiableFunctionInst *inst) {
128137
return inst->getLoc().getAsASTNode<DifferentiableFunctionExpr>();

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -755,19 +755,6 @@ emitDerivativeFunctionReference(
755755
return None;
756756
}
757757

758-
/// Emits a reference to the transpose function of `originalFunction`,
759-
/// differentiated with respect to exactly `desiredIndices`. Returns the
760-
/// transpose function `SILValue`.
761-
///
762-
/// Returns `None` on failure, signifying that a diagnostic has been emitted
763-
/// using `invoker`.
764-
static Optional<SILValue> emitTransposeFunctionReference(
765-
DifferentiationTransformer &transformer, SILBuilder &builder,
766-
SILAutoDiffIndices desiredIndices, SILValue originalFunction,
767-
DifferentiationInvoker invoker) {
768-
// TODO: Fill in.
769-
}
770-
771758
//===----------------------------------------------------------------------===//
772759
// `SILDifferentiabilityWitness` processing
773760
//===----------------------------------------------------------------------===//
@@ -1226,13 +1213,25 @@ SILValue DifferentiationTransformer::promoteToDifferentiableFunction(
12261213
}
12271214

12281215
SILValue DifferentiationTransformer::promoteToLinearFunction(
1229-
LinearFunctionInst *inst, SILBuilder &builder, SILLocation loc,
1216+
LinearFunctionInst *lfi, SILBuilder &builder, SILLocation loc,
12301217
DifferentiationInvoker invoker) {
12311218
// TODO: Fill in. Copy code from above.
12321219
// For now, create a new `linear_function` instruction with an undef
12331220
// transpose.
12341221
// Eventually, use `emitTransposeFunctionReference` to fill in legitimately.
1235-
return inst;
1222+
auto origFnOperand = lfi->getOriginalFunction();
1223+
auto origFnCopy = builder.emitCopyValueOperation(loc, origFnOperand);
1224+
auto *parameterIndices = lfi->getParameterIndices();
1225+
auto originalType = origFnOperand->getType().castTo<SILFunctionType>();
1226+
auto transposeFnType = originalType->getAutoDiffTransposeFunctionType(
1227+
parameterIndices, context.getTypeConverter(),
1228+
LookUpConformanceInModule(builder.getModule().getSwiftModule()));
1229+
auto transposeType = SILType::getPrimitiveObjectType(transposeFnType);
1230+
auto transposeFn = SILUndef::get(transposeType, builder.getFunction());
1231+
auto *newLinearFn = context.createLinearFunction(
1232+
builder, loc, parameterIndices, origFnCopy, SILValue(transposeFn));
1233+
context.getLinearFunctionInstWorklist().push_back(lfi);
1234+
return newLinearFn;
12361235
}
12371236

12381237
/// Fold `differentiable_function_extract` users of the given
@@ -1390,7 +1389,8 @@ void Differentiation::run() {
13901389

13911390
// If nothing has triggered differentiation, there's nothing to do.
13921391
if (context.getInvokers().empty() &&
1393-
context.getDifferentiableFunctionInstWorklist().empty())
1392+
context.getDifferentiableFunctionInstWorklist().empty() &&
1393+
context.getLinearFunctionInstWorklist().empty())
13941394
return;
13951395

13961396
// Differentiation relies on the stdlib (the Swift module).
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// RUN: %target-sil-opt -differentiation %s | %FileCheck %s
2+
3+
sil_stage raw
4+
5+
import Swift
6+
import Builtin
7+
8+
import _Differentiation
9+
10+
sil hidden @foo : $@convention(thin) (Float, Float, Float) -> Float {
11+
bb0(%0 : $Float, %1 : $Float, %2 : $Float):
12+
return %2 : $Float
13+
}
14+
15+
sil @make_linear_func : $@convention(thin) () -> () {
16+
bb0:
17+
%orig = function_ref @foo : $@convention(thin) (Float, Float, Float) -> Float
18+
%linear_fn_0 = linear_function [parameters 0] %orig : $@convention(thin) (Float, Float, Float) -> Float
19+
%linear_fn_1 = linear_function [parameters 0 2] %orig : $@convention(thin) (Float, Float, Float) -> Float
20+
return undef : $()
21+
}
22+
23+
// CHECK-LABEL: sil @make_linear_func
24+
// CHECK: bb0:
25+
// CHECK: [[ORIG_FN:%.*]] = function_ref @foo : $@convention(thin) (Float, Float, Float) -> Float
26+
// CHECK: linear_function [parameters 0] %0 : $@convention(thin) (Float, Float, Float) -> Float with_transpose undef : $@convention(thin) (Float, Float, Float) -> Float
27+
// CHECK: linear_function [parameters 0 2] %0 : $@convention(thin) (Float, Float, Float) -> Float with_transpose undef : $@convention(thin) (Float, Float) -> (Float, Float)
28+
// CHECK: }

0 commit comments

Comments
 (0)