Skip to content

[AutoDiff] Revamp usefulness propagation in activity analysis. #28225

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 108 additions & 110 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1433,7 +1433,7 @@ class DifferentiableActivityInfo {
/// Marks the given value as varied and propagates variedness to users.
void setVariedAndPropagateToUsers(SILValue value,
unsigned independentVariableIndex);
/// Propagates variedness for the given operand to its user's results.
/// Propagates variedness from the given operand to its user's results.
void propagateVaried(Operand *operand, unsigned independentVariableIndex);
/// Marks the given value as varied and recursively propagates variedness
/// inwards (to operands) through projections. Skips `@noDerivative` struct
Expand All @@ -1444,8 +1444,18 @@ class DifferentiableActivityInfo {
void setUseful(SILValue value, unsigned dependentVariableIndex);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: currently, setUseful has two users (setUsefulAndPropagateToOperands and propagateUsefulThroughAddress), so it hasn't been inlined.

void setUsefulAcrossArrayInitialization(SILValue value,
unsigned dependentVariableIndex);
void propagateUsefulThroughBuffer(SILValue value,
unsigned dependentVariableIndex);
/// Marks the given value as useful and recursively propagates usefulness to:
/// - Defining instruction operands, if the value has a defining instruction.
/// - Incoming values, if the value is a basic block argument.
void setUsefulAndPropagateToOperands(SILValue value,
unsigned dependentVariableIndex);
/// Propagates usefulness to the operands of the given instruction.
void propagateUseful(SILInstruction *inst, unsigned dependentVariableIndex);
/// Marks the given address as useful and recursively propagates usefulness
/// inwards (to operands) through projections. Skips `@noDerivative` struct
/// field projections.
void propagateUsefulThroughAddress(SILValue value,
unsigned dependentVariableIndex);

public:
explicit DifferentiableActivityInfo(
Expand Down Expand Up @@ -1975,6 +1985,71 @@ void DifferentiableActivityInfo::propagateVaried(
}
}

void DifferentiableActivityInfo::setUsefulAndPropagateToOperands(
SILValue value, unsigned dependentVariableIndex) {
// Skip already-useful values to prevent infinite recursion.
if (isUseful(value, dependentVariableIndex))
return;
if (value->getType().isAddress()) {
propagateUsefulThroughAddress(value, dependentVariableIndex);
return;
}
setUseful(value, dependentVariableIndex);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: both setUsefulAndPropagateToOperands and propagateUsefulThroughAddress call isUseful and setUseful, which seems suboptimal. I tried removing setUseful from one of them but ran into infinite loops.

Related: I think setUsefulAndPropagateToOperands should be the primary entry point for propagating usefulness, so I eliminated most direct calls to propagateUsefulThroughAddress.

// If the given value is a basic block argument, propagate usefulness to
// incoming values.
if (auto *bbArg = dyn_cast<SILPhiArgument>(value)) {
SmallVector<SILValue, 4> incomingValues;
bbArg->getSingleTerminatorOperands(incomingValues);
for (auto incomingValue : incomingValues)
setUsefulAndPropagateToOperands(incomingValue, dependentVariableIndex);
return;
}
auto *inst = value->getDefiningInstruction();
if (!inst)
return;
propagateUseful(inst, dependentVariableIndex);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: propagateUseful cannot be inlined because it has multiple users (it is called multiple times in propagateUsefulThroughAddress).

}

void DifferentiableActivityInfo::propagateUseful(
SILInstruction *inst, unsigned dependentVariableIndex) {
// Propagate usefulness for the given instruction: mark operands as useful and
// recursively propagate usefulness to defining instructions of operands.
// Which operands are marked depends on the instruction kind; the cases below
// form a single `if`/`else if` chain, so exactly one of them applies.
auto i = dependentVariableIndex;
// Handle `apply`: propagate usefulness to every argument except the indirect
// results. Nothing is propagated when `isWithoutDerivative` holds for the
// callee.
if (auto *ai = dyn_cast<ApplyInst>(inst)) {
if (isWithoutDerivative(ai->getCallee()))
return;
for (auto arg : ai->getArgumentsWithoutIndirectResults())
setUsefulAndPropagateToOperands(arg, i);
}
// Handle store-like instructions:
// `store`, `store_borrow`, `copy_addr`, `unconditional_checked_cast_addr`.
// Only the source operand is marked useful; the destination is not touched
// here.
#define PROPAGATE_USEFUL_THROUGH_STORE(INST) \
else if (auto *si = dyn_cast<INST##Inst>(inst)) { \
setUsefulAndPropagateToOperands(si->getSrc(), i); \
}
PROPAGATE_USEFUL_THROUGH_STORE(Store)
PROPAGATE_USEFUL_THROUGH_STORE(StoreBorrow)
PROPAGATE_USEFUL_THROUGH_STORE(CopyAddr)
PROPAGATE_USEFUL_THROUGH_STORE(UnconditionalCheckedCastAddr)
#undef PROPAGATE_USEFUL_THROUGH_STORE
// Handle struct element extraction, skipping `@noDerivative` fields:
// `struct_extract`, `struct_element_addr`.
// For a `@noDerivative` field the aggregate operand is left unmarked.
#define PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(INST) \
else if (auto *sei = dyn_cast<INST##Inst>(inst)) { \
if (!sei->getField()->getAttrs().hasAttribute<NoDerivativeAttr>()) \
setUsefulAndPropagateToOperands(sei->getOperand(), i); \
}
PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(StructExtract)
PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(StructElementAddr)
#undef PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION
// Handle everything else: propagate usefulness to all operands of the
// instruction.
else {
for (auto &op : inst->getAllOperands())
setUsefulAndPropagateToOperands(op.get(), i);
}
}

void DifferentiableActivityInfo::analyze(DominanceInfo *di,
PostDominanceInfo *pdi) {
auto &function = getFunction();
Expand Down Expand Up @@ -2010,117 +2085,40 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di,

// Mark differentiable outputs as useful.
assert(usefulValueSets.empty());
for (auto output : outputValues) {
for (auto outputAndIdx : enumerate(outputValues)) {
auto output = outputAndIdx.value();
unsigned i = outputAndIdx.index();
usefulValueSets.push_back({});
// If the output has an address or class type, propagate usefulness
// recursively.
if (output->getType().isAddress() ||
output->getType().isClassOrClassMetatype())
propagateUsefulThroughBuffer(output, usefulValueSets.size() - 1);
// Otherwise, just mark the output as useful.
else
setUseful(output, usefulValueSets.size() - 1);
}
// Propagate usefulness through the function in post-dominance order.
PostDominanceOrder postDomOrder(&*function.findReturnBB(), pdi);
while (auto *bb = postDomOrder.getNext()) {
for (auto &inst : llvm::reverse(*bb)) {
for (auto i : indices(outputValues)) {
// Handle indirect results in `apply`.
if (auto *ai = dyn_cast<ApplyInst>(&inst)) {
if (isWithoutDerivative(ai->getCallee()))
continue;
auto checkAndSetUseful = [&](SILValue res) {
if (isUseful(res, i))
for (auto arg : ai->getArgumentsWithoutIndirectResults())
setUseful(arg, i);
};
for (auto dirRes : ai->getResults())
checkAndSetUseful(dirRes);
for (auto indRes : ai->getIndirectSILResults())
checkAndSetUseful(indRes);
auto paramInfos = ai->getSubstCalleeConv().getParameters();
for (auto i : indices(paramInfos))
if (paramInfos[i].isIndirectInOut())
checkAndSetUseful(ai->getArgumentsWithoutIndirectResults()[i]);
}
// Handle store-like instructions:
// `store`, `store_borrow`, `copy_addr`, `unconditional_checked_cast_addr`
#define PROPAGATE_USEFUL_THROUGH_STORE(INST, PROPAGATE) \
else if (auto *si = dyn_cast<INST##Inst>(&inst)) { \
if (isUseful(si->getDest(), i)) \
PROPAGATE(si->getSrc(), i); \
}
PROPAGATE_USEFUL_THROUGH_STORE(Store, setUseful)
PROPAGATE_USEFUL_THROUGH_STORE(StoreBorrow, setUseful)
PROPAGATE_USEFUL_THROUGH_STORE(CopyAddr, propagateUsefulThroughBuffer)
PROPAGATE_USEFUL_THROUGH_STORE(UnconditionalCheckedCastAddr,
propagateUsefulThroughBuffer)
#undef PROPAGATE_USEFUL_THROUGH_STORE
// Handle struct element extraction, skipping `@noDerivative` fields:
// `struct_extract`, `struct_element_addr`.
#define PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(INST, PROPAGATE) \
else if (auto *sei = dyn_cast<INST##Inst>(&inst)) { \
if (isUseful(sei, i)) { \
auto hasNoDeriv = sei->getField()->getAttrs() \
.hasAttribute<NoDerivativeAttr>(); \
if (!hasNoDeriv) \
PROPAGATE(sei->getOperand(), i); \
} \
}
PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(StructExtract, setUseful)
PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(StructElementAddr,
propagateUsefulThroughBuffer)
#undef PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION
// Handle everything else.
else if (llvm::any_of(inst.getResults(),
[&](SILValue res) { return isUseful(res, i); })) {
for (auto &op : inst.getAllOperands()) {
auto value = op.get();
if (value->getType().isAddress())
propagateUsefulThroughBuffer(value, i);
setUseful(value, i);
}
}
}
}
// Propagate usefulness from basic block arguments to incoming phi values.
for (auto i : indices(outputValues)) {
for (auto *arg : bb->getArguments()) {
if (isUseful(arg, i)) {
SmallVector<SILValue, 4> incomingValues;
arg->getSingleTerminatorOperands(incomingValues);
for (auto incomingValue : incomingValues)
setUseful(incomingValue, i);
}
}
}
postDomOrder.pushChildren(bb);
setUsefulAndPropagateToOperands(output, i);
}
}

void DifferentiableActivityInfo::setUsefulAcrossArrayInitialization(
SILValue value, unsigned dependentVariableIndex) {
// Array initializer syntax is lowered to an intrinsic and one or more
// stores to a `RawPointer` returned by the intrinsic.
auto uai = getAllocateUninitializedArrayIntrinsic(value);
auto *uai = getAllocateUninitializedArrayIntrinsic(value);
if (!uai) return;
for (auto use : value->getUses()) {
auto dti = dyn_cast<DestructureTupleInst>(use->getUser());
auto *dti = dyn_cast<DestructureTupleInst>(use->getUser());
if (!dti) continue;
// The second tuple field of the return value is the `RawPointer`.
for (auto use : dti->getResult(1)->getUses()) {
// The `RawPointer` passes through a `pointer_to_address`. That
// instruction's first use is a `store` whose src is useful; its
// instruction's first use is a `store` whose source is useful; its
// subsequent uses are `index_addr`s whose only use is a useful `store`.
for (auto use : use->getUser()->getResult(0)->getUses()) {
auto inst = use->getUser();
if (auto si = dyn_cast<StoreInst>(inst)) {
setUseful(si->getSrc(), dependentVariableIndex);
} else if (auto iai = dyn_cast<IndexAddrInst>(inst)) {
auto *ptai = dyn_cast<PointerToAddressInst>(use->getUser());
assert(ptai && "Expected `pointer_to_address` user for uninitialized "
"array intrinsic");
for (auto use : ptai->getUses()) {
auto *inst = use->getUser();
if (auto *si = dyn_cast<StoreInst>(inst)) {
setUsefulAndPropagateToOperands(si->getSrc(), dependentVariableIndex);
} else if (auto *iai = dyn_cast<IndexAddrInst>(inst)) {
for (auto use : iai->getUses())
if (auto si = dyn_cast<StoreInst>(use->getUser()))
setUseful(si->getSrc(), dependentVariableIndex);
setUsefulAndPropagateToOperands(si->getSrc(),
dependentVariableIndex);
}
}
}
Expand Down Expand Up @@ -2154,21 +2152,20 @@ void DifferentiableActivityInfo::propagateVariedInwardsThroughProjections(
op.get(), independentVariableIndex);
}

void DifferentiableActivityInfo::propagateUsefulThroughBuffer(
void DifferentiableActivityInfo::propagateUsefulThroughAddress(
SILValue value, unsigned dependentVariableIndex) {
assert(value->getType().isAddress() ||
value->getType().isClassOrClassMetatype());
assert(value->getType().isAddress());
// Check whether value is already useful to prevent infinite recursion.
if (isUseful(value, dependentVariableIndex))
return;
setUseful(value, dependentVariableIndex);
if (auto *inst = value->getDefiningInstruction())
for (auto &operand : inst->getAllOperands())
if (operand.get()->getType().isAddress())
propagateUsefulThroughBuffer(operand.get(), dependentVariableIndex);
propagateUseful(inst, dependentVariableIndex);
// Recursively propagate usefulness through users that are projections or
// `begin_access` instructions.
for (auto use : value->getUses()) {
// Propagate usefulness through user's operands.
propagateUseful(use->getUser(), dependentVariableIndex);
for (auto res : use->getUser()->getResults()) {
#define SKIP_NODERIVATIVE(INST) \
if (auto *sei = dyn_cast<INST##Inst>(res)) \
Expand All @@ -2178,7 +2175,7 @@ void DifferentiableActivityInfo::propagateUsefulThroughBuffer(
SKIP_NODERIVATIVE(StructElementAddr)
#undef SKIP_NODERIVATIVE
if (Projection::isAddressProjection(res) || isa<BeginAccessInst>(res))
propagateUsefulThroughBuffer(res, dependentVariableIndex);
propagateUsefulThroughAddress(res, dependentVariableIndex);
}
}
}
Expand Down Expand Up @@ -6219,15 +6216,19 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
SmallPtrSet<SILValue, 8> visited(bbActiveValues.begin(),
bbActiveValues.end());
// Register a value as active if it has not yet been visited.
bool diagnosedActiveEnumValue = false;
auto addActiveValue = [&](SILValue v) {
if (visited.count(v))
return;
// Diagnose active enum values. Differentiation of enum values is not
// yet supported; requires special adjoint value handling.
if (v->getType().getEnumOrBoundGenericEnum()) {
// Diagnose active enum values. Differentiation of enum values requires
// special adjoint value handling and is not yet supported. Diagnose
// only the first active enum value to prevent too many diagnostics.
if (!diagnosedActiveEnumValue &&
v->getType().getEnumOrBoundGenericEnum()) {
getContext().emitNondifferentiabilityError(
v, getInvoker(), diag::autodiff_enums_unsupported);
errorOccurred = true;
diagnosedActiveEnumValue = true;
}
// Skip address projections.
// Address projections do not need their own adjoint buffers; they
Expand All @@ -6238,9 +6239,6 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
bbActiveValues.push_back(v);
};
// Register bb arguments and all instruction operands/results.
for (auto *arg : bb->getArguments())
if (getActivityInfo().isActive(arg, getIndices()))
addActiveValue(arg);
for (auto &inst : *bb) {
for (auto op : inst.getOperandValues())
if (getActivityInfo().isActive(op, getIndices()))
Expand Down
6 changes: 3 additions & 3 deletions test/AutoDiff/activity_analysis.swift
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,12 @@ func TF_954(_ x: Float) -> Float {
// CHECK: bb1:
// CHECK: [ACTIVE] %10 = alloc_stack $Float, var, name "inner"
// CHECK: [ACTIVE] %11 = begin_access [read] [static] %2 : $*Float
// CHECK: [NONE] %14 = metatype $@thin Float.Type
// CHECK: [USEFUL] %14 = metatype $@thin Float.Type
// CHECK: [ACTIVE] %15 = begin_access [read] [static] %10 : $*Float
// CHECK: [VARIED] %16 = load [trivial] %15 : $*Float
// CHECK: [ACTIVE] %16 = load [trivial] %15 : $*Float
// CHECK: [NONE] // function_ref static Float.* infix(_:_:)
// CHECK: %18 = function_ref @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
// CHECK: [VARIED] %19 = apply %18(%16, %0, %14) : $@convention(method) (Float, Float, @thin Float.Type) -> Float
// CHECK: [ACTIVE] %19 = apply %18(%16, %0, %14) : $@convention(method) (Float, Float, @thin Float.Type) -> Float
// CHECK: [ACTIVE] %20 = begin_access [modify] [static] %10 : $*Float
// CHECK: bb3:
// CHECK: [ACTIVE] %31 = begin_access [read] [static] %10 : $*Float
Expand Down
19 changes: 5 additions & 14 deletions test/AutoDiff/control_flow.swift
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,7 @@ ControlFlowTests.test("Conditionals") {
}
return outer
}
// FIXME(TF-954): Investigate incorrect derivative related to addresses and
// nested control flow.
// expectEqual((9, 6), valueWithGradient(at: 3, in: cond4_var))
expectEqual((9, 0), valueWithGradient(at: 3, in: cond4_var))
expectEqual((9, 6), valueWithGradient(at: 3, in: cond4_var))

func cond_tuple(_ x: Float) -> Float {
// Convoluted function returning `x + x`.
Expand Down Expand Up @@ -707,16 +704,10 @@ ControlFlowTests.test("Loops") {
}
return outer
}
// FIXME(TF-954): Investigate incorrect derivative related to addresses and
// nested control flow.
// expectEqual((6, 5), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 1) }))
// expectEqual((20, 22), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 2) }))
// expectEqual((52, 80), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 3) }))
// expectEqual((24, 28), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 4) }))
expectEqual((6, 0), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 1) }))
expectEqual((20, 0), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 2) }))
expectEqual((52, 26), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 3) }))
expectEqual((24, 12), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 4) }))
expectEqual((6, 5), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 1) }))
expectEqual((20, 22), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 2) }))
expectEqual((52, 80), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 3) }))
expectEqual((24, 28), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 4) }))
}

