@@ -51,29 +51,45 @@ class AdjointValueBase {
51
51
// / The type of this value as if it were materialized as a SIL value.
52
52
SILType type;
53
53
54
+ using DebugInfo = std::pair<SILDebugLocation, SILDebugVariable>;
55
+
56
+ // / The debug location and variable info associated with the original value.
57
+ Optional<DebugInfo> debugInfo;
58
+
54
59
// / The underlying value.
55
60
union Value {
56
- llvm::ArrayRef<AdjointValue> aggregate ;
61
+ unsigned numAggregateElements ;
57
62
SILValue concrete;
58
- Value (llvm::ArrayRef<AdjointValue> v) : aggregate (v) {}
63
+ Value (unsigned numAggregateElements)
64
+ : numAggregateElements (numAggregateElements) {}
59
65
Value (SILValue v) : concrete (v) {}
60
66
Value () {}
61
67
} value;
62
68
69
+ // Begins tail-allocated aggregate elements, if
70
+ // `kind == AdjointValueKind::Aggregate`.
71
+
63
72
explicit AdjointValueBase (SILType type,
64
- llvm::ArrayRef<AdjointValue> aggregate)
65
- : kind(AdjointValueKind::Aggregate), type(type), value(aggregate) {}
73
+ llvm::ArrayRef<AdjointValue> aggregate,
74
+ Optional<DebugInfo> debugInfo)
75
+ : kind(AdjointValueKind::Aggregate), type(type), debugInfo(debugInfo),
76
+ value(aggregate.size()) {
77
+ MutableArrayRef<AdjointValue> tailElements (
78
+ reinterpret_cast <AdjointValue *>(this + 1 ), aggregate.size ());
79
+ std::uninitialized_copy (
80
+ aggregate.begin (), aggregate.end (), tailElements.begin ());
81
+ }
66
82
67
- explicit AdjointValueBase (SILValue v)
68
- : kind(AdjointValueKind::Concrete), type(v->getType ()), value(v) {}
83
+ explicit AdjointValueBase (SILValue v, Optional<DebugInfo> debugInfo)
84
+ : kind(AdjointValueKind::Concrete), type(v->getType ()),
85
+ debugInfo(debugInfo), value(v) {}
69
86
70
- explicit AdjointValueBase (SILType type)
71
- : kind(AdjointValueKind::Zero), type(type) {}
87
+ explicit AdjointValueBase (SILType type, Optional<DebugInfo> debugInfo )
88
+ : kind(AdjointValueKind::Zero), type(type), debugInfo(debugInfo) {}
72
89
};
73
90
74
- // / A symbolic adjoint value that is capable of representing zero value 0 and
75
- // / 1, in addition to a materialized SILValue. This is expected to be passed
76
- // / around by value in most cases, as it's two words long.
91
+ // / A symbolic adjoint value that wraps a `SILValue`, a zero, or an aggregate
92
+ // / thereof.
77
93
class AdjointValue final {
78
94
79
95
private:
@@ -85,26 +101,37 @@ class AdjointValue final {
85
101
AdjointValueBase *operator ->() const { return base; }
86
102
AdjointValueBase &operator *() const { return *base; }
87
103
88
- static AdjointValue createConcrete (llvm::BumpPtrAllocator &allocator,
89
- SILValue value) {
90
- return new (allocator.Allocate <AdjointValueBase>()) AdjointValueBase (value);
104
+ using DebugInfo = AdjointValueBase::DebugInfo;
105
+
106
+ static AdjointValue createConcrete (
107
+ llvm::BumpPtrAllocator &allocator, SILValue value,
108
+ Optional<DebugInfo> debugInfo = None) {
109
+ auto *buf = allocator.Allocate <AdjointValueBase>();
110
+ return new (buf) AdjointValueBase (value, debugInfo);
91
111
}
92
112
93
- static AdjointValue createZero (llvm::BumpPtrAllocator &allocator,
94
- SILType type) {
95
- return new (allocator.Allocate <AdjointValueBase>()) AdjointValueBase (type);
113
+ static AdjointValue createZero (
114
+ llvm::BumpPtrAllocator &allocator, SILType type,
115
+ Optional<DebugInfo> debugInfo = None) {
116
+ auto *buf = allocator.Allocate <AdjointValueBase>();
117
+ return new (buf) AdjointValueBase (type, debugInfo);
96
118
}
97
119
98
- static AdjointValue createAggregate (llvm::BumpPtrAllocator &allocator,
99
- SILType type,
100
- llvm::ArrayRef<AdjointValue> aggregate) {
101
- return new (allocator.Allocate <AdjointValueBase>())
102
- AdjointValueBase (type, aggregate);
120
+ static AdjointValue createAggregate (
121
+ llvm::BumpPtrAllocator &allocator, SILType type,
122
+ ArrayRef<AdjointValue> elements,
123
+ Optional<DebugInfo> debugInfo = None) {
124
+ AdjointValue *buf = reinterpret_cast <AdjointValue *>(allocator.Allocate (
125
+ sizeof (AdjointValueBase) + elements.size () * sizeof (AdjointValue),
126
+ alignof (AdjointValueBase)));
127
+ return new (buf) AdjointValueBase (type, elements, debugInfo);
103
128
}
104
129
105
130
AdjointValueKind getKind () const { return base->kind ; }
106
131
SILType getType () const { return base->type ; }
107
132
CanType getSwiftType () const { return getType ().getASTType (); }
133
+ Optional<DebugInfo> getDebugInfo () const { return base->debugInfo ; }
134
+ void setDebugInfo (DebugInfo debugInfo) const { base->debugInfo = debugInfo; }
108
135
109
136
NominalTypeDecl *getAnyNominal () const {
110
137
return getSwiftType ()->getAnyNominal ();
@@ -116,16 +143,18 @@ class AdjointValue final {
116
143
117
144
unsigned getNumAggregateElements () const {
118
145
assert (isAggregate ());
119
- return base->value .aggregate . size () ;
146
+ return base->value .numAggregateElements ;
120
147
}
121
148
122
149
AdjointValue getAggregateElement (unsigned i) const {
123
- assert (isAggregate ());
124
- return base->value .aggregate [i];
150
+ return getAggregateElements ()[i];
125
151
}
126
152
127
153
llvm::ArrayRef<AdjointValue> getAggregateElements () const {
128
- return base->value .aggregate ;
154
+ assert (isAggregate ());
155
+ return {
156
+ reinterpret_cast <const AdjointValue *>(base + 1 ),
157
+ getNumAggregateElements ()};
129
158
}
130
159
131
160
SILValue getConcreteValue () const {
@@ -143,15 +172,15 @@ class AdjointValue final {
143
172
if (auto *decl =
144
173
getType ().getASTType ()->getStructOrBoundGenericStruct ()) {
145
174
interleave (
146
- llvm::zip (decl->getStoredProperties (), base-> value . aggregate ),
175
+ llvm::zip (decl->getStoredProperties (), getAggregateElements () ),
147
176
[&s](std::tuple<VarDecl *, const AdjointValue &> elt) {
148
177
s << std::get<0 >(elt)->getName () << " : " ;
149
178
std::get<1 >(elt).print (s);
150
179
},
151
180
[&s] { s << " , " ; });
152
181
} else if (getType ().is <TupleType>()) {
153
182
interleave (
154
- base-> value . aggregate ,
183
+ getAggregateElements () ,
155
184
[&s](const AdjointValue &elt) { elt.print (s); },
156
185
[&s] { s << " , " ; });
157
186
} else {
0 commit comments