@@ -755,19 +755,6 @@ emitDerivativeFunctionReference(
755
755
return None;
756
756
}
757
757
758
- // / Emits a reference to the transpose function of `originalFunction`,
759
- // / differentiated with respect to exactly `desiredIndices`. Returns the
760
- // / transpose function `SILValue`.
761
- // /
762
- // / Returns `None` on failure, signifying that a diagnostic has been emitted
763
- // / using `invoker`.
764
- static Optional<SILValue> emitTransposeFunctionReference (
765
- DifferentiationTransformer &transformer, SILBuilder &builder,
766
- SILAutoDiffIndices desiredIndices, SILValue originalFunction,
767
- DifferentiationInvoker invoker) {
768
- // TODO: Fill in.
769
- }
770
-
771
758
// ===----------------------------------------------------------------------===//
772
759
// `SILDifferentiabilityWitness` processing
773
760
// ===----------------------------------------------------------------------===//
@@ -1226,13 +1213,25 @@ SILValue DifferentiationTransformer::promoteToDifferentiableFunction(
1226
1213
}
1227
1214
1228
1215
SILValue DifferentiationTransformer::promoteToLinearFunction (
1229
- LinearFunctionInst *inst , SILBuilder &builder, SILLocation loc,
1216
+ LinearFunctionInst *lfi , SILBuilder &builder, SILLocation loc,
1230
1217
DifferentiationInvoker invoker) {
1231
1218
// TODO: Fill in. Copy code from above.
1232
1219
// For now, create a new `linear_function` instruction with an undef
1233
1220
// transpose.
1234
1221
// Eventually, use `emitTransposeFunctionReference` to fill in legitimately.
1235
- return inst;
1222
+ auto origFnOperand = lfi->getOriginalFunction ();
1223
+ auto origFnCopy = builder.emitCopyValueOperation (loc, origFnOperand);
1224
+ auto *parameterIndices = lfi->getParameterIndices ();
1225
+ auto originalType = origFnOperand->getType ().castTo <SILFunctionType>();
1226
+ auto transposeFnType = originalType->getAutoDiffTransposeFunctionType (
1227
+ parameterIndices, context.getTypeConverter (),
1228
+ LookUpConformanceInModule (builder.getModule ().getSwiftModule ()));
1229
+ auto transposeType = SILType::getPrimitiveObjectType (transposeFnType);
1230
+ auto transposeFn = SILUndef::get (transposeType, builder.getFunction ());
1231
+ auto *newLinearFn = context.createLinearFunction (
1232
+ builder, loc, parameterIndices, origFnCopy, SILValue (transposeFn));
1233
+ context.getLinearFunctionInstWorklist ().push_back (lfi);
1234
+ return newLinearFn;
1236
1235
}
1237
1236
1238
1237
// / Fold `differentiable_function_extract` users of the given
@@ -1390,7 +1389,8 @@ void Differentiation::run() {
1390
1389
1391
1390
// If nothing has triggered differentiation, there's nothing to do.
1392
1391
if (context.getInvokers ().empty () &&
1393
- context.getDifferentiableFunctionInstWorklist ().empty ())
1392
+ context.getDifferentiableFunctionInstWorklist ().empty () &&
1393
+ context.getLinearFunctionInstWorklist ().empty ())
1394
1394
return ;
1395
1395
1396
1396
// Differentiation relies on the stdlib (the Swift module).
0 commit comments