@@ -2052,6 +2052,8 @@ bool PullbackCloner::Implementation::run() {
2052
2052
SmallVector<SILValue, 8 > retElts;
2053
2053
// This vector will contain all indirect parameter adjoint buffers.
2054
2054
SmallVector<SILValue, 4 > indParamAdjoints;
2055
+ // This vector will identify the locations where initialization is needed.
2056
+ SmallVector<bool , 8 > outputsToInitialize;
2055
2057
2056
2058
auto conv = getOriginal ().getConventions ();
2057
2059
auto origParams = getOriginal ().getArgumentsWithoutIndirectResults ();
@@ -2071,27 +2073,62 @@ bool PullbackCloner::Implementation::run() {
2071
2073
case SILValueCategory::Address: {
2072
2074
auto adjBuf = getAdjointBuffer (origEntry, origParam);
2073
2075
indParamAdjoints.push_back (adjBuf);
2076
+ outputsToInitialize.push_back (
2077
+ !conv.getParameters ()[parameterIndex].isIndirectMutating ());
2074
2078
break ;
2075
2079
}
2076
2080
}
2077
2081
};
2082
+ SmallVector<SILArgument *, 4 > pullbackIndirectResults (
2083
+ getPullback ().getIndirectResults ().begin (),
2084
+ getPullback ().getIndirectResults ().end ());
2085
+
2078
2086
// Collect differentiation parameter adjoints.
2087
+ // Do a first pass to collect non-inout values.
2088
+ unsigned pullbackInoutArgumentIndex = 0 ;
2079
2089
for (auto i : getConfig ().parameterIndices ->getIndices ()) {
2080
- // Skip `inout` parameters.
2081
- if (conv.getParameters ()[i].isIndirectMutating ())
2082
- continue ;
2083
- addRetElt (i);
2090
+ auto isParameterInout = conv.getParameters ()[i].isIndirectMutating ();
2091
+ if (!isParameterInout) {
2092
+ addRetElt (i);
2093
+ }
2094
+ }
2095
+
2096
+ // Do a second pass for all inout values.
2097
+ for (auto i : getConfig ().parameterIndices ->getIndices ()) {
2098
+ // Skip `inout` parameters for functions with single basic blocks:
2099
+ // additional adjoint accumulation is not necessary.
2100
+
2101
+ // For functions with multiple basic blocks, accumulation is needed
2102
+ // for `inout` parameters because pullback basic blocks have different
2103
+ // adjoint buffers.
2104
+ auto isOriginalFunctionSingleBasicBlock = getOriginal ().size () == 1 ;
2105
+ auto isParameterInout = conv.getParameters ()[i].isIndirectMutating ();
2106
+ if (isParameterInout) {
2107
+ if (isOriginalFunctionSingleBasicBlock) {
2108
+ continue ;
2109
+ } else {
2110
+ pullbackIndirectResults.push_back (
2111
+ getPullback ().getArgumentsWithoutIndirectResults ()[pullbackInoutArgumentIndex++]);
2112
+ }
2113
+ addRetElt (i);
2114
+ }
2084
2115
}
2085
2116
2086
2117
// Copy them to adjoint indirect results.
2087
- assert (indParamAdjoints.size () == getPullback (). getIndirectResults () .size () &&
2118
+ assert (indParamAdjoints.size () == pullbackIndirectResults .size () &&
2088
2119
" Indirect parameter adjoint count mismatch" );
2089
- for (auto pair : zip (indParamAdjoints, getPullback ().getIndirectResults ())) {
2120
+ unsigned currentIndex = 0 ;
2121
+ for (auto pair : zip (indParamAdjoints, pullbackIndirectResults)) {
2090
2122
auto source = std::get<0 >(pair);
2091
2123
auto *dest = std::get<1 >(pair);
2092
- builder.createCopyAddr (pbLoc, source, dest, IsTake, IsInitialization);
2124
+ if (outputsToInitialize[currentIndex]) {
2125
+ builder.createCopyAddr (pbLoc, source, dest, IsTake, IsInitialization);
2126
+ } else {
2127
+ builder.createCopyAddr (pbLoc, source, dest, IsTake, IsNotInitialization);
2128
+ }
2093
2129
// Prevent source buffer from being deallocated, since the underlying
2094
2130
// value is moved.
2131
+ currentIndex++;
2095
2132
destroyedLocalAllocations.insert (source);
2096
2133
}
2097
2134
0 commit comments