Skip to content

[AutoDiff] Simplify varied propagation in activity analysis. #28191

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 3 additions & 20 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1430,9 +1430,6 @@ class DifferentiableActivityInfo {
/// Perform analysis and populate sets.
void analyze(DominanceInfo *di, PostDominanceInfo *pdi);

void setVaried(SILValue value, unsigned independentVariableIndex);
void setVariedAcrossArrayInitialization(SILValue value,
unsigned independentVariableIndex);
/// Marks the given value as varied and propagates variedness to users.
void setVariedAndPropagateToUsers(SILValue value,
unsigned independentVariableIndex);
Expand Down Expand Up @@ -1883,7 +1880,9 @@ void DifferentiableActivityInfo::setVariedAndPropagateToUsers(
// Skip already-varied values to prevent infinite recursion.
if (isVaried(value, independentVariableIndex))
return;
setVaried(value, independentVariableIndex);
// Set the value as varied.
variedValueSets[independentVariableIndex].insert(value);
// Propagate variedness to users.
for (auto *use : value->getUses())
propagateVaried(use, independentVariableIndex);
}
Expand Down Expand Up @@ -2100,16 +2099,6 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di,
}
}

void DifferentiableActivityInfo::setVariedAcrossArrayInitialization(
SILValue value, unsigned independentVariableIndex) {
auto uai = getAllocateUninitializedArrayIntrinsic(value);
if (!uai) return;
for (auto use : value->getUses())
if (auto *dti = dyn_cast<DestructureTupleInst>(use->getUser()))
// The first tuple field of the intrinsic's return value is the array.
setVariedAndPropagateToUsers(dti->getResult(0), independentVariableIndex);
}

void DifferentiableActivityInfo::setUsefulAcrossArrayInitialization(
SILValue value, unsigned dependentVariableIndex) {
// Array initializer syntax is lowered to an intrinsic and one or more
Expand Down Expand Up @@ -2138,12 +2127,6 @@ void DifferentiableActivityInfo::setUsefulAcrossArrayInitialization(
}
}

void DifferentiableActivityInfo::setVaried(SILValue value,
unsigned independentVariableIndex) {
variedValueSets[independentVariableIndex].insert(value);
setVariedAcrossArrayInitialization(value, independentVariableIndex);
}

void DifferentiableActivityInfo::setUseful(SILValue value,
unsigned dependentVariableIndex) {
usefulValueSets[dependentVariableIndex].insert(value);
Expand Down
36 changes: 36 additions & 0 deletions test/AutoDiff/activity_analysis.swift
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,42 @@ func testNondifferentiableTupleElementAddr<T>(_ x: T) -> T {
// CHECK: [ACTIVE] %56 = tuple_element_addr %55 : $*(Int, Int, (T, Int), Int), 2
// CHECK: [ACTIVE] %57 = tuple_element_addr %56 : $*(T, Int), 0

// Check activity analysis for `array.uninitialized_intrinsic` applications.

@differentiable
func testArrayUninitializedIntrinsic(_ x: Float, _ y: Float) -> [Float] {
return [x, y]
}

// CHECK-LABEL: [AD] Activity info for ${{.*}}testArrayUninitializedIntrinsic{{.*}} at (source=0 parameters=(0 1))
// CHECK: [ACTIVE] %0 = argument of bb0 : $Float
// CHECK: [ACTIVE] %1 = argument of bb0 : $Float
// CHECK: [USEFUL] %4 = integer_literal $Builtin.Word, 2
// CHECK: [NONE] // function_ref _allocateUninitializedArray<A>(_:)
// CHECK: [ACTIVE] %6 = apply %5<Float>(%4) : $@convention(thin) <τ_0_0> (Builtin.Word) -> (@owned Array<τ_0_0>, Builtin.RawPointer)
// CHECK: [ACTIVE] (**%7**, %8) = destructure_tuple %6 : $(Array<Float>, Builtin.RawPointer)
// CHECK: [VARIED] (%7, **%8**) = destructure_tuple %6 : $(Array<Float>, Builtin.RawPointer)
// CHECK: [VARIED] %9 = pointer_to_address %8 : $Builtin.RawPointer to [strict] $*Float
// CHECK: [VARIED] %11 = integer_literal $Builtin.Word, 1
// CHECK: [VARIED] %12 = index_addr %9 : $*Float, %11 : $Builtin.Word

@differentiable(where T: Differentiable)
func testArrayUninitializedIntrinsicGeneric<T>(_ x: T, _ y: T) -> [T] {
return [x, y]
}

// CHECK-LABEL: [AD] Activity info for ${{.*}}testArrayUninitializedIntrinsicGeneric{{.*}} at (source=0 parameters=(0 1))
// CHECK: [VARIED] %0 = argument of bb0 : $*T
// CHECK: [VARIED] %1 = argument of bb0 : $*T
// CHECK: [USEFUL] %4 = integer_literal $Builtin.Word, 2
// CHECK: [NONE] // function_ref _allocateUninitializedArray<A>(_:)
// CHECK: [ACTIVE] %6 = apply %5<T>(%4) : $@convention(thin) <τ_0_0> (Builtin.Word) -> (@owned Array<τ_0_0>, Builtin.RawPointer)
// CHECK: [ACTIVE] (**%7**, %8) = destructure_tuple %6 : $(Array<T>, Builtin.RawPointer)
// CHECK: [VARIED] (%7, **%8**) = destructure_tuple %6 : $(Array<T>, Builtin.RawPointer)
// CHECK: [VARIED] %9 = pointer_to_address %8 : $Builtin.RawPointer to [strict] $*T
// CHECK: [VARIED] %11 = integer_literal $Builtin.Word, 1
// CHECK: [VARIED] %12 = index_addr %9 : $*T, %11 : $Builtin.Word

// TF-781: check activity analysis for active local address + nested conditionals.

@differentiable(wrt: x)
Expand Down
22 changes: 22 additions & 0 deletions test/AutoDiff/array.swift
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,28 @@ ArrayAutoDiffTests.test("ArrayLiteralIndirect") {
expectEqual(Float(2), gradY)
}

ArrayAutoDiffTests.test("ExpressibleByArrayLiteralIndirect") {
struct Indirect<T: Differentiable>: Differentiable & ExpressibleByArrayLiteral {
var x: T

typealias ArrayLiteralElement = T
init(arrayLiteral: T...) {
assert(arrayLiteral.count > 1)
self.x = arrayLiteral[0]
}
}

func testArrayUninitializedIntrinsic<T>(_ x: T, _ y: T) -> Indirect<T> {
return [x, y]
}

let (gradX, gradY) = pullback(at: Float(1), Float(1), in: {
x, y in testArrayUninitializedIntrinsic(x, y)
})(Indirect<Float>.TangentVector(x: 1))
expectEqual(1, gradX)
expectEqual(0, gradY)
}

ArrayAutoDiffTests.test("ArrayConcat") {
struct TwoArrays : Differentiable {
var a: [Float]
Expand Down