@@ -2093,25 +2093,27 @@ bool PullbackCloner::Implementation::run() {
2093
2093
}
2094
2094
}
2095
2095
2096
- // Do a second pass for all inout values .
2096
+ // Do a second pass for all inout parameters .
2097
2097
for (auto i : getConfig ().parameterIndices ->getIndices ()) {
2098
- // Skip `inout` parameters for functions with single basic blocks:
2099
- // additional adjoint accumulation is not necessary.
2098
+ // Skip non-inout parameters.
2099
+ auto isParameterInout = conv.getParameters ()[i].isIndirectMutating ();
2100
+ if (!isParameterInout)
2101
+ continue ;
2102
+
2103
+ // Skip `inout` parameters for functions with a single basic block:
2104
+ // adjoint accumulation for those parameters is already done by
2105
+ // per-instruction visitors.
2106
+ if (getOriginal ().size () == 1 )
2107
+ continue ;
2100
2108
2101
2109
// For functions with multiple basic blocks, accumulation is needed
2102
2110
// for `inout` parameters because pullback basic blocks have different
2103
2111
// adjoint buffers.
2104
- auto isOriginalFunctionSingleBasicBlock = getOriginal ().size () == 1 ;
2105
- auto isParameterInout = conv.getParameters ()[i].isIndirectMutating ();
2106
- if (isParameterInout) {
2107
- if (isOriginalFunctionSingleBasicBlock) {
2108
- continue ;
2109
- } else {
2110
- pullbackIndirectResults.push_back (
2111
- getPullback ().getArgumentsWithoutIndirectResults ()[pullbackInoutArgumentIndex++]);
2112
- }
2113
- addRetElt (i);
2114
- }
2112
+ auto pullbackInoutArgument =
2113
+ getPullback ()
2114
+ .getArgumentsWithoutIndirectResults ()[pullbackInoutArgumentIndex++];
2115
+ pullbackIndirectResults.push_back (pullbackInoutArgument);
2116
+ addRetElt (i);
2115
2117
}
2116
2118
2117
2119
// Copy them to adjoint indirect results.
@@ -2126,9 +2128,9 @@ bool PullbackCloner::Implementation::run() {
2126
2128
} else {
2127
2129
builder.createCopyAddr (pbLoc, source, dest, IsTake, IsNotInitialization);
2128
2130
}
2131
+ currentIndex++;
2129
2132
// Prevent source buffer from being deallocated, since the underlying
2130
2133
// value is moved.
2131
- currentIndex++;
2132
2134
destroyedLocalAllocations.insert (source);
2133
2135
}
2134
2136
0 commit comments