Skip to content

Commit 43041ee

Browse files
dan-zhengbgogul
authored andcommitted
[AutoDiff] Simplify varied propagation in activity analysis. (#28191)
- Remove `setVaried`. - `setVaried` has only one user and can be inlined. - Remove `setVariedAcrossArrayInitialization`. - Activity info does not change when removing special variedness propagation support for `array.uninitialized_intrinsic` applications. - Verified with array initialization activity info test cases. - Add indirect `ExpressibleByArrayLiteral` test.
1 parent 3f7d4c1 commit 43041ee

File tree

3 files changed

+61
-20
lines changed

3 files changed

+61
-20
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1430,9 +1430,6 @@ class DifferentiableActivityInfo {
14301430
/// Perform analysis and populate sets.
14311431
void analyze(DominanceInfo *di, PostDominanceInfo *pdi);
14321432

1433-
void setVaried(SILValue value, unsigned independentVariableIndex);
1434-
void setVariedAcrossArrayInitialization(SILValue value,
1435-
unsigned independentVariableIndex);
14361433
/// Marks the given value as varied and propagates variedness to users.
14371434
void setVariedAndPropagateToUsers(SILValue value,
14381435
unsigned independentVariableIndex);
@@ -1883,7 +1880,9 @@ void DifferentiableActivityInfo::setVariedAndPropagateToUsers(
18831880
// Skip already-varied values to prevent infinite recursion.
18841881
if (isVaried(value, independentVariableIndex))
18851882
return;
1886-
setVaried(value, independentVariableIndex);
1883+
// Set the value as varied.
1884+
variedValueSets[independentVariableIndex].insert(value);
1885+
// Propagate variedness to users.
18871886
for (auto *use : value->getUses())
18881887
propagateVaried(use, independentVariableIndex);
18891888
}
@@ -2100,16 +2099,6 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di,
21002099
}
21012100
}
21022101

2103-
void DifferentiableActivityInfo::setVariedAcrossArrayInitialization(
2104-
SILValue value, unsigned independentVariableIndex) {
2105-
auto uai = getAllocateUninitializedArrayIntrinsic(value);
2106-
if (!uai) return;
2107-
for (auto use : value->getUses())
2108-
if (auto *dti = dyn_cast<DestructureTupleInst>(use->getUser()))
2109-
// The first tuple field of the intrinsic's return value is the array.
2110-
setVariedAndPropagateToUsers(dti->getResult(0), independentVariableIndex);
2111-
}
2112-
21132102
void DifferentiableActivityInfo::setUsefulAcrossArrayInitialization(
21142103
SILValue value, unsigned dependentVariableIndex) {
21152104
// Array initializer syntax is lowered to an intrinsic and one or more
@@ -2138,12 +2127,6 @@ void DifferentiableActivityInfo::setUsefulAcrossArrayInitialization(
21382127
}
21392128
}
21402129

