Skip to content

Commit 2902ad2

Browse files
committed
Adding @asl's fix for subset parameters thunks involving functions with multiple results, and an activity analysis test representing code that had exposed that issue.
1 parent 4918248 commit 2902ad2

File tree

3 files changed

+106
-12
lines changed

3 files changed

+106
-12
lines changed

lib/SILOptimizer/Differentiation/Common.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,8 @@ void collectMinimalIndicesForFunctionCall(
232232
auto &param = paramAndIdx.value();
233233
if (!param.isIndirectMutating())
234234
continue;
235-
unsigned idx = paramAndIdx.index();
235+
unsigned idx =
236+
paramAndIdx.index() + calleeFnTy->getNumIndirectFormalResults();
236237
auto inoutArg = ai->getArgument(idx);
237238
results.push_back(inoutArg);
238239
resultIndices.push_back(inoutParamResultIndex++);

lib/SILOptimizer/Differentiation/Thunk.cpp

Lines changed: 66 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,12 @@ getOrCreateSubsetParametersThunkForLinearMap(
472472
return mappedIndex;
473473
};
474474

475+
auto toIndirectResultsIter = thunk->getIndirectResults().begin();
476+
auto useNextIndirectResult = [&]() {
477+
assert(toIndirectResultsIter != thunk->getIndirectResults().end());
478+
arguments.push_back(*toIndirectResultsIter++);
479+
};
480+
475481
switch (kind) {
476482
// Differential arguments are:
477483
// - All indirect results, followed by:
@@ -480,9 +486,29 @@ getOrCreateSubsetParametersThunkForLinearMap(
480486
// indices).
481487
// - Zeros (when parameter is not in desired indices).
482488
case AutoDiffDerivativeFunctionKind::JVP: {
483-
// Forward all indirect results.
484-
arguments.append(thunk->getIndirectResults().begin(),
485-
thunk->getIndirectResults().end());
489+
unsigned numIndirectResults = linearMapType->getNumIndirectFormalResults();
490+
// Forward desired indirect results.
491+
for (unsigned idx : *actualConfig.resultIndices) {
492+
if (idx >= numIndirectResults)
493+
break;
494+
495+
auto resultInfo = linearMapType->getResults()[idx];
496+
assert(idx < linearMapType->getNumResults());
497+
498+
// Forward result argument in case we do not need to thunk it away.
499+
if (desiredConfig.resultIndices->contains(idx)) {
500+
useNextIndirectResult();
501+
continue;
502+
}
503+
504+
// Otherwise, allocate and use an uninitialized indirect result.
505+
auto *indirectResult = builder.createAllocStack(
506+
loc, resultInfo.getSILStorageInterfaceType());
507+
localAllocations.push_back(indirectResult);
508+
arguments.push_back(indirectResult);
509+
}
510+
assert(toIndirectResultsIter == thunk->getIndirectResults().end());
511+
486512
auto toArgIter = thunk->getArgumentsWithoutIndirectResults().begin();
487513
auto useNextArgument = [&]() { arguments.push_back(*toArgIter++); };
488514
// Iterate over actual indices.
@@ -507,10 +533,6 @@ getOrCreateSubsetParametersThunkForLinearMap(
507533
// - Zeros (when parameter is not in desired indices).
508534
// - All actual arguments.
509535
case AutoDiffDerivativeFunctionKind::VJP: {
510-
auto toIndirectResultsIter = thunk->getIndirectResults().begin();
511-
auto useNextIndirectResult = [&]() {
512-
arguments.push_back(*toIndirectResultsIter++);
513-
};
514536
// Collect pullback arguments.
515537
unsigned pullbackResultIndex = 0;
516538
for (unsigned i : actualConfig.parameterIndices->getIndices()) {
@@ -539,8 +561,18 @@ getOrCreateSubsetParametersThunkForLinearMap(
539561
arguments.push_back(indirectResult);
540562
}
541563
// Forward all actual non-indirect-result arguments.
542-
arguments.append(thunk->getArgumentsWithoutIndirectResults().begin(),
543-
thunk->getArgumentsWithoutIndirectResults().end() - 1);
564+
auto thunkArgs = thunk->getArgumentsWithoutIndirectResults();
565+
// Slice out the function to be called.
566+
thunkArgs = thunkArgs.slice(0, thunkArgs.size() - 1);
567+
unsigned thunkArg = 0;
568+
for (unsigned idx : *actualConfig.resultIndices) {
569+
// Forward result argument in case we do not need to thunk it away.
570+
if (desiredConfig.resultIndices->contains(idx))
571+
arguments.push_back(thunkArgs[thunkArg++]);
572+
else { // Otherwise, zero it out.
573+
buildZeroArgument(linearMapType->getParameters()[arguments.size()]);
574+
}
575+
}
544576
break;
545577
}
546578
}
@@ -550,10 +582,33 @@ getOrCreateSubsetParametersThunkForLinearMap(
550582
auto *ai = builder.createApply(loc, linearMap, SubstitutionMap(), arguments);
551583

552584
// If differential thunk, deallocate local allocations and directly return
553-
// `apply` result.
585+
// `apply` result (if it is desired).
554586
if (kind == AutoDiffDerivativeFunctionKind::JVP) {
587+
SmallVector<SILValue, 8> differentialDirectResults;
588+
extractAllElements(ai, builder, differentialDirectResults);
589+
SmallVector<SILValue, 8> allResults;
590+
collectAllActualResultsInTypeOrder(ai, differentialDirectResults,
591+
allResults);
592+
unsigned numResults = thunk->getConventions().getNumDirectSILResults() +
593+
thunk->getConventions().getNumDirectSILResults();
594+
SmallVector<SILValue, 8> results;
595+
for (unsigned idx : *actualConfig.resultIndices) {
596+
if (idx >= numResults)
597+
break;
598+
599+
auto result = allResults[idx];
600+
if (desiredConfig.isWrtResult(idx))
601+
results.push_back(result);
602+
else {
603+
if (result->getType().isAddress())
604+
builder.emitDestroyAddrAndFold(loc, result);
605+
else
606+
builder.emitDestroyValueOperation(loc, result);
607+
}
608+
}
555609
cleanupValues();
556-
builder.createReturn(loc, ai);
610+
auto result = joinElements(results, builder, loc);
611+
builder.createReturn(loc, result);
557612
return {thunk, interfaceSubs};
558613
}
559614

test/AutoDiff/SILOptimizer/activity_analysis.swift

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,44 @@ func activeInoutArgNonactiveInitialResult(_ x: Float) -> Float {
533533
// CHECK: [ACTIVE] %13 = begin_access [read] [static] %2 : $*Float
534534
// CHECK: [ACTIVE] %14 = load [trivial] %13 : $*Float
535535

536+
public struct ArrayWrapper: Differentiable {
537+
var values: [Float]
538+
539+
@differentiable(reverse)
540+
mutating func get(index: Int) -> Float {
541+
self.values[index]
542+
}
543+
544+
// Check `inout` with result.
545+
546+
// CHECK-LABEL: [AD] Activity info for ${{.*}}get{{.*}} at parameter indices (1) and result indices (0, 1)
547+
// CHECK: bb0:
548+
// CHECK: [USEFUL] %0 = argument of bb0 : $Int
549+
// CHECK: [ACTIVE] %1 = argument of bb0 : $*ArrayWrapper
550+
// CHECK: [ACTIVE] %4 = begin_access [read] [static] %1 : $*ArrayWrapper
551+
// CHECK: [ACTIVE] %5 = struct_element_addr %4 : $*ArrayWrapper, #ArrayWrapper.values
552+
// CHECK: [ACTIVE] %6 = load_borrow %5 : $*Array<Float>
553+
// CHECK: [ACTIVE] %7 = alloc_stack $Float
554+
// CHECK: [NONE] // function_ref Array.subscript.getter
555+
// CHECK: %8 = function_ref @$sSayxSicig : $@convention(method) <τ_0_0> (Int, @guaranteed Array<τ_0_0>) -> @out τ_0_0
556+
// CHECK: [NONE] %9 = apply %8<Float>(%7, %0, %6) : $@convention(method) <τ_0_0> (Int, @guaranteed Array<τ_0_0>) -> @out τ_0_0
557+
// CHECK: [ACTIVE] %10 = load [trivial] %7 : $*Float
558+
}
559+
560+
@differentiable(reverse)
561+
func testInoutAndResult(x: Int, y: inout ArrayWrapper) {
562+
let _ = y.get(index: x)
563+
}
564+
565+
// CHECK-LABEL: [AD] Activity info for ${{.*}}testInoutAndResult{{.*}} at parameter indices (1) and result indices (0)
566+
// CHECK: bb0:
567+
// CHECK: [USEFUL] %0 = argument of bb0 : $Int
568+
// CHECK: [ACTIVE] %1 = argument of bb0 : $*ArrayWrapper
569+
// CHECK: [ACTIVE] %4 = begin_access [modify] [static] %1 : $*ArrayWrapper
570+
// CHECK: [NONE] // function_ref ArrayWrapper.get(index:)
571+
// CHECK: %5 = function_ref @$s17activity_analysis12ArrayWrapperV3get5indexSfSi_tF : $@convention(method) (Int, @inout ArrayWrapper) -> Float
572+
// CHECK: [VARIED] %6 = apply %5(%0, %4) : $@convention(method) (Int, @inout ArrayWrapper) -> Float
573+
536574
//===----------------------------------------------------------------------===//
537575
// Throwing function differentiation (`try_apply`)
538576
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)