@@ -42,16 +42,8 @@ enum AdjointValueKind {
42
42
Concrete,
43
43
44
44
// / A special adjoint, made up of 2 adjoints -- a base adjoint and an element
45
- // / adjoint to add to it. This case exists due to the existence of custom
46
- // / tangent vectors which may comprise of non-differentiable fields and may
47
- // / be used in the adjoint of the `struct_extract` and `tuple_extract` SIL
48
- // / instructions.
49
- // /
50
- // / The adjoints for such tangent vectors are not pieceswise materializable,
51
- // / i.e., cannot be materialized by materializing individual fields. Therefore
52
- // / when used w/ a `struct_extact`/`tuple_extract` they must be materialized
53
- // / by first creating a zero tangent vector of the base adjoint and then
54
- // / in-place adding element adjoint to the specified field.
45
+ // / adjoint to add to it. This case exists to avoid eager materialization of
46
+ // / `Aggregate` and `Zero` adjoints upon addition.
55
47
AddElement,
56
48
};
57
49
@@ -227,13 +219,15 @@ struct AddElementValue final {
227
219
AddElementValue (AdjointValue baseAdjoint, AdjointValue eltToAdd,
228
220
VarDecl *field)
229
221
: baseAdjoint(baseAdjoint), eltToAdd(eltToAdd), fieldLocator(field) {
230
- assert (baseAdjoint.getKind () == AdjointValueKind::Zero);
222
+ assert (baseAdjoint.getKind () == AdjointValueKind::Zero ||
223
+ baseAdjoint.getKind () == AdjointValueKind::Aggregate);
231
224
}
232
225
233
226
AddElementValue (AdjointValue baseAdjoint, AdjointValue eltToAdd,
234
227
unsigned int index)
235
228
: baseAdjoint(baseAdjoint), eltToAdd(eltToAdd), fieldLocator(index) {
236
- assert (baseAdjoint.getKind () == AdjointValueKind::Zero);
229
+ assert (baseAdjoint.getKind () == AdjointValueKind::Zero ||
230
+ baseAdjoint.getKind () == AdjointValueKind::Aggregate);
237
231
}
238
232
239
233
bool isStructAdjoint () const {
0 commit comments