@@ -41,9 +41,10 @@ enum AdjointValueKind {
41
41
// / A concrete SIL value.
42
42
Concrete,
43
43
44
- // / A special adjoint, made up of 2 adjoints -- a base adjoint and an element
45
- // / adjoint to add to it. This case exists to avoid eager materialization of
46
- // / `Aggregate` and `Zero` adjoints upon addition.
44
+ // / A special adjoint, made up of 2 adjoints -- an aggregate base adjoint and
45
+ // / an element adjoint to add to one of its fields. This case exists to avoid
46
+ // / eager materialization of a base adjoint upon addition with one of its
47
+ // / fields.
47
48
AddElement,
48
49
};
49
50
@@ -201,49 +202,64 @@ class AdjointValue final {
201
202
SWIFT_DEBUG_DUMP { print (llvm::dbgs ()); };
202
203
};
203
204
204
- // / @brief The underlying value for an `AddElement` adjoint.
205
- struct AddElementValue final {
206
- AdjointValue baseAdjoint;
207
- AdjointValue eltToAdd;
205
+ // / An abstraction that represents the field locator in
206
+ // / an `AddElement` adjoint kind. Depending on the aggregate
207
+ // / kind - tuple or struct, of the `baseAdjoint` in an
208
+ // / `AddElement` adjoint, the field locator may be an `unsigned int`
209
+ // / or a `VarDecl *`.
210
+ struct FieldLocator final {
211
+ FieldLocator (VarDecl *field) : inner(field) {}
212
+ FieldLocator (unsigned int index) : inner(index) {}
213
+
214
+ friend AddElementValue;
208
215
209
216
private:
210
- union FieldLocator {
211
- VarDecl *field ;
212
- unsigned int index;
217
+ bool isTupleFieldLocator () const {
218
+ return std::holds_alternative< unsigned int >(inner) ;
219
+ }
213
220
214
- FieldLocator (VarDecl *field) : field (field) {}
215
- FieldLocator (unsigned int index) : index (index) {}
216
- } fieldLocator;
221
+ const static constexpr std::true_type TUPLE_FIELD_LOCATOR_TAG =
222
+ std::true_type{};
223
+ const static constexpr std::false_type STRUCT_FIELD_LOCATOR_TAG =
224
+ std::false_type{};
217
225
218
- public:
219
- AddElementValue (AdjointValue baseAdjoint, AdjointValue eltToAdd,
220
- VarDecl *field)
221
- : baseAdjoint(baseAdjoint), eltToAdd(eltToAdd), fieldLocator(field) {
222
- assert (baseAdjoint.getKind () == AdjointValueKind::Zero ||
223
- baseAdjoint.getKind () == AdjointValueKind::Aggregate);
226
+ unsigned int getInner (std::true_type) const {
227
+ return std::get<unsigned int >(inner);
224
228
}
225
229
226
- AddElementValue (AdjointValue baseAdjoint, AdjointValue eltToAdd,
227
- unsigned int index)
228
- : baseAdjoint(baseAdjoint), eltToAdd(eltToAdd), fieldLocator(index) {
229
- assert (baseAdjoint.getKind () == AdjointValueKind::Zero ||
230
- baseAdjoint.getKind () == AdjointValueKind::Aggregate);
230
+ VarDecl *getInner (std::false_type) const {
231
+ return std::get<VarDecl *>(inner);
231
232
}
232
233
233
- bool isStructAdjoint () const {
234
- return !baseAdjoint.getType ().is <TupleType>();
234
+ std::variant<unsigned int , VarDecl *> inner;
235
+ };
236
+
237
+ // / The underlying value for an `AddElement` adjoint.
238
+ struct AddElementValue final {
239
+ AdjointValue baseAdjoint;
240
+ AdjointValue eltToAdd;
241
+ FieldLocator fieldLocator;
242
+
243
+ AddElementValue (AdjointValue baseAdjoint, AdjointValue eltToAdd,
244
+ FieldLocator fieldLocator)
245
+ : baseAdjoint(baseAdjoint), eltToAdd(eltToAdd),
246
+ fieldLocator (fieldLocator) {
247
+ assert (baseAdjoint.getType ().is <TupleType>() ||
248
+ baseAdjoint.getType ().getStructOrBoundGenericStruct () != nullptr );
235
249
}
236
250
237
- bool isTupleAdjoint () const { return baseAdjoint.getType ().is <TupleType>(); }
251
+ bool isTupleAdjoint () const { return fieldLocator.isTupleFieldLocator (); }
252
+
253
+ bool isStructAdjoint () const { return !isTupleAdjoint (); }
238
254
239
255
VarDecl *getFieldDecl () const {
240
256
assert (isStructAdjoint ());
241
- return this ->fieldLocator .field ;
257
+ return this ->fieldLocator .getInner (FieldLocator::STRUCT_FIELD_LOCATOR_TAG) ;
242
258
}
243
259
244
260
unsigned int getFieldIndex () const {
245
261
assert (isTupleAdjoint ());
246
- return this ->fieldLocator .index ;
262
+ return this ->fieldLocator .getInner (FieldLocator::TUPLE_FIELD_LOCATOR_TAG) ;
247
263
}
248
264
};
249
265
0 commit comments