Skip to content

Commit 35aefd3

Browse files
committed
[AutoDiff] Handle materializing adjoints with non-differentiable fields
Fixes #66522
1 parent 2effcab commit 35aefd3

File tree

4 files changed

+182
-139
lines changed

4 files changed

+182
-139
lines changed

include/swift/SILOptimizer/Differentiation/AdjointValue.h

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,8 @@ enum AdjointValueKind {
4242
Concrete,
4343

4444
/// A special adjoint, made up of 2 adjoints -- a base adjoint and an element
45-
/// adjoint to add to it. This case exists due to the existence of custom
46-
/// tangent vectors which may comprise of non-differentiable fields and may
47-
/// be used in the adjoint of the `struct_extract` and `tuple_extract` SIL
48-
/// instructions.
49-
///
50-
/// The adjoints for such tangent vectors are not pieceswise materializable,
51-
/// i.e., cannot be materialized by materializing individual fields. Therefore
52-
/// when used w/ a `struct_extact`/`tuple_extract` they must be materialized
53-
/// by first creating a zero tangent vector of the base adjoint and then
54-
/// in-place adding element adjoint to the specified field.
45+
/// adjoint to add to it. This case exists to avoid eager materialization of
46+
/// `Aggregate` and `Zero` adjoints upon addition.
5547
AddElement,
5648
};
5749

@@ -227,13 +219,15 @@ struct AddElementValue final {
227219
AddElementValue(AdjointValue baseAdjoint, AdjointValue eltToAdd,
228220
VarDecl *field)
229221
: baseAdjoint(baseAdjoint), eltToAdd(eltToAdd), fieldLocator(field) {
230-
assert(baseAdjoint.getKind() == AdjointValueKind::Zero);
222+
assert(baseAdjoint.getKind() == AdjointValueKind::Zero ||
223+
baseAdjoint.getKind() == AdjointValueKind::Aggregate);
231224
}
232225

233226
AddElementValue(AdjointValue baseAdjoint, AdjointValue eltToAdd,
234227
unsigned int index)
235228
: baseAdjoint(baseAdjoint), eltToAdd(eltToAdd), fieldLocator(index) {
236-
assert(baseAdjoint.getKind() == AdjointValueKind::Zero);
229+
assert(baseAdjoint.getKind() == AdjointValueKind::Zero ||
230+
baseAdjoint.getKind() == AdjointValueKind::Aggregate);
237231
}
238232

239233
bool isStructAdjoint() const {

lib/SILOptimizer/Differentiation/AdjointValue.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,16 @@ void swift::autodiff::AdjointValue::print(llvm::raw_ostream &s) const {
5151
auto *addElementValue = getAddElementValue();
5252
auto baseAdjoint = addElementValue->baseAdjoint;
5353
auto eltToAdd = addElementValue->eltToAdd;
54+
auto baseAdjointKind =
55+
baseAdjoint.getKind() == AdjointValueKind::Zero ? "Zero" : "Aggregate";
5456

5557
s << "AddElement[";
5658
if (addElementValue->isTupleAdjoint()) {
57-
s << "Zero[" << baseAdjoint.getType() << "].#n["
59+
s << baseAdjointKind << "[" << baseAdjoint.getType() << "].#n["
5860
<< addElementValue->getFieldIndex() << "] += Concrete["
5961
<< eltToAdd.getType() << "]";
6062
} else {
61-
s << "Zero[" << baseAdjoint.getType() << "].#field["
63+
s << baseAdjointKind << "[" << baseAdjoint.getType() << "].#field["
6264
<< addElementValue->getFieldDecl()->getNameStr() << "] += Concrete["
6365
<< eltToAdd.getType() << "]";
6466
}

0 commit comments

Comments
 (0)