Skip to content

Commit 211b5ae

Browse files
committed
[AutoDiff] Handle materializing adjoints with non-differentiable fields
Fixes #66522
1 parent cd14d72 commit 211b5ae

File tree

4 files changed

+147
-169
lines changed

4 files changed

+147
-169
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,12 +1163,7 @@ class PullbackCloner::Implementation final
11631163
// Accumulate adjoint for the `struct_extract` operand.
11641164
auto av = getAdjointValue(bb, sei);
11651165
switch (av.getKind()) {
1166-
case AdjointValueKind::Zero:
1167-
addAdjointValue(bb, sei->getOperand(),
1168-
makeZeroAdjointValue(tangentVectorSILTy), loc);
1169-
break;
1170-
case AdjointValueKind::Concrete:
1171-
case AdjointValueKind::Aggregate: {
1166+
case AdjointValueKind::Concrete: {
11721167
// Materialize adj[x]
11731168
auto *adjBaseAlloc = builder.createAllocStack(loc, tangentVectorSILTy);
11741169
materializeAdjointIndirect(makeZeroAdjointValue(tangentVectorSILTy),
@@ -1199,6 +1194,11 @@ class PullbackCloner::Implementation final
11991194
makeConcreteAdjointValue(adjBaseConcrete));
12001195
break;
12011196
}
1197+
case AdjointValueKind::Zero:
1198+
case AdjointValueKind::Aggregate: {
1199+
llvm_unreachable("Adjoint value kind of the input to `struct_extract` "
1200+
"can only ever be concrete.");
1201+
}
12021202
}
12031203
break;
12041204
}
@@ -1375,11 +1375,6 @@ class PullbackCloner::Implementation final
13751375
auto tupleTanTy = getRemappedTangentType(tei->getOperand()->getType());
13761376
auto av = getAdjointValue(bb, tei);
13771377
switch (av.getKind()) {
1378-
case AdjointValueKind::Zero:
1379-
addAdjointValue(bb, tei->getOperand(), makeZeroAdjointValue(tupleTanTy),
1380-
loc);
1381-
break;
1382-
case AdjointValueKind::Aggregate:
13831378
case AdjointValueKind::Concrete: {
13841379
// Materialize adj[x]
13851380
auto *adjBaseAlloc = builder.createAllocStack(loc, tupleTanTy);
@@ -1411,6 +1406,11 @@ class PullbackCloner::Implementation final
14111406
makeConcreteAdjointValue(adjBaseConcrete));
14121407
break;
14131408
}
1409+
case AdjointValueKind::Zero:
1410+
case AdjointValueKind::Aggregate: {
1411+
llvm_unreachable("Adjoint value kind of the input to `tuple_extract` can "
1412+
"only ever be concrete.");
1413+
}
14141414
}
14151415
}
14161416

test/AutoDiff/SILOptimizer/pullback_generation.sil

Lines changed: 74 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// Pullback generation tests written in SIL for features
22
// that may not be directly supported by the Swift frontend
33

4-
// RUN: %target-sil-opt --differentiation -debug-only=differentiation -emit-sorted-sil %s 2>&1 | %FileCheck %s
4+
// RUN: %target-sil-opt --differentiation -emit-sorted-sil %s 2>&1 | %FileCheck %s
55

66
//===----------------------------------------------------------------------===//
77
// Pullback generation - `struct_extract`
@@ -57,27 +57,31 @@ bb0(%0 : @guaranteed $Y):
5757
return %2 : $X
5858
}
5959

60-
// CHECK-LABEL: [ORIG] %1 = struct_extract %0 : $Y, #Y.a // user: %2
61-
// CHECK: [ADJ] Emitted in pullback (pb bb0):
62-
// CHECK: %1 = alloc_stack $Y // users: {{.*}}
63-
// CHECK: %2 = witness_method $Y, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 // user: %4
64-
// CHECK: %3 = metatype $@thick Y.Type // user: %4
65-
// CHECK: %4 = apply %2<Y>(%1, %3) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
60+
// CHECK-LABEL: sil private [ossa] @$function_with_struct_extract_1TJpSpSr : $@convention(thin) (@guaranteed X) -> @owned Y {
61+
// CHECK: bb0(%0 : @guaranteed $X):
62+
// CHECK: %1 = alloc_stack $Y // users: {{.*}}
63+
// CHECK: %2 = witness_method $Y, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 // user: %4
64+
// CHECK: %3 = metatype $@thick Y.Type // user: %4
65+
// CHECK: %4 = apply %2<Y>(%1, %3) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
6666

