@@ -1001,6 +1001,7 @@ class ADContext {
1001
1001
// / parameters as the function.
1002
1002
StructDecl *createPrimalValueStruct (const DifferentiationTask *task);
1003
1003
1004
+ private:
1004
1005
// / Finds the `[differentiable]` attribute on the specified original function
1005
1006
// / corresponding to the specified parameter indices. Returns nullptr if it
1006
1007
// / does not exist.
@@ -1078,6 +1079,7 @@ class ADContext {
1078
1079
return differentiationTasks[existing->getSecond ()].get ();
1079
1080
}
1080
1081
1082
+ public:
1081
1083
// / Register a differentiation task in the global worklist. This will ensure
1082
1084
// / that a `[differentiable]` attribute will be generated for the specified
1083
1085
// / indices, and that primal/adjoint synthesis will be run in the
@@ -1113,6 +1115,21 @@ class ADContext {
1113
1115
lookUpOrRegisterDifferentiationTask (SILFunction *original,
1114
1116
const SILAutoDiffIndices &indices,
1115
1117
DifferentiationInvoker invoker) {
1118
+ // If `original` has no differentiable attributes, it may be the case that
1119
+ // it has not been loaded yet. Load it, check for differentiable attributes,
1120
+ // and register the attributes as tasks so that they can be looked up.
1121
+ if (original->getDifferentiableAttrs ().empty () &&
1122
+ original->isExternalDeclaration ()) {
1123
+ auto loaded = module .loadFunction (original);
1124
+ assert (loaded && " Cannot load original function" );
1125
+ (void )loaded;
1126
+ for (auto *diffAttr : original->getDifferentiableAttrs ()) {
1127
+ registerDifferentiationTask (
1128
+ original, diffAttr->getIndices (),
1129
+ DifferentiationInvoker (diffAttr, original));
1130
+ }
1131
+ }
1132
+
1116
1133
if (auto *existingTask = lookUpMinimalDifferentiationTask (original, indices))
1117
1134
return existingTask;
1118
1135
return registerDifferentiationTask (original, indices, invoker);
0 commit comments