runAllTests()
16 changes: 11 additions & 5 deletions test/AutoDiff/control_flow_diagnostics.swift
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,12 @@ enum Tree : Differentiable & AdditiveArithmetic {

// expected-error @+1 {{function is not differentiable}}
@differentiable
// expected-note @+1 {{when differentiating this function definition}}
// TODO(TF-956): Improve location of active enum non-differentiability errors
// so that they are closer to the source of the non-differentiability.
// expected-note @+2 {{when differentiating this function definition}}
// expected-note @+1 {{differentiating enum values is not yet supported}}
static func +(_ lhs: Self, _ rhs: Self) -> Self {
switch (lhs, rhs) {
// expected-note @+1 {{differentiating enum values is not yet supported}}
case let (.leaf(x), .leaf(y)):
return .leaf(x + y)
case let (.branch(x1, x2), .branch(y1, y2)):
Expand All @@ -128,10 +130,12 @@ enum Tree : Differentiable & AdditiveArithmetic {

// expected-error @+1 {{function is not differentiable}}
@differentiable
// expected-note @+1 {{when differentiating this function definition}}
// TODO(TF-956): Improve location of active enum non-differentiability errors
// so that they are closer to the source of the non-differentiability.
// expected-note @+2 {{when differentiating this function definition}}
// expected-note @+1 {{differentiating enum values is not yet supported}}
static func -(_ lhs: Self, _ rhs: Self) -> Self {
switch (lhs, rhs) {
// expected-note @+1 {{differentiating enum values is not yet supported}}
case let (.leaf(x), .leaf(y)):
return .leaf(x - y)
case let (.branch(x1, x2), .branch(y1, y2)):
Expand All @@ -147,7 +151,9 @@ enum Tree : Differentiable & AdditiveArithmetic {
// expected-note @+1 {{when differentiating this function definition}}
func loop_array(_ array: [Float]) -> Float {
var result: Float = 1
// expected-note @+1 {{differentiating enum values is not yet supported}}
// TODO(TF-957): Improve non-differentiability errors for for-in loops
// (`Collection.makeIterator` and `IteratorProtocol.next`).
// expected-note @+1 {{cannot differentiate through a non-differentiable result; do you want to use 'withoutDerivative(at:)'?}}
for x in array {
result = result * x
}
Expand Down
Loading