Skip to content

Commit 8b645c3

Browse files
committed
[AutoDiff] Handle materializing adjoints with non-differentiable fields
Fixes #66522
1 parent 63cfd7c commit 8b645c3

File tree

3 files changed

+343
-44
lines changed

3 files changed

+343
-44
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 59 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1349,22 +1349,35 @@ class PullbackCloner::Implementation final
13491349
break;
13501350
case AdjointValueKind::Concrete:
13511351
case AdjointValueKind::Aggregate: {
1352-
SmallVector<AdjointValue, 8> eltVals;
1353-
for (auto *field : tangentVectorDecl->getStoredProperties()) {
1354-
if (field == tanField) {
1355-
eltVals.push_back(av);
1356-
} else {
1357-
auto substMap = tangentVectorTy->getMemberSubstitutionMap(
1358-
field->getModuleContext(), field);
1359-
auto fieldTy = field->getType().subst(substMap);
1360-
auto fieldSILTy = getTypeLowering(fieldTy).getLoweredType();
1361-
assert(fieldSILTy.isObject());
1362-
eltVals.push_back(makeZeroAdjointValue(fieldSILTy));
1363-
}
1364-
}
1365-
addAdjointValue(bb, sei->getOperand(),
1366-
makeAggregateAdjointValue(tangentVectorSILTy, eltVals),
1367-
loc);
1352+
// Materialize adj[x]
1353+
auto *adjBaseAlloc = builder.createAllocStack(loc, tangentVectorSILTy);
1354+
materializeAdjointIndirect(makeZeroAdjointValue(tangentVectorSILTy),
1355+
adjBaseAlloc, loc);
1356+
1357+
// Materialize adj[y]
1358+
auto adjElt = materializeAdjointDirect(av, loc);
1359+
// Copy `adjElt` so we can get a value with owned ownership
1360+
// semantics, required for using `adjElt` in the store instruction.
1361+
auto adjEltCopy = builder.emitCopyValueOperation(loc, adjElt);
1362+
auto *adjEltAlloc = builder.createAllocStack(loc, adjElt->getType());
1363+
builder.emitStoreValueOperation(loc, adjEltCopy, adjEltAlloc,
1364+
StoreOwnershipQualifier::Init);
1365+
1366+
// Get adj[x].#field`
1367+
auto *adjBaseEltAddr =
1368+
builder.createStructElementAddr(loc, adjBaseAlloc, tanField);
1369+
1370+
// adj[x].#field` += adj[y]
1371+
builder.emitInPlaceAdd(loc, adjBaseEltAddr, adjEltAlloc);
1372+
1373+
auto adjBaseConcrete = recordTemporary(builder.emitLoadValueOperation(
1374+
loc, adjBaseAlloc, LoadOwnershipQualifier::Take));
1375+
builder.createDestroyAddr(loc, adjEltAlloc);
1376+
builder.createDeallocStack(loc, adjEltAlloc);
1377+
builder.createDeallocStack(loc, adjBaseAlloc);
1378+
setAdjointValue(bb, sei->getOperand(),
1379+
makeConcreteAdjointValue(adjBaseConcrete));
1380+
break;
13681381
}
13691382
}
13701383
break;
@@ -1538,42 +1551,44 @@ class PullbackCloner::Implementation final
15381551
/// index corresponding to n
15391552
void visitTupleExtractInst(TupleExtractInst *tei) {
15401553
auto *bb = tei->getParent();
1554+
auto loc = tei->getLoc();
15411555
auto tupleTanTy = getRemappedTangentType(tei->getOperand()->getType());
15421556
auto av = getAdjointValue(bb, tei);
15431557
switch (av.getKind()) {
15441558
case AdjointValueKind::Zero:
15451559
addAdjointValue(bb, tei->getOperand(), makeZeroAdjointValue(tupleTanTy),
1546-
tei->getLoc());
1560+
loc);
15471561
break;
15481562
case AdjointValueKind::Aggregate:
15491563
case AdjointValueKind::Concrete: {
1550-
auto tupleTy = tei->getTupleType();
1551-
auto tupleTanTupleTy = tupleTanTy.getAs<TupleType>();
1552-
if (!tupleTanTupleTy) {
1553-
addAdjointValue(bb, tei->getOperand(), av, tei->getLoc());
1554-
break;
1555-
}
1556-
SmallVector<AdjointValue, 8> elements;
1557-
unsigned adjIdx = 0;
1558-
for (unsigned i : range(tupleTy->getNumElements())) {
1559-
if (!getTangentSpace(
1560-
tupleTy->getElement(i).getType()->getCanonicalType()))
1561-
continue;
1562-
if (tei->getFieldIndex() == i)
1563-
elements.push_back(av);
1564-
else
1565-
elements.push_back(makeZeroAdjointValue(
1566-
getRemappedTangentType(SILType::getPrimitiveObjectType(
1567-
tupleTanTupleTy->getElementType(adjIdx++)
1568-
->getCanonicalType()))));
1569-
}
1570-
if (elements.size() == 1) {
1571-
addAdjointValue(bb, tei->getOperand(), elements.front(), tei->getLoc());
1572-
break;
1573-
}
1574-
addAdjointValue(bb, tei->getOperand(),
1575-
makeAggregateAdjointValue(tupleTanTy, elements),
1576-
tei->getLoc());
1564+
// Materialize adj[x]
1565+
auto *adjBaseAlloc = builder.createAllocStack(loc, tupleTanTy);
1566+
materializeAdjointIndirect(makeZeroAdjointValue(tupleTanTy), adjBaseAlloc,
1567+
loc);
1568+
1569+
// Materialize adj[y]
1570+
auto adjElt = materializeAdjointDirect(av, loc);
1571+
// Copy `adjElt` so we can get a value with owned ownership
1572+
// semantics, required for using `adjElt` in the store instruction.
1573+
auto adjEltCopy = builder.emitCopyValueOperation(loc, adjElt);
1574+
auto *adjEltAlloc = builder.createAllocStack(loc, adjElt->getType());
1575+
builder.emitStoreValueOperation(loc, adjEltCopy, adjEltAlloc,
1576+
StoreOwnershipQualifier::Init);
1577+
1578+
// Get adj[x][n`]
1579+
auto *adjBaseEltAddr = builder.createTupleElementAddr(
1580+
loc, adjBaseAlloc, tei->getFieldIndex());
1581+
1582+
// adj[x][n`] += adj[y]
1583+
builder.emitInPlaceAdd(loc, adjBaseEltAddr, adjEltAlloc);
1584+
1585+
auto adjBaseConcrete = recordTemporary(builder.emitLoadValueOperation(
1586+
loc, adjBaseAlloc, LoadOwnershipQualifier::Take));
1587+
builder.createDestroyAddr(loc, adjEltAlloc);
1588+
builder.createDeallocStack(loc, adjEltAlloc);
1589+
builder.createDeallocStack(loc, adjBaseAlloc);
1590+
setAdjointValue(bb, tei->getOperand(),
1591+
makeConcreteAdjointValue(adjBaseConcrete));
15771592
break;
15781593
}
15791594
}
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
// Pullback generation tests written in SIL for features
2+
// that may not be directly supported by the Swift frontend
3+
4+
// RUN: %target-sil-opt --differentiation -debug-only=differentiation -emit-sorted-sil %s 2>&1 | %FileCheck %s
5+
6+
//===----------------------------------------------------------------------===//
7+
// Pullback generation - `struct_extract`
8+
// - Input to pullback has non-owned ownership semantics which requires copying
9+
// this value to stack before lifetime-ending uses.
10+
//===----------------------------------------------------------------------===//
11+
12+
sil_stage raw
13+
14+
import Builtin
15+
import Swift
16+
import SwiftShims
17+
18+
import _Differentiation
19+
20+
struct X {
21+
@_hasStorage var a: Float { get set }
22+
@_hasStorage var b: String { get set }
23+
init(a: Float, b: String)
24+
}
25+
26+
extension X : Differentiable, Equatable, AdditiveArithmetic {
27+
public typealias TangentVector = X
28+
mutating func move(by offset: X)
29+
public static var zero: X { get }
30+
public static func + (lhs: X, rhs: X) -> X
31+
public static func - (lhs: X, rhs: X) -> X
32+
@_implements(Equatable, ==(_:_:)) static func __derived_struct_equals(_ a: X, _ b: X) -> Bool
33+
}
34+
35+
struct Y {
36+
@_hasStorage var a: X { get set }
37+
@_hasStorage var b: String { get set }
38+
init(a: X, b: String)
39+
}
40+
41+
extension Y : Differentiable, Equatable, AdditiveArithmetic {
42+
public typealias TangentVector = Y
43+
mutating func move(by offset: Y)
44+
public static var zero: Y { get }
45+
public static func + (lhs: Y, rhs: Y) -> Y
46+
public static func - (lhs: Y, rhs: Y) -> Y
47+
@_implements(Equatable, ==(_:_:)) static func __derived_struct_equals(_ a: Y, _ b: Y) -> Bool
48+
}
49+
50+
sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @$function_with_struct_extract_1 : $@convention(thin) (@guaranteed Y) -> @owned X {
51+
}
52+
53+
sil hidden [ossa] @$function_with_struct_extract_1 : $@convention(thin) (@guaranteed Y) -> @owned X {
54+
bb0(%0 : @guaranteed $Y):
55+
%1 = struct_extract %0 : $Y, #Y.a
56+
%2 = copy_value %1 : $X
57+
return %2 : $X
58+
}
59+
60+
// CHECK-LABEL: [ORIG] %1 = struct_extract %0 : $Y, #Y.a // user: %2
61+
// CHECK: [ADJ] Emitted in pullback (pb bb0):
62+
// CHECK: %1 = alloc_stack $Y // users: {{.*}}
63+
// CHECK: %2 = witness_method $Y, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 // user: %4
64+
// CHECK: %3 = metatype $@thick Y.Type // user: %4
65+
// CHECK: %4 = apply %2<Y>(%1, %3) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
66+
67+
// Since input parameter $0 has non-owned ownership semantics, it
68+
// needs to be copied before a lifetime-ending use.
69+
// CHECK: %5 = copy_value %0 : $X // user: %7
70+
71+
// CHECK: %6 = alloc_stack $X // users: {{.*}}
72+
// CHECK: store %5 to [init] %6 : $*X // id: %7
73+
// CHECK: %8 = struct_element_addr %1 : $*Y, #Y.a // user: %11
74+
// CHECK: %9 = witness_method $X, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () // user: %11
75+
// CHECK: %10 = metatype $@thick X.Type // user: %11
76+
// CHECK: %11 = apply %9<X>(%8, %6, %10) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
77+
// CHECK: %12 = load [take] %1 : $*Y
78+
// CHECK: destroy_addr %6 : $*X // id: %13
79+
// CHECK: dealloc_stack %6 : $*X // id: %14
80+
// CHECK: dealloc_stack %1 : $*Y // id: %15
81+
82+
//===----------------------------------------------------------------------===//
83+
// Pullback generation - `tuple_extract`
84+
// - Tuples as differentiable input arguments are not supported yet, so creating
85+
// a basic test in SIL instead.
86+
//===----------------------------------------------------------------------===//
87+
88+
sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @function_with_tuple_extract_1: $@convention(thin) ((Float, Float)) -> Float {
89+
}
90+
91+
sil hidden [ossa] @function_with_tuple_extract_1: $@convention(thin) ((Float, Float)) -> Float {
92+
bb0(%0 : $(Float, Float)):
93+
%1 = tuple_extract %0 : $(Float, Float), 0
94+
return %1 : $Float
95+
}
96+
97+
// CHECK-LABEL: [ORIG] %1 = tuple_extract %0 : $(Float, Float), 0 // user: %2
98+
// CHECK: [ADJ] Emitted in pullback (pb bb0):
99+
// CHECK: %1 = alloc_stack $(Float, Float) // users: {{.*}}
100+
// CHECK: %2 = tuple_element_addr %1 : $*(Float, Float), 0 // user: %5
101+
// CHECK: %3 = witness_method $Float, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 // user: %5
102+
// CHECK: %4 = metatype $@thick Float.Type // user: %5
103+
// CHECK: %5 = apply %3<Float>(%2, %4) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
104+
// CHECK: %6 = tuple_element_addr %1 : $*(Float, Float), 1 // user: %9
105+
// CHECK: %7 = witness_method $Float, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 // user: %9
106+
// CHECK: %8 = metatype $@thick Float.Type // user: %9
107+
// CHECK: %9 = apply %7<Float>(%6, %8) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
108+
// CHECK: %10 = alloc_stack $Float // users: {{.*}}
109+
// CHECK: store %0 to [trivial] %10 : $*Float // id: %11
110+
// CHECK: %12 = tuple_element_addr %1 : $*(Float, Float), 0 // user: %15
111+
// CHECK: %13 = witness_method $Float, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () // user: %15
112+
// CHECK: %14 = metatype $@thick Float.Type // user: %15
113+
// CHECK: %15 = apply %13<Float>(%12, %10, %14) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
114+
// CHECK: %16 = load [trivial] %1 : $*(Float, Float)
115+
// CHECK: destroy_addr %10 : $*Float // id: %17
116+
// CHECK: dealloc_stack %10 : $*Float // id: %18
117+
// CHECK: dealloc_stack %1 : $*(Float, Float) // id: %19
118+
119+
//===----------------------------------------------------------------------===//
120+
// Pullback generation - `tuple_extract`
121+
// - Input to pullback has non-owned ownership semantics which requires copying
122+
// this value to stack before lifetime-ending uses.
123+
//===----------------------------------------------------------------------===//
124+
sil_differentiability_witness hidden [reverse] [parameters 0] [results 0] @function_with_tuple_extract_2: $@convention(thin) (@guaranteed (X, X)) -> @owned X {
125+
}
126+
127+
sil hidden [ossa] @function_with_tuple_extract_2: $@convention(thin) (@guaranteed (X, X)) -> @owned X {
128+
bb0(%0 : @guaranteed $(X, X)):
129+
%1 = tuple_extract %0 : $(X, X), 0
130+
%2 = copy_value %1: $X
131+
return %2 : $X
132+
}
133+
134+
// CHECK-LABEL: [ORIG] %1 = tuple_extract %0 : $(X, X), 0 // user: %2
135+
// CHECK: [ADJ] Emitted in pullback (pb bb0):
136+
// CHECK: %1 = alloc_stack $(X, X) // users: {{.*}}
137+
// CHECK: %2 = tuple_element_addr %1 : $*(X, X), 0 // user: %5
138+
// CHECK: %3 = witness_method $X, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 // user: %5
139+
// CHECK: %4 = metatype $@thick X.Type // user: %5
140+
// CHECK: %5 = apply %3<X>(%2, %4) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
141+
// CHECK: %6 = tuple_element_addr %1 : $*(X, X), 1 // user: %9
142+
// CHECK: %7 = witness_method $X, #AdditiveArithmetic.zero!getter : <Self where Self : AdditiveArithmetic> (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 // user: %9
143+
// CHECK: %8 = metatype $@thick X.Type // user: %9
144+
// CHECK: %9 = apply %7<X>(%6, %8) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0
145+
// CHECK: %10 = copy_value %0 : $X // user: %12
146+
// CHECK: %11 = alloc_stack $X // users: {{.*}}
147+
// CHECK: store %10 to [init] %11 : $*X // id: %12
148+
// CHECK: %13 = tuple_element_addr %1 : $*(X, X), 0 // user: %16
149+
// CHECK: %14 = witness_method $X, #AdditiveArithmetic."+=" : <Self where Self : AdditiveArithmetic> (Self.Type) -> (inout Self, Self) -> () : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> () // user: %16
150+
// CHECK: %15 = metatype $@thick X.Type // user: %16
151+
// CHECK: %16 = apply %14<X>(%13, %11, %15) : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@inout τ_0_0, @in_guaranteed τ_0_0, @thick τ_0_0.Type) -> ()
152+
// CHECK: %17 = load [take] %1 : $*(X, X)
153+
// CHECK: destroy_addr %11 : $*X // id: %18
154+
// CHECK: dealloc_stack %11 : $*X // id: %19
155+
// CHECK: dealloc_stack %1 : $*(X, X) // id: %20

0 commit comments

Comments
 (0)