@@ -2085,37 +2085,38 @@ bool PullbackCloner::Implementation::run() {
2085
2085
2086
2086
// Collect differentiation parameter adjoints.
2087
2087
// Do a first pass to collect non-inout values.
2088
- unsigned pullbackInoutArgumentIndex = 0 ;
2089
2088
for (auto i : getConfig ().parameterIndices ->getIndices ()) {
2090
- auto isParameterInout = conv.getParameters ()[i].isIndirectMutating ();
2091
- if (!isParameterInout) {
2089
+ if (!conv.getParameters ()[i].isIndirectMutating ()) {
2090
+ addRetElt (i);
2091
+ }
2092
+ }
2093
+
2094
+ // Do a second pass for all inout parameters, however this is only necessary
2095
+ // for functions with multiple basic blocks. For functions with a single
2096
+ // basic block adjoint accumulation for those parameters is already done by
2097
+ // per-instruction visitors.
2098
+ if (getOriginal ().size () > 1 ) {
2099
+ const auto &pullbackConv = pullback.getConventions ();
2100
+ SmallVector<SILArgument *, 1 > pullbackInOutArgs;
2101
+ for (auto pullbackArg : enumerate(pullback.getArgumentsWithoutIndirectResults ())) {
2102
+ if (pullbackConv.getParameters ()[pullbackArg.index ()].isIndirectMutating ())
2103
+ pullbackInOutArgs.push_back (pullbackArg.value ());
2104
+ }
2105
+
2106
+ unsigned pullbackInoutArgumentIdx = 0 ;
2107
+ for (auto i : getConfig ().parameterIndices ->getIndices ()) {
2108
+ // Skip non-inout parameters.
2109
+ if (!conv.getParameters ()[i].isIndirectMutating ())
2110
+ continue ;
2111
+
2112
+ // For functions with multiple basic blocks, accumulation is needed
2113
+ // for `inout` parameters because pullback basic blocks have different
2114
+ // adjoint buffers.
2115
+ pullbackIndirectResults.push_back (pullbackInOutArgs[pullbackInoutArgumentIdx++]);
2092
2116
addRetElt (i);
2093
2117
}
2094
2118
}
2095
2119
2096
- // Do a second pass for all inout parameters.
2097
- for (auto i : getConfig ().parameterIndices ->getIndices ()) {
2098
- // Skip non-inout parameters.
2099
- auto isParameterInout = conv.getParameters ()[i].isIndirectMutating ();
2100
- if (!isParameterInout)
2101
- continue ;
2102
-
2103
- // Skip `inout` parameters for functions with a single basic block:
2104
- // adjoint accumulation for those parameters is already done by
2105
- // per-instruction visitors.
2106
- if (getOriginal ().size () == 1 )
2107
- continue ;
2108
-
2109
- // For functions with multiple basic blocks, accumulation is needed
2110
- // for `inout` parameters because pullback basic blocks have different
2111
- // adjoint buffers.
2112
- auto pullbackInoutArgument =
2113
- getPullback ()
2114
- .getArgumentsWithoutIndirectResults ()[pullbackInoutArgumentIndex++];
2115
- pullbackIndirectResults.push_back (pullbackInoutArgument);
2116
- addRetElt (i);
2117
- }
2118
-
2119
2120
// Copy them to adjoint indirect results.
2120
2121
assert (indParamAdjoints.size () == pullbackIndirectResults.size () &&
2121
2122
" Indirect parameter adjoint count mismatch" );
0 commit comments