-
Notifications
You must be signed in to change notification settings - Fork 10.5k
[AutoDiff] Revamp usefulness propagation in activity analysis. #28225
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1433,7 +1433,7 @@ class DifferentiableActivityInfo { | |
/// Marks the given value as varied and propagates variedness to users. | ||
void setVariedAndPropagateToUsers(SILValue value, | ||
unsigned independentVariableIndex); | ||
/// Propagates variedness for the given operand to its user's results. | ||
/// Propagates variedness from the given operand to its user's results. | ||
void propagateVaried(Operand *operand, unsigned independentVariableIndex); | ||
/// Marks the given value as varied and recursively propagates variedness | ||
/// inwards (to operands) through projections. Skips `@noDerivative` struct | ||
|
@@ -1444,8 +1444,18 @@ class DifferentiableActivityInfo { | |
void setUseful(SILValue value, unsigned dependentVariableIndex); | ||
void setUsefulAcrossArrayInitialization(SILValue value, | ||
unsigned dependentVariableIndex); | ||
void propagateUsefulThroughBuffer(SILValue value, | ||
unsigned dependentVariableIndex); | ||
/// Marks the given value as useful and recursively propagates usefulness to: | ||
/// - Defining instruction operands, if the value has a defining instruction. | ||
/// - Incoming values, if the value is a basic block argument. | ||
void setUsefulAndPropagateToOperands(SILValue value, | ||
unsigned dependentVariableIndex); | ||
/// Propagates usefulnesss to the operands of the given instruction. | ||
void propagateUseful(SILInstruction *inst, unsigned dependentVariableIndex); | ||
/// Marks the given address as useful and recursively propagates usefulness | ||
/// inwards (to operands) through projections. Skips `@noDerivative` struct | ||
/// field projections. | ||
void propagateUsefulThroughAddress(SILValue value, | ||
unsigned dependentVariableIndex); | ||
|
||
public: | ||
explicit DifferentiableActivityInfo( | ||
|
@@ -1975,6 +1985,71 @@ void DifferentiableActivityInfo::propagateVaried( | |
} | ||
} | ||
|
||
void DifferentiableActivityInfo::setUsefulAndPropagateToOperands( | ||
SILValue value, unsigned dependentVariableIndex) { | ||
// Skip already-useful values to prevent infinite recursion. | ||
if (isUseful(value, dependentVariableIndex)) | ||
return; | ||
if (value->getType().isAddress()) { | ||
propagateUsefulThroughAddress(value, dependentVariableIndex); | ||
return; | ||
} | ||
setUseful(value, dependentVariableIndex); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note: both Related: I think |
||
// If the given value is a basic block argument, propagate usefulness to | ||
// incoming values. | ||
if (auto *bbArg = dyn_cast<SILPhiArgument>(value)) { | ||
SmallVector<SILValue, 4> incomingValues; | ||
bbArg->getSingleTerminatorOperands(incomingValues); | ||
for (auto incomingValue : incomingValues) | ||
setUsefulAndPropagateToOperands(incomingValue, dependentVariableIndex); | ||
return; | ||
} | ||
auto *inst = value->getDefiningInstruction(); | ||
if (!inst) | ||
return; | ||
propagateUseful(inst, dependentVariableIndex); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note: |
||
} | ||
|
||
void DifferentiableActivityInfo::propagateUseful( | ||
SILInstruction *inst, unsigned dependentVariableIndex) { | ||
// Propagate usefulness for the given instruction: mark operands as useful and | ||
// recursively propagate usefulness to defining instructions of operands. | ||
auto i = dependentVariableIndex; | ||
// Handle indirect results in `apply`. | ||
if (auto *ai = dyn_cast<ApplyInst>(inst)) { | ||
if (isWithoutDerivative(ai->getCallee())) | ||
return; | ||
for (auto arg : ai->getArgumentsWithoutIndirectResults()) | ||
setUsefulAndPropagateToOperands(arg, i); | ||
} | ||
// Handle store-like instructions: | ||
// `store`, `store_borrow`, `copy_addr`, `unconditional_checked_cast` | ||
#define PROPAGATE_USEFUL_THROUGH_STORE(INST) \ | ||
else if (auto *si = dyn_cast<INST##Inst>(inst)) { \ | ||
setUsefulAndPropagateToOperands(si->getSrc(), i); \ | ||
} | ||
PROPAGATE_USEFUL_THROUGH_STORE(Store) | ||
PROPAGATE_USEFUL_THROUGH_STORE(StoreBorrow) | ||
PROPAGATE_USEFUL_THROUGH_STORE(CopyAddr) | ||
PROPAGATE_USEFUL_THROUGH_STORE(UnconditionalCheckedCastAddr) | ||
#undef PROPAGATE_USEFUL_THROUGH_STORE | ||
// Handle struct element extraction, skipping `@noDerivative` fields: | ||
// `struct_extract`, `struct_element_addr`. | ||
#define PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(INST) \ | ||
else if (auto *sei = dyn_cast<INST##Inst>(inst)) { \ | ||
if (!sei->getField()->getAttrs().hasAttribute<NoDerivativeAttr>()) \ | ||
setUsefulAndPropagateToOperands(sei->getOperand(), i); \ | ||
} | ||
PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(StructExtract) | ||
PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(StructElementAddr) | ||
#undef PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION | ||
// Handle everything else. | ||
else { | ||
for (auto &op : inst->getAllOperands()) | ||
setUsefulAndPropagateToOperands(op.get(), i); | ||
} | ||
} | ||
|
||
void DifferentiableActivityInfo::analyze(DominanceInfo *di, | ||
PostDominanceInfo *pdi) { | ||
auto &function = getFunction(); | ||
|
@@ -2010,117 +2085,40 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di, | |
|
||
// Mark differentiable outputs as useful. | ||
assert(usefulValueSets.empty()); | ||
for (auto output : outputValues) { | ||
for (auto outputAndIdx : enumerate(outputValues)) { | ||
auto output = outputAndIdx.value(); | ||
unsigned i = outputAndIdx.index(); | ||
usefulValueSets.push_back({}); | ||
// If the output has an address or class type, propagate usefulness | ||
// recursively. | ||
if (output->getType().isAddress() || | ||
output->getType().isClassOrClassMetatype()) | ||
propagateUsefulThroughBuffer(output, usefulValueSets.size() - 1); | ||
// Otherwise, just mark the output as useful. | ||
else | ||
setUseful(output, usefulValueSets.size() - 1); | ||
} | ||
// Propagate usefulness through the function in post-dominance order. | ||
PostDominanceOrder postDomOrder(&*function.findReturnBB(), pdi); | ||
while (auto *bb = postDomOrder.getNext()) { | ||
for (auto &inst : llvm::reverse(*bb)) { | ||
for (auto i : indices(outputValues)) { | ||
// Handle indirect results in `apply`. | ||
if (auto *ai = dyn_cast<ApplyInst>(&inst)) { | ||
if (isWithoutDerivative(ai->getCallee())) | ||
continue; | ||
auto checkAndSetUseful = [&](SILValue res) { | ||
if (isUseful(res, i)) | ||
for (auto arg : ai->getArgumentsWithoutIndirectResults()) | ||
setUseful(arg, i); | ||
}; | ||
for (auto dirRes : ai->getResults()) | ||
checkAndSetUseful(dirRes); | ||
for (auto indRes : ai->getIndirectSILResults()) | ||
checkAndSetUseful(indRes); | ||
auto paramInfos = ai->getSubstCalleeConv().getParameters(); | ||
for (auto i : indices(paramInfos)) | ||
if (paramInfos[i].isIndirectInOut()) | ||
checkAndSetUseful(ai->getArgumentsWithoutIndirectResults()[i]); | ||
} | ||
// Handle store-like instructions: | ||
// `store`, `store_borrow`, `copy_addr`, `unconditional_checked_cast` | ||
#define PROPAGATE_USEFUL_THROUGH_STORE(INST, PROPAGATE) \ | ||
else if (auto *si = dyn_cast<INST##Inst>(&inst)) { \ | ||
if (isUseful(si->getDest(), i)) \ | ||
PROPAGATE(si->getSrc(), i); \ | ||
} | ||
PROPAGATE_USEFUL_THROUGH_STORE(Store, setUseful) | ||
PROPAGATE_USEFUL_THROUGH_STORE(StoreBorrow, setUseful) | ||
PROPAGATE_USEFUL_THROUGH_STORE(CopyAddr, propagateUsefulThroughBuffer) | ||
PROPAGATE_USEFUL_THROUGH_STORE(UnconditionalCheckedCastAddr, | ||
propagateUsefulThroughBuffer) | ||
#undef PROPAGATE_USEFUL_THROUGH_STORE | ||
// Handle struct element extraction, skipping `@noDerivative` fields: | ||
// `struct_extract`, `struct_element_addr`. | ||
#define PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(INST, PROPAGATE) \ | ||
else if (auto *sei = dyn_cast<INST##Inst>(&inst)) { \ | ||
if (isUseful(sei, i)) { \ | ||
auto hasNoDeriv = sei->getField()->getAttrs() \ | ||
.hasAttribute<NoDerivativeAttr>(); \ | ||
if (!hasNoDeriv) \ | ||
PROPAGATE(sei->getOperand(), i); \ | ||
} \ | ||
} | ||
PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(StructExtract, setUseful) | ||
PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(StructElementAddr, | ||
propagateUsefulThroughBuffer) | ||
#undef PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION | ||
// Handle everything else. | ||
else if (llvm::any_of(inst.getResults(), | ||
[&](SILValue res) { return isUseful(res, i); })) { | ||
for (auto &op : inst.getAllOperands()) { | ||
auto value = op.get(); | ||
if (value->getType().isAddress()) | ||
propagateUsefulThroughBuffer(value, i); | ||
setUseful(value, i); | ||
} | ||
} | ||
} | ||
} | ||
// Propagate usefulness from basic block arguments to incoming phi values. | ||
for (auto i : indices(outputValues)) { | ||
for (auto *arg : bb->getArguments()) { | ||
if (isUseful(arg, i)) { | ||
SmallVector<SILValue, 4> incomingValues; | ||
arg->getSingleTerminatorOperands(incomingValues); | ||
for (auto incomingValue : incomingValues) | ||
setUseful(incomingValue, i); | ||
} | ||
} | ||
} | ||
postDomOrder.pushChildren(bb); | ||
setUsefulAndPropagateToOperands(output, i); | ||
} | ||
} | ||
|
||
void DifferentiableActivityInfo::setUsefulAcrossArrayInitialization( | ||
SILValue value, unsigned dependentVariableIndex) { | ||
// Array initializer syntax is lowered to an intrinsic and one or more | ||
// stores to a `RawPointer` returned by the intrinsic. | ||
auto uai = getAllocateUninitializedArrayIntrinsic(value); | ||
auto *uai = getAllocateUninitializedArrayIntrinsic(value); | ||
if (!uai) return; | ||
for (auto use : value->getUses()) { | ||
auto dti = dyn_cast<DestructureTupleInst>(use->getUser()); | ||
auto *dti = dyn_cast<DestructureTupleInst>(use->getUser()); | ||
if (!dti) continue; | ||
// The second tuple field of the return value is the `RawPointer`. | ||
for (auto use : dti->getResult(1)->getUses()) { | ||
// The `RawPointer` passes through a `pointer_to_address`. That | ||
// instruction's first use is a `store` whose src is useful; its | ||
// instruction's first use is a `store` whose source is useful; its | ||
// subsequent uses are `index_addr`s whose only use is a useful `store`. | ||
for (auto use : use->getUser()->getResult(0)->getUses()) { | ||
auto inst = use->getUser(); | ||
if (auto si = dyn_cast<StoreInst>(inst)) { | ||
setUseful(si->getSrc(), dependentVariableIndex); | ||
} else if (auto iai = dyn_cast<IndexAddrInst>(inst)) { | ||
auto *ptai = dyn_cast<PointerToAddressInst>(use->getUser()); | ||
assert(ptai && "Expected `pointer_to_address` user for uninitialized " | ||
"array intrinsic"); | ||
for (auto use : ptai->getUses()) { | ||
auto *inst = use->getUser(); | ||
if (auto *si = dyn_cast<StoreInst>(inst)) { | ||
setUsefulAndPropagateToOperands(si->getSrc(), dependentVariableIndex); | ||
} else if (auto *iai = dyn_cast<IndexAddrInst>(inst)) { | ||
for (auto use : iai->getUses()) | ||
if (auto si = dyn_cast<StoreInst>(use->getUser())) | ||
setUseful(si->getSrc(), dependentVariableIndex); | ||
setUsefulAndPropagateToOperands(si->getSrc(), | ||
dependentVariableIndex); | ||
} | ||
} | ||
} | ||
|
@@ -2154,21 +2152,20 @@ void DifferentiableActivityInfo::propagateVariedInwardsThroughProjections( | |
op.get(), independentVariableIndex); | ||
} | ||
|
||
void DifferentiableActivityInfo::propagateUsefulThroughBuffer( | ||
void DifferentiableActivityInfo::propagateUsefulThroughAddress( | ||
SILValue value, unsigned dependentVariableIndex) { | ||
assert(value->getType().isAddress() || | ||
value->getType().isClassOrClassMetatype()); | ||
assert(value->getType().isAddress()); | ||
// Check whether value is already useful to prevent infinite recursion. | ||
if (isUseful(value, dependentVariableIndex)) | ||
return; | ||
setUseful(value, dependentVariableIndex); | ||
if (auto *inst = value->getDefiningInstruction()) | ||
for (auto &operand : inst->getAllOperands()) | ||
if (operand.get()->getType().isAddress()) | ||
propagateUsefulThroughBuffer(operand.get(), dependentVariableIndex); | ||
propagateUseful(inst, dependentVariableIndex); | ||
// Recursively propagate usefulness through users that are projections or | ||
// `begin_access` instructions. | ||
for (auto use : value->getUses()) { | ||
// Propagate usefulness through user's operands. | ||
propagateUseful(use->getUser(), dependentVariableIndex); | ||
for (auto res : use->getUser()->getResults()) { | ||
#define SKIP_NODERIVATIVE(INST) \ | ||
if (auto *sei = dyn_cast<INST##Inst>(res)) \ | ||
|
@@ -2178,7 +2175,7 @@ void DifferentiableActivityInfo::propagateUsefulThroughBuffer( | |
SKIP_NODERIVATIVE(StructElementAddr) | ||
#undef SKIP_NODERIVATIVE | ||
if (Projection::isAddressProjection(res) || isa<BeginAccessInst>(res)) | ||
propagateUsefulThroughBuffer(res, dependentVariableIndex); | ||
propagateUsefulThroughAddress(res, dependentVariableIndex); | ||
} | ||
} | ||
} | ||
|
@@ -6219,15 +6216,19 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> { | |
SmallPtrSet<SILValue, 8> visited(bbActiveValues.begin(), | ||
bbActiveValues.end()); | ||
// Register a value as active if it has not yet been visited. | ||
bool diagnosedActiveEnumValue = false; | ||
auto addActiveValue = [&](SILValue v) { | ||
if (visited.count(v)) | ||
return; | ||
// Diagnose active enum values. Differentiation of enum values is not | ||
// yet supported; requires special adjoint value handling. | ||
if (v->getType().getEnumOrBoundGenericEnum()) { | ||
// Diagnose active enum values. Differentiation of enum values requires | ||
// special adjoint value handling and is not yet supported. Diagnose | ||
// only the first active enum value to prevent too many diagnostics. | ||
if (!diagnosedActiveEnumValue && | ||
v->getType().getEnumOrBoundGenericEnum()) { | ||
getContext().emitNondifferentiabilityError( | ||
v, getInvoker(), diag::autodiff_enums_unsupported); | ||
errorOccurred = true; | ||
diagnosedActiveEnumValue = true; | ||
} | ||
// Skip address projections. | ||
// Address projections do not need their own adjoint buffers; they | ||
|
@@ -6238,9 +6239,6 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> { | |
bbActiveValues.push_back(v); | ||
}; | ||
// Register bb arguments and all instruction operands/results. | ||
for (auto *arg : bb->getArguments()) | ||
if (getActivityInfo().isActive(arg, getIndices())) | ||
addActiveValue(arg); | ||
for (auto &inst : *bb) { | ||
for (auto op : inst.getOperandValues()) | ||
if (getActivityInfo().isActive(op, getIndices())) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: currently,
setUseful
has two users (setUsefulAndPropagateToOperands
andpropagateUsefulThroughAddress
), so it hasn't been inlined.