Skip to content

Commit 869fc17

Browse files
authored
[AutoDiff] Ensure autodiff code does not ignore getSingleTerminatorOperands return value (#64200)
Also extend activity analysis to handle `try_apply` normal result properly. Add testcase from #63728
1 parent df755d4 commit 869fc17

File tree

4 files changed

+123
-52
lines changed

4 files changed

+123
-52
lines changed

include/swift/SIL/SILArgument.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ class SILArgument : public ValueBase {
178178
/// Note: this peeks through any projections or cast implied by the
179179
/// terminator. e.g. the incoming value for a switch_enum payload argument is
180180
/// the enum itself (the operand of the switch_enum).
181-
bool getSingleTerminatorOperands(
181+
[[nodiscard]] bool getSingleTerminatorOperands(
182182
SmallVectorImpl<SILValue> &returnedSingleTermOperands) const;
183183

184184
/// Returns true if we were able to find single terminator operand values for
@@ -188,7 +188,7 @@ class SILArgument : public ValueBase {
188188
/// Note: this peeks through any projections or cast implied by the
189189
/// terminator. e.g. the incoming value for a switch_enum payload argument is
190190
/// the enum itself (the operand of the switch_enum).
191-
bool getSingleTerminatorOperands(
191+
[[nodiscard]] bool getSingleTerminatorOperands(
192192
SmallVectorImpl<std::pair<SILBasicBlock *, SILValue>>
193193
&returnedSingleTermOperands) const;
194194

@@ -303,7 +303,7 @@ class SILPhiArgument : public SILArgument {
303303
/// Note: this peeks through any projections or cast implied by the
304304
/// terminator. e.g. the incoming value for a switch_enum payload argument is
305305
/// the enum itself (the operand of the switch_enum).
306-
bool getSingleTerminatorOperands(
306+
[[nodiscard]] bool getSingleTerminatorOperands(
307307
SmallVectorImpl<SILValue> &returnedSingleTermOperands) const;
308308

309309
/// Returns true if we were able to find single terminator operand values for
@@ -313,7 +313,7 @@ class SILPhiArgument : public SILArgument {
313313
/// Note: this peeks through any projections or cast implied by the
314314
/// terminator. e.g. the incoming value for a switch_enum payload argument is
315315
/// the enum itself (the operand of the switch_enum).
316-
bool getSingleTerminatorOperands(
316+
[[nodiscard]] bool getSingleTerminatorOperands(
317317
SmallVectorImpl<std::pair<SILBasicBlock *, SILValue>>
318318
&returnedSingleTermOperands) const;
319319

lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -303,15 +303,25 @@ void DifferentiableActivityInfo::setUsefulAndPropagateToOperands(
303303
return;
304304
}
305305
setUseful(value, dependentVariableIndex);
306+
306307
// If the given value is a basic block argument, propagate usefulness to
307308
// incoming values.
308309
if (auto *bbArg = dyn_cast<SILPhiArgument>(value)) {
309310
SmallVector<SILValue, 4> incomingValues;
310-
bbArg->getSingleTerminatorOperands(incomingValues);
311-
for (auto incomingValue : incomingValues)
312-
setUsefulAndPropagateToOperands(incomingValue, dependentVariableIndex);
313-
return;
311+
if (bbArg->getSingleTerminatorOperands(incomingValues)) {
312+
for (auto incomingValue : incomingValues)
313+
setUsefulAndPropagateToOperands(incomingValue, dependentVariableIndex);
314+
return;
315+
} else if (bbArg->isTerminatorResult()) {
316+
if (TryApplyInst *tai = dyn_cast<TryApplyInst>(bbArg->getTerminatorForResult())) {
317+
propagateUseful(tai, dependentVariableIndex);
318+
return;
319+
} else
320+
llvm::report_fatal_error("unknown terminator with result");
321+
} else
322+
llvm::report_fatal_error("do not know how to handle this incoming bb argument");
314323
}
324+
315325
auto *inst = value->getDefiningInstruction();
316326
if (!inst)
317327
return;

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 46 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2515,12 +2515,11 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) {
25152515

25162516
// Get predecessor terminator operands.
25172517
SmallVector<std::pair<SILBasicBlock *, SILValue>, 4> incomingValues;
2518-
bbArg->getSingleTerminatorOperands(incomingValues);
2519-
2520-
// Returns true if the given terminator instruction is a `switch_enum` on
2521-
// an `Optional`-typed value. `switch_enum` instructions require
2522-
// special-case adjoint value propagation for the operand.
2523-
auto isSwitchEnumInstOnOptional =
2518+
if (bbArg->getSingleTerminatorOperands(incomingValues)) {
2519+
// Returns true if the given terminator instruction is a `switch_enum` on
2520+
// an `Optional`-typed value. `switch_enum` instructions require
2521+
// special-case adjoint value propagation for the operand.
2522+
auto isSwitchEnumInstOnOptional =
25242523
[&ctx = getASTContext()](TermInst *termInst) {
25252524
if (!termInst)
25262525
return false;
@@ -2531,49 +2530,52 @@ void PullbackCloner::Implementation::visitSILBasicBlock(SILBasicBlock *bb) {
25312530
return false;
25322531
};
25332532

2534-
// Check the tangent value category of the active basic block argument.
2535-
switch (getTangentValueCategory(bbArg)) {
2536-
// If argument has a loadable tangent value category: materialize adjoint
2537-
// value of the argument, create a copy, and set the copy as the adjoint
2538-
// value of incoming values.
2539-
case SILValueCategory::Object: {
2540-
auto bbArgAdj = getAdjointValue(bb, bbArg);
2541-
auto concreteBBArgAdj = materializeAdjointDirect(bbArgAdj, pbLoc);
2542-
auto concreteBBArgAdjCopy =
2533+
// Check the tangent value category of the active basic block argument.
2534+
switch (getTangentValueCategory(bbArg)) {
2535+
// If argument has a loadable tangent value category: materialize adjoint
2536+
// value of the argument, create a copy, and set the copy as the adjoint
2537+
// value of incoming values.
2538+
case SILValueCategory::Object: {
2539+
auto bbArgAdj = getAdjointValue(bb, bbArg);
2540+
auto concreteBBArgAdj = materializeAdjointDirect(bbArgAdj, pbLoc);
2541+
auto concreteBBArgAdjCopy =
25432542
builder.emitCopyValueOperation(pbLoc, concreteBBArgAdj);
2544-
for (auto pair : incomingValues) {
2545-
auto *predBB = std::get<0>(pair);
2546-
auto incomingValue = std::get<1>(pair);
2547-
// Handle `switch_enum` on `Optional`.
2548-
auto termInst = bbArg->getSingleTerminator();
2549-
if (isSwitchEnumInstOnOptional(termInst)) {
2550-
accumulateAdjointForOptional(bb, incomingValue, concreteBBArgAdjCopy);
2551-
} else {
2552-
blockTemporaries[getPullbackBlock(predBB)].insert(
2543+
for (auto pair : incomingValues) {
2544+
pair.second->dump();
2545+
auto *predBB = std::get<0>(pair);
2546+
auto incomingValue = std::get<1>(pair);
2547+
// Handle `switch_enum` on `Optional`.
2548+
auto termInst = bbArg->getSingleTerminator();
2549+
if (isSwitchEnumInstOnOptional(termInst)) {
2550+
accumulateAdjointForOptional(bb, incomingValue, concreteBBArgAdjCopy);
2551+
} else {
2552+
blockTemporaries[getPullbackBlock(predBB)].insert(
25532553
concreteBBArgAdjCopy);
2554-
setAdjointValue(predBB, incomingValue,
2555-
makeConcreteAdjointValue(concreteBBArgAdjCopy));
2554+
setAdjointValue(predBB, incomingValue,
2555+
makeConcreteAdjointValue(concreteBBArgAdjCopy));
2556+
}
25562557
}
2558+
break;
25572559
}
2558-
break;
2559-
}
2560-
// If argument has an address tangent value category: materialize adjoint
2561-
// value of the argument, create a copy, and set the copy as the adjoint
2562-
// value of incoming values.
2563-
case SILValueCategory::Address: {
2564-
auto bbArgAdjBuf = getAdjointBuffer(bb, bbArg);
2565-
for (auto pair : incomingValues) {
2566-
auto incomingValue = std::get<1>(pair);
2567-
// Handle `switch_enum` on `Optional`.
2568-
auto termInst = bbArg->getSingleTerminator();
2569-
if (isSwitchEnumInstOnOptional(termInst))
2570-
accumulateAdjointForOptional(bb, incomingValue, bbArgAdjBuf);
2571-
else
2572-
addToAdjointBuffer(bb, incomingValue, bbArgAdjBuf, pbLoc);
2560+
// If argument has an address tangent value category: materialize adjoint
2561+
// value of the argument, create a copy, and set the copy as the adjoint
2562+
// value of incoming values.
2563+
case SILValueCategory::Address: {
2564+
auto bbArgAdjBuf = getAdjointBuffer(bb, bbArg);
2565+
for (auto pair : incomingValues) {
2566+
auto incomingValue = std::get<1>(pair);
2567+
// Handle `switch_enum` on `Optional`.
2568+
auto termInst = bbArg->getSingleTerminator();
2569+
if (isSwitchEnumInstOnOptional(termInst))
2570+
accumulateAdjointForOptional(bb, incomingValue, bbArgAdjBuf);
2571+
else
2572+
addToAdjointBuffer(bb, incomingValue, bbArgAdjBuf, pbLoc);
2573+
}
2574+
break;
25732575
}
2574-
break;
2575-
}
2576-
}
2576+
}
2577+
} else
2578+
llvm::report_fatal_error("do not know how to handle this incoming bb argument");
25772579
}
25782580

25792581
// 3. Build the pullback successor cases for the `switch_enum`
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
// RUN: not %target-swift-frontend -emit-sil -verify %s
2+
3+
// The testcase from https://github.com/apple/swift/issues/63728 is not valid
4+
// (the function is not differentiable), however, it should not cause verifier errors
5+
// Here the root case is lack of activity analysis for `try_apply` terminators
6+
7+
import _Differentiation
8+
9+
func a() throws {
10+
let keyPaths = (readable: [String: KeyPath<T, Double>](), writable: [String: WritableKeyPath<T, Double>]())
11+
@differentiable(reverse)
12+
func f(p: PAndT) -> Double {
13+
var mutableP = p
14+
let s = p.p.e
15+
var sArray: [[Double]] = []
16+
sArray.append((s["a"]!.asArray()).map {$0.value})
17+
mutableP.s = w(mutableP.s, at: keyPaths.writable["a"]!, with: sArray[0][0])
18+
return mutableP.s[keyPath: keyPaths.writable["a"]!]
19+
}
20+
}
21+
22+
public struct S<I: SProtocol, D> {
23+
public func asArray() -> [(index: I, value: D)] {
24+
return [(index: I, value: D)]()
25+
}
26+
}
27+
struct T: Differentiable {}
28+
struct P: Differentiable {
29+
public var e: F<Double, Double>
30+
}
31+
32+
struct PAndT: Differentiable{
33+
@differentiable(reverse) public var p: P
34+
@differentiable(reverse) public var s: T
35+
}
36+
37+
public struct F<I: SProtocol, D>{
38+
public func asArray() -> [(index: I, value: D)] {return [(index: I, value: D)]()}
39+
public subscript(_ name: String) -> S<I, D>? {get { return self.sGet(name) }}
40+
func sGet(_ name: String) -> S<I, D>? { fatalError("") }
41+
}
42+
43+
public protocol ZProtocol: Differentiable {var z: () -> TangentVector { get }}
44+
public protocol SProtocol: Hashable {}
45+
46+
//extension S: Differentiable where D: ZProtocol, D.TangentVector == D {}
47+
48+
extension F: Differentiable where D: ZProtocol, D.TangentVector == D {}
49+
public extension ZProtocol {var z: () -> TangentVector {{ Self.TangentVector.zero }}}
50+
51+
extension F: Equatable where I: Equatable, D: Equatable {}
52+
//extension S: Equatable where I: Equatable, D: Equatable {}
53+
extension Double: SProtocol, ZProtocol {}
54+
55+
@differentiable(reverse where O: Differentiable, M: ZProtocol)
56+
func w<O, M>(_ o: O, at m: WritableKeyPath<O, M>, with v: M) -> O {return o}
57+
58+
@derivative(of: w)
59+
func vjpw<O, M>(_ o: O, at m: WritableKeyPath<O, M>, with v: M) -> (value: O, pullback: (O.TangentVector) -> (O.TangentVector, M.TangentVector)) where O: Differentiable, M: ZProtocol{fatalError("")}

0 commit comments

Comments
 (0)