Skip to content

Commit 36bfbd6

Browse files
committed
[AutoDiff] Handle materializing adjoints with non-differentiable fields
Fixes #66522
1 parent 35aefd3 commit 36bfbd6

File tree

6 files changed

+240
-410
lines changed

6 files changed

+240
-410
lines changed

include/swift/SILOptimizer/Differentiation/AdjointValue.h

Lines changed: 45 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,10 @@ enum AdjointValueKind {
4141
/// A concrete SIL value.
4242
Concrete,
4343

44-
/// A special adjoint, made up of 2 adjoints -- a base adjoint and an element
45-
/// adjoint to add to it. This case exists to avoid eager materialization of
46-
/// `Aggregate` and `Zero` adjoints upon addition.
44+
/// A special adjoint, made up of 2 adjoints -- an aggregate base adjoint and
45+
/// an element adjoint to add to one of its fields. This case exists to avoid
46+
/// eager materialization of a base adjoint upon addition with one of its
47+
/// fields.
4748
AddElement,
4849
};
4950

@@ -201,49 +202,64 @@ class AdjointValue final {
201202
SWIFT_DEBUG_DUMP { print(llvm::dbgs()); };
202203
};
203204

204-
/// @brief The underlying value for an `AddElement` adjoint.
205-
struct AddElementValue final {
206-
AdjointValue baseAdjoint;
207-
AdjointValue eltToAdd;
205+
/// An abstraction that represents the field locator in
206+
/// an `AddElement` adjoint kind. Depending on the aggregate
207+
/// kind - tuple or struct, of the `baseAdjoint` in an
208+
/// `AddElement` adjoint, the field locator may be an `unsigned int`
209+
/// or a `VarDecl *`.
210+
struct FieldLocator final {
211+
FieldLocator(VarDecl *field) : inner(field) {}
212+
FieldLocator(unsigned int index) : inner(index) {}
213+
214+
friend AddElementValue;
208215

209216
private:
210-
union FieldLocator {
211-
VarDecl *field;
212-
unsigned int index;
217+
bool isTupleFieldLocator() const {
218+
return std::holds_alternative<unsigned int>(inner);
219+
}
213220

214-
FieldLocator(VarDecl *field) : field(field) {}
215-
FieldLocator(unsigned int index) : index(index) {}
216-
} fieldLocator;
221+
const static constexpr std::true_type TUPLE_FIELD_LOCATOR_TAG =
222+
std::true_type{};
223+
const static constexpr std::false_type STRUCT_FIELD_LOCATOR_TAG =
224+
std::false_type{};
217225

218-
public:
219-
AddElementValue(AdjointValue baseAdjoint, AdjointValue eltToAdd,
220-
VarDecl *field)
221-
: baseAdjoint(baseAdjoint), eltToAdd(eltToAdd), fieldLocator(field) {
222-
assert(baseAdjoint.getKind() == AdjointValueKind::Zero ||
223-
baseAdjoint.getKind() == AdjointValueKind::Aggregate);
226+
unsigned int getInner(std::true_type) const {
227+
return std::get<unsigned int>(inner);
224228
}
225229

226-
AddElementValue(AdjointValue baseAdjoint, AdjointValue eltToAdd,
227-
unsigned int index)
228-
: baseAdjoint(baseAdjoint), eltToAdd(eltToAdd), fieldLocator(index) {
229-
assert(baseAdjoint.getKind() == AdjointValueKind::Zero ||
230-
baseAdjoint.getKind() == AdjointValueKind::Aggregate);
230+
VarDecl *getInner(std::false_type) const {
231+
return std::get<VarDecl *>(inner);
231232
}
232233

233-
bool isStructAdjoint() const {
234-
return !baseAdjoint.getType().is<TupleType>();
234+
std::variant<unsigned int, VarDecl *> inner;
235+
};
236+
237+
/// The underlying value for an `AddElement` adjoint.
238+
struct AddElementValue final {
239+
AdjointValue baseAdjoint;
240+
AdjointValue eltToAdd;
241+
FieldLocator fieldLocator;
242+
243+
AddElementValue(AdjointValue baseAdjoint, AdjointValue eltToAdd,
244+
FieldLocator fieldLocator)
245+
: baseAdjoint(baseAdjoint), eltToAdd(eltToAdd),
246+
fieldLocator(fieldLocator) {
247+
assert(baseAdjoint.getType().is<TupleType>() ||
248+
baseAdjoint.getType().getStructOrBoundGenericStruct() != nullptr);
235249
}
236250

237-
bool isTupleAdjoint() const { return baseAdjoint.getType().is<TupleType>(); }
251+
bool isTupleAdjoint() const { return fieldLocator.isTupleFieldLocator(); }
252+
253+
bool isStructAdjoint() const { return !isTupleAdjoint(); }
238254

239255
VarDecl *getFieldDecl() const {
240256
assert(isStructAdjoint());
241-
return this->fieldLocator.field;
257+
return this->fieldLocator.getInner(FieldLocator::STRUCT_FIELD_LOCATOR_TAG);
242258
}
243259

244260
unsigned int getFieldIndex() const {
245261
assert(isTupleAdjoint());
246-
return this->fieldLocator.index;
262+
return this->fieldLocator.getInner(FieldLocator::TUPLE_FIELD_LOCATOR_TAG);
247263
}
248264
};
249265

lib/SILOptimizer/Differentiation/AdjointValue.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,19 +51,20 @@ void swift::autodiff::AdjointValue::print(llvm::raw_ostream &s) const {
5151
auto *addElementValue = getAddElementValue();
5252
auto baseAdjoint = addElementValue->baseAdjoint;
5353
auto eltToAdd = addElementValue->eltToAdd;
54-
auto baseAdjointKind =
55-
baseAdjoint.getKind() == AdjointValueKind::Zero ? "Zero" : "Aggregate";
5654

5755
s << "AddElement[";
56+
baseAdjoint.print(s);
57+
58+
s << ", Field(";
5859
if (addElementValue->isTupleAdjoint()) {
59-
s << baseAdjointKind << "[" << baseAdjoint.getType() << "].#n["
60-
<< addElementValue->getFieldIndex() << "] += Concrete["
61-
<< eltToAdd.getType() << "]";
60+
s << addElementValue->getFieldIndex();
6261
} else {
63-
s << baseAdjointKind << "[" << baseAdjoint.getType() << "].#field["
64-
<< addElementValue->getFieldDecl()->getNameStr() << "] += Concrete["
65-
<< eltToAdd.getType() << "]";
62+
s << addElementValue->getFieldDecl()->getNameStr();
6663
}
64+
s << "), ";
65+
66+
eltToAdd.print(s);
67+
6768
s << "]";
6869
break;
6970
}

0 commit comments

Comments
 (0)