@@ -265,6 +265,11 @@ struct DifferentiationInvoker {
   // a Swift AST node.
   GradientInst,
 
+  // No known invoker. This is the case when the differentiation is requested
+  // from SIL source via an `autodiff_function` instruction **without** being
+  // linked to a Swift AST node.
+  AutoDiffFunctionInst,
+
   // Invoked by the indirect application of differentiation. This case has an
   // associated differentiation task reference.
   IndirectDifferentiation,
@@ -287,10 +292,14 @@ struct DifferentiationInvoker {
 private:
   Kind kind;
   union Value {
-    /// The instruction associated with the `SILSource` case.
+    /// The instruction associated with the `GradientInst` case.
     GradientInst *gradientInst;
     Value(GradientInst *inst) : gradientInst(inst) {}
 
+    /// The instruction associated with the `AutoDiffFunctionInst` case.
+    AutoDiffFunctionInst *adFuncInst;
+    Value(AutoDiffFunctionInst *inst) : adFuncInst(inst) {}
+
     /// The parent differentiation task associated with the
     /// `IndirectDifferentiation` case.
     std::pair<ApplyInst *, DifferentiationTask *> indirectDifferentiation;
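Aside for readers of this patch: `DifferentiationInvoker` is a hand-rolled tagged union, where `kind` records which union member is live and each accessor asserts the discriminant before reading it. The following self-contained C++ sketch models just that pattern; the `GradientInst` and `AutoDiffFunctionInst` structs here are hypothetical stand-ins, not the real SIL classes.

#include <cassert>

// Hypothetical stand-in instruction types.
struct GradientInst {};
struct AutoDiffFunctionInst {};

class Invoker {
public:
  enum class Kind { GradientInst, AutoDiffFunctionInst };

private:
  Kind kind;
  union Value {
    GradientInst *gradientInst;
    AutoDiffFunctionInst *adFuncInst;
    Value(GradientInst *i) : gradientInst(i) {}
    Value(AutoDiffFunctionInst *i) : adFuncInst(i) {}
  } value;

public:
  Invoker(GradientInst *i) : kind(Kind::GradientInst), value(i) {}
  Invoker(AutoDiffFunctionInst *i)
      : kind(Kind::AutoDiffFunctionInst), value(i) {}

  Kind getKind() const { return kind; }

  // Each accessor asserts the discriminant before touching the union, so a
  // mismatched read is caught immediately in asserts builds.
  GradientInst *getGradientInst() const {
    assert(kind == Kind::GradientInst);
    return value.gradientInst;
  }
  AutoDiffFunctionInst *getAutoDiffFunctionInst() const {
    assert(kind == Kind::AutoDiffFunctionInst);
    return value.adFuncInst;
  }
};

int main() {
  AutoDiffFunctionInst adfi;
  Invoker inv(&adfi);
  assert(inv.getKind() == Invoker::Kind::AutoDiffFunctionInst);
  assert(inv.getAutoDiffFunctionInst() == &adfi);
}

This is why the later hunks only need to add one enum case, one union member, one constructor, and one accessor to support the new invoker kind.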
@@ -322,6 +331,8 @@ struct DifferentiationInvoker {
 public:
   DifferentiationInvoker(GradientInst *inst)
       : kind(Kind::GradientInst), value(inst) {}
+  DifferentiationInvoker(AutoDiffFunctionInst *inst)
+      : kind(Kind::AutoDiffFunctionInst), value(inst) {}
   DifferentiationInvoker(ApplyInst *applyInst, DifferentiationTask *task)
       : kind(Kind::IndirectDifferentiation), value(applyInst, task) {}
   DifferentiationInvoker(ReverseAutoDiffExpr *expr)
@@ -338,6 +349,11 @@ struct DifferentiationInvoker {
     return value.gradientInst;
   }
 
+  AutoDiffFunctionInst *getAutoDiffFunctionInst() const {
+    assert(kind == Kind::AutoDiffFunctionInst);
+    return value.adFuncInst;
+  }
+
   std::pair<ApplyInst *, DifferentiationTask *>
   getIndirectDifferentiation() const {
     assert(kind == Kind::IndirectDifferentiation);
@@ -812,6 +828,9 @@ void DifferentiationInvoker::print(llvm::raw_ostream &os) const {
   case Kind::GradientInst:
     os << "gradient_inst=(" << *getGradientInst() << ")";
     break;
+  case Kind::AutoDiffFunctionInst:
+    os << "autodiff_function_inst=(" << *getAutoDiffFunctionInst() << ")";
+    break;
   case Kind::IndirectDifferentiation: {
     auto indDiff = getIndirectDifferentiation();
     os << "indirect_differentiation=(apply_inst=(" << *indDiff.first
@@ -1177,6 +1196,7 @@ void ADContext::emitNondifferentiabilityError(SILInstruction *inst,
   // associated with any source location, we emit a diagnostic at the
   // instruction source location.
   case DifferentiationInvoker::Kind::GradientInst:
+  case DifferentiationInvoker::Kind::AutoDiffFunctionInst:
   case DifferentiationInvoker::Kind::SILDifferentiableAttribute:
     diagnose(opLoc,
              diag.getValueOr(diag::autodiff_expression_is_not_differentiable));
@@ -1293,7 +1313,7 @@ class DifferentiableActivityInfo;
 /// does not matter for the final result.
 ///
 /// Reference:
-/// Laurent Hascoët. Automatic Differentiation by Program Transformation. 2017.
+/// Laurent Hascoët. Automatic Differentiation by Program Transformation. 2007.
 class DifferentiableActivityAnalysis
     : public FunctionAnalysisBase<DifferentiableActivityInfo> {
 private:
@@ -4673,6 +4693,14 @@ static ReverseAutoDiffExpr *findDifferentialOperator(GradientInst *inst) {
   return inst->getLoc().getAsASTNode<ReverseAutoDiffExpr>();
 }
 
+/// Given an `autodiff_function` instruction, find the corresponding
+/// differential operator used in the AST. If no differential operator is
+/// found, return nullptr.
+static AutoDiffFunctionExpr *findDifferentialOperator(
+    AutoDiffFunctionInst *inst) {
+  return inst->getLoc().getAsASTNode<AutoDiffFunctionExpr>();
+}
+
 // Retrieve or create an empty gradient function based on a `gradient`
 // instruction and replace all users of the `gradient` instruction with the
 // gradient function. Returns the gradient function.
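Both `findDifferentialOperator` overloads rely on the fact that a SIL location can carry the AST node it was lowered from. Below is a rough standalone model of that lookup, assuming a location type holding an optional AST pointer; the real `SILLocation::getAsASTNode` uses Swift's own cast machinery rather than `dynamic_cast`, and all type names here are stand-ins.

#include <iostream>

// Hypothetical stand-in AST nodes.
struct Expr { virtual ~Expr() = default; };
struct AutoDiffFunctionExpr : Expr {};

// Hypothetical stand-in for a source location that may carry the AST node
// the instruction was lowered from.
struct Location {
  Expr *node = nullptr; // null when the SIL has no AST provenance

  template <typename T>
  T *getAsASTNode() const {
    return dynamic_cast<T *>(node);
  }
};

struct AutoDiffFunctionInst {
  Location loc;
  const Location &getLoc() const { return loc; }
};

// Mirrors the overload in the patch: recover the differential operator if the
// instruction originated from one, else return nullptr.
static AutoDiffFunctionExpr *findDifferentialOperator(
    AutoDiffFunctionInst *inst) {
  return inst->getLoc().getAsASTNode<AutoDiffFunctionExpr>();
}

int main() {
  AutoDiffFunctionExpr expr;
  AutoDiffFunctionInst fromAST{Location{&expr}};
  AutoDiffFunctionInst fromSIL{Location{nullptr}};
  std::cout << (findDifferentialOperator(&fromAST) != nullptr) << "\n"; // 1
  std::cout << (findDifferentialOperator(&fromSIL) != nullptr) << "\n"; // 0
}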
@@ -4953,6 +4981,9 @@ class Differentiation : public SILModuleTransform {
   /// pushing a differentiation task onto the global list. Returns true if any
   /// error occurred.
   bool processGradientInst(GradientInst *gi, ADContext &context);
+
+  bool processAutoDiffFunctionInst(AutoDiffFunctionInst *adfi,
+                                   ADContext &context);
 };
 } // end anonymous namespace
 
@@ -5003,6 +5034,60 @@ bool Differentiation::processGradientInst(GradientInst *gi,
   return false;
 }
 
+bool Differentiation::processAutoDiffFunctionInst(AutoDiffFunctionInst *adfi,
+                                                  ADContext &context) {
+  SILFunction *parent = adfi->getFunction();
+  auto origFnOperand = adfi->getOriginalFunction();
+  // If it traces back to a `function_ref`, differentiate that.
+  if (auto *originalFRI = findReferenceToVisibleFunction(origFnOperand)) {
+    auto *original = originalFRI->getReferencedFunction();
+    // TODO: Find syntax-level AD invoker from `adfi`.
+    auto *task = context.lookUpOrRegisterDifferentiationTask(
+        original, SILAutoDiffIndices(0, adfi->getParameterIndices()),
+        DifferentiationInvoker(adfi));
+    // Expand the `autodiff_function` instruction by adding the JVP and VJP
+    // functions.
+    SILBuilder builder(adfi);
+    auto loc = parent->getLocation();
+    auto *vjp = task->getVJP();
+    auto *vjpRef = builder.createFunctionRef(loc, vjp);
+    auto finalVJP = reapplyFunctionConversion(
+        vjpRef, originalFRI, origFnOperand, builder, loc);
+    SILValue finalJVP;
+    // TODO: Implement "forward-mode differentiation" to get JVP. Currently
+    // it's just an undef because we won't use it.
+    {
+      auto origFnTy = origFnOperand->getType().getAs<SILFunctionType>();
+      auto jvpType = origFnTy->getAutoDiffAssociatedFunctionType(
+          adfi->getParameterIndices(), /*resultIndex*/ 0,
+          /*differentiationOrder*/ 1, AutoDiffAssociatedFunctionKind::JVP,
+          *getModule(),
+          /*lookupConformance*/
+          LookUpConformanceInModule(getModule()->getSwiftModule()));
+      finalJVP = SILUndef::get(SILType::getPrimitiveObjectType(jvpType),
+                               getModule());
+    }
+    auto *newADFI = builder.createAutoDiffFunction(
+        loc, adfi->getParameterIndices(), adfi->getDifferentiationOrder(),
+        origFnOperand, {finalJVP, finalVJP});
+    adfi->replaceAllUsesWith(newADFI);
+    adfi->eraseFromParent();
+  }
+  // Differentiating opaque functions is not supported yet.
+  else {
+    // Find the original differential operator expression. Show an error at
+    // the operator, highlight the argument, and show a note at the definition
+    // site of the argument.
+    if (auto *expr = findDifferentialOperator(adfi))
+      context.diagnose(expr->getSubExpr()->getLoc(),
+                       diag::autodiff_function_not_differentiable)
+          .highlight(expr->getSubExpr()->getSourceRange());
+    return true;
+  }
+  PM->invalidateAnalysis(parent, SILAnalysis::InvalidationKind::FunctionBody);
+  return false;
+}
+
 /// AD pass entry.
 void Differentiation::run() {
   auto &module = *getModule();
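The core of `processAutoDiffFunctionInst` is a standard SIL rewrite sequence: build the expanded instruction next to the old one, redirect all users with `replaceAllUsesWith`, then erase the original. Here is a minimal self-contained sketch of that use-list redirection, with hypothetical `Inst` and `Use` types in place of the real SIL classes.

#include <cassert>
#include <iostream>
#include <string>
#include <vector>

struct Inst;

// A "use" is a slot in some other instruction that refers to an Inst.
struct Use {
  Inst *user;
  Inst **slot; // points at the operand field inside `user`
};

struct Inst {
  std::string name;
  std::vector<Use> uses;

  explicit Inst(std::string n) : name(std::move(n)) {}

  // Redirect every recorded use of `this` to `replacement`.
  void replaceAllUsesWith(Inst *replacement) {
    for (Use &u : uses) {
      *u.slot = replacement;
      replacement->uses.push_back(u);
    }
    uses.clear();
  }
};

// A single-operand user instruction, enough to demonstrate the rewrite.
struct UnaryInst : Inst {
  Inst *operand;
  UnaryInst(std::string n, Inst *op) : Inst(std::move(n)), operand(op) {
    op->uses.push_back(Use{this, &operand});
  }
};

int main() {
  Inst oldFn("autodiff_function(original)");
  UnaryInst apply("apply", &oldFn);

  // Build the expanded replacement (in the pass this carries the JVP/VJP),
  // redirect all users to it; afterwards the old instruction can be erased.
  Inst newFn("autodiff_function(original, jvp, vjp)");
  oldFn.replaceAllUsesWith(&newFn);

  assert(apply.operand == &newFn);
  std::cout << "apply now uses: " << apply.operand->name << "\n";
}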
@@ -5013,6 +5098,7 @@ void Differentiation::run() {
   SmallVector<std::pair<SILFunction *,
                         SILDifferentiableAttr *>, 8> diffAttrs;
   SmallVector<GradientInst *, 16> gradInsts;
-  // Handle each `gradient` instruction and each `differentiable`
-  // attribute in the module.
+  SmallVector<AutoDiffFunctionInst *, 16> autodiffInsts;
+  // Handle each `gradient` instruction, each `autodiff_function` instruction,
+  // and each `differentiable` attribute in the module.
   for (SILFunction &f : module) {
@@ -5037,6 +5123,8 @@ void Differentiation::run() {
         // operator, push it to the work list.
         if (auto *gi = dyn_cast<GradientInst>(&i))
           gradInsts.push_back(gi);
+        else if (auto *adfi = dyn_cast<AutoDiffFunctionInst>(&i))
+          autodiffInsts.push_back(adfi);
       }
     }
   }
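The collection loop above uses LLVM-style `dyn_cast` dispatch. As a self-contained illustration of how that works, here is a hand-rolled `classof`/`dyn_cast` over hypothetical stand-in instruction types; the real llvm::dyn_cast is more general but follows the same shape.

#include <iostream>
#include <vector>

// Minimal LLVM-style RTTI: a kind tag plus classof(), which is what lets
// dyn_cast sort a function body into per-kind worklists without C++ RTTI.
struct SILInst {
  enum class Kind { Gradient, AutoDiffFunction, Other };
  Kind kind;
  explicit SILInst(Kind k) : kind(k) {}
};

struct GradientInst : SILInst {
  GradientInst() : SILInst(Kind::Gradient) {}
  static bool classof(const SILInst *i) { return i->kind == Kind::Gradient; }
};

struct AutoDiffFunctionInst : SILInst {
  AutoDiffFunctionInst() : SILInst(Kind::AutoDiffFunction) {}
  static bool classof(const SILInst *i) {
    return i->kind == Kind::AutoDiffFunction;
  }
};

// Hand-rolled dyn_cast with the same shape as llvm::dyn_cast.
template <typename To, typename From>
To *dyn_cast(From *value) {
  return To::classof(value) ? static_cast<To *>(value) : nullptr;
}

int main() {
  GradientInst g;
  AutoDiffFunctionInst a;
  SILInst other(SILInst::Kind::Other);
  std::vector<SILInst *> body = {&g, &a, &other};

  std::vector<GradientInst *> gradInsts;
  std::vector<AutoDiffFunctionInst *> autodiffInsts;
  for (SILInst *i : body) {
    if (auto *gi = dyn_cast<GradientInst>(i))
      gradInsts.push_back(gi);
    else if (auto *adfi = dyn_cast<AutoDiffFunctionInst>(i))
      autodiffInsts.push_back(adfi);
  }
  std::cout << gradInsts.size() << " " << autodiffInsts.size() << "\n"; // 1 1
}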
@@ -5081,9 +5169,12 @@ void Differentiation::run() {
   // turned into a differentiation task. But we don't back out just yet - primal
   // synthesis and adjoint synthesis for newly created differentiation tasks
   // should still run because they will diagnose more errors.
-  bool errorProcessingGradInsts = false;
+  bool errorProcessingAutoDiffInsts = false;
   for (auto *gi : gradInsts)
-    errorProcessingGradInsts |= processGradientInst(gi, context);
+    errorProcessingAutoDiffInsts |= processGradientInst(gi, context);
+
+  for (auto *adfi : autodiffInsts)
+    errorProcessingAutoDiffInsts |= processAutoDiffFunctionInst(adfi, context);
 
   // Run primal generation for newly created differentiation tasks. If any error
   // occurs, back out.
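A small design point in the hunk above: errors accumulate with `|=` rather than an early return or a short-circuiting `||`, so every collected instruction is still processed and can emit its own diagnostics before the pass backs out. A tiny sketch of the pattern, with a hypothetical `process` stand-in:

#include <iostream>
#include <vector>

// Hypothetical stand-in: returns true when processing an item fails.
static bool process(int item) { return item < 0; }

int main() {
  std::vector<int> work = {1, -2, 3, -4};
  bool anyError = false;
  // `anyError |= process(item)` always evaluates process(item), unlike
  // `anyError = anyError || process(item)`, which would skip the call after
  // the first failure.
  for (int item : work)
    anyError |= process(item);
  std::cout << (anyError ? "back out\n" : "continue\n");
}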
@@ -5099,7 +5190,7 @@ void Differentiation::run() {
 
-  // If there was any error that occurred during `gradient` instruction
-  // processing, back out.
-  if (errorProcessingGradInsts)
+  // If there was any error that occurred during `gradient` or
+  // `autodiff_function` instruction processing, back out.
+  if (errorProcessingAutoDiffInsts)
     return;
 
   // Fill the body of each empty canonical gradient function.