Skip to content

[AutoDiff] NFC: gardening. #28251

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 14, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 63 additions & 54 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ static ApplyInst *getAllocateUninitializedArrayIntrinsic(SILValue v) {
return nullptr;
}

/// Given a value, find its single `destructure_tuple` user if the value is
/// Given a value, finds its single `destructure_tuple` user if the value is
/// tuple-typed and such a user exists.
static DestructureTupleInst *getSingleDestructureTupleUser(SILValue value) {
bool foundDestructureTupleUser = false;
Expand Down Expand Up @@ -130,7 +130,7 @@ static void forEachApplyDirectResult(
resultCallback(result);
}

/// Given a function, gather all of its formal results (both direct and
/// Given a function, gathers all of its formal results (both direct and
/// indirect) in an order defined by its result type. Note that "formal results"
/// refer to result values in the body of the function, not at call sites.
static void
Expand All @@ -154,7 +154,7 @@ collectAllFormalResultsInTypeOrder(SILFunction &function,
: indResults[indResIdx++]);
}

/// Given a function, gather all of its direct results in an order defined by
/// Given a function, gathers all of its direct results in an order defined by
/// its result type. Note that "formal results" refer to result values in the
/// body of the function, not at call sites.
static void
Expand All @@ -171,7 +171,7 @@ collectAllDirectResultsInTypeOrder(SILFunction &function,
results.push_back(retVal);
}

