Skip to content

Commit 2effcab

Browse files
committed
[AutoDiff] Handle materializing adjoints with non-differentiable fields
Fixes #66522
1 parent 211b5ae commit 2effcab

File tree

5 files changed

+487
-107
lines changed

5 files changed

+487
-107
lines changed

include/swift/SILOptimizer/Differentiation/AdjointValue.h

Lines changed: 86 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADJOINTVALUE_H
2020

2121
#include "swift/AST/Decl.h"
22+
#include "swift/SIL/SILDebugVariable.h"
23+
#include "swift/SIL/SILLocation.h"
2224
#include "swift/SIL/SILValue.h"
2325
#include "llvm/ADT/ArrayRef.h"
2426
#include "llvm/Support/Debug.h"
@@ -38,10 +40,25 @@ enum AdjointValueKind {
3840

3941
/// A concrete SIL value.
4042
Concrete,
43+
44+
/// A special adjoint, made up of 2 adjoints -- a base adjoint and an element
45+
/// adjoint to add to it. This case exists due to the existence of custom
46+
/// tangent vectors which may consist of non-differentiable fields and may
47+
/// be used in the adjoint of the `struct_extract` and `tuple_extract` SIL
48+
/// instructions.
49+
///
50+
/// The adjoints for such tangent vectors are not piecewise materializable,
51+
/// i.e., cannot be materialized by materializing individual fields. Therefore
52+
/// when used with a `struct_extract`/`tuple_extract` they must be materialized
53+
/// by first creating a zero tangent vector of the base adjoint and then
54+
/// in-place adding element adjoint to the specified field.
55+
AddElement,
4156
};
4257

4358
class AdjointValue;
4459

