19
19
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADJOINTVALUE_H
20
20
21
21
#include " swift/AST/Decl.h"
22
+ #include " swift/SIL/SILDebugVariable.h"
23
+ #include " swift/SIL/SILLocation.h"
22
24
#include " swift/SIL/SILValue.h"
23
25
#include " llvm/ADT/ArrayRef.h"
24
26
#include " llvm/Support/Debug.h"
@@ -38,10 +40,25 @@ enum AdjointValueKind {
38
40
39
41
// / A concrete SIL value.
40
42
Concrete,
43
+
44
+ // / A special adjoint, made up of 2 adjoints -- a base adjoint and an element
45
+ // / adjoint to add to it. This case exists due to the existence of custom
46
+ // / tangent vectors which may comprise of non-differentiable fields and may
47
+ // / be used in the adjoint of the `struct_extract` and `tuple_extract` SIL
48
+ // / instructions.
49
+ // /
50
+ // / The adjoints for such tangent vectors are not pieceswise materializable,
51
+ // / i.e., cannot be materialized by materializing individual fields. Therefore
52
+ // / when used w/ a `struct_extact`/`tuple_extract` they must be materialized
53
+ // / by first creating a zero tangent vector of the base adjoint and then
54
+ // / in-place adding element adjoint to the specified field.
55
+ AddElement,
41
56
};
42
57
43
58
class AdjointValue ;
44
59
60
+ struct AddElementValue ;
61
+
45
62
class AdjointValueBase {
46
63
friend class AdjointValue ;
47
64
@@ -60,9 +77,13 @@ class AdjointValueBase {
60
77
union Value {
61
78
unsigned numAggregateElements;
62
79
SILValue concrete;
80
+ AddElementValue *addElementValue;
81
+
63
82
Value (unsigned numAggregateElements)
64
83
: numAggregateElements (numAggregateElements) {}
65
84
Value (SILValue v) : concrete (v) {}
85
+ Value (AddElementValue *addElementValue)
86
+ : addElementValue (addElementValue) {}
66
87
Value () {}
67
88
} value;
68
89
@@ -86,6 +107,11 @@ class AdjointValueBase {
86
107
87
108
explicit AdjointValueBase (SILType type, llvm::Optional<DebugInfo> debugInfo)
88
109
: kind(AdjointValueKind::Zero), type(type), debugInfo(debugInfo) {}
110
+
111
+ explicit AdjointValueBase (SILType type, AddElementValue *addElementValue,
112
+ llvm::Optional<DebugInfo> debugInfo)
113
+ : kind(AdjointValueKind::AddElement), type(type), debugInfo(debugInfo),
114
+ value(addElementValue) {}
89
115
};
90
116
91
117
// / A symbolic adjoint value that wraps a `SILValue`, a zero, or an aggregate
@@ -127,6 +153,14 @@ class AdjointValue final {
127
153
return new (buf) AdjointValueBase (type, elements, debugInfo);
128
154
}
129
155
156
+ static AdjointValue
157
+ createAddElement (llvm::BumpPtrAllocator &allocator, SILType type,
158
+ AddElementValue *addElementValue,
159
+ llvm::Optional<DebugInfo> debugInfo = llvm::None) {
160
+ auto *buf = allocator.Allocate <AdjointValueBase>();
161
+ return new (buf) AdjointValueBase (type, addElementValue, debugInfo);
162
+ }
163
+
130
164
AdjointValueKind getKind () const { return base->kind ; }
131
165
SILType getType () const { return base->type ; }
132
166
CanType getSwiftType () const { return getType ().getASTType (); }
@@ -140,6 +174,9 @@ class AdjointValue final {
140
174
bool isZero () const { return getKind () == AdjointValueKind::Zero; }
141
175
bool isAggregate () const { return getKind () == AdjointValueKind::Aggregate; }
142
176
bool isConcrete () const { return getKind () == AdjointValueKind::Concrete; }
177
+ bool isAddElement () const {
178
+ return getKind () == AdjointValueKind::AddElement;
179
+ }
143
180
144
181
unsigned getNumAggregateElements () const {
145
182
assert (isAggregate ());
@@ -162,41 +199,60 @@ class AdjointValue final {
162
199
return base->value .concrete ;
163
200
}
164
201
165
- void print (llvm::raw_ostream &s) const {
166
- switch (getKind ()) {
167
- case AdjointValueKind::Zero:
168
- s << " Zero[" << getType () << ' ]' ;
169
- break ;
170
- case AdjointValueKind::Aggregate:
171
- s << " Aggregate[" << getType () << " ](" ;
172
- if (auto *decl =
173
- getType ().getASTType ()->getStructOrBoundGenericStruct ()) {
174
- interleave (
175
- llvm::zip (decl->getStoredProperties (), getAggregateElements ()),
176
- [&s](std::tuple<VarDecl *, const AdjointValue &> elt) {
177
- s << std::get<0 >(elt)->getName () << " : " ;
178
- std::get<1 >(elt).print (s);
179
- },
180
- [&s] { s << " , " ; });
181
- } else if (getType ().is <TupleType>()) {
182
- interleave (
183
- getAggregateElements (),
184
- [&s](const AdjointValue &elt) { elt.print (s); },
185
- [&s] { s << " , " ; });
186
- } else {
187
- llvm_unreachable (" Invalid aggregate" );
188
- }
189
- s << ' )' ;
190
- break ;
191
- case AdjointValueKind::Concrete:
192
- s << " Concrete[" << getType () << " ](" << base->value .concrete << ' )' ;
193
- break ;
194
- }
202
+ AddElementValue *getAddElementValue () const {
203
+ assert (isAddElement ());
204
+ return base->value .addElementValue ;
195
205
}
196
206
207
+ void print (llvm::raw_ostream &s) const ;
208
+
197
209
SWIFT_DEBUG_DUMP { print (llvm::dbgs ()); };
198
210
};
199
211
212
+ // / @brief The underlying value for an `AddElement` adjoint.
213
+ struct AddElementValue final {
214
+ AdjointValue baseAdjoint;
215
+ AdjointValue eltToAdd;
216
+
217
+ private:
218
+ union FieldLocator {
219
+ VarDecl *field;
220
+ unsigned int index;
221
+
222
+ FieldLocator (VarDecl *field) : field (field) {}
223
+ FieldLocator (unsigned int index) : index (index) {}
224
+ } fieldLocator;
225
+
226
+ public:
227
+ AddElementValue (AdjointValue baseAdjoint, AdjointValue eltToAdd,
228
+ VarDecl *field)
229
+ : baseAdjoint(baseAdjoint), eltToAdd(eltToAdd), fieldLocator(field) {
230
+ assert (baseAdjoint.getKind () == AdjointValueKind::Zero);
231
+ }
232
+
233
+ AddElementValue (AdjointValue baseAdjoint, AdjointValue eltToAdd,
234
+ unsigned int index)
235
+ : baseAdjoint(baseAdjoint), eltToAdd(eltToAdd), fieldLocator(index) {
236
+ assert (baseAdjoint.getKind () == AdjointValueKind::Zero);
237
+ }
238
+
239
+ bool isStructAdjoint () const {
240
+ return !baseAdjoint.getType ().is <TupleType>();
241
+ }
242
+
243
+ bool isTupleAdjoint () const { return baseAdjoint.getType ().is <TupleType>(); }
244
+
245
+ VarDecl *getFieldDecl () const {
246
+ assert (isStructAdjoint ());
247
+ return this ->fieldLocator .field ;
248
+ }
249
+
250
+ unsigned int getFieldIndex () const {
251
+ assert (isTupleAdjoint ());
252
+ return this ->fieldLocator .index ;
253
+ }
254
+ };
255
+
200
256
inline llvm::raw_ostream &operator <<(llvm::raw_ostream &os,
201
257
const AdjointValue &adjVal) {
202
258
adjVal.print (os);
0 commit comments