Skip to content

Commit 111917e

Browse files
committed
[AutoDiff] Fixes small bugs in the PR fixing #66522
During internal testing we discovered 2 more bugs - 1. The element adjoint of a struct_extract can itself be an AddElement. 2. Indirect concrete adjoint materialization was missing a copy operation. This commit fixes these bugs and adds relevant test cases.
1 parent b9fc40b commit 111917e

File tree

3 files changed

+99
-14
lines changed

3 files changed

+99
-14
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,8 @@ class PullbackCloner::Implementation final
424424
/// the destination address.
425425
case AdjointValueKind::Concrete: {
426426
auto concreteVal = val.getConcreteValue();
427-
builder.emitStoreValueOperation(loc, concreteVal, destAddress,
427+
auto copyOfConcreteVal = builder.emitCopyValueOperation(loc, concreteVal);
428+
builder.emitStoreValueOperation(loc, copyOfConcreteVal, destAddress,
428429
StoreOwnershipQualifier::Init);
429430
break;
430431
}
@@ -1241,17 +1242,14 @@ class PullbackCloner::Implementation final
12411242
break;
12421243
}
12431244
case AdjointValueKind::Aggregate:
1244-
case AdjointValueKind::Concrete: {
1245+
case AdjointValueKind::Concrete:
1246+
case AdjointValueKind::AddElement: {
12451247
auto baseAdj = makeZeroAdjointValue(tangentVectorSILTy);
12461248
addAdjointValue(bb, sei->getOperand(),
12471249
makeAddElementAdjointValue(baseAdj, eltAdj, tanField),
12481250
loc);
12491251
break;
12501252
}
1251-
case AdjointValueKind::AddElement: {
1252-
llvm_unreachable(
1253-
"Adjoint of extracted element in `StructExtractInst` cannot be of kind `AddElement`");
1254-
}
12551253
}
12561254
break;
12571255
}
@@ -1439,7 +1437,8 @@ class PullbackCloner::Implementation final
14391437
break;
14401438
}
14411439
case AdjointValueKind::Aggregate:
1442-
case AdjointValueKind::Concrete: {
1440+
case AdjointValueKind::Concrete:
1441+
case AdjointValueKind::AddElement: {
14431442
auto tupleTy = tei->getTupleType();
14441443
auto tupleTanTupleTy = tupleTanTy.getAs<TupleType>();
14451444
if (!tupleTanTupleTy) {
@@ -1466,10 +1465,6 @@ class PullbackCloner::Implementation final
14661465
}
14671466
break;
14681467
}
1469-
case AdjointValueKind::AddElement: {
1470-
llvm_unreachable(
1471-
"Adjoint of extracted element in `TupleExtractInst` cannot be of kind `AddElement`");
1472-
}
14731468
}
14741469
}
14751470

test/AutoDiff/SILOptimizer/pullback_generation.sil

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,13 @@ bb0(%0 : $(Float, Float)):
124124
// CHECK: } // end sil function 'function_with_tuple_extract_1TJpSpSr'
125125

126126
//===----------------------------------------------------------------------===//
127-
// Pullback generation - `tuple_extract`
128-
// - Input to pullback has non-owned ownership semantics which requires copying
129-
// this value to stack before lifetime-ending uses.
127+
// Pullback generation - Inner values of concrete adjoints must be copied
128+
// during direct materialization.
129+
// - If the input to pullback BB has non-owned ownership semantics we cannot
130+
// perform a lifetime-ending operation on it.
131+
// - If the input to the pullback BB is an owned, non-trivial value we must
132+
// copy it or there will be a double consume when all owned parameters are
133+
// destroyed at the end of the basic block.
130134
//===----------------------------------------------------------------------===//
131135
sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @function_with_tuple_extract_2: $@convention(thin) (@guaranteed (X, X)) -> @owned X {
132136
}
@@ -164,3 +168,19 @@ bb0(%0 : @guaranteed $(X, X)):
164168
// CHECK: destroy_value %19 : $(X, X)
165169
// CHECK: return %21 : $(X, X)
166170
// CHECK: } // end sil function 'function_with_tuple_extract_2TJpSpSr'
171+
172+
//===----------------------------------------------------------------------===//
173+
// Pullback generation - `tuple_extract`
174+
// - Adjoint of extracted element can be `AddElement`
175+
// - Just need to make sure that we are able to generate a pullback
176+
//===----------------------------------------------------------------------===//
177+
sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @function_with_tuple_extract_3: $@convention(thin) (((Float, Float), Float)) -> Float {
178+
}
179+
180+
sil hidden [ossa] @function_with_tuple_extract_3: $@convention(thin) (((Float, Float), Float)) -> Float {
181+
bb0(%0 : $((Float, Float), Float)):
182+
%1 = tuple_extract %0 : $((Float, Float), Float), 0
183+
%2 = tuple_extract %1 : $(Float, Float), 0
184+
return %2 : $Float
185+
}
186+
// CHECK-LABEL: sil private [ossa] @function_with_tuple_extract_3TJpSpSr : $@convention(thin) (Float) -> ((Float, Float), Float) {

