@@ -1433,7 +1433,7 @@ class DifferentiableActivityInfo {
1433
1433
// / Marks the given value as varied and propagates variedness to users.
1434
1434
void setVariedAndPropagateToUsers (SILValue value,
1435
1435
unsigned independentVariableIndex);
1436
- // / Propagates variedness for the given operand to its user's results.
1436
+ // / Propagates variedness from the given operand to its user's results.
1437
1437
void propagateVaried (Operand *operand, unsigned independentVariableIndex);
1438
1438
// / Marks the given value as varied and recursively propagates variedness
1439
1439
// / inwards (to operands) through projections. Skips `@noDerivative` struct
@@ -1444,8 +1444,18 @@ class DifferentiableActivityInfo {
1444
1444
void setUseful (SILValue value, unsigned dependentVariableIndex);
1445
1445
void setUsefulAcrossArrayInitialization (SILValue value,
1446
1446
unsigned dependentVariableIndex);
1447
- void propagateUsefulThroughBuffer (SILValue value,
1448
- unsigned dependentVariableIndex);
1447
+ // / Marks the given value as useful and recursively propagates usefulness to:
1448
+ // / - Defining instruction operands, if the value has a defining instruction.
1449
+ // / - Incoming values, if the value is a basic block argument.
1450
+ void setUsefulAndPropagateToOperands (SILValue value,
1451
+ unsigned dependentVariableIndex);
1452
+ // / Propagates usefulnesss to the operands of the given instruction.
1453
+ void propagateUseful (SILInstruction *inst, unsigned dependentVariableIndex);
1454
+ // / Marks the given address as useful and recursively propagates usefulness
1455
+ // / inwards (to operands) through projections. Skips `@noDerivative` struct
1456
+ // / field projections.
1457
+ void propagateUsefulThroughAddress (SILValue value,
1458
+ unsigned dependentVariableIndex);
1449
1459
1450
1460
public:
1451
1461
explicit DifferentiableActivityInfo (
@@ -1975,6 +1985,71 @@ void DifferentiableActivityInfo::propagateVaried(
1975
1985
}
1976
1986
}
1977
1987
1988
+ void DifferentiableActivityInfo::setUsefulAndPropagateToOperands (
1989
+ SILValue value, unsigned dependentVariableIndex) {
1990
+ // Skip already-useful values to prevent infinite recursion.
1991
+ if (isUseful (value, dependentVariableIndex))
1992
+ return ;
1993
+ if (value->getType ().isAddress ()) {
1994
+ propagateUsefulThroughAddress (value, dependentVariableIndex);
1995
+ return ;
1996
+ }
1997
+ setUseful (value, dependentVariableIndex);
1998
+ // If the given value is a basic block argument, propagate usefulness to
1999
+ // incoming values.
2000
+ if (auto *bbArg = dyn_cast<SILPhiArgument>(value)) {
2001
+ SmallVector<SILValue, 4 > incomingValues;
2002
+ bbArg->getSingleTerminatorOperands (incomingValues);
2003
+ for (auto incomingValue : incomingValues)
2004
+ setUsefulAndPropagateToOperands (incomingValue, dependentVariableIndex);
2005
+ return ;
2006
+ }
2007
+ auto *inst = value->getDefiningInstruction ();
2008
+ if (!inst)
2009
+ return ;
2010
+ propagateUseful (inst, dependentVariableIndex);
2011
+ }
2012
+
2013
+ void DifferentiableActivityInfo::propagateUseful (
2014
+ SILInstruction *inst, unsigned dependentVariableIndex) {
2015
+ // Propagate usefulness for the given instruction: mark operands as useful and
2016
+ // recursively propagate usefulness to defining instructions of operands.
2017
+ auto i = dependentVariableIndex;
2018
+ // Handle indirect results in `apply`.
2019
+ if (auto *ai = dyn_cast<ApplyInst>(inst)) {
2020
+ if (isWithoutDerivative (ai->getCallee ()))
2021
+ return ;
2022
+ for (auto arg : ai->getArgumentsWithoutIndirectResults ())
2023
+ setUsefulAndPropagateToOperands (arg, i);
2024
+ }
2025
+ // Handle store-like instructions:
2026
+ // `store`, `store_borrow`, `copy_addr`, `unconditional_checked_cast`
2027
+ #define PROPAGATE_USEFUL_THROUGH_STORE (INST ) \
2028
+ else if (auto *si = dyn_cast<INST##Inst>(inst)) { \
2029
+ setUsefulAndPropagateToOperands (si->getSrc (), i); \
2030
+ }
2031
+ PROPAGATE_USEFUL_THROUGH_STORE (Store)
2032
+ PROPAGATE_USEFUL_THROUGH_STORE (StoreBorrow)
2033
+ PROPAGATE_USEFUL_THROUGH_STORE (CopyAddr)
2034
+ PROPAGATE_USEFUL_THROUGH_STORE (UnconditionalCheckedCastAddr)
2035
+ #undef PROPAGATE_USEFUL_THROUGH_STORE
2036
+ // Handle struct element extraction, skipping `@noDerivative` fields:
2037
+ // `struct_extract`, `struct_element_addr`.
2038
+ #define PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION (INST ) \
2039
+ else if (auto *sei = dyn_cast<INST##Inst>(inst)) { \
2040
+ if (!sei->getField ()->getAttrs ().hasAttribute <NoDerivativeAttr>()) \
2041
+ setUsefulAndPropagateToOperands (sei->getOperand (), i); \
2042
+ }
2043
+ PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION (StructExtract)
2044
+ PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION (StructElementAddr)
2045
+ #undef PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION
2046
+ // Handle everything else.
2047
+ else {
2048
+ for (auto &op : inst->getAllOperands ())
2049
+ setUsefulAndPropagateToOperands (op.get (), i);
2050
+ }
2051
+ }
2052
+
1978
2053
void DifferentiableActivityInfo::analyze (DominanceInfo *di,
1979
2054
PostDominanceInfo *pdi) {
1980
2055
auto &function = getFunction ();
@@ -2010,117 +2085,40 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di,
2010
2085
2011
2086
// Mark differentiable outputs as useful.
2012
2087
assert (usefulValueSets.empty ());
2013
- for (auto output : outputValues) {
2088
+ for (auto outputAndIdx : enumerate(outputValues)) {
2089
+ auto output = outputAndIdx.value ();
2090
+ unsigned i = outputAndIdx.index ();
2014
2091
usefulValueSets.push_back ({});
2015
- // If the output has an address or class type, propagate usefulness
2016
- // recursively.
2017
- if (output->getType ().isAddress () ||
2018
- output->getType ().isClassOrClassMetatype ())
2019
- propagateUsefulThroughBuffer (output, usefulValueSets.size () - 1 );
2020
- // Otherwise, just mark the output as useful.
2021
- else
2022
- setUseful (output, usefulValueSets.size () - 1 );
2023
- }
2024
- // Propagate usefulness through the function in post-dominance order.
2025
- PostDominanceOrder postDomOrder (&*function.findReturnBB (), pdi);
2026
- while (auto *bb = postDomOrder.getNext ()) {
2027
- for (auto &inst : llvm::reverse (*bb)) {
2028
- for (auto i : indices (outputValues)) {
2029
- // Handle indirect results in `apply`.
2030
- if (auto *ai = dyn_cast<ApplyInst>(&inst)) {
2031
- if (isWithoutDerivative (ai->getCallee ()))
2032
- continue ;
2033
- auto checkAndSetUseful = [&](SILValue res) {
2034
- if (isUseful (res, i))
2035
- for (auto arg : ai->getArgumentsWithoutIndirectResults ())
2036
- setUseful (arg, i);
2037
- };
2038
- for (auto dirRes : ai->getResults ())
2039
- checkAndSetUseful (dirRes);
2040
- for (auto indRes : ai->getIndirectSILResults ())
2041
- checkAndSetUseful (indRes);
2042
- auto paramInfos = ai->getSubstCalleeConv ().getParameters ();
2043
- for (auto i : indices (paramInfos))
2044
- if (paramInfos[i].isIndirectInOut ())
2045
- checkAndSetUseful (ai->getArgumentsWithoutIndirectResults ()[i]);
2046
- }
2047
- // Handle store-like instructions:
2048
- // `store`, `store_borrow`, `copy_addr`, `unconditional_checked_cast`
2049
- #define PROPAGATE_USEFUL_THROUGH_STORE (INST, PROPAGATE ) \
2050
- else if (auto *si = dyn_cast<INST##Inst>(&inst)) { \
2051
- if (isUseful (si->getDest (), i)) \
2052
- PROPAGATE (si->getSrc (), i); \
2053
- }
2054
- PROPAGATE_USEFUL_THROUGH_STORE (Store, setUseful)
2055
- PROPAGATE_USEFUL_THROUGH_STORE (StoreBorrow, setUseful)
2056
- PROPAGATE_USEFUL_THROUGH_STORE (CopyAddr, propagateUsefulThroughBuffer)
2057
- PROPAGATE_USEFUL_THROUGH_STORE (UnconditionalCheckedCastAddr,
2058
- propagateUsefulThroughBuffer)
2059
- #undef PROPAGATE_USEFUL_THROUGH_STORE
2060
- // Handle struct element extraction, skipping `@noDerivative` fields:
2061
- // `struct_extract`, `struct_element_addr`.
2062
- #define PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION (INST, PROPAGATE ) \
2063
- else if (auto *sei = dyn_cast<INST##Inst>(&inst)) { \
2064
- if (isUseful (sei, i)) { \
2065
- auto hasNoDeriv = sei->getField ()->getAttrs () \
2066
- .hasAttribute <NoDerivativeAttr>(); \
2067
- if (!hasNoDeriv) \
2068
- PROPAGATE (sei->getOperand (), i); \
2069
- } \
2070
- }
2071
- PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION (StructExtract, setUseful)
2072
- PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION (StructElementAddr,
2073
- propagateUsefulThroughBuffer)
2074
- #undef PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION
2075
- // Handle everything else.
2076
- else if (llvm::any_of (inst.getResults (),
2077
- [&](SILValue res) { return isUseful (res, i); })) {
2078
- for (auto &op : inst.getAllOperands ()) {
2079
- auto value = op.get ();
2080
- if (value->getType ().isAddress ())
2081
- propagateUsefulThroughBuffer (value, i);
2082
- setUseful (value, i);
2083
- }
2084
- }
2085
- }
2086
- }
2087
- // Propagate usefulness from basic block arguments to incoming phi values.
2088
- for (auto i : indices (outputValues)) {
2089
- for (auto *arg : bb->getArguments ()) {
2090
- if (isUseful (arg, i)) {
2091
- SmallVector<SILValue, 4 > incomingValues;
2092
- arg->getSingleTerminatorOperands (incomingValues);
2093
- for (auto incomingValue : incomingValues)
2094
- setUseful (incomingValue, i);
2095
- }
2096
- }
2097
- }
2098
- postDomOrder.pushChildren (bb);
2092
+ setUsefulAndPropagateToOperands (output, i);
2099
2093
}
2100
2094
}
2101
2095
2102
2096
void DifferentiableActivityInfo::setUsefulAcrossArrayInitialization (
2103
2097
SILValue value, unsigned dependentVariableIndex) {
2104
2098
// Array initializer syntax is lowered to an intrinsic and one or more
2105
2099
// stores to a `RawPointer` returned by the intrinsic.
2106
- auto uai = getAllocateUninitializedArrayIntrinsic (value);
2100
+ auto * uai = getAllocateUninitializedArrayIntrinsic (value);
2107
2101
if (!uai) return ;
2108
2102
for (auto use : value->getUses ()) {
2109
- auto dti = dyn_cast<DestructureTupleInst>(use->getUser ());
2103
+ auto * dti = dyn_cast<DestructureTupleInst>(use->getUser ());
2110
2104
if (!dti) continue ;
2111
2105
// The second tuple field of the return value is the `RawPointer`.
2112
2106
for (auto use : dti->getResult (1 )->getUses ()) {
2113
2107
// The `RawPointer` passes through a `pointer_to_address`. That
2114
- // instruction's first use is a `store` whose src is useful; its
2108
+ // instruction's first use is a `store` whose source is useful; its
2115
2109
// subsequent uses are `index_addr`s whose only use is a useful `store`.
2116
- for (auto use : use->getUser ()->getResult (0 )->getUses ()) {
2117
- auto inst = use->getUser ();
2118
- if (auto si = dyn_cast<StoreInst>(inst)) {
2119
- setUseful (si->getSrc (), dependentVariableIndex);
2120
- } else if (auto iai = dyn_cast<IndexAddrInst>(inst)) {
2110
+ auto *ptai = dyn_cast<PointerToAddressInst>(use->getUser ());
2111
+ assert (ptai && " Expected `pointer_to_address` user for uninitialized "
2112
+ " array intrinsic" );
2113
+ for (auto use : ptai->getUses ()) {
2114
+ auto *inst = use->getUser ();
2115
+ if (auto *si = dyn_cast<StoreInst>(inst)) {
2116
+ setUsefulAndPropagateToOperands (si->getSrc (), dependentVariableIndex);
2117
+ } else if (auto *iai = dyn_cast<IndexAddrInst>(inst)) {
2121
2118
for (auto use : iai->getUses ())
2122
2119
if (auto si = dyn_cast<StoreInst>(use->getUser ()))
2123
- setUseful (si->getSrc (), dependentVariableIndex);
2120
+ setUsefulAndPropagateToOperands (si->getSrc (),
2121
+ dependentVariableIndex);
2124
2122
}
2125
2123
}
2126
2124
}
@@ -2154,21 +2152,20 @@ void DifferentiableActivityInfo::propagateVariedInwardsThroughProjections(
2154
2152
op.get (), independentVariableIndex);
2155
2153
}
2156
2154
2157
- void DifferentiableActivityInfo::propagateUsefulThroughBuffer (
2155
+ void DifferentiableActivityInfo::propagateUsefulThroughAddress (
2158
2156
SILValue value, unsigned dependentVariableIndex) {
2159
- assert (value->getType ().isAddress () ||
2160
- value->getType ().isClassOrClassMetatype ());
2157
+ assert (value->getType ().isAddress ());
2161
2158
// Check whether value is already useful to prevent infinite recursion.
2162
2159
if (isUseful (value, dependentVariableIndex))
2163
2160
return ;
2164
2161
setUseful (value, dependentVariableIndex);
2165
2162
if (auto *inst = value->getDefiningInstruction ())
2166
- for (auto &operand : inst->getAllOperands ())
2167
- if (operand.get ()->getType ().isAddress ())
2168
- propagateUsefulThroughBuffer (operand.get (), dependentVariableIndex);
2163
+ propagateUseful (inst, dependentVariableIndex);
2169
2164
// Recursively propagate usefulness through users that are projections or
2170
2165
// `begin_access` instructions.
2171
2166
for (auto use : value->getUses ()) {
2167
+ // Propagate usefulness through user's operands.
2168
+ propagateUseful (use->getUser (), dependentVariableIndex);
2172
2169
for (auto res : use->getUser ()->getResults ()) {
2173
2170
#define SKIP_NODERIVATIVE (INST ) \
2174
2171
if (auto *sei = dyn_cast<INST##Inst>(res)) \
@@ -2178,7 +2175,7 @@ void DifferentiableActivityInfo::propagateUsefulThroughBuffer(
2178
2175
SKIP_NODERIVATIVE (StructElementAddr)
2179
2176
#undef SKIP_NODERIVATIVE
2180
2177
if (Projection::isAddressProjection (res) || isa<BeginAccessInst>(res))
2181
- propagateUsefulThroughBuffer (res, dependentVariableIndex);
2178
+ propagateUsefulThroughAddress (res, dependentVariableIndex);
2182
2179
}
2183
2180
}
2184
2181
}
@@ -6219,15 +6216,19 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
6219
6216
SmallPtrSet<SILValue, 8 > visited (bbActiveValues.begin (),
6220
6217
bbActiveValues.end ());
6221
6218
// Register a value as active if it has not yet been visited.
6219
+ bool diagnosedActiveEnumValue = false ;
6222
6220
auto addActiveValue = [&](SILValue v) {
6223
6221
if (visited.count (v))
6224
6222
return ;
6225
- // Diagnose active enum values. Differentiation of enum values is not
6226
- // yet supported; requires special adjoint value handling.
6227
- if (v->getType ().getEnumOrBoundGenericEnum ()) {
6223
+ // Diagnose active enum values. Differentiation of enum values requires
6224
+ // special adjoint value handling and is not yet supported. Diagnose
6225
+ // only the first active enum value to prevent too many diagnostics.
6226
+ if (!diagnosedActiveEnumValue &&
6227
+ v->getType ().getEnumOrBoundGenericEnum ()) {
6228
6228
getContext ().emitNondifferentiabilityError (
6229
6229
v, getInvoker (), diag::autodiff_enums_unsupported);
6230
6230
errorOccurred = true ;
6231
+ diagnosedActiveEnumValue = true ;
6231
6232
}
6232
6233
// Skip address projections.
6233
6234
// Address projections do not need their own adjoint buffers; they
@@ -6238,9 +6239,6 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
6238
6239
bbActiveValues.push_back (v);
6239
6240
};
6240
6241
// Register bb arguments and all instruction operands/results.
6241
- for (auto *arg : bb->getArguments ())
6242
- if (getActivityInfo ().isActive (arg, getIndices ()))
6243
- addActiveValue (arg);
6244
6242
for (auto &inst : *bb) {
6245
6243
for (auto op : inst.getOperandValues ())
6246
6244
if (getActivityInfo ().isActive (op, getIndices ()))
0 commit comments