Skip to content

Commit e3b480b

Browse files
committed
[AutoDiff] Enable variable debugging support for pullback functions.
- Properly clone and use debug scopes for all instructions in pullback functions. - Emit `debug_value` instructions for adjoint values. - Add debug locations and variable info to adjoint buffer allocations. - Add `TangentBuilder` (a `SILBuilder` subclass) to unify and simplify special emitter utilities for tangent vector code generation. More simplifications to come. Pullback variable inspection example: ```console (lldb) n Process 50984 stopped * thread #1, queue = 'com.apple.main-thread', stop reason = step over frame #0: 0x0000000100003497 main`pullback of foo(x=0) at main.swift:12:11 9 import _Differentiation 10 11 func foo(_ x: Float) -> Float { -> 12 let y = sin(x) 13 let z = cos(y) 14 let k = tanh(z) + cos(y) 15 return k Target 0: (main) stopped. (lldb) fr v (Float) x = 0 (Float) k = 1 (Float) z = 0.495846391 (Float) y = -0.689988375 ``` Resolves rdar://68616528 / SR-13535.
1 parent 1c2b80f commit e3b480b

File tree

14 files changed

+573
-425
lines changed

14 files changed

+573
-425
lines changed

include/swift/SILOptimizer/Differentiation/ADContext.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ class ADContext {
115115
mutable FuncDecl *cachedPlusFn = nullptr;
116116
/// `AdditiveArithmetic.+=` declaration.
117117
mutable FuncDecl *cachedPlusEqualFn = nullptr;
118+
/// `AdditiveArithmetic.zero` declaration.
119+
mutable AccessorDecl *cachedZeroGetter = nullptr;
118120

119121
public:
120122
/// Construct an ADContext for the given module.
@@ -201,6 +203,7 @@ class ADContext {
201203

202204
FuncDecl *getPlusDecl() const;
203205
FuncDecl *getPlusEqualDecl() const;
206+
AccessorDecl *getAdditiveArithmeticZeroGetter() const;
204207

205208
/// Cleans up all the internal state.
206209
void cleanUp();
@@ -269,6 +272,10 @@ class ADContext {
269272
Diag<T...> diag, U &&... args);
270273
};
271274

275+
raw_ostream &getADDebugStream();
276+
SILLocation getValidLocation(SILValue v);
277+
SILLocation getValidLocation(SILInstruction *inst);
278+
272279
template <typename... T, typename... U>
273280
InFlightDiagnostic
274281
ADContext::emitNondifferentiabilityError(SILValue value,

include/swift/SILOptimizer/Differentiation/AdjointValue.h

Lines changed: 57 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -51,29 +51,45 @@ class AdjointValueBase {
5151
/// The type of this value as if it were materialized as a SIL value.
5252
SILType type;
5353

54+
using DebugInfo = std::pair<SILDebugLocation, SILDebugVariable>;
55+
56+
/// The debug location and variable info associated with the original value.
57+
Optional<DebugInfo> debugInfo;
58+
5459
/// The underlying value.
5560
union Value {
56-
llvm::ArrayRef<AdjointValue> aggregate;
61+
unsigned numAggregateElements;
5762
SILValue concrete;
58-
Value(llvm::ArrayRef<AdjointValue> v) : aggregate(v) {}
63+
Value(unsigned numAggregateElements)
64+
: numAggregateElements(numAggregateElements) {}
5965
Value(SILValue v) : concrete(v) {}
6066
Value() {}
6167
} value;
6268

69+
// Begins tail-allocated aggregate elements, if
70+
// `kind == AdjointValueKind::Aggregate`.
71+
6372
explicit AdjointValueBase(SILType type,
64-
llvm::ArrayRef<AdjointValue> aggregate)
65-
: kind(AdjointValueKind::Aggregate), type(type), value(aggregate) {}
73+
llvm::ArrayRef<AdjointValue> aggregate,
74+
Optional<DebugInfo> debugInfo)
75+
: kind(AdjointValueKind::Aggregate), type(type), debugInfo(debugInfo),
76+
value(aggregate.size()) {
77+
MutableArrayRef<AdjointValue> tailElements(
78+
reinterpret_cast<AdjointValue *>(this + 1), aggregate.size());
79+
std::uninitialized_copy(
80+
aggregate.begin(), aggregate.end(), tailElements.begin());
81+
}
6682

67-
explicit AdjointValueBase(SILValue v)
68-
: kind(AdjointValueKind::Concrete), type(v->getType()), value(v) {}
83+
explicit AdjointValueBase(SILValue v, Optional<DebugInfo> debugInfo)
84+
: kind(AdjointValueKind::Concrete), type(v->getType()),
85+
debugInfo(debugInfo), value(v) {}
6986

70-
explicit AdjointValueBase(SILType type)
71-
: kind(AdjointValueKind::Zero), type(type) {}
87+
explicit AdjointValueBase(SILType type, Optional<DebugInfo> debugInfo)
88+
: kind(AdjointValueKind::Zero), type(type), debugInfo(debugInfo) {}
7289
};
7390

74-
/// A symbolic adjoint value that is capable of representing zero value 0 and
75-
/// 1, in addition to a materialized SILValue. This is expected to be passed
76-
/// around by value in most cases, as it's two words long.
91+
/// A symbolic adjoint value that wraps a `SILValue`, a zero, or an aggregate
92+
/// thereof.
7793
class AdjointValue final {
7894

7995
private:
@@ -85,26 +101,37 @@ class AdjointValue final {
85101
AdjointValueBase *operator->() const { return base; }
86102
AdjointValueBase &operator*() const { return *base; }
87103

88-
static AdjointValue createConcrete(llvm::BumpPtrAllocator &allocator,
89-
SILValue value) {
90-
return new (allocator.Allocate<AdjointValueBase>()) AdjointValueBase(value);
104+
using DebugInfo = AdjointValueBase::DebugInfo;
105+
106+
static AdjointValue createConcrete(
107+
llvm::BumpPtrAllocator &allocator, SILValue value,
108+
Optional<DebugInfo> debugInfo = None) {
109+
auto *buf = allocator.Allocate<AdjointValueBase>();
110+
return new (buf) AdjointValueBase(value, debugInfo);
91111
}
92112

93-
static AdjointValue createZero(llvm::BumpPtrAllocator &allocator,
94-
SILType type) {
95-
return new (allocator.Allocate<AdjointValueBase>()) AdjointValueBase(type);
113+
static AdjointValue createZero(
114+
llvm::BumpPtrAllocator &allocator, SILType type,
115+
Optional<DebugInfo> debugInfo = None) {
116+
auto *buf = allocator.Allocate<AdjointValueBase>();
117+
return new (buf) AdjointValueBase(type, debugInfo);
96118
}
97119

98-
static AdjointValue createAggregate(llvm::BumpPtrAllocator &allocator,
99-
SILType type,
100-
llvm::ArrayRef<AdjointValue> aggregate) {
101-
return new (allocator.Allocate<AdjointValueBase>())
102-
AdjointValueBase(type, aggregate);
120+
static AdjointValue createAggregate(
121+
llvm::BumpPtrAllocator &allocator, SILType type,
122+
ArrayRef<AdjointValue> elements,
123+
Optional<DebugInfo> debugInfo = None) {
124+
AdjointValue *buf = reinterpret_cast<AdjointValue *>(allocator.Allocate(
125+
sizeof(AdjointValueBase) + elements.size() * sizeof(AdjointValue),
126+
alignof(AdjointValueBase)));
127+
return new (buf) AdjointValueBase(type, elements, debugInfo);
103128
}
104129

105130
AdjointValueKind getKind() const { return base->kind; }
106131
SILType getType() const { return base->type; }
107132
CanType getSwiftType() const { return getType().getASTType(); }
133+
Optional<DebugInfo> getDebugInfo() const { return base->debugInfo; }
134+
void setDebugInfo(DebugInfo debugInfo) const { base->debugInfo = debugInfo; }
108135

109136
NominalTypeDecl *getAnyNominal() const {
110137
return getSwiftType()->getAnyNominal();
@@ -116,16 +143,18 @@ class AdjointValue final {
116143

117144
unsigned getNumAggregateElements() const {
118145
assert(isAggregate());
119-
return base->value.aggregate.size();
146+
return base->value.numAggregateElements;
120147
}
121148

122149
AdjointValue getAggregateElement(unsigned i) const {
123-
assert(isAggregate());
124-
return base->value.aggregate[i];
150+
return getAggregateElements()[i];
125151
}
126152

127153
llvm::ArrayRef<AdjointValue> getAggregateElements() const {
128-
return base->value.aggregate;
154+
assert(isAggregate());
155+
return {
156+
reinterpret_cast<const AdjointValue *>(base + 1),
157+
getNumAggregateElements()};
129158
}
130159

131160
SILValue getConcreteValue() const {
@@ -143,15 +172,15 @@ class AdjointValue final {
143172
if (auto *decl =
144173
getType().getASTType()->getStructOrBoundGenericStruct()) {
145174
interleave(
146-
llvm::zip(decl->getStoredProperties(), base->value.aggregate),
175+
llvm::zip(decl->getStoredProperties(), getAggregateElements()),
147176
[&s](std::tuple<VarDecl *, const AdjointValue &> elt) {
148177
s << std::get<0>(elt)->getName() << ": ";
149178
std::get<1>(elt).print(s);
150179
},
151180
[&s] { s << ", "; });
152181
} else if (getType().is<TupleType>()) {
153182
interleave(
154-
base->value.aggregate,
183+
getAggregateElements(),
155184
[&s](const AdjointValue &elt) { elt.print(s); },
156185
[&s] { s << ", "; });
157186
} else {

include/swift/SILOptimizer/Differentiation/Common.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727
#include "swift/SIL/TypeSubstCloner.h"
2828
#include "swift/SILOptimizer/Analysis/ArraySemantic.h"
2929
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
30+
#include "swift/SILOptimizer/Differentiation/ADContext.h"
3031
#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h"
32+
#include "swift/SILOptimizer/Differentiation/TangentBuilder.h"
3133

3234
namespace swift {
3335

@@ -142,6 +144,9 @@ template <class Inst> Inst *peerThroughFunctionConversions(SILValue value) {
142144
return nullptr;
143145
}
144146

147+
Optional<std::pair<SILDebugLocation, SILDebugVariable>>
148+
findDebugLocationAndVariable(SILValue originalValue);
149+
145150
//===----------------------------------------------------------------------===//
146151
// Diagnostic utilities
147152
//===----------------------------------------------------------------------===//
@@ -190,12 +195,6 @@ SILValue joinElements(ArrayRef<SILValue> elements, SILBuilder &builder,
190195
void extractAllElements(SILValue value, SILBuilder &builder,
191196
SmallVectorImpl<SILValue> &results);
192197

193-
/// Emit a zero value into the given buffer access by calling
194-
/// `AdditiveArithmetic.zero`. The given type must conform to
195-
/// `AdditiveArithmetic`.
196-
void emitZeroIntoBuffer(SILBuilder &builder, CanType type,
197-
SILValue bufferAccess, SILLocation loc);
198-
199198
/// Emit a `Builtin.Word` value that represents the given type's memory layout
200199
/// size.
201200
SILValue emitMemoryLayoutSize(
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
//===--- TangentBuilder.h - Tangent SIL builder --------------*- C++ -*----===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
//
13+
// This file defines a helper class for emitting tangent code for automatic
14+
// differentiation.
15+
//
16+
//===----------------------------------------------------------------------===//
17+
18+
#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_TANGENTBUILDER_H
19+
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_TANGENTBUILDER_H
20+
21+
#include "swift/SIL/SILBuilder.h"
22+
23+
namespace swift {
24+
namespace autodiff {
25+
26+
class ADContext;
27+
28+
class TangentBuilder: public SILBuilder {
29+
private:
30+
ADContext &adContext;
31+
32+
public:
33+
TangentBuilder(SILFunction &fn, ADContext &adContext)
34+
: SILBuilder(fn), adContext(adContext) {}
35+
TangentBuilder(SILBasicBlock *bb, ADContext &adContext)
36+
: SILBuilder(bb), adContext(adContext) {}
37+
TangentBuilder(SILBasicBlock::iterator insertionPt, ADContext &adContext)
38+
: SILBuilder(insertionPt), adContext(adContext) {}
39+
TangentBuilder(SILBasicBlock *bb, SILBasicBlock::iterator insertionPt,
40+
ADContext &adContext)
41+
: SILBuilder(bb, insertionPt), adContext(adContext) {}
42+
43+
/// Emits an `AdditiveArithmetic.zero` into the given buffer. If it is not an
44+
/// initialization (`isInit`), a `destroy_addr` will be emitted on the buffer
45+
/// first. The buffer must have a type that conforms to `AdditiveArithmetic`
46+
/// or be a tuple thereof.
47+
void emitZeroIntoBuffer(SILLocation loc, SILValue buffer,
48+
IsInitialization_t isInit);
49+
50+
/// Emits an `AdditiveArithmetic.zero` of the given type. The type must be a
51+
/// loadable type, and must conform to `AddditiveArithmetic` or be a tuple
52+
/// thereof.
53+
SILValue emitZero(SILLocation loc, CanType type);
54+
55+
/// Emits an `AdditiveArithmetic.+=` for the given destination buffer and
56+
/// operand. The type of the buffer and the operand must conform to
57+
/// `AddditiveArithmetic` or be a tuple thereof. The operand will not be
58+
/// consumed.
59+
void emitInPlaceAdd(SILLocation loc, SILValue destinationBuffer,
60+
SILValue operand);
61+
62+
/// Emits an `AdditiveArithmetic.+` for the given operands. The type of the
63+
/// operands must conform to `AddditiveArithmetic` or be a tuple thereof. The
64+
/// operands will not be consumed.
65+
void emitAddIntoBuffer(SILLocation loc, SILValue destinationBuffer,
66+
SILValue lhsAddress, SILValue rhsAddress);
67+
68+
/// Emits an `AdditiveArithmetic.+` for the given operands. The type of the
69+
/// operands must be a loadable type, and must conform to
70+
/// `AddditiveArithmetic` or be a tuple thereof. The operands will not be
71+
/// consumed.
72+
SILValue emitAdd(SILLocation loc, SILValue lhs, SILValue rhs);
73+
};
74+
75+
} // end namespace autodiff
76+
} // end namespace swift
77+
78+
#endif /* SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_TANGENTBUILDER_H */

include/swift/SILOptimizer/Differentiation/Thunk.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ class ArchetypeType;
3838

3939
namespace autodiff {
4040

41+
class ADContext;
42+
4143
//===----------------------------------------------------------------------===//
4244
// Thunk helpers
4345
//===----------------------------------------------------------------------===//
@@ -107,7 +109,7 @@ std::pair<SILFunction *, SubstitutionMap>
107109
getOrCreateSubsetParametersThunkForDerivativeFunction(
108110
SILOptFunctionBuilder &fb, SILValue origFnOperand, SILValue derivativeFn,
109111
AutoDiffDerivativeFunctionKind kind, AutoDiffConfig desiredConfig,
110-
AutoDiffConfig actualConfig);
112+
AutoDiffConfig actualConfig, ADContext &adContext);
111113

112114
/// Get or create a derivative function parameter index subset thunk from
113115
/// `actualIndices` to `desiredIndices` for the given associated function
@@ -119,7 +121,8 @@ getOrCreateSubsetParametersThunkForLinearMap(
119121
SILOptFunctionBuilder &fb, SILFunction *assocFn,
120122
CanSILFunctionType origFnType, CanSILFunctionType linearMapType,
121123
CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind,
122-
AutoDiffConfig desiredConfig, AutoDiffConfig actualConfig);
124+
AutoDiffConfig desiredConfig, AutoDiffConfig actualConfig,
125+
ADContext &adContext);
123126

124127
} // end namespace autodiff
125128

lib/SILOptimizer/Differentiation/ADContext.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,17 @@ FuncDecl *ADContext::getPlusEqualDecl() const {
9595
return cachedPlusEqualFn;
9696
}
9797

98+
AccessorDecl *ADContext::getAdditiveArithmeticZeroGetter() const {
99+
if (cachedZeroGetter)
100+
return cachedZeroGetter;
101+
auto zeroDeclLookup = getAdditiveArithmeticProtocol()
102+
->lookupDirect(getASTContext().Id_zero);
103+
auto *zeroDecl = cast<VarDecl>(zeroDeclLookup.front());
104+
assert(zeroDecl->isProtocolRequirement());
105+
cachedZeroGetter = zeroDecl->getOpaqueAccessor(AccessorKind::Get);
106+
return cachedZeroGetter;
107+
}
108+
98109
void ADContext::cleanUp() {
99110
// Delete all references to generated functions.
100111
for (auto fnRef : generatedFunctionReferences) {

lib/SILOptimizer/Differentiation/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,6 @@ target_sources(swiftSILOptimizer PRIVATE
55
JVPCloner.cpp
66
LinearMapInfo.cpp
77
PullbackCloner.cpp
8+
TangentBuilder.cpp
89
Thunk.cpp
910
VJPCloner.cpp)

0 commit comments

Comments
 (0)