6767
// Since input parameter $0 has non-owned ownership semantics, it
6868
// needs to be copied before a lifetime-ending use.
69-
// CHECK: %5 = copy_value %0 : $X // user: %7
70-
71-
// CHECK: %6 = alloc_stack $X // users: {{.*}}
72-
// CHECK: store %5 to [init] %6 : $*X // id: %7
73-
// CHECK: %8 = struct_element_addr %1 : $*Y, #Y.a // user: %11
74-
// CHECK: %9 = witness_method $X, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () // user: %11
75-
// CHECK: %10 = metatype $@thick X.Type // user: %11
76-
// CHECK: %11 = apply %9<X>(%8, %6, %10) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
77-
// CHECK: %12 = load [take] %1 : $*Y
78-
// CHECK: destroy_addr %6 : $*X // id: %13
79-
// CHECK: dealloc_stack %6 : $*X // id: %14
80-
// CHECK: dealloc_stack %1 : $*Y // id: %15
69+
// CHECK: %5 = copy_value %0 : $X // user: %7
70+
71+
// CHECK: %6 = alloc_stack $X // users: {{.*}}
72+
// CHECK: store %5 to [init] %6 : $*X // id: %7
73+
// CHECK: %8 = struct_element_addr %1 : $*Y, #Y.a // user: %11
74+
// CHECK: %9 = witness_method $X, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () // user: %11
75+
// CHECK: %10 = metatype $@thick X.Type // user: %11
76+
// CHECK: %11 = apply %9<X>(%8, %6, %10) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
77+
// CHECK: %12 = load [take] %1 : $*Y // users: {{.*}}
78+
// CHECK: destroy_addr %6 : $*X // id: %13
79+
// CHECK: dealloc_stack %6 : $*X // id: %14
80+
// CHECK: dealloc_stack %1 : $*Y // id: %15
81+
// CHECK: %16 = copy_value %12 : $Y // user: %18
82+
// CHECK: destroy_value %12 : $Y // id: %17
83+
// CHECK: return %16 : $Y // id: %18
84+
// CHECK: } // end sil function '$function_with_struct_extract_1TJpSpSr'
8185

8286
//===----------------------------------------------------------------------===//
8387
// Pullback generation - `tuple_extract`
@@ -94,27 +98,30 @@ bb0(%0 : $(Float, Float)):
9498
return %1 : $Float
9599
}
96100

97-
// CHECK-LABEL: [ORIG] %1 = tuple_extract %0 : $(Float, Float), 0 // user: %2
98-
// CHECK: [ADJ] Emitted in pullback (pb bb0):
99-
// CHECK: %1 = alloc_stack $(Float, Float) // users: {{.*}}
100-
// CHECK: %2 = tuple_element_addr %1 : $*(Float, Float), 0 // user: %5
101-
// CHECK: %3 = witness_method $Float, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 // user: %5
102-
// CHECK: %4 = metatype $@thick Float.Type // user: %5
103-
// CHECK: %5 = apply %3<Float>(%2, %4) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
104-
// CHECK: %6 = tuple_element_addr %1 : $*(Float, Float), 1 // user: %9
105-
// CHECK: %7 = witness_method $Float, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 // user: %9
106-
// CHECK: %8 = metatype $@thick Float.Type // user: %9
107-
// CHECK: %9 = apply %7<Float>(%6, %8) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
108-
// CHECK: %10 = alloc_stack $Float // users: {{.*}}
109-
// CHECK: store %0 to [trivial] %10 : $*Float // id: %11
110-
// CHECK: %12 = tuple_element_addr %1 : $*(Float, Float), 0 // user: %15
111-
// CHECK: %13 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () // user: %15
112-
// CHECK: %14 = metatype $@thick Float.Type // user: %15
113-
// CHECK: %15 = apply %13<Float>(%12, %10, %14) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
114-
// CHECK: %16 = load [trivial] %1 : $*(Float, Float)
115-
// CHECK: destroy_addr %10 : $*Float // id: %17
116-
// CHECK: dealloc_stack %10 : $*Float // id: %18
117-
// CHECK: dealloc_stack %1 : $*(Float, Float) // id: %19
101+
102+
// CHECK-LABEL: sil private [ossa] @function_with_tuple_extract_1TJpSpSr : $@convention(thin) (Float) -> (Float, Float) {
103+
// CHECK: bb0(%0 : $Float):
104+
// CHECK: %1 = alloc_stack $(Float, Float) // users: {{.*}}
105+
// CHECK: %2 = tuple_element_addr %1 : $*(Float, Float), 0 // user: %5
106+
// CHECK: %3 = witness_method $Float, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 // user: %5
107+
// CHECK: %4 = metatype $@thick Float.Type // user: %5
108+
// CHECK: %5 = apply %3<Float>(%2, %4) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
109+
// CHECK: %6 = tuple_element_addr %1 : $*(Float, Float), 1 // user: %9
110+
// CHECK: %7 = witness_method $Float, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 // user: %9
111+
// CHECK: %8 = metatype $@thick Float.Type // user: %9
112+
// CHECK: %9 = apply %7<Float>(%6, %8) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
113+
// CHECK: %10 = alloc_stack $Float // users: {{.*}}
114+
// CHECK: store %0 to [trivial] %10 : $*Float // id: %11
115+
// CHECK: %12 = tuple_element_addr %1 : $*(Float, Float), 0 // user: %15
116+
// CHECK: %13 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () // user: %15
117+
// CHECK: %14 = metatype $@thick Float.Type // user: %15
118+
// CHECK: %15 = apply %13<Float>(%12, %10, %14) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
119+
// CHECK: %16 = load [trivial] %1 : $*(Float, Float) // user: %20
120+
// CHECK: destroy_addr %10 : $*Float // id: %17
121+
// CHECK: dealloc_stack %10 : $*Float // id: %18
122+
// CHECK: dealloc_stack %1 : $*(Float, Float) // id: %19
123+
// CHECK: return %16 : $(Float, Float) // id: %20
124+
// CHECK: } // end sil function 'function_with_tuple_extract_1TJpSpSr'
118125

