Skip to content

[AutoDiff] populate diff witnesses during differentiation #28402

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion lib/SIL/SILFunctionBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,25 @@ void SILFunctionBuilder::addFunctionAttributes(SILFunction *F,
for (auto *A : Attrs.getAttributes<DifferentiableAttr>())
(void)A->getParameterIndices();
for (auto *A : Attrs.getAttributes<DifferentiableAttr>()) {
auto &ctx = F->getASTContext();
// Get lowered argument indices.
auto *paramIndices = A->getParameterIndices();
assert(paramIndices && "Parameter indices should have been resolved");
auto *loweredParamIndices = autodiff::getLoweredParameterIndices(
paramIndices, decl->getInterfaceType()->castTo<AnyFunctionType>());
// NOTE(TF-893): Extending capacity is necessary when `origSilFnType` has
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if the only case that triggers this "capacity extending due to captured variables" logic is the @differentiable attribute on func original nested in func differentiableFunction(from:)?

Full context: there are known issues regarding differentiation and local variable capture (TF-881). @rxwei mentioned disallowing @differentiable attribute on nested functions for now and creating a builtin to support func differentiableFunction(from:). One known user of differentiableFunction(from:) is the custom differentiation tutorial.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's the only thing (in the stdlib + swift-apis + tests, at least) that triggers this problem.

It would indeed be nice to remove that and forbid @differentiable on nested functions for now.

// parameters corresponding to captured variables. These parameters do not
// appear in the type of `origFnType`.
// TODO: If posssible, change `autodiff::getLoweredParameterIndices` to
// take `CaptureInfo` into account.
auto origSilFnType = F->getLoweredFunctionType();
if (origSilFnType->getNumParameters() >
loweredParamIndices->getCapacity())
loweredParamIndices = loweredParamIndices->extendingCapacity(
ctx, origSilFnType->getNumParameters());
SILAutoDiffIndices indices(/*source*/ 0, loweredParamIndices);
// Get JVP/VJP names.
std::string jvpName, vjpName;
auto &ctx = F->getASTContext();
if (auto *jvpFn = A->getJVPFunction()) {
Mangle::ASTMangler mangler;
jvpName = ctx.getIdentifier(
Expand Down
4 changes: 2 additions & 2 deletions lib/SIL/SILPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2821,8 +2821,8 @@ static void printSILDifferentiabilityWitnesses(
[] (const SILDifferentiabilityWitness *w1,
const SILDifferentiabilityWitness *w2) -> bool {
// TODO(TF-893): Sort based on more criteria for deterministic ordering.
return w1->getOriginalFunction()->getName()
.compare(w2->getOriginalFunction()->getName());
return w1->getOriginalFunction()->getName().compare(
w2->getOriginalFunction()->getName()) == -1;
}
);
for (auto *dw : sortedDiffWitnesses)
Expand Down
5 changes: 5 additions & 0 deletions lib/SIL/SILVerifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5405,6 +5405,11 @@ void SILDifferentiabilityWitness::verify(const SILModule &M) const {
if (!M.getOptions().VerifyAll)
return;
#endif
// Skip lowered SIL: LoadableByAddress changes parameter/result conventions.
// TODO: Check that derivative function types match excluding
// parameter/result conventions in lowered SIL.
if (M.getStage() == SILStage::Lowered)
return;
auto origFnType = getOriginalFunction()->getLoweredFunctionType();
CanGenericSignature derivativeCanGenSig;
if (auto derivativeGenSig = getDerivativeGenericSignature())
Expand Down
72 changes: 70 additions & 2 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,57 @@ static Inst *peerThroughFunctionConversions(SILValue value) {
return nullptr;
}

/// Finds the differentiability witness corresponding to `attr` in `module`.
static SILDifferentiabilityWitness *
findDifferentiabilityWitness(SILModule &module, SILDifferentiableAttr *attr) {
auto *resultIndices =
IndexSubset::get(module.getASTContext(), 1, {attr->getIndices().source});
AutoDiffConfig config(attr->getIndices().parameters, resultIndices,
attr->getDerivativeGenericSignature());
return module.lookUpDifferentiabilityWitness(
{attr->getOriginal()->getName(), config});
}

/// Sets the differentiability witness JVP and VJP to the JVP and VJP in `attr`.
///
/// `attr` must have a JVP and VJP.
static void
canonicalizeDifferentiabilityWitness(SILModule &module,
const SILDifferentiableAttr *attr,
SILDifferentiabilityWitness *witness) {
auto jvpName = attr->getJVPName();
assert(!jvpName.empty() && "Expected JVP name");
auto *jvpFn = module.lookUpFunction(attr->getJVPName());
assert(jvpFn && "Expected JVP function");
assert(!witness->getJVP() ||
witness->getJVP() == jvpFn && "Pass trying to change witness jvp");
witness->setJVP(jvpFn);
auto vjpName = attr->getVJPName();
assert(!vjpName.empty() && "Expected VJP name");
auto *vjpFn = module.lookUpFunction(attr->getVJPName());
assert(vjpFn && "Expected VJP function");
assert(!witness->getVJP() ||
witness->getVJP() == vjpFn && "Pass trying to change witness vjp");
witness->setVJP(vjpFn);
}

/// Creates a differentiability witness definition corresponding to `attr` in
/// `module`.
///
/// `attr` must have a JVP and VJP.
static SILDifferentiabilityWitness *
createDifferentiabilityWitness(SILModule &module, SILLinkage linkage,
SILDifferentiableAttr *attr) {
auto *resultIndices =
IndexSubset::get(module.getASTContext(), 1, {attr->getIndices().source});
auto *witness = SILDifferentiabilityWitness::createDefinition(
module, linkage, attr->getOriginal(), attr->getIndices().parameters,
resultIndices, attr->getDerivativeGenericSignature(), /*jvp*/ nullptr,
/*vjp*/ nullptr, /*isSerialized*/ false);
canonicalizeDifferentiabilityWitness(module, attr, witness);
return witness;
}

//===----------------------------------------------------------------------===//
// Auxiliary data structures
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2703,6 +2754,8 @@ emitDerivativeFunctionReference(
originalFn, desiredIndices, contextualDerivativeGenSig);
if (context.processDifferentiableAttribute(originalFn, newAttr, invoker))
return None;
createDifferentiabilityWitness(context.getModule(), SILLinkage::Hidden,
newAttr);
minimalAttr = newAttr;
}
assert(minimalAttr);
Expand Down Expand Up @@ -8945,8 +8998,23 @@ void Differentiation::run() {
auto *attr = invokerPair.first;
auto *original = attr->getOriginal();
auto invoker = invokerPair.second;
errorOccurred |=
context.processDifferentiableAttribute(original, attr, invoker);

if (context.processDifferentiableAttribute(original, attr, invoker)) {
errorOccurred = true;
continue;
}

// External function witnesses are defined externally, so we don't need to
// define them here.
if (original->isExternalDeclaration())
continue;

auto *witness = findDifferentiabilityWitness(module, attr);
assert(
witness &&
"SILGen should create a witness for every [differentiable] attribute");
assert(witness->isDefinition());
canonicalizeDifferentiabilityWitness(module, attr, witness);
}

// Iteratively process `differentiable_function` instruction worklist.
Expand Down
4 changes: 4 additions & 0 deletions test/AutoDiff/differentiable_function_inst_irgen.sil
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ sil_stage raw
import Swift
import Builtin

sil_differentiability_witness hidden [parameters 0] [results 0] @foo : $@convention(thin) (Float) -> Float {
vjp: @foo_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
}

// The adjoint function emitted by the compiler. Parameters are a vector, as in
// vector-Jacobian products, and pullback struct value. The function is not
// itself a pullback, but to be partially applied to form a pullback, which
Expand Down
78 changes: 78 additions & 0 deletions test/AutoDiff/pass_creates_differentiability_witnesses.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// RUN: %target-swift-frontend -emit-sil -emit-sorted-sil %s | %FileCheck %s

// MARK: - Public functions

@differentiable
@_silgen_name("f000_invokedDirectlyByDifferentiableAttrPublic")
public func f000_invokedDirectlyByDifferentiableAttrPublic(_ x: Float) -> Float {
return f001_invokedIndirectlyByDifferentiableAttrPublic(x)
}
// CHECK-LABEL: sil_differentiability_witness [parameters 0] [results 0] @f000_invokedDirectlyByDifferentiableAttrPublic
// CHECK-NEXT: jvp
// CHECK-NEXT: vjp

@_silgen_name("f001_invokedIndirectlyByDifferentiableAttrPublic")
public func f001_invokedIndirectlyByDifferentiableAttrPublic(_ x: Float) -> Float {
return x
}
// CHECK-LABEL: sil_differentiability_witness hidden [parameters 0] [results 0] @f001_invokedIndirectlyByDifferentiableAttrPublic
// CHECK-NEXT: jvp
// CHECK-NEXT: vjp

@_silgen_name("f002_invokedDirectlyByConversionPublic")
public func f002_invokedDirectlyByConversionPublic(_ x: Float) -> Float {
return f003_invokedIndirectlyByConversionPublic(x)
}
// CHECK-LABEL: sil_differentiability_witness hidden [parameters 0] [results 0] @f002_invokedDirectlyByConversionPublic
// CHECK-NEXT: jvp
// CHECK-NEXT: vjp

@_silgen_name("f003_invokedIndirectlyByConversionPublic")
public func f003_invokedIndirectlyByConversionPublic(_ x: Float) -> Float {
return x
}
// CHECK-LABEL: sil_differentiability_witness hidden [parameters 0] [results 0] @f003_invokedIndirectlyByConversionPublic
// CHECK-NEXT: jvp
// CHECK-NEXT: vjp

// MARK: - Internal functions

@differentiable
@_silgen_name("f004_invokedDirectlyByDifferentiableAttrInternal")
internal func f004_invokedDirectlyByDifferentiableAttrInternal(_ x: Float) -> Float {
return f005_invokedIndirectlyByDifferentiableAttrInternal(x)
}
// CHECK-LABEL: sil_differentiability_witness hidden [parameters 0] [results 0] @f004_invokedDirectlyByDifferentiableAttrInternal
// CHECK-NEXT: jvp
// CHECK-NEXT: vjp

@_silgen_name("f005_invokedIndirectlyByDifferentiableAttrInternal")
internal func f005_invokedIndirectlyByDifferentiableAttrInternal(_ x: Float) -> Float {
return x
}
// CHECK-LABEL: sil_differentiability_witness hidden [parameters 0] [results 0] @f005_invokedIndirectlyByDifferentiableAttrInternal
// CHECK-NEXT: jvp
// CHECK-NEXT: vjp

@_silgen_name("f006_invokedDirectlyByConversionInternal")
internal func f006_invokedDirectlyByConversionInternal(_ x: Float) -> Float {
return f007_invokedIndirectlyByConversionInternal(x)
}
// CHECK-LABEL: sil_differentiability_witness hidden [parameters 0] [results 0] @f006_invokedDirectlyByConversionInternal
// CHECK-NEXT: jvp
// CHECK-NEXT: vjp

@_silgen_name("f007_invokedIndirectlyByConversionInternal")
internal func f007_invokedIndirectlyByConversionInternal(_ x: Float) -> Float {
return x
}
// CHECK-LABEL: sil_differentiability_witness hidden [parameters 0] [results 0] @f007_invokedIndirectlyByConversionInternal
// CHECK-NEXT: jvp
// CHECK-NEXT: vjp

func invokesByConversion() -> Float {
var result: Float = 0
result += gradient(at: 0, in: f002_invokedDirectlyByConversionPublic)
result += gradient(at: 0, in: f006_invokedDirectlyByConversionInternal)
return result
}
3 changes: 2 additions & 1 deletion test/AutoDiff/subset_parameters_thunk.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ func differentiate_foo_wrt_0(_ x: Float) -> Float {
foo(x, 1)
}

// CHECK-LABEL: @{{.*}}differentiate_foo_wrt_0{{.*}}__vjp
// Intentional "//" in the label so that this doesn't match a [differentiable] attr pointing at the vjp.
// CHECK-LABEL: // {{.*}}differentiate_foo_wrt_0{{.*}}__vjp
// CHECK: bb0
// CHECK: [[FOO_ORIG:%.*]] = function_ref @{{.*}}foo{{.*}} : $@convention(thin) <τ_0_0 where τ_0_0 : Numeric> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> @out τ_0_0
// CHECK: [[FOO_FLOAT:%.*]] = partial_apply [callee_guaranteed] [[FOO_ORIG]]<Float>() : $@convention(thin) <τ_0_0 where τ_0_0 : Numeric> (@in_guaranteed τ_0_0, @in_guaranteed τ_0_0) -> @out τ_0_0
Expand Down