Skip to content

[AutoDiff] Use @_fieldwiseDifferentiable on struct decls and deprecate @_fieldwiseProductSpace. #22009

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/swift/AST/Attr.def
Original file line number Diff line number Diff line change
Expand Up @@ -395,8 +395,8 @@ SIMPLE_DECL_ATTR(TensorFlowGraph, TensorFlowGraph,
OnFunc, 82)
SIMPLE_DECL_ATTR(TFParameter, TFParameter,
OnVar, 83)
SIMPLE_DECL_ATTR(_fieldwiseProductSpace, FieldwiseProductSpace,
OnTypeAlias | OnNominalType | UserInaccessible, 84)
SIMPLE_DECL_ATTR(_fieldwiseDifferentiable, FieldwiseDifferentiable,
OnNominalType | UserInaccessible, 84)
SIMPLE_DECL_ATTR(noDerivative, NoDerivative,
OnVar, 85)

Expand Down
15 changes: 12 additions & 3 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -2742,6 +2742,11 @@ ERROR(noderivative_only_on_stored_properties_in_differentiable_structs,none,
"@noDerivative is only allowed on stored properties in structure types "
"that declare a conformance to 'Differentiable'", ())

// @_fieldwiseDifferentiable attribute
ERROR(fieldwise_differentiable_only_on_differentiable_structs,none,
"@_fieldwiseDifferentiable is only allowed on structure types that "
"conform to 'Differentiable'", ())

//------------------------------------------------------------------------------
// MARK: Type Check Expressions
//------------------------------------------------------------------------------
Expand Down Expand Up @@ -3655,11 +3660,15 @@ ERROR(unreferenced_generic_parameter,none,
// SWIFT_ENABLE_TENSORFLOW
// Function differentiability
ERROR(autodiff_attr_argument_not_differentiable,none,
"argument is not differentiable, but the enclosing function type is marked '@autodiff'; did you want to add '@nondiff' to this argument?", ())
"argument is not differentiable, but the enclosing function type is "
"marked '@autodiff'; did you want to add '@nondiff' to this argument?",
())
ERROR(autodiff_attr_result_not_differentiable,none,
"result is not differentiable, but the function type is marked '@autodiff'", ())
"result is not differentiable, but the function type is marked "
"'@autodiff'", ())
ERROR(nondiff_attr_invalid_on_nondifferentiable_function,none,
"'nondiff' cannot be applied to arguments of a non-differentiable function", ())
"'nondiff' cannot be applied to arguments of a non-differentiable "
"function", ())

// SIL
ERROR(opened_non_protocol,none,
Expand Down
183 changes: 60 additions & 123 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,6 @@ static void createEntryArguments(SILFunction *f) {
}
}

/// Resolves the SIL function named `name`.
///
/// The current module is consulted first via `lookUpFunction`. When that
/// yields nothing, the result of `findFunction` with `PublicExternal` linkage
/// is returned instead (per the original note, this links the function from
/// imported modules and is null when no function with that name exists).
///
/// `name` must be non-empty; this is asserted.
static SILFunction *lookUpOrLinkFunction(StringRef name, SILModule &module) {
  assert(!name.empty());
  SILFunction *resolvedFn = module.lookUpFunction(name);
  // Fall back to an external lookup only when there is no local definition.
  if (!resolvedFn)
    resolvedFn = module.findFunction(name, SILLinkage::PublicExternal);
  return resolvedFn;
}