/// Given a function call site, gather all of its actual results (both direct
/// Given a function call site, gathers all of its actual results (both direct
/// and indirect) in an order defined by its result type.
static void collectAllActualResultsInTypeOrder(
ApplyInst *ai, ArrayRef<SILValue> extractedDirectResults,
Expand Down Expand Up @@ -291,7 +291,7 @@ static GenericParamList *cloneGenericParameters(ASTContext &ctx,
return GenericParamList::create(ctx, SourceLoc(), clonedParams, SourceLoc());
}

/// Given an `differentiable_function` instruction, find the corresponding
/// Given a `differentiable_function` instruction, finds the corresponding
/// differential operator used in the AST. If no differential operator is found,
/// returns nullptr.
static DifferentiableFunctionExpr *
Expand Down Expand Up @@ -412,15 +412,23 @@ struct DifferentiationInvoker {

class DifferentiableActivityInfo;

/// Information about the JVP/VJP function produced during JVP/VJP generation,
/// e.g. mappings from original values to corresponding values in the
/// pullback/differential struct.
/// Linear map struct and branching trace enum information for an original
/// function and a derivative function (JVP or VJP).
///
/// A linear map struct is an aggregate value containing linear maps checkpointed
/// during the JVP/VJP computation. Linear map structs are generated for every
/// original function during JVP/VJP generation. Linear map struct values are
/// constructed by JVP/VJP functions and consumed by pullback/differential
/// functions.
/// Linear map structs contain all callee linear maps produced in a JVP/VJP
/// basic block. A linear map struct is created for each basic block in the
/// original function, and a linear map struct field is created for every active
/// `apply` in the original basic block.
///
/// Branching trace enums model the control flow graph of the original function.
/// A branching trace enum is created for each basic block in the original
/// function, and a branching trace enum case is created for every basic block
/// predecessor/successor. This supports control flow differentiation: JVP/VJP
/// functions build branching trace enums to record an execution trace. Indirect
/// branching trace enums are created for basic blocks that are in loops.
///
/// Linear map struct values and branching trace enum values are constructed in
/// JVP/VJP functions and consumed in pullback/differential functions.
class LinearMapInfo {
private:
/// The linear map kind.
Expand All @@ -446,13 +454,12 @@ class LinearMapInfo {
/// For differentials: these are successor enums.
DenseMap<SILBasicBlock *, EnumDecl *> branchingTraceDecls;

/// Mapping from `apply` and `struct_extract` instructions in the original
/// function to the corresponding linear map declaration in the linear map
/// struct.
DenseMap<SILInstruction *, VarDecl *> linearMapValueMap;
/// Mapping from `apply` instructions in the original function to the
/// corresponding linear map field declaration in the linear map struct.
DenseMap<ApplyInst *, VarDecl *> linearMapFieldMap;

/// Mapping from predecessor+succcessor basic block pairs in original function
/// to the corresponding branching trace enum case.
/// Mapping from predecessor-successor basic block pairs in the original
/// function to the corresponding branching trace enum case.
DenseMap<std::pair<SILBasicBlock *, SILBasicBlock *>, EnumElementDecl *>
branchingTraceEnumCases;

Expand Down Expand Up @@ -505,7 +512,7 @@ class LinearMapInfo {
llvm_unreachable("No files?");
}

/// Compute and set the access level for the given nominal type, given the
/// Computes and sets the access level for the given nominal type, given the
/// original function linkage.
void computeAccessLevel(
NominalTypeDecl *nominal, SILLinkage originalLinkage) {
Expand Down Expand Up @@ -661,8 +668,8 @@ class LinearMapInfo {
return linearMapStruct;
}

/// Add a linear map to the linear map struct.
VarDecl *addLinearMapDecl(SILInstruction *inst, SILType linearMapType) {
/// Adds a linear map field to the linear map struct.
VarDecl *addLinearMapDecl(ApplyInst *ai, SILType linearMapType) {
// IRGen requires decls to have AST types (not `SILFunctionType`), so we
// convert the `SILFunctionType` of the linear map to a `FunctionType` with
// the same parameters and results.
Expand All @@ -678,28 +685,28 @@ class LinearMapInfo {
astFnTy = FunctionType::get(
params, silFnTy->getAllResultsInterfaceType().getASTType());

auto *origBB = inst->getParent();
auto *origBB = ai->getParent();
auto *linMapStruct = getLinearMapStruct(origBB);
std::string linearMapName;
switch (kind) {
case AutoDiffLinearMapKind::Differential:
linearMapName = "differential_" + llvm::itostr(linearMapValueMap.size());
linearMapName = "differential_" + llvm::itostr(linearMapFieldMap.size());
break;
case AutoDiffLinearMapKind::Pullback:
linearMapName = "pullback_" + llvm::itostr(linearMapValueMap.size());
linearMapName = "pullback_" + llvm::itostr(linearMapFieldMap.size());
break;
}
auto *linearMapDecl = addVarDecl(linMapStruct, linearMapName, astFnTy);
linearMapValueMap.insert({inst, linearMapDecl});
linearMapFieldMap.insert({ai, linearMapDecl});
return linearMapDecl;
}

/// Given an `apply` instruction, conditionally adds its linear map function
/// to the linear map struct if it is active.
/// Given an `apply` instruction, conditionally adds a linear map struct field
/// for its linear map function if it is active.
void addLinearMapToStruct(ADContext &context, ApplyInst *ai,
const SILAutoDiffIndices &indices);

/// Generate linear map struct and branching enum declarations for the given
/// Generates linear map struct and branching enum declarations for the given
/// function. Linear map structs are populated with linear map fields and a
/// branching enum field.
void generateDifferentiationDataStructures(
Expand Down Expand Up @@ -771,12 +778,13 @@ class LinearMapInfo {
return linearMapStructEnumFields.lookup(linearMapStruct);
}

/// Finds the linear map declaration in the pullback struct for an `apply` or
/// `struct_extract` in the original function.
VarDecl *lookUpLinearMapDecl(SILInstruction *inst) {
auto lookup = linearMapValueMap.find(inst);
assert(lookup != linearMapValueMap.end() &&
"No linear map declaration corresponding to the given instruction");
/// Finds the linear map field declaration in the linear map struct for the
/// given `apply` instruction in the original function.
VarDecl *lookUpLinearMapDecl(ApplyInst *ai) {
assert(ai->getFunction() == original);
auto lookup = linearMapFieldMap.find(ai);
assert(lookup != linearMapFieldMap.end() &&
"No linear map field corresponding to the given `apply`");
return lookup->getSecond();
}
};
Expand Down Expand Up @@ -1047,7 +1055,8 @@ class ADContext {
for (auto *da : original->getAttrs().getAttributes<DifferentiableAttr>()) {
auto *daParamIndices = da->getParameterIndices();
auto *daIndexSet = autodiff::getLoweredParameterIndices(
daParamIndices, original->getInterfaceType()->castTo<AnyFunctionType>());
daParamIndices,
original->getInterfaceType()->castTo<AnyFunctionType>());
// If all indices in `indexSet` are in `daIndexSet`, and it has fewer
// indices than our current candidate and a primitive VJP, then `da` is
// our new candidate.
Expand Down Expand Up @@ -1646,8 +1655,6 @@ bool LinearMapInfo::shouldDifferentiateInstruction(SILInstruction *inst) {
return false;
}

/// Takes an `apply` instruction and adds its linear map function to the
/// linear map struct if it is active.
void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
const SILAutoDiffIndices &indices) {
SmallVector<SILValue, 4> allResults;
Expand Down Expand Up @@ -1696,15 +1703,15 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
[&](CanSILFunctionType origFnTy) {
// Check non-differentiable arguments.
for (unsigned paramIndex : range(origFnTy->getNumParameters())) {
auto remappedParamType =
origFnTy->getParameters()[paramIndex].getSILStorageInterfaceType();
auto remappedParamType = origFnTy->getParameters()[paramIndex]
.getSILStorageInterfaceType();
if (applyIndices.isWrtParameter(paramIndex) &&
!remappedParamType.isDifferentiable(derivative->getModule()))
return true;
}
// Check non-differentiable results.
auto remappedResultType =
origFnTy->getResults()[applyIndices.source].getSILStorageInterfaceType();
auto remappedResultType = origFnTy->getResults()[applyIndices.source]
.getSILStorageInterfaceType();
if (!remappedResultType.isDifferentiable(derivative->getModule()))
return true;
return false;
Expand All @@ -1713,13 +1720,13 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
return;

AutoDiffDerivativeFunctionKind derivativeFnKind(kind);
auto derivativeFnType = remappedOrigFnSubstTy->getAutoDiffDerivativeFunctionType(
parameters, source, derivativeFnKind, context.getTypeConverter(),
LookUpConformanceInModule(derivative->getModule().getSwiftModule()));
auto derivativeFnType =
remappedOrigFnSubstTy->getAutoDiffDerivativeFunctionType(
parameters, source, derivativeFnKind, context.getTypeConverter(),
LookUpConformanceInModule(derivative->getModule().getSwiftModule()));

auto derivativeFnResultTypes =
derivativeFnType->getAllResultsInterfaceType().castTo<TupleType>();
derivativeFnResultTypes->getElement(derivativeFnResultTypes->getElements().size() - 1);
auto linearMapSILType = SILType::getPrimitiveObjectType(
derivativeFnResultTypes
->getElement(derivativeFnResultTypes->getElements().size() - 1)
Expand Down Expand Up @@ -1760,8 +1767,8 @@ void LinearMapInfo::generateDifferentiationDataStructures(
break;
}
for (auto &origBB : *original) {
auto *traceEnum =
createBranchingTraceDecl(&origBB, indices, derivativeFnGenSig, loopInfo);
auto *traceEnum = createBranchingTraceDecl(
&origBB, indices, derivativeFnGenSig, loopInfo);
branchingTraceDecls.insert({&origBB, traceEnum});
if (origBB.isEntry())
continue;
Expand Down Expand Up @@ -5847,6 +5854,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
/// Record a temporary value for cleanup before its block's terminator.
SILValue recordTemporary(SILValue value) {
assert(value->getType().isObject());
assert(value->getFunction() == &getPullback());
blockTemporaries[value->getParentBlock()].push_back(value);
LLVM_DEBUG(getADDebugStream() << "Recorded temporary " << value);
auto insertion = blockTemporarySet.insert(value); (void)insertion;
Expand Down Expand Up @@ -5971,7 +5979,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
auto insertion = valueMap.try_emplace({origBB, originalValue},
adjointValue);
LLVM_DEBUG(getADDebugStream()
<< "The existing adjoint value will be replaced: "
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: the previous message made it sound like the "existing (old) adjoint value" is displayed instead of the new value.

<< "The new adjoint value, replacing the existing one, is: "
<< insertion.first->getSecond());
if (!insertion.second)
insertion.first->getSecond() = adjointValue;
Expand Down Expand Up @@ -6000,7 +6008,8 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
assert(originalValue->getType().isObject());
assert(newAdjointValue.getType().isObject());
assert(originalValue->getFunction() == &getOriginal());
LLVM_DEBUG(getADDebugStream() << "Adding adjoint for " << originalValue);
LLVM_DEBUG(getADDebugStream() << "Adding adjoint value for "
<< originalValue);
// The adjoint value must be in the tangent space.
assert(newAdjointValue.getType() ==
getRemappedTangentType(originalValue->getType()));
Expand Down Expand Up @@ -6494,8 +6503,8 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
SmallVector<SILValue, 4> directResults;
auto indirectResultIt = pullback.getIndirectResults().begin();
for (auto resultInfo : pullback.getLoweredFunctionType()->getResults()) {
auto resultType =
pullback.mapTypeIntoContext(resultInfo.getInterfaceType())->getCanonicalType();
auto resultType = pullback.mapTypeIntoContext(
resultInfo.getInterfaceType())->getCanonicalType();
if (resultInfo.isFormalDirect())
directResults.push_back(emitZeroDirect(resultType, pbLoc));
else
Expand Down Expand Up @@ -6539,7 +6548,7 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
auto &predBBActiveValues = activeValues[origPredBB];
for (auto activeValue : predBBActiveValues) {
LLVM_DEBUG(getADDebugStream()
<< "Propagating active adjoint " << activeValue
<< "Propagating adjoint of active value " << activeValue
<< " to predecessors' pullback blocks\n");
if (activeValue->getType().isObject()) {
auto activeValueAdj = getAdjointValue(origBB, activeValue);
Expand Down Expand Up @@ -7641,7 +7650,7 @@ SILValue PullbackEmitter::accumulateDirect(SILValue lhs, SILValue rhs,
// TODO: Optimize for the case when lhs == rhs.
LLVM_DEBUG(getADDebugStream() <<
"Emitting adjoint accumulation for lhs: " << lhs <<
" and rhs: " << rhs << "\n");
" and rhs: " << rhs);
assert(lhs->getType() == rhs->getType() && "Adjoints must have equal types!");
assert(lhs->getType().isObject() && rhs->getType().isObject() &&
"Adjoint types must be both object types!");
Expand Down