Skip to content

Commit b41adcd

Browse files
author
marcrasi
authored
[AutoDiff] populate diff witnesses during differentiation (#28402)
1 parent 46b7a58 commit b41adcd

File tree

7 files changed

+172
-6
lines changed

7 files changed

+172
-6
lines changed

lib/SIL/SILFunctionBuilder.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,15 +93,25 @@ void SILFunctionBuilder::addFunctionAttributes(SILFunction *F,
9393
for (auto *A : Attrs.getAttributes<DifferentiableAttr>())
9494
(void)A->getParameterIndices();
9595
for (auto *A : Attrs.getAttributes<DifferentiableAttr>()) {
96+
auto &ctx = F->getASTContext();
9697
// Get lowered argument indices.
9798
auto *paramIndices = A->getParameterIndices();
9899
assert(paramIndices && "Parameter indices should have been resolved");
99100
auto *loweredParamIndices = autodiff::getLoweredParameterIndices(
100101
paramIndices, decl->getInterfaceType()->castTo<AnyFunctionType>());
102+
// NOTE(TF-893): Extending capacity is necessary when `origSilFnType` has
103+
// parameters corresponding to captured variables. These parameters do not
104+
// appear in the type of `origFnType`.
105+
// TODO: If posssible, change `autodiff::getLoweredParameterIndices` to
106+
// take `CaptureInfo` into account.
107+
auto origSilFnType = F->getLoweredFunctionType();
108+
if (origSilFnType->getNumParameters() >
109+
loweredParamIndices->getCapacity())
110+
loweredParamIndices = loweredParamIndices->extendingCapacity(
111+
ctx, origSilFnType->getNumParameters());
101112
SILAutoDiffIndices indices(/*source*/ 0, loweredParamIndices);
102113
// Get JVP/VJP names.
103114
std::string jvpName, vjpName;
104-
auto &ctx = F->getASTContext();
105115
if (auto *jvpFn = A->getJVPFunction()) {
106116
Mangle::ASTMangler mangler;
107117
jvpName = ctx.getIdentifier(

lib/SIL/SILPrinter.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2821,8 +2821,8 @@ static void printSILDifferentiabilityWitnesses(
28212821
[] (const SILDifferentiabilityWitness *w1,
28222822
const SILDifferentiabilityWitness *w2) -> bool {
28232823
// TODO(TF-893): Sort based on more criteria for deterministic ordering.
2824-
return w1->getOriginalFunction()->getName()
2825-
.compare(w2->getOriginalFunction()->getName());
2824+
return w1->getOriginalFunction()->getName().compare(
2825+
w2->getOriginalFunction()->getName()) == -1;
28262826
}
28272827
);
28282828
for (auto *dw : sortedDiffWitnesses)

lib/SIL/SILVerifier.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5405,6 +5405,11 @@ void SILDifferentiabilityWitness::verify(const SILModule &M) const {
54055405
if (!M.getOptions().VerifyAll)
54065406
return;
54075407
#endif
5408+
// Skip lowered SIL: LoadableByAddress changes parameter/result conventions.
5409+
// TODO: Check that derivative function types match excluding
5410+
// parameter/result conventions in lowered SIL.
5411+
if (M.getStage() == SILStage::Lowered)
5412+
return;
54085413
auto origFnType = getOriginalFunction()->getLoweredFunctionType();
54095414
CanGenericSignature derivativeCanGenSig;
54105415
if (auto derivativeGenSig = getDerivativeGenericSignature())

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,57 @@ static Inst *peerThroughFunctionConversions(SILValue value) {
342342
return nullptr;
343343
}
344344

345+
/// Finds the differentiability witness corresponding to `attr` in `module`.
346+
static SILDifferentiabilityWitness *
347+
findDifferentiabilityWitness(SILModule &module, SILDifferentiableAttr *attr) {
348+
auto *resultIndices =
349+
IndexSubset::get(module.getASTContext(), 1, {attr->getIndices().source});
350+
AutoDiffConfig config(attr->getIndices().parameters, resultIndices,
351+
attr->getDerivativeGenericSignature());
352+
return module.lookUpDifferentiabilityWitness(
353+
{attr->getOriginal()->getName(), config});
354+
}
355+
356+
/// Sets the differentiability witness JVP and VJP to the JVP and VJP in `attr`.
357+
///
358+
/// `attr` must have a JVP and VJP.
359+
static void
360+
canonicalizeDifferentiabilityWitness(SILModule &module,
361+
const SILDifferentiableAttr *attr,
362+
SILDifferentiabilityWitness *witness) {
363+
auto jvpName = attr->getJVPName();
364+
assert(!jvpName.empty() && "Expected JVP name");
365+
auto *jvpFn = module.lookUpFunction(attr->getJVPName());
366+
assert(jvpFn && "Expected JVP function");
367+
assert(!witness->getJVP() ||
368+
witness->getJVP() == jvpFn && "Pass trying to change witness jvp");
369+
witness->setJVP(jvpFn);
370+
auto vjpName = attr->getVJPName();
371+
assert(!vjpName.empty() && "Expected VJP name");
372+
auto *vjpFn = module.lookUpFunction(attr->getVJPName());
373+
assert(vjpFn && "Expected VJP function");
374+
assert(!witness->getVJP() ||
375+
witness->getVJP() == vjpFn && "Pass trying to change witness vjp");
376+
witness->setVJP(vjpFn);
377+
}
378+
379+
/// Creates a differentiability witness definition corresponding to `attr` in
380+
/// `module`.
381+
///
382+
/// `attr` must have a JVP and VJP.
383+
static SILDifferentiabilityWitness *
384+
createDifferentiabilityWitness(SILModule &module, SILLinkage linkage,
385+
SILDifferentiableAttr *attr) {
386+
auto *resultIndices =
387+
IndexSubset::get(module.getASTContext(), 1, {attr->getIndices().source});
388+
auto *witness = SILDifferentiabilityWitness::createDefinition(
389+
module, linkage, attr->getOriginal(), attr->getIndices().parameters,
390+
resultIndices, attr->getDerivativeGenericSignature(), /*jvp*/ nullptr,
391+
/*vjp*/ nullptr, /*isSerialized*/ false);
392+
canonicalizeDifferentiabilityWitness(module, attr, witness);
393+
return witness;
394+
}
395+
345396
//===----------------------------------------------------------------------===//
346397
// Auxiliary data structures
347398
//===----------------------------------------------------------------------===//
@@ -2703,6 +2754,8 @@ emitDerivativeFunctionReference(
27032754
originalFn, desiredIndices, contextualDerivativeGenSig);
27042755
if (context.processDifferentiableAttribute(originalFn, newAttr, invoker))
27052756
return None;
2757+
createDifferentiabilityWitness(context.getModule(), SILLinkage::Hidden,
2758+
newAttr);
27062759
minimalAttr = newAttr;
27072760
}
27082761
assert(minimalAttr);
@@ -8945,8 +8998,23 @@ void Differentiation::run() {
89458998
auto *attr = invokerPair.first;
89468999
auto *original = attr->getOriginal();
89479000
auto invoker = invokerPair.second;
8948-
errorOccurred |=
8949-
context.processDifferentiableAttribute(original, attr, invoker);
9001+
9002+
if (context.processDifferentiableAttribute(original, attr, invoker)) {
9003+
errorOccurred = true;
9004+
continue;
9005+
}
9006+
9007+
// External function witnesses are defined externally, so we don't need to
9008+
// define them here.
9009+
if (original->isExternalDeclaration())
9010+
continue;
9011+
9012+
auto *witness = findDifferentiabilityWitness(module, attr);
9013+
assert(
9014+
witness &&
9015+
"SILGen should create a witness for every [differentiable] attribute");
9016+
assert(witness->isDefinition());
9017+
canonicalizeDifferentiabilityWitness(module, attr, witness);
89509018
}
89519019

89529020
// Iteratively process `differentiable_function` instruction worklist.

test/AutoDiff/differentiable_function_inst_irgen.sil

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ sil_stage raw
55
import Swift
66
import Builtin
77

8+
sil_differentiability_witness hidden [parameters 0] [results 0] @foo : $@convention(thin) (Float) -> Float {
9+
vjp: @foo_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
10+
}
11+
812
// The adjoint function emitted by the compiler. Parameters are a vector, as in
913
// vector-Jacobian products, and pullback struct value. The function is not
1014
// itself a pullback, but to be partially applied to form a pullback, which
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
// RUN: %target-swift-frontend -emit-sil -emit-sorted-sil %s | %FileCheck %s
2+
3+
// MARK: - Public functions
4+
5+
@differentiable
6+
@_silgen_name("f000_invokedDirectlyByDifferentiableAttrPublic")
7+
public func f000_invokedDirectlyByDifferentiableAttrPublic(_ x: Float) -> Float {
8+
return f001_invokedIndirectlyByDifferentiableAttrPublic(x)
9+
}
10+
// CHECK-LABEL: sil_differentiability_witness [parameters 0] [results 0] @f000_invokedDirectlyByDifferentiableAttrPublic
11+
// CHECK-NEXT: jvp
12+
// CHECK-NEXT: vjp
13+
14+
@_silgen_name("f001_invokedIndirectlyByDifferentiableAttrPublic")
15+
public func f001_invokedIndirectlyByDifferentiableAttrPublic(_ x: Float) -> Float {
16+
return x
17+
}
18+
// CHECK-LABEL: sil_differentiability_witness hidden [parameters 0] [results 0] @f001_invokedIndirectlyByDifferentiableAttrPublic
19+
// CHECK-NEXT: jvp
20+
// CHECK-NEXT: vjp
21+
22+
@_silgen_name("f002_invokedDirectlyByConversionPublic")
23+
public func f002_invokedDirectlyByConversionPublic(_ x: Float) -> Float {
24+
return f003_invokedIndirectlyByConversionPublic(x)
25+
}
26+
// CHECK-LABEL: sil_differentiability_witness hidden [parameters 0] [results 0] @f002_invokedDirectlyByConversionPublic
27+
// CHECK-NEXT: jvp
28+
// CHECK-NEXT: vjp
29+
30+
@_silgen_name("f003_invokedIndirectlyByConversionPublic")
31+
public func f003_invokedIndirectlyByConversionPublic(_ x: Float) -> Float {
32+
return x
33+
}
34+
// CHECK-LABEL: sil_differentiability_witness hidden [parameters 0] [results 0] @f003_invokedIndirectlyByConversionPublic
35+
// CHECK-NEXT: jvp
36+
// CHECK-NEXT: vjp
37+
38+
// MARK: - Internal functions
39+
40+
@differentiable
41+
@_silgen_name("f004_invokedDirectlyByDifferentiableAttrInternal")
42+
internal func f004_invokedDirectlyByDifferentiableAttrInternal(_ x: Float) -> Float {
43+
return f005_invokedIndirectlyByDifferentiableAttrInternal(x)
44+
}
45+
// CHECK-LABEL: sil_differentiability_witness hidden [parameters 0] [results 0] @f004_invokedDirectlyByDifferentiableAttrInternal
46+
// CHECK-NEXT: jvp
47+
// CHECK-NEXT: vjp
48+
49+
@_silgen_name("f005_invokedIndirectlyByDifferentiableAttrInternal")
50+
internal func f005_invokedIndirectlyByDifferentiableAttrInternal(_ x: Float) -> Float {
51+
return x
52+
}
53+
// CHECK-LABEL: sil_differentiability_witness hidden [parameters 0] [results 0] @f005_invokedIndirectlyByDifferentiableAttrInternal
54+
// CHECK-NEXT: jvp
55+
// CHECK-NEXT: vjp
56+
57+
@_silgen_name("f006_invokedDirectlyByConversionInternal")
58+
internal func f006_invokedDirectlyByConversionInternal(_ x: Float) -> Float {
59+
return f007_invokedIndirectlyByConversionInternal(x)
60+
}
61+
// CHECK-LABEL: sil_differentiability_witness hidden [parameters 0] [results 0] @f006_invokedDirectlyByConversionInternal
62+
// CHECK-NEXT: jvp
63+
// CHECK-NEXT: vjp
64+
65+
@_silgen_name("f007_invokedIndirectlyByConversionInternal")
66+
internal func f007_invokedIndirectlyByConversionInternal(_ x: Float) -> Float {
67+
return x
68+
}
69+
// CHECK-LABEL: sil_differentiability_witness hidden [parameters 0] [results 0] @f007_invokedIndirectlyByConversionInternal
70+
// CHECK-NEXT: jvp
71+
// CHECK-NEXT: vjp
72+
73+
func invokesByConversion() -> Float {
74+
var result: Float = 0
75+
result += gradient(at: 0, in: f002_invokedDirectlyByConversionPublic)
76+
result += gradient(at: 0, in: f006_invokedDirectlyByConversionInternal)
77+
return result
78+
}

test/AutoDiff/subset_parameters_thunk.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ func differentiate_foo_wrt_0(_ x: Float) -> Float {
1414
foo(x, 1)
1515
}
1616

17-
// CHECK-LABEL: @{{.*}}differentiate_foo_wrt_0{{.*}}__vjp
17+
// Intentional "//" in the label so that this doesn't match a [differentiable] attr pointing at the vjp.
18+
// CHECK-LABEL: // {{.*}}differentiate_foo_wrt_0{{.*}}__vjp
1819
// CHECK: bb0
1920
// CHECK: [[FOO_ORIG:%.*]] = function_ref @{{.*}}foo{{.*}} : $@convention(thin) <τ_0_0 where τ_0_0 : Numeric> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> @out τ_0_0
2021
// CHECK: [[FOO_FLOAT:%.*]] = partial_apply [callee_guaranteed] [[FOO_ORIG]]<Float>() : $@convention(thin) <τ_0_0 where τ_0_0 : Numeric> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> @out τ_0_0

0 commit comments

Comments
 (0)