Skip to content

Commit f81411a

Browse files
committed
[AutoDiff] Handle materializing adjoints with non-differentiable fields
Fixes #66522
1 parent 6de8597 commit f81411a

File tree

9 files changed

+694
-96
lines changed

9 files changed

+694
-96
lines changed

include/swift/SILOptimizer/Differentiation/AdjointValue.h

Lines changed: 96 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADJOINTVALUE_H
2020

2121
#include "swift/AST/Decl.h"
22+
#include "swift/SIL/SILDebugVariable.h"
23+
#include "swift/SIL/SILLocation.h"
2224
#include "swift/SIL/SILValue.h"
2325
#include "llvm/ADT/ArrayRef.h"
2426
#include "llvm/Support/Debug.h"
@@ -38,10 +40,18 @@ enum AdjointValueKind {
3840

3941
/// A concrete SIL value.
4042
Concrete,
43+
44+
/// A special adjoint, made up of 2 adjoints -- an aggregate base adjoint and
45+
/// an element adjoint to add to one of its fields. This case exists to avoid
46+
/// eager materialization of a base adjoint upon addition with one of its
47+
/// fields.
48+
AddElement,
4149
};
4250

4351
class AdjointValue;
4452

53+
struct AddElementValue;
54+
4555
class AdjointValueBase {
4656
friend class AdjointValue;
4757

@@ -60,9 +70,13 @@ class AdjointValueBase {
6070
// Kind-specific payload of an adjoint value. Exactly which member is
// meaningful depends on the `kind` recorded by the enclosing object.
union Value {
6171
// NOTE(review): presumably the element count for `Aggregate` adjoints
// (`getNumAggregateElements()` asserts `isAggregate()`) -- confirm against
// the accessor body.
unsigned numAggregateElements;
6272
// Read by `AdjointValue::print` when the kind is `Concrete`.
SILValue concrete;
73+
// Payload for `AdjointValueKind::AddElement`: stored by the
// `AddElementValue *` constructor and read back by `getAddElementValue()`.
AddElementValue *addElementValue;
74+
6375
Value(unsigned numAggregateElements)
6476
: numAggregateElements(numAggregateElements) {}
6577
Value(SILValue v) : concrete(v) {}
78+
Value(AddElementValue *addElementValue)
79+
: addElementValue(addElementValue) {}
// Initializes no member; used when a constructor's initializer list does not
// mention `value` (e.g. the `Zero` constructor).
6680
Value() {}
6781
} value;
6882

@@ -86,6 +100,11 @@ class AdjointValueBase {
86100

87101
/// Constructs a `Zero`-kind adjoint base of the given type. `value` is not
/// named in the initializer list, so it is default-constructed.
explicit AdjointValueBase(SILType type, llvm::Optional<DebugInfo> debugInfo)
88102
: kind(AdjointValueKind::Zero), type(type), debugInfo(debugInfo) {}
103+
104+
/// Constructs an `AddElement`-kind adjoint base of the given type. The
/// `addElementValue` pointer is stored as-is in the `value` union.
explicit AdjointValueBase(SILType type, AddElementValue *addElementValue,
105+
llvm::Optional<DebugInfo> debugInfo)
106+
: kind(AdjointValueKind::AddElement), type(type), debugInfo(debugInfo),
107+
value(addElementValue) {}
89108
};
90109

91110
/// A symbolic adjoint value that wraps a `SILValue`, a zero, or an aggregate
@@ -127,6 +146,14 @@ class AdjointValue final {
127146
return new (buf) AdjointValueBase(type, elements, debugInfo);
128147
}
129148

149+
/// Creates an `AddElement` adjoint of type `type`.
///
/// Storage for the underlying `AdjointValueBase` is obtained from `allocator`.
/// The new adjoint refers to `addElementValue` by pointer and records the
/// optional `debugInfo`.
static AdjointValue
createAddElement(llvm::BumpPtrAllocator &allocator, SILType type,
                 AddElementValue *addElementValue,
                 llvm::Optional<DebugInfo> debugInfo = llvm::None) {
  auto *storage = allocator.Allocate<AdjointValueBase>();
  auto *newBase =
      new (storage) AdjointValueBase(type, addElementValue, debugInfo);
  // Converted to `AdjointValue` on return, exactly as the pointer produced by
  // the placement-new expression would be.
  return newBase;
}
156+
130157
AdjointValueKind getKind() const { return base->kind; }
131158
SILType getType() const { return base->type; }
132159
CanType getSwiftType() const { return getType().getASTType(); }
@@ -140,6 +167,9 @@ class AdjointValue final {
140167
bool isZero() const { return getKind() == AdjointValueKind::Zero; }
141168
bool isAggregate() const { return getKind() == AdjointValueKind::Aggregate; }
142169
bool isConcrete() const { return getKind() == AdjointValueKind::Concrete; }
170+
/// Returns true if this adjoint's kind is `AdjointValueKind::AddElement`.
bool isAddElement() const {
171+
return getKind() == AdjointValueKind::AddElement;
172+
}
143173

144174
unsigned getNumAggregateElements() const {
145175
assert(isAggregate());
@@ -162,41 +192,77 @@ class AdjointValue final {
162192
return base->value.concrete;
163193
}
164194

165-
void print(llvm::raw_ostream &s) const {
166-
switch (getKind()) {
167-
case AdjointValueKind::Zero:
168-
s << "Zero[" << getType() << ']';
169-
break;
170-
case AdjointValueKind::Aggregate:
171-
s << "Aggregate[" << getType() << "](";
172-
if (auto *decl =
173-
getType().getASTType()->getStructOrBoundGenericStruct()) {
174-
interleave(
175-
llvm::zip(decl->getStoredProperties(), getAggregateElements()),
176-
[&s](std::tuple<VarDecl *, const AdjointValue &> elt) {
177-
s << std::get<0>(elt)->getName() << ": ";
178-
std::get<1>(elt).print(s);
179-
},
180-
[&s] { s << ", "; });
181-
} else if (getType().is<TupleType>()) {
182-
interleave(
183-
getAggregateElements(),
184-
[&s](const AdjointValue &elt) { elt.print(s); },
185-
[&s] { s << ", "; });
186-
} else {
187-
llvm_unreachable("Invalid aggregate");
188-
}
189-
s << ')';
190-
break;
191-
case AdjointValueKind::Concrete:
192-
s << "Concrete[" << getType() << "](" << base->value.concrete << ')';
193-
break;
194-
}
195+
/// Returns the `AddElementValue` payload of this adjoint. Asserts that the
/// adjoint's kind is `AddElement`.
AddElementValue *getAddElementValue() const {
196+
assert(isAddElement());
197+
return base->value.addElementValue;
195198
}
196199

200+
void print(llvm::raw_ostream &s) const;
201+
197202
SWIFT_DEBUG_DUMP { print(llvm::dbgs()); };
198203
};
199204

205+
/// An abstraction that represents the field locator in
206+
/// an `AddElement` adjoint kind. Depending on the aggregate
207+
/// kind - tuple or struct, of the `baseAdjoint` in an
208+
/// `AddElement` adjoint, the field locator may be an `unsigned int`
209+
/// or a `VarDecl *`.
210+
struct FieldLocator final {
211+
FieldLocator(VarDecl *field) : inner(field) {}
212+
FieldLocator(unsigned int index) : inner(index) {}
213+
214+
friend AddElementValue;
215+
216+
private:
217+
bool isTupleFieldLocator() const {
218+
return std::holds_alternative<unsigned int>(inner);
219+
}
220+
221+
const static constexpr std::true_type TUPLE_FIELD_LOCATOR_TAG =
222+
std::true_type{};
223+
const static constexpr std::false_type STRUCT_FIELD_LOCATOR_TAG =
224+
std::false_type{};
225+
226+
unsigned int getInner(std::true_type) const {
227+
return std::get<unsigned int>(inner);
228+
}
229+
230+
VarDecl *getInner(std::false_type) const {
231+
return std::get<VarDecl *>(inner);
232+
}
233+
234+
std::variant<unsigned int, VarDecl *> inner;
235+
};
236+
237+
/// The underlying value for an `AddElement` adjoint.
238+
struct AddElementValue final {
239+
AdjointValue baseAdjoint;
240+
AdjointValue eltToAdd;
241+
FieldLocator fieldLocator;
242+
243+
AddElementValue(AdjointValue baseAdjoint, AdjointValue eltToAdd,
244+
FieldLocator fieldLocator)
245+
: baseAdjoint(baseAdjoint), eltToAdd(eltToAdd),
246+
fieldLocator(fieldLocator) {
247+
assert(baseAdjoint.getType().is<TupleType>() ||
248+
baseAdjoint.getType().getStructOrBoundGenericStruct() != nullptr);
249+
}
250+
251+
bool isTupleAdjoint() const { return fieldLocator.isTupleFieldLocator(); }
252+
253+
bool isStructAdjoint() const { return !isTupleAdjoint(); }
254+
255+
VarDecl *getFieldDecl() const {
256+
assert(isStructAdjoint());
257+
return this->fieldLocator.getInner(FieldLocator::STRUCT_FIELD_LOCATOR_TAG);
258+
}
259+
260+
unsigned int getFieldIndex() const {
261+
assert(isTupleAdjoint());
262+
return this->fieldLocator.getInner(FieldLocator::TUPLE_FIELD_LOCATOR_TAG);
263+
}
264+
};
265+
200266
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
201267
const AdjointValue &adjVal) {
202268
adjVal.print(os);
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
//===--- AdjointValue.cpp - Helper class for differentiation --------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
//
13+
// AdjointValue - a symbolic representation for adjoint values enabling
14+
// efficient differentiation by avoiding zero materialization.
15+
//
16+
//===----------------------------------------------------------------------===//
17+
18+
#define DEBUG_TYPE "differentiation"
19+
20+
#include "swift/SILOptimizer/Differentiation/AdjointValue.h"
21+
22+
/// Prints a textual description of this adjoint value to `s`.
///
/// The output starts with the kind name. `Zero`, `Aggregate` and `Concrete`
/// are followed by the type in square brackets; `Aggregate` then lists its
/// elements (prefixed with the stored property name for structs) and
/// `Concrete` its SIL value. `AddElement` prints the base adjoint, the
/// targeted field (tuple index or struct field name), and the element adjoint.
void swift::autodiff::AdjointValue::print(llvm::raw_ostream &s) const {
  switch (getKind()) {
  case AdjointValueKind::Zero:
    s << "Zero[" << getType() << ']';
    break;
  case AdjointValueKind::Aggregate:
    s << "Aggregate[" << getType() << "](";
    if (auto *decl = getType().getASTType()->getStructOrBoundGenericStruct()) {
      // Struct: pair each stored property with its element adjoint.
      interleave(
          llvm::zip(decl->getStoredProperties(), getAggregateElements()),
          [&s](std::tuple<VarDecl *, const AdjointValue &> elt) {
            s << std::get<0>(elt)->getName() << ": ";
            std::get<1>(elt).print(s);
          },
          [&s] { s << ", "; });
    } else if (getType().is<TupleType>()) {
      interleave(
          getAggregateElements(),
          [&s](const AdjointValue &elt) { elt.print(s); }, [&s] { s << ", "; });
    } else {
      llvm_unreachable("Invalid aggregate");
    }
    s << ')';
    break;
  case AdjointValueKind::Concrete:
    s << "Concrete[" << getType() << "](" << base->value.concrete << ')';
    break;
  case AdjointValueKind::AddElement: {
    // Braced so the local below has its own scope; an unscoped initialized
    // local would make any later `case` label jump past its initialization.
    auto *addElementValue = getAddElementValue();

    s << "AddElement[";
    addElementValue->baseAdjoint.print(s);

    s << ", Field(";
    if (addElementValue->isTupleAdjoint()) {
      s << addElementValue->getFieldIndex();
    } else {
      s << addElementValue->getFieldDecl()->getNameStr();
    }
    s << "), ";

    addElementValue->eltToAdd.print(s);

    s << "]";
    break;
  }
  }
}

lib/SILOptimizer/Differentiation/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
target_sources(swiftSILOptimizer PRIVATE
22
ADContext.cpp
3+
AdjointValue.cpp
34
Common.cpp
45
DifferentiationInvoker.cpp
56
JVPCloner.cpp

lib/SILOptimizer/Differentiation/JVPCloner.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,11 +228,12 @@ class JVPCloner::Implementation final
228228
auto zeroVal = emitZeroDirect(val.getSwiftType(), loc);
229229
return zeroVal;
230230
}
231+
case AdjointValueKind::Concrete:
232+
return val.getConcreteValue();
231233
case AdjointValueKind::Aggregate:
234+
case AdjointValueKind::AddElement:
232235
llvm_unreachable(
233236
"Tuples and structs are not supported in forward mode yet.");
234-
case AdjointValueKind::Concrete:
235-
return val.getConcreteValue();
236237
}
237238
llvm_unreachable("Invalid adjoint value kind"); // silences MSVC C4715
238239
}

0 commit comments

Comments
 (0)