Skip to content

Commit cb2eacc

Browse files
authored
fix bug where AD is not always seeing custom [differentiable] attrs (#21565)
1 parent 7903809 commit cb2eacc

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

lib/SILOptimizer/Mandatory/TFDifferentiation.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,6 +1001,7 @@ class ADContext {
10011001
/// parameters as the function.
10021002
StructDecl *createPrimalValueStruct(const DifferentiationTask *task);
10031003

1004+
private:
10041005
/// Finds the `[differentiable]` attribute on the specified original function
10051006
/// corresponding to the specified parameter indices. Returns nullptr if it
10061007
/// does not exist.
@@ -1078,6 +1079,7 @@ class ADContext {
10781079
return differentiationTasks[existing->getSecond()].get();
10791080
}
10801081

1082+
public:
10811083
/// Register a differentiation task in the global worklist. This will ensure
10821084
/// that a `[differentiable]` attribute will be generated for the specified
10831085
/// indices, and that primal/adjoint synthesis will be run in the
@@ -1113,6 +1115,21 @@ class ADContext {
11131115
lookUpOrRegisterDifferentiationTask(SILFunction *original,
11141116
const SILAutoDiffIndices &indices,
11151117
DifferentiationInvoker invoker) {
1118+
// If `original` has no differentiable attributes, it may be the case that
1119+
// it has not been loaded yet. Load it, check for differentiable attributes,
1120+
// and register the attributes as tasks so that they can be looked up.
1121+
if (original->getDifferentiableAttrs().empty() &&
1122+
original->isExternalDeclaration()) {
1123+
auto loaded = module.loadFunction(original);
1124+
assert(loaded && "Cannot load original function");
1125+
(void)loaded;
1126+
for (auto *diffAttr : original->getDifferentiableAttrs()) {
1127+
registerDifferentiationTask(
1128+
original, diffAttr->getIndices(),
1129+
DifferentiationInvoker(diffAttr, original));
1130+
}
1131+
}
1132+
11161133
if (auto *existingTask = lookUpMinimalDifferentiationTask(original, indices))
11171134
return existingTask;
11181135
return registerDifferentiationTask(original, indices, invoker);
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// RUN: %target-swift-frontend -emit-sil %s | %FileCheck %s
2+
3+
// Do not add other tests to this file, because differentiating other things
4+
// can hide this bug.
5+
6+
// There was a bug where the AD pass would not see the custom [differentiable]
7+
// attribute on `Float.*` wrt both parameters until after the AD pass generated
8+
// its own adjoint for `Float.*` wrt the second parameter. This test verifies
9+
// that the AD pass uses the custom adjoint defined on `Float.*`.
10+
11+
func mul3(_ x: Float) -> Float {
12+
return 3 * x
13+
}
14+
15+
let _ = gradient(at: 0, in: mul3)
16+
17+
// CHECK-LABEL: sil{{.*}} @AD__{{.*}}mul3{{.*}}__primal{{.*}}
18+
// CHECK: function_ref static Float._adjointMultiply(_:_:_:_:)
19+
// CHECK: } // end sil function 'AD__{{.*}}mul3{{.*}}__primal{{.*}}'

0 commit comments

Comments
 (0)