19
19
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADJOINTVALUE_H
20
20
21
21
#include " swift/AST/Decl.h"
22
+ #include " swift/SIL/SILDebugVariable.h"
23
+ #include " swift/SIL/SILLocation.h"
22
24
#include " swift/SIL/SILValue.h"
23
25
#include " llvm/ADT/ArrayRef.h"
24
26
#include " llvm/Support/Debug.h"
@@ -38,10 +40,18 @@ enum AdjointValueKind {
38
40
39
41
// / A concrete SIL value.
40
42
Concrete,
43
+
44
+ // / A special adjoint, made up of 2 adjoints -- an aggregate base adjoint and
45
+ // / an element adjoint to add to one of its fields. This case exists to avoid
46
+ // / eager materialization of a base adjoint upon addition with one of its
47
+ // / fields.
48
+ AddElement,
41
49
};
42
50
43
51
class AdjointValue ;
44
52
53
+ struct AddElementValue ;
54
+
45
55
class AdjointValueBase {
46
56
friend class AdjointValue ;
47
57
@@ -60,9 +70,13 @@ class AdjointValueBase {
60
70
union Value {
61
71
unsigned numAggregateElements;
62
72
SILValue concrete;
73
+ AddElementValue *addElementValue;
74
+
63
75
Value (unsigned numAggregateElements)
64
76
: numAggregateElements (numAggregateElements) {}
65
77
Value (SILValue v) : concrete (v) {}
78
+ Value (AddElementValue *addElementValue)
79
+ : addElementValue (addElementValue) {}
66
80
Value () {}
67
81
} value;
68
82
@@ -86,6 +100,11 @@ class AdjointValueBase {
86
100
87
101
explicit AdjointValueBase (SILType type, llvm::Optional<DebugInfo> debugInfo)
88
102
: kind(AdjointValueKind::Zero), type(type), debugInfo(debugInfo) {}
103
+
104
+ explicit AdjointValueBase (SILType type, AddElementValue *addElementValue,
105
+ llvm::Optional<DebugInfo> debugInfo)
106
+ : kind(AdjointValueKind::AddElement), type(type), debugInfo(debugInfo),
107
+ value(addElementValue) {}
89
108
};
90
109
91
110
// / A symbolic adjoint value that wraps a `SILValue`, a zero, or an aggregate
@@ -127,6 +146,14 @@ class AdjointValue final {
127
146
return new (buf) AdjointValueBase (type, elements, debugInfo);
128
147
}
129
148
149
+ static AdjointValue
150
+ createAddElement (llvm::BumpPtrAllocator &allocator, SILType type,
151
+ AddElementValue *addElementValue,
152
+ llvm::Optional<DebugInfo> debugInfo = llvm::None) {
153
+ auto *buf = allocator.Allocate <AdjointValueBase>();
154
+ return new (buf) AdjointValueBase (type, addElementValue, debugInfo);
155
+ }
156
+
130
157
AdjointValueKind getKind () const { return base->kind ; }
131
158
SILType getType () const { return base->type ; }
132
159
CanType getSwiftType () const { return getType ().getASTType (); }
@@ -140,6 +167,9 @@ class AdjointValue final {
140
167
bool isZero () const { return getKind () == AdjointValueKind::Zero; }
141
168
bool isAggregate () const { return getKind () == AdjointValueKind::Aggregate; }
142
169
bool isConcrete () const { return getKind () == AdjointValueKind::Concrete; }
170
+ bool isAddElement () const {
171
+ return getKind () == AdjointValueKind::AddElement;
172
+ }
143
173
144
174
unsigned getNumAggregateElements () const {
145
175
assert (isAggregate ());
@@ -162,41 +192,77 @@ class AdjointValue final {
162
192
return base->value .concrete ;
163
193
}
164
194
165
- void print (llvm::raw_ostream &s) const {
166
- switch (getKind ()) {
167
- case AdjointValueKind::Zero:
168
- s << " Zero[" << getType () << ' ]' ;
169
- break ;
170
- case AdjointValueKind::Aggregate:
171
- s << " Aggregate[" << getType () << " ](" ;
172
- if (auto *decl =
173
- getType ().getASTType ()->getStructOrBoundGenericStruct ()) {
174
- interleave (
175
- llvm::zip (decl->getStoredProperties (), getAggregateElements ()),
176
- [&s](std::tuple<VarDecl *, const AdjointValue &> elt) {
177
- s << std::get<0 >(elt)->getName () << " : " ;
178
- std::get<1 >(elt).print (s);
179
- },
180
- [&s] { s << " , " ; });
181
- } else if (getType ().is <TupleType>()) {
182
- interleave (
183
- getAggregateElements (),
184
- [&s](const AdjointValue &elt) { elt.print (s); },
185
- [&s] { s << " , " ; });
186
- } else {
187
- llvm_unreachable (" Invalid aggregate" );
188
- }
189
- s << ' )' ;
190
- break ;
191
- case AdjointValueKind::Concrete:
192
- s << " Concrete[" << getType () << " ](" << base->value .concrete << ' )' ;
193
- break ;
194
- }
195
+ AddElementValue *getAddElementValue () const {
196
+ assert (isAddElement ());
197
+ return base->value .addElementValue ;
195
198
}
196
199
200
+ void print (llvm::raw_ostream &s) const ;
201
+
197
202
SWIFT_DEBUG_DUMP { print (llvm::dbgs ()); };
198
203
};
199
204
205
+ // / An abstraction that represents the field locator in
206
+ // / an `AddElement` adjoint kind. Depending on the aggregate
207
+ // / kind - tuple or struct, of the `baseAdjoint` in an
208
+ // / `AddElement` adjoint, the field locator may be an `unsigned int`
209
+ // / or a `VarDecl *`.
210
+ struct FieldLocator final {
211
+ FieldLocator (VarDecl *field) : inner(field) {}
212
+ FieldLocator (unsigned int index) : inner(index) {}
213
+
214
+ friend AddElementValue;
215
+
216
+ private:
217
+ bool isTupleFieldLocator () const {
218
+ return std::holds_alternative<unsigned int >(inner);
219
+ }
220
+
221
+ const static constexpr std::true_type TUPLE_FIELD_LOCATOR_TAG =
222
+ std::true_type{};
223
+ const static constexpr std::false_type STRUCT_FIELD_LOCATOR_TAG =
224
+ std::false_type{};
225
+
226
+ unsigned int getInner (std::true_type) const {
227
+ return std::get<unsigned int >(inner);
228
+ }
229
+
230
+ VarDecl *getInner (std::false_type) const {
231
+ return std::get<VarDecl *>(inner);
232
+ }
233
+
234
+ std::variant<unsigned int , VarDecl *> inner;
235
+ };
236
+
237
+ // / The underlying value for an `AddElement` adjoint.
238
+ struct AddElementValue final {
239
+ AdjointValue baseAdjoint;
240
+ AdjointValue eltToAdd;
241
+ FieldLocator fieldLocator;
242
+
243
+ AddElementValue (AdjointValue baseAdjoint, AdjointValue eltToAdd,
244
+ FieldLocator fieldLocator)
245
+ : baseAdjoint(baseAdjoint), eltToAdd(eltToAdd),
246
+ fieldLocator (fieldLocator) {
247
+ assert (baseAdjoint.getType ().is <TupleType>() ||
248
+ baseAdjoint.getType ().getStructOrBoundGenericStruct () != nullptr );
249
+ }
250
+
251
+ bool isTupleAdjoint () const { return fieldLocator.isTupleFieldLocator (); }
252
+
253
+ bool isStructAdjoint () const { return !isTupleAdjoint (); }
254
+
255
+ VarDecl *getFieldDecl () const {
256
+ assert (isStructAdjoint ());
257
+ return this ->fieldLocator .getInner (FieldLocator::STRUCT_FIELD_LOCATOR_TAG);
258
+ }
259
+
260
+ unsigned int getFieldIndex () const {
261
+ assert (isTupleAdjoint ());
262
+ return this ->fieldLocator .getInner (FieldLocator::TUPLE_FIELD_LOCATOR_TAG);
263
+ }
264
+ };
265
+
200
266
inline llvm::raw_ostream &operator <<(llvm::raw_ostream &os,
201
267
const AdjointValue &adjVal) {
202
268
adjVal.print (os);
0 commit comments