/// Computes the correct linkage for functions generated by the AD pass
/// associated with a function with linkage `originalLinkage`.
static SILLinkage getAutoDiffFunctionLinkage(SILLinkage originalLinkage) {
Expand Down Expand Up @@ -528,7 +518,7 @@ enum class StructExtractDifferentiationStrategy {
// that is zero except along the direction of the corresponding field.
//
// Fields correspond by matching name.
FieldwiseProductSpace,
Fieldwise,

// Differentiate the `struct_extract` by looking up the corresponding getter
// and using its VJP.
Expand Down Expand Up @@ -1291,6 +1281,7 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di,
if (isVaried(cai->getSrc(), i))
recursivelySetVariedIfDifferentiable(cai->getDest(), i);
}
// Handle `struct_extract`.
else if (auto *sei = dyn_cast<StructExtractInst>(&inst)) {
if (isVaried(sei->getOperand(), i)) {
auto hasNoDeriv = sei->getField()->getAttrs()
Expand Down Expand Up @@ -2091,46 +2082,30 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
}

void visitStructExtractInst(StructExtractInst *sei) {
auto &astCtx = getContext().getASTContext();
auto &structExtractDifferentiationStrategies =
auto &strategies =
getDifferentiationTask()->getStructExtractDifferentiationStrategies();

// Special handling logic only applies when the `struct_extract` is active.
// If not, just do standard cloning.
if (!activityInfo.isActive(sei, synthesis.indices)) {
LLVM_DEBUG(getADDebugStream() << "Not active:\n" << *sei << '\n');
structExtractDifferentiationStrategies.insert(
strategies.insert(
{sei, StructExtractDifferentiationStrategy::Inactive});
SILClonerWithScopes::visitStructExtractInst(sei);
return;
}

// This instruction is active. Determine the appropriate differentiation
// strategy, and use it.

// Use the FieldwiseProductSpace strategy, if appropriate.
auto *structDecl = sei->getStructDecl();
auto cotangentDeclLookup =
structDecl->lookupDirect(astCtx.Id_CotangentVector);
if (cotangentDeclLookup.size() >= 1) {
assert(cotangentDeclLookup.size() == 1);
auto cotangentTypeDecl = cotangentDeclLookup.front();
assert(isa<TypeAliasDecl>(cotangentTypeDecl) ||
isa<StructDecl>(cotangentTypeDecl));
if (cotangentTypeDecl->getAttrs()
.hasAttribute<FieldwiseProductSpaceAttr>()) {
structExtractDifferentiationStrategies.insert(
{sei, StructExtractDifferentiationStrategy::FieldwiseProductSpace});
SILClonerWithScopes::visitStructExtractInst(sei);
return;
}
if (structDecl->getAttrs().hasAttribute<FieldwiseDifferentiableAttr>()) {
strategies.insert(
{sei, StructExtractDifferentiationStrategy::Fieldwise});
SILClonerWithScopes::visitStructExtractInst(sei);
return;
}

// The FieldwiseProductSpace strategy is not appropriate, so use the Getter
// strategy.
structExtractDifferentiationStrategies.insert(
strategies.insert(
{sei, StructExtractDifferentiationStrategy::Getter});

// Find the corresponding getter and its VJP.
auto *getterDecl = sei->getField()->getGetter();
assert(getterDecl);
Expand All @@ -2142,42 +2117,29 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
errorOccurred = true;
return;
}
auto getterDiffAttrs = getterFn->getDifferentiableAttrs();
if (getterDiffAttrs.size() < 1) {
getContext().emitNondifferentiabilityError(
sei, synthesis.task, diag::autodiff_property_not_differentiable);
errorOccurred = true;
return;
}
auto *getterDiffAttr = getterDiffAttrs[0];
if (!getterDiffAttr->hasVJP()) {
SILAutoDiffIndices indices(/*source*/ 0, /*parameters*/ {0});
auto *task = getContext().lookUpDifferentiationTask(getterFn, indices);
if (!task) {
getContext().emitNondifferentiabilityError(
sei, synthesis.task, diag::autodiff_property_not_differentiable);
errorOccurred = true;
return;
}
assert(getterDiffAttr->getIndices() ==
SILAutoDiffIndices(/*source*/ 0, /*parameters*/{0}));
auto *getterVJP = lookUpOrLinkFunction(getterDiffAttr->getVJPName(),
getContext().getModule());

// Reference and apply the VJP.
auto loc = sei->getLoc();
auto *getterVJPRef = getBuilder().createFunctionRef(loc, getterVJP);
auto *getterVJPRef = getBuilder().createFunctionRef(loc, task->getVJP());
auto *getterVJPApply = getBuilder().createApply(
loc, getterVJPRef, /*substitutionMap*/ {},
/*args*/ {getMappedValue(sei->getOperand())}, /*isNonThrowing*/ false);
SmallVector<SILValue, 8> vjpDirectResults;
extractAllElements(getterVJPApply, getBuilder(), vjpDirectResults);
ArrayRef<SILValue> originalDirectResults =
ArrayRef<SILValue>(vjpDirectResults).drop_back(1);

// Map original results.
auto originalDirectResults =
ArrayRef<SILValue>(vjpDirectResults).drop_back(1);
SILValue originalDirectResult = joinElements(originalDirectResults,
getBuilder(),
getterVJPApply->getLoc());
mapValue(sei, originalDirectResult);

// Checkpoint the pullback.
SILValue pullback = vjpDirectResults.back();
getPrimalInfo().addPullbackDecl(sei, pullback->getType().getASTType());
Expand Down Expand Up @@ -3079,60 +3041,41 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
auto loc = remapLocation(sei->getLoc());
auto &differentiationStrategies =
getDifferentiationTask()->getStructExtractDifferentiationStrategies();
auto differentiationStrategyLookUp = differentiationStrategies.find(sei);
assert(differentiationStrategyLookUp != differentiationStrategies.end());
auto differentiationStrategy = differentiationStrategyLookUp->second;

if (differentiationStrategy ==
StructExtractDifferentiationStrategy::Inactive) {
auto strategy = differentiationStrategies.lookup(sei);
switch (strategy) {
case StructExtractDifferentiationStrategy::Inactive:
assert(!activityInfo.isActive(sei, synthesis.indices));
return;
}

if (differentiationStrategy ==
StructExtractDifferentiationStrategy::FieldwiseProductSpace) {
case StructExtractDifferentiationStrategy::Fieldwise: {
// Compute adjoint as follows:
// y = struct_extract <key>, x
// adj[x] = struct (0, ..., key': adj[y], ..., 0)
// where `key'` is the field in the cotangent space corresponding to
// `key`.

// Find the decl of the cotangent space type.
auto structTy = sei->getOperand()->getType().getASTType();
auto cotangentVectorTy = structTy->getAutoDiffAssociatedVectorSpace(
AutoDiffAssociatedVectorSpaceKind::Cotangent,
LookUpConformanceInModule(getModule().getSwiftModule()))
->getType()->getCanonicalType();
assert(!getModule()
.Types.getTypeLowering(cotangentVectorTy)
.isAddressOnly());
->getType()->getCanonicalType();
assert(!getModule().Types.getTypeLowering(cotangentVectorTy)
.isAddressOnly());
auto cotangentVectorSILTy =
SILType::getPrimitiveObjectType(cotangentVectorTy);
auto *cotangentVectorDecl =
cotangentVectorTy->getStructOrBoundGenericStruct();
assert(cotangentVectorDecl);

// Find the corresponding field in the cotangent space.
VarDecl *correspondingField = nullptr;
// If the cotangent space is the original space, then it's the same field.
if (cotangentVectorDecl == sei->getStructDecl())
correspondingField = sei->getField();
// Otherwise we just look it up by name.
else {
auto correspondingFieldLookup =
cotangentVectorDecl->lookupDirect(sei->getField()->getName());
assert(correspondingFieldLookup.size() == 1);
assert(isa<VarDecl>(correspondingFieldLookup[0]));
correspondingField = cast<VarDecl>(correspondingFieldLookup[0]);
correspondingField = cast<VarDecl>(correspondingFieldLookup.front());
}
assert(correspondingField);

#ifndef NDEBUG
unsigned numMatchingStoredProperties = 0;
for (auto *storedProperty : cotangentVectorDecl->getStoredProperties())
if (storedProperty == correspondingField)
numMatchingStoredProperties += 1;
assert(numMatchingStoredProperties == 1);
#endif

// Compute adjoint.
auto av = getAdjointValue(sei);
switch (av.getKind()) {
Expand All @@ -3148,44 +3091,41 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
eltVals.push_back(av);
else
eltVals.push_back(AdjointValue::getZero(
SILType::getPrimitiveObjectType(field->getType()
->getCanonicalType())));
SILType::getPrimitiveObjectType(
field->getType()->getCanonicalType())));
}
addAdjointValue(sei->getOperand(),
AdjointValue::getAggregate(cotangentVectorSILTy,
eltVals, allocator));
}
}

return;
}

// The only remaining strategy is the getter strategy.
// Replace the `struct_extract` with a call to its pullback.
assert(differentiationStrategy ==
StructExtractDifferentiationStrategy::Getter);

// Get the pullback.
auto *pullbackField = getPrimalInfo().lookUpPullbackDecl(sei);
assert(pullbackField);
SILValue pullback = builder.createStructExtract(loc,
primalValueAggregateInAdj,
pullbackField);

// Construct the pullback arguments.
SmallVector<SILValue, 8> args;
auto seed = getAdjointValue(sei);
assert(seed.getType().isObject());
args.push_back(materializeAdjointDirect(seed, loc));

// Call the pullback.
auto *pullbackCall = builder.createApply(loc, pullback, SubstitutionMap(),
args, /*isNonThrowing*/ false);
assert(!pullbackCall->hasIndirectResults());

// Set adjoint for the `struct_extract` operand.
addAdjointValue(sei->getOperand(),
AdjointValue::getMaterialized(pullbackCall));
case StructExtractDifferentiationStrategy::Getter: {
// Get the pullback.
auto *pullbackField = getPrimalInfo().lookUpPullbackDecl(sei);
assert(pullbackField);
SILValue pullback = builder.createStructExtract(loc,
primalValueAggregateInAdj,
pullbackField);

// Construct the pullback arguments.
SmallVector<SILValue, 8> args;
auto seed = getAdjointValue(sei);
assert(seed.getType().isObject());
args.push_back(materializeAdjointDirect(seed, loc));

// Call the pullback.
auto *pullbackCall = builder.createApply(loc, pullback, SubstitutionMap(),
args, /*isNonThrowing*/ false);
assert(!pullbackCall->hasIndirectResults());

// Set adjoint for the `struct_extract` operand.
addAdjointValue(sei->getOperand(),
AdjointValue::getMaterialized(pullbackCall));
break;
}
}
}

/// Handle `tuple` instruction.
Expand Down Expand Up @@ -4236,25 +4176,22 @@ void DifferentiationTask::createVJP() {
loc, adjointRef, vjpSubstMap, partialAdjointArgs,
ParameterConvention::Direct_Guaranteed);

// === Clean up the stack allocations. ===
// Clean up the stack allocations.
for (auto alloc : reversed(stackAllocsToCleanUp))
builder.createDeallocStack(loc, alloc);

// === Return the direct results. ===
// (Note that indirect results have already been filled in by the application
// of the primal).
// Return the direct results. Note that indirect results have already been
// filled in by the application of the primal.
SmallVector<SILValue, 8> directResults;
auto originalDirectResults = ArrayRef<SILValue>(primalDirectResults)
.take_back(originalConv.getNumDirectSILResults());
for (auto originalDirectResult : originalDirectResults)
directResults.push_back(originalDirectResult);
directResults.push_back(adjointPartialApply);
if (directResults.size() > 1) {
auto tupleRet = builder.createTuple(loc, directResults);
builder.createReturn(loc, tupleRet);
} else {
builder.createReturn(loc, directResults[0]);
}
if (directResults.size() > 1)
builder.createReturn(loc, builder.createTuple(loc, directResults));
else
builder.createReturn(loc, directResults.front());
}

//===----------------------------------------------------------------------===//
Expand Down
Loading