119126
//===----------------------------------------------------------------------===//
120127
// Pullback generation - `tuple_extract`
@@ -131,25 +138,30 @@ bb0(%0 : @guaranteed $(X, X)):
131138
return %2 : $X
132139
}
133140

134-
// CHECK-LABEL: [ORIG] %1 = tuple_extract %0 : $(X, X), 0 // user: %2
135-
// CHECK: [ADJ] Emitted in pullback (pb bb0):
136-
// CHECK: %1 = alloc_stack $(X, X) // users: {{.*}}
137-
// CHECK: %2 = tuple_element_addr %1 : $*(X, X), 0 // user: %5
138-
// CHECK: %3 = witness_method $X, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 // user: %5
139-
// CHECK: %4 = metatype $@thick X.Type // user: %5
140-
// CHECK: %5 = apply %3<X>(%2, %4) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
141-
// CHECK: %6 = tuple_element_addr %1 : $*(X, X), 1 // user: %9
142-
// CHECK: %7 = witness_method $X, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 // user: %9
143-
// CHECK: %8 = metatype $@thick X.Type // user: %9
144-
// CHECK: %9 = apply %7<X>(%6, %8) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
145-
// CHECK: %10 = copy_value %0 : $X // user: %12
146-
// CHECK: %11 = alloc_stack $X // users: {{.*}}
147-
// CHECK: store %10 to [init] %11 : $*X // id: %12
148-
// CHECK: %13 = tuple_element_addr %1 : $*(X, X), 0 // user: %16
149-
// CHECK: %14 = witness_method $X, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () // user: %16
150-
// CHECK: %15 = metatype $@thick X.Type // user: %16
151-
// CHECK: %16 = apply %14<X>(%13, %11, %15) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
152-
// CHECK: %17 = load [take] %1 : $*(X, X)
153-
// CHECK: destroy_addr %11 : $*X // id: %18
154-
// CHECK: dealloc_stack %11 : $*X // id: %19
155-
// CHECK: dealloc_stack %1 : $*(X, X) // id: %20
141+
// CHECK-LABEL: sil private [ossa] @function_with_tuple_extract_2TJpSpSr : $@convention(thin) (@guaranteed X) -> @owned (X, X) {
142+
// CHECK: bb0(%0 : @guaranteed $X):
143+
// CHECK: %1 = alloc_stack $(X, X) // users: {{.*}}
144+
// CHECK: %2 = tuple_element_addr %1 : $*(X, X), 0 // user: %5
145+
// CHECK: %3 = witness_method $X, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 // user: %5
146+
// CHECK: %4 = metatype $@thick X.Type // user: %5
147+
// CHECK: %5 = apply %3<X>(%2, %4) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
148+
// CHECK: %6 = tuple_element_addr %1 : $*(X, X), 1 // user: %9
149+
// CHECK: %7 = witness_method $X, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 // user: %9
150+
// CHECK: %8 = metatype $@thick X.Type // user: %9
151+
// CHECK: %9 = apply %7<X>(%6, %8) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
152+
// CHECK: %10 = copy_value %0 : $X // user: %12
153+
// CHECK: %11 = alloc_stack $X // users: {{.*}}
154+
// CHECK: store %10 to [init] %11 : $*X // id: %12
155+
// CHECK: %13 = tuple_element_addr %1 : $*(X, X), 0 // user: %16
156+
// CHECK: %14 = witness_method $X, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () // user: %16
157+
// CHECK: %15 = metatype $@thick X.Type // user: %16
158+
// CHECK: %16 = apply %14<X>(%13, %11, %15) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
159+
// CHECK: %17 = load [take] %1 : $*(X, X) // users: {{.*}}
160+
// CHECK: destroy_addr %11 : $*X // id: %18
161+
// CHECK: dealloc_stack %11 : $*X // id: %19
162+
// CHECK: dealloc_stack %1 : $*(X, X) // id: %20
163+
// CHECK: %21 = copy_value %17 : $(X, X) // user: %23
164+
// CHECK: destroy_value %17 : $(X, X) // id: %22
165+
// CHECK: return %21 : $(X, X) // id: %23
166+
// CHECK: } // end sil function 'function_with_tuple_extract_2TJpSpSr'
167+

0 commit comments

Comments
 (0)