@@ -2052,6 +2052,8 @@ bool PullbackCloner::Implementation::run() {
2052
2052
SmallVector<SILValue, 8 > retElts;
2053
2053
// This vector will contain all indirect parameter adjoint buffers.
2054
2054
SmallVector<SILValue, 4 > indParamAdjoints;
2055
+ // This vector will identify the locations where initialization is needed.
2056
+ SmallBitVector outputsToInitialize;
2055
2057
2056
2058
auto conv = getOriginal ().getConventions ();
2057
2059
auto origParams = getOriginal ().getArgumentsWithoutIndirectResults ();
@@ -2071,25 +2073,62 @@ bool PullbackCloner::Implementation::run() {
2071
2073
case SILValueCategory::Address: {
2072
2074
auto adjBuf = getAdjointBuffer (origEntry, origParam);
2073
2075
indParamAdjoints.push_back (adjBuf);
2076
+ outputsToInitialize.push_back (
2077
+ !conv.getParameters ()[parameterIndex].isIndirectMutating ());
2074
2078
break ;
2075
2079
}
2076
2080
}
2077
2081
};
2082
+ SmallVector<SILArgument *, 4 > pullbackIndirectResults (
2083
+ getPullback ().getIndirectResults ().begin (),
2084
+ getPullback ().getIndirectResults ().end ());
2085
+
2078
2086
// Collect differentiation parameter adjoints.
2087
+ // Do a first pass to collect non-inout values.
2088
+ unsigned pullbackInoutArgumentIndex = 0 ;
2089
+ for (auto i : getConfig ().parameterIndices ->getIndices ()) {
2090
+ auto isParameterInout = conv.getParameters ()[i].isIndirectMutating ();
2091
+ if (!isParameterInout) {
2092
+ addRetElt (i);
2093
+ }
2094
+ }
2095
+
2096
+ // Do a second pass for all inout parameters.
2079
2097
for (auto i : getConfig ().parameterIndices ->getIndices ()) {
2080
- // Skip `inout` parameters.
2081
- if (conv.getParameters ()[i].isIndirectMutating ())
2098
+ // Skip non-inout parameters.
2099
+ auto isParameterInout = conv.getParameters ()[i].isIndirectMutating ();
2100
+ if (!isParameterInout)
2082
2101
continue ;
2102
+
2103
+ // Skip `inout` parameters for functions with a single basic block:
2104
+ // adjoint accumulation for those parameters is already done by
2105
+ // per-instruction visitors.
2106
+ if (getOriginal ().size () == 1 )
2107
+ continue ;
2108
+
2109
+ // For functions with multiple basic blocks, accumulation is needed
2110
+ // for `inout` parameters because pullback basic blocks have different
2111
+ // adjoint buffers.
2112
+ auto pullbackInoutArgument =
2113
+ getPullback ()
2114
+ .getArgumentsWithoutIndirectResults ()[pullbackInoutArgumentIndex++];
2115
+ pullbackIndirectResults.push_back (pullbackInoutArgument);
2083
2116
addRetElt (i);
2084
2117
}
2085
2118
2086
2119
// Copy them to adjoint indirect results.
2087
- assert (indParamAdjoints.size () == getPullback (). getIndirectResults () .size () &&
2120
+ assert (indParamAdjoints.size () == pullbackIndirectResults .size () &&
2088
2121
" Indirect parameter adjoint count mismatch" );
2089
- for (auto pair : zip (indParamAdjoints, getPullback ().getIndirectResults ())) {
2122
+ unsigned currentIndex = 0 ;
2123
+ for (auto pair : zip (indParamAdjoints, pullbackIndirectResults)) {
2090
2124
auto source = std::get<0 >(pair);
2091
2125
auto *dest = std::get<1 >(pair);
2092
- builder.createCopyAddr (pbLoc, source, dest, IsTake, IsInitialization);
2126
+ if (outputsToInitialize[currentIndex]) {
2127
+ builder.createCopyAddr (pbLoc, source, dest, IsTake, IsInitialization);
2128
+ } else {
2129
+ builder.createCopyAddr (pbLoc, source, dest, IsTake, IsNotInitialization);
2130
+ }
2131
+ currentIndex++;
2093
2132
// Prevent source buffer from being deallocated, since the underlying
2094
2133
// value is moved.
2095
2134
destroyedLocalAllocations.insert (source);
0 commit comments