2141-
void DifferentiableActivityInfo::setVaried(SILValue value,
2142-
unsigned independentVariableIndex) {
2143-
variedValueSets[independentVariableIndex].insert(value);
2144-
setVariedAcrossArrayInitialization(value, independentVariableIndex);
2145-
}
2146-
21472130
void DifferentiableActivityInfo::setUseful(SILValue value,
21482131
unsigned dependentVariableIndex) {
21492132
usefulValueSets[dependentVariableIndex].insert(value);

test/AutoDiff/activity_analysis.swift

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,42 @@ func testNondifferentiableTupleElementAddr<T>(_ x: T) -> T {
5757
// CHECK: [ACTIVE] %56 = tuple_element_addr %55 : $*(Int, Int, (T, Int), Int), 2
5858
// CHECK: [ACTIVE] %57 = tuple_element_addr %56 : $*(T, Int), 0
5959

60+
// Check activity analysis for `array.uninitialized_intrinsic` applications.
61+
62+
@differentiable
63+
func testArrayUninitializedIntrinsic(_ x: Float, _ y: Float) -> [Float] {
64+
return [x, y]
65+
}
66+
67+
// CHECK-LABEL: [AD] Activity info for ${{.*}}testArrayUninitializedIntrinsic{{.*}} at (source=0 parameters=(0 1))
68+
// CHECK: [ACTIVE] %0 = argument of bb0 : $Float
69+
// CHECK: [ACTIVE] %1 = argument of bb0 : $Float
70+
// CHECK: [USEFUL] %4 = integer_literal $Builtin.Word, 2
71+
// CHECK: [NONE] // function_ref _allocateUninitializedArray<A>(_:)
72+
// CHECK: [ACTIVE] %6 = apply %5<Float>(%4) : $@convention(thin) <τ_0_0> (Builtin.Word) -> (@owned Array<τ_0_0>, Builtin.RawPointer)
73+
// CHECK: [ACTIVE] (**%7**, %8) = destructure_tuple %6 : $(Array<Float>, Builtin.RawPointer)
74+
// CHECK: [VARIED] (%7, **%8**) = destructure_tuple %6 : $(Array<Float>, Builtin.RawPointer)
75+
// CHECK: [VARIED] %9 = pointer_to_address %8 : $Builtin.RawPointer to [strict] $*Float
76+
// CHECK: [VARIED] %11 = integer_literal $Builtin.Word, 1
77+
// CHECK: [VARIED] %12 = index_addr %9 : $*Float, %11 : $Builtin.Word
78+
79+
@differentiable(where T: Differentiable)
80+
func testArrayUninitializedIntrinsicGeneric<T>(_ x: T, _ y: T) -> [T] {
81+
return [x, y]
82+
}
83+
84+
// CHECK-LABEL: [AD] Activity info for ${{.*}}testArrayUninitializedIntrinsicGeneric{{.*}} at (source=0 parameters=(0 1))
85+
// CHECK: [VARIED] %0 = argument of bb0 : $*T
86+
// CHECK: [VARIED] %1 = argument of bb0 : $*T
87+
// CHECK: [USEFUL] %4 = integer_literal $Builtin.Word, 2
88+
// CHECK: [NONE] // function_ref _allocateUninitializedArray<A>(_:)
89+
// CHECK: [ACTIVE] %6 = apply %5<T>(%4) : $@convention(thin) <τ_0_0> (Builtin.Word) -> (@owned Array<τ_0_0>, Builtin.RawPointer)
90+
// CHECK: [ACTIVE] (**%7**, %8) = destructure_tuple %6 : $(Array<T>, Builtin.RawPointer)
91+
// CHECK: [VARIED] (%7, **%8**) = destructure_tuple %6 : $(Array<T>, Builtin.RawPointer)
92+
// CHECK: [VARIED] %9 = pointer_to_address %8 : $Builtin.RawPointer to [strict] $*T
93+
// CHECK: [VARIED] %11 = integer_literal $Builtin.Word, 1
94+
// CHECK: [VARIED] %12 = index_addr %9 : $*T, %11 : $Builtin.Word
95+
6096
// TF-781: check activity analysis for active local address + nested conditionals.
6197

6298
@differentiable(wrt: x)

test/AutoDiff/array.swift

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,28 @@ ArrayAutoDiffTests.test("ArrayLiteralIndirect") {
5656
expectEqual(Float(2), gradY)
5757
}
5858

59+
ArrayAutoDiffTests.test("ExpressibleByArrayLiteralIndirect") {
60+
struct Indirect<T: Differentiable>: Differentiable & ExpressibleByArrayLiteral {
61+
var x: T
62+
63+
typealias ArrayLiteralElement = T
64+
init(arrayLiteral: T...) {
65+
assert(arrayLiteral.count > 1)
66+
self.x = arrayLiteral[0]
67+
}
68+
}
69+
70+
func testArrayUninitializedIntrinsic<T>(_ x: T, _ y: T) -> Indirect<T> {
71+
return [x, y]
72+
}
73+
74+
let (gradX, gradY) = pullback(at: Float(1), Float(1), in: {
75+
x, y in testArrayUninitializedIntrinsic(x, y)
76+
})(Indirect<Float>.TangentVector(x: 1))
77+
expectEqual(1, gradX)
78+
expectEqual(0, gradY)
79+
}
80+
5981
ArrayAutoDiffTests.test("ArrayConcat") {
6082
struct TwoArrays : Differentiable {
6183
var a: [Float]

0 commit comments

Comments
 (0)