60+
struct AddElementValue;
61+
4562
class AdjointValueBase {
4663
friend class AdjointValue;
4764

@@ -60,9 +77,13 @@ class AdjointValueBase {
6077
union Value {
6178
unsigned numAggregateElements;
6279
SILValue concrete;
80+
AddElementValue *addElementValue;
81+
6382
Value(unsigned numAggregateElements)
6483
: numAggregateElements(numAggregateElements) {}
6584
Value(SILValue v) : concrete(v) {}
85+
Value(AddElementValue *addElementValue)
86+
: addElementValue(addElementValue) {}
6687
Value() {}
6788
} value;
6889

@@ -86,6 +107,11 @@ class AdjointValueBase {
86107

87108
explicit AdjointValueBase(SILType type, llvm::Optional<DebugInfo> debugInfo)
88109
: kind(AdjointValueKind::Zero), type(type), debugInfo(debugInfo) {}
110+
111+
explicit AdjointValueBase(SILType type, AddElementValue *addElementValue,
112+
llvm::Optional<DebugInfo> debugInfo)
113+
: kind(AdjointValueKind::AddElement), type(type), debugInfo(debugInfo),
114+
value(addElementValue) {}
89115
};
90116

91117
/// A symbolic adjoint value that wraps a `SILValue`, a zero, or an aggregate
@@ -127,6 +153,14 @@ class AdjointValue final {
127153
return new (buf) AdjointValueBase(type, elements, debugInfo);
128154
}
129155

156+
static AdjointValue
157+
createAddElement(llvm::BumpPtrAllocator &allocator, SILType type,
158+
AddElementValue *addElementValue,
159+
llvm::Optional<DebugInfo> debugInfo = llvm::None) {
160+
auto *buf = allocator.Allocate<AdjointValueBase>();
161+
return new (buf) AdjointValueBase(type, addElementValue, debugInfo);
162+
}
163+
130164
AdjointValueKind getKind() const { return base->kind; }
131165
SILType getType() const { return base->type; }
132166
CanType getSwiftType() const { return getType().getASTType(); }
@@ -140,6 +174,9 @@ class AdjointValue final {
140174
bool isZero() const { return getKind() == AdjointValueKind::Zero; }
141175
bool isAggregate() const { return getKind() == AdjointValueKind::Aggregate; }
142176
bool isConcrete() const { return getKind() == AdjointValueKind::Concrete; }
177+
bool isAddElement() const {
178+
return getKind() == AdjointValueKind::AddElement;
179+
}
143180

144181
unsigned getNumAggregateElements() const {
145182
assert(isAggregate());
@@ -162,41 +199,60 @@ class AdjointValue final {
162199
return base->value.concrete;
163200
}
164201

165-
void print(llvm::raw_ostream &s) const {
166-
switch (getKind()) {
167-
case AdjointValueKind::Zero:
168-
s << "Zero[" << getType() << ']';
169-
break;
170-
case AdjointValueKind::Aggregate:
171-
s << "Aggregate[" << getType() << "](";
172-
if (auto *decl =
173-
getType().getASTType()->getStructOrBoundGenericStruct()) {
174-
interleave(
175-
llvm::zip(decl->getStoredProperties(), getAggregateElements()),
176-
[&s](std::tuple<VarDecl *, const AdjointValue &> elt) {
177-
s << std::get<0>(elt)->getName() << ": ";
178-
std::get<1>(elt).print(s);
179-
},
180-
[&s] { s << ", "; });
181-
} else if (getType().is<TupleType>()) {
182-
interleave(
183-
getAggregateElements(),
184-
[&s](const AdjointValue &elt) { elt.print(s); },
185-
[&s] { s << ", "; });
186-
} else {
187-
llvm_unreachable("Invalid aggregate");
188-
}
189-
s << ')';
190-
break;
191-
case AdjointValueKind::Concrete:
192-
s << "Concrete[" << getType() << "](" << base->value.concrete << ')';
193-
break;
194-
}
202+
AddElementValue *getAddElementValue() const {
203+
assert(isAddElement());
204+
return base->value.addElementValue;
195205
}
196206

207+
void print(llvm::raw_ostream &s) const;
208+
197209
SWIFT_DEBUG_DUMP { print(llvm::dbgs()); };
198210
};
199211

212+
/// @brief The underlying value for an `AddElement` adjoint.
213+
struct AddElementValue final {
214+
AdjointValue baseAdjoint;
215+
AdjointValue eltToAdd;
216+
217+
private:
218+
union FieldLocator {
219+
VarDecl *field;
220+
unsigned int index;
221+
222+
FieldLocator(VarDecl *field) : field(field) {}
223+
FieldLocator(unsigned int index) : index(index) {}
224+
} fieldLocator;
225+
226+
public:
227+
AddElementValue(AdjointValue baseAdjoint, AdjointValue eltToAdd,
228+
VarDecl *field)
229+
: baseAdjoint(baseAdjoint), eltToAdd(eltToAdd), fieldLocator(field) {
230+
assert(baseAdjoint.getKind() == AdjointValueKind::Zero);
231+
}
232+
233+
AddElementValue(AdjointValue baseAdjoint, AdjointValue eltToAdd,
234+
unsigned int index)
235+
: baseAdjoint(baseAdjoint), eltToAdd(eltToAdd), fieldLocator(index) {
236+
assert(baseAdjoint.getKind() == AdjointValueKind::Zero);
237+
}
238+
239+
bool isStructAdjoint() const {
240+
return !baseAdjoint.getType().is<TupleType>();
241+
}
242+
243+
bool isTupleAdjoint() const { return baseAdjoint.getType().is<TupleType>(); }
244+
245+
VarDecl *getFieldDecl() const {
246+
assert(isStructAdjoint());
247+
return this->fieldLocator.field;
248+
}
249+
250+
unsigned int getFieldIndex() const {
251+
assert(isTupleAdjoint());
252+
return this->fieldLocator.index;
253+
}
254+
};
255+
200256
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
201257
const AdjointValue &adjVal) {
202258
adjVal.print(os);
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
//===--- AdjointValue.cpp - Helper class for differentiation ----*- C++
2+
//-*---===//
3+
//
4+
// This source file is part of the Swift.org open source project
5+
//
6+
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
7+
// Licensed under Apache License v2.0 with Runtime Library Exception
8+
//
9+
// See https://swift.org/LICENSE.txt for license information
10+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
11+
//
12+
//===----------------------------------------------------------------------===//
13+
//
14+
// AdjointValue - a symbolic representation for adjoint values enabling
15+
// efficient differentiation by avoiding zero materialization.
16+
//
17+
//===----------------------------------------------------------------------===//
18+
19+
#define DEBUG_TYPE "differentiation"
20+
21+
#include "swift/SILOptimizer/Differentiation/AdjointValue.h"
22+
23+
void swift::autodiff::AdjointValue::print(llvm::raw_ostream &s) const {
24+
switch (getKind()) {
25+
case AdjointValueKind::Zero:
26+
s << "Zero[" << getType() << ']';
27+
break;
28+
case AdjointValueKind::Aggregate:
29+
s << "Aggregate[" << getType() << "](";
30+
if (auto *decl = getType().getASTType()->getStructOrBoundGenericStruct()) {
31+
interleave(
32+
llvm::zip(decl->getStoredProperties(), getAggregateElements()),
33+
[&s](std::tuple<VarDecl *, const AdjointValue &> elt) {
34+
s << std::get<0>(elt)->getName() << ": ";
35+
std::get<1>(elt).print(s);
36+
},
37+
[&s] { s << ", "; });
38+
} else if (getType().is<TupleType>()) {
39+
interleave(
40+
getAggregateElements(),
41+
[&s](const AdjointValue &elt) { elt.print(s); }, [&s] { s << ", "; });
42+
} else {
43+
llvm_unreachable("Invalid aggregate");
44+
}
45+
s << ')';
46+
break;
47+
case AdjointValueKind::Concrete:
48+
s << "Concrete[" << getType() << "](" << base->value.concrete << ')';
49+
break;
50+
case AdjointValueKind::AddElement:
51+
auto *addElementValue = getAddElementValue();
52+
auto baseAdjoint = addElementValue->baseAdjoint;
53+
auto eltToAdd = addElementValue->eltToAdd;
54+
55+
s << "AddElement[";
56+
if (addElementValue->isTupleAdjoint()) {
57+
s << "Zero[" << baseAdjoint.getType() << "].#n["
58+
<< addElementValue->getFieldIndex() << "] += Concrete["
59+
<< eltToAdd.getType() << "]";
60+
} else {
61+
s << "Zero[" << baseAdjoint.getType() << "].#field["
62+
<< addElementValue->getFieldDecl()->getNameStr() << "] += Concrete["
63+
<< eltToAdd.getType() << "]";
64+
}
65+
s << "]";
66+
break;
67+
}
68+
}

lib/SILOptimizer/Differentiation/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
target_sources(swiftSILOptimizer PRIVATE
22
ADContext.cpp
3+
AdjointValue.cpp
34
Common.cpp
45
DifferentiationInvoker.cpp
56
JVPCloner.cpp

lib/SILOptimizer/Differentiation/JVPCloner.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,11 +228,12 @@ class JVPCloner::Implementation final
228228
auto zeroVal = emitZeroDirect(val.getSwiftType(), loc);
229229
return zeroVal;
230230
}
231+
case AdjointValueKind::Concrete:
232+
return val.getConcreteValue();
231233
case AdjointValueKind::Aggregate:
234+
case AdjointValueKind::AddElement:
232235
llvm_unreachable(
233236
"Tuples and structs are not supported in forward mode yet.");
234-
case AdjointValueKind::Concrete:
235-
return val.getConcreteValue();
236237
}
237238
llvm_unreachable("Invalid adjoint value kind"); // silences MSVC C4715
238239
}

0 commit comments

Comments
 (0)