test/AutoDiff/SILOptimizer/pullback_generation.swift

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,73 @@ func f3(v: NonPiecewiseMaterializableWithAggDifferentiableField) -> PiecewiseMat
128128
// CHECK: destroy_value %13 : $NonPiecewiseMaterializableWithAggDifferentiableField
129129
// CHECK: return %16 : $NonPiecewiseMaterializableWithAggDifferentiableField
130130
// CHECK: } // end sil function '$s19pullback_generation2f31vAA23PiecewiseMaterializableVAA03NondE26WithAggDifferentiableFieldV_tFTJpSpSr'
131+
132+
//===----------------------------------------------------------------------===//
133+
// Pullback generation - `struct_extract`
134+
// - Adjoint of extracted element can be `AddElement`
135+
// - Just need to make sure that we are able to generate a pullback for B.x's
136+
// getter
137+
//===----------------------------------------------------------------------===//
138+
struct A: Differentiable {
139+
public var x: Float
140+
}
141+
142+
struct B: Differentiable {
143+
var y: A
144+
145+
public init(a: A) {
146+
self.y = a
147+
}
148+
149+
@differentiable(reverse)
150+
public var x: Float {
151+
get { return self.y.x }
152+
}
153+
}
154+
155+
// CHECK-LABEL: sil private [ossa] @$s19pullback_generation1BV1xSfvgTJpSpSr : $@convention(thin) (Float) -> B.TangentVector {
156+
157+
//===----------------------------------------------------------------------===//
158+
// Pullback generation - Inner values of concrete adjoints must be copied
159+
// during indirect materialization
160+
//===----------------------------------------------------------------------===//
161+
162+
struct NonTrivial {
163+
var x: Float
164+
var y: String
165+
}
166+
167+
extension NonTrivial: Differentiable, Equatable, AdditiveArithmetic {
168+
public typealias TangentVector = Self
169+
mutating func move(by offset: TangentVector) {fatalError()}
170+
public static var zero: Self {fatalError()}
171+
public static func + (lhs: Self, rhs: Self) -> Self {fatalError()}
172+
public static func - (lhs: Self, rhs: Self) -> Self {fatalError()}
173+
}
174+
175+
@differentiable(reverse)
176+
func f4(a: NonTrivial) -> Float {
177+
var sum: Float = 0
178+
for _ in 0..<1 {
179+
sum += a.x
180+
}
181+
return sum
182+
}
183+
184+
// CHECK-LABEL: sil private [ossa] @$s19pullback_generation2f41aSfAA10NonTrivialV_tFTJpSpSr : $@convention(thin) (Float, @guaranteed Builtin.NativeObject) -> @owned NonTrivial {
185+
// CHECK: bb5(%67 : @owned $NonTrivial, %68 : $Float, %69 : @owned $(predecessor: _AD__$s19pullback_generation2f41aSfAA10NonTrivialV_tF_bb2__Pred__src_0_wrt_0, @callee_guaranteed (@inout Float) -> Float)):
186+
// CHECK: %88 = alloc_stack $NonTrivial
187+
188+
// Non-trivial value must be copied or there will be a
189+
// double consume when all owned parameters are destroyed
190+
// at the end of the basic block.
191+
// CHECK: %89 = copy_value %67 : $NonTrivial
192+
193+
// CHECK: store %89 to [init] %88 : $*NonTrivial
194+
// CHECK: %91 = struct_element_addr %88 : $*NonTrivial, #NonTrivial.x
195+
// CHECK: %92 = alloc_stack $Float
196+
// CHECK: store %86 to [trivial] %92 : $*Float
197+
// CHECK: %94 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
198+
// CHECK: %95 = metatype $@thick Float.Type
199+
// CHECK: %96 = apply %94<Float>(%91, %92, %95) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
200+
// CHECK: destroy_value %67 : $NonTrivial

0 commit comments

Comments
 (0)