@@ -79,14 +79,20 @@ class AutoDiffParameter {
79
79
};
80
80
81
81
class AnyFunctionType ;
82
+ class AutoDiffParameterIndicesBuilder ;
82
83
class Type ;
83
84
84
85
// / Identifies a subset of a function's parameters.
85
86
// /
86
87
// / Works with AST-level function decls and types. Requires further lowering to
87
88
// / work with SIL-level functions and types. (In particular, tuples must be
88
89
// / exploded).
89
- class AutoDiffParameterIndices {
90
+ // /
91
+ // / Is uniquely allocated within an ASTContext so that it can be hashed and
92
+ // / compared by opaque pointer value.
93
+ class AutoDiffParameterIndices : public llvm ::FoldingSetNode {
94
+ friend AutoDiffParameterIndicesBuilder;
95
+
90
96
// / Bits corresponding to parameters in the set are "on", and bits
91
97
// / corresponding to parameters not in the set are "off".
92
98
// /
@@ -103,29 +109,19 @@ class AutoDiffParameterIndices {
103
109
// / Function type: (Self) -> (A, B, C) -> R
104
110
// / Bits: [A][B][C][Self]
105
111
// /
106
- llvm::SmallBitVector indices;
112
+ const llvm::SmallBitVector indices;
107
113
108
114
// / Whether the function is a method.
109
115
// /
110
- bool isMethodFlag;
111
-
112
- unsigned getNumNonSelfParameters () const ;
113
-
114
- AutoDiffParameterIndices (unsigned numIndices, bool isMethodFlag,
115
- bool setAllParams = false )
116
- : indices(numIndices, setAllParams), isMethodFlag(isMethodFlag) {}
116
+ const bool isMethodFlag;
117
117
118
118
AutoDiffParameterIndices (llvm::SmallBitVector indices, bool isMethodFlag)
119
119
: indices(indices), isMethodFlag(isMethodFlag) {}
120
120
121
- public:
122
- // / Allocates and initializes an empty `AutoDiffParameterIndices` for the
123
- // / given `functionType`. `isMethod` specifies whether to treat the function
124
- // / as a method.
125
- static AutoDiffParameterIndices *
126
- create (ASTContext &C, AnyFunctionType *functionType, bool isMethod,
127
- bool setAllParams = false );
121
+ static AutoDiffParameterIndices *get (llvm::SmallBitVector indices,
122
+ bool isMethodFlag, ASTContext &C);
128
123
124
+ public:
129
125
bool isMethod () const { return isMethodFlag; }
130
126
131
127
// / Allocates and initializes an `AutoDiffParameterIndices` corresponding to
@@ -146,40 +142,6 @@ class AutoDiffParameterIndices {
146
142
// / Tests whether this set of parameters is empty.
147
143
bool isEmpty () const { return indices.none (); }
148
144
149
- // / Adds the indexed parameter to the set. When `isMethodFlag` is not set,
150
- // / the indices index into the first parameter list. For example,
151
- // /
152
- // / functionType = (A, B, C) -> R
153
- // / paramIndex = 0
154
- // / ==> adds "A" to the set.
155
- // /
156
- // / When `isMethodFlag` is set, the indices index into the first non-self
157
- // / parameter list. For example,
158
- // /
159
- // / functionType = (Self) -> (A, B, C) -> R
160
- // / paramIndex = 0
161
- // / ==> adds "A" to the set.
162
- // /
163
- void setNonSelfParameter (unsigned parameterIndex);
164
-
165
- // / Adds all the paramaters from the first non-self parameter list to the set.
166
- // / For example,
167
- // /
168
- // / functionType = (A, B, C) -> R
169
- // / ==> adds "A", B", and "C" to the set.
170
- // /
171
- // / functionType = (Self) -> (A, B, C) -> R
172
- // / ==> adds "A", B", and "C" to the set.
173
- // /
174
- void setAllNonSelfParameters ();
175
-
176
- // / Adds the self parameter to the set. `isMethodFlag` must be set. For
177
- // / example,
178
- // / functionType = (Self) -> (A, B, C) -> R
179
- // / ==> adds "Self" to the set
180
- // /
181
- void setSelfParameter ();
182
-
183
145
// / Pushes the subset's parameter's types to `paramTypes`, in the order in
184
146
// / which they appear in the function type. For example,
185
147
// /
@@ -223,11 +185,66 @@ class AutoDiffParameterIndices {
223
185
llvm::SmallBitVector getLowered (AnyFunctionType *functionType,
224
186
bool selfUncurried = false ) const ;
225
187
226
- bool operator ==(const AutoDiffParameterIndices &other) const {
227
- return isMethodFlag == other.isMethodFlag && indices == other.indices ;
188
+ void Profile (llvm::FoldingSetNodeID &ID) const {
189
+ ID.AddBoolean (isMethodFlag);
190
+ ID.AddInteger (indices.size ());
191
+ for (unsigned setBit : indices.set_bits ())
192
+ ID.AddInteger (setBit);
228
193
}
229
194
};
230
195
196
+ // / Builder for `AutoDiffParameterIndices`.
197
+ class AutoDiffParameterIndicesBuilder {
198
+ llvm::SmallBitVector indices;
199
+ bool isMethodFlag;
200
+
201
+ unsigned getNumNonSelfParameters () const ;
202
+
203
+ public:
204
+ // / Start building an `AutoDiffParameterIndices` for the given function type.
205
+ // / `isMethod` specifies whether to treat the function as a method.
206
+ AutoDiffParameterIndicesBuilder (AnyFunctionType *functionType, bool isMethod,
207
+ bool setAllParams = false );
208
+
209
+ // / Builds the `AutoDiffParameterIndices`, returning a pointer to an existing
210
+ // / one if it has already been allocated in the `ASTContext`.
211
+ AutoDiffParameterIndices *build (ASTContext &C) const ;
212
+
213
+ // / Adds the indexed parameter to the set. When `isMethodFlag` is not set,
214
+ // / the indices index into the first parameter list. For example,
215
+ // /
216
+ // / functionType = (A, B, C) -> R
217
+ // / paramIndex = 0
218
+ // / ==> adds "A" to the set.
219
+ // /
220
+ // / When `isMethodFlag` is set, the indices index into the first non-self
221
+ // / parameter list. For example,
222
+ // /
223
+ // / functionType = (Self) -> (A, B, C) -> R
224
+ // / paramIndex = 0
225
+ // / ==> adds "A" to the set.
226
+ // /
227
+ void setNonSelfParameter (unsigned parameterIndex);
228
+
229
+ // / Adds all the paramaters from the first non-self parameter list to the set.
230
+ // / For example,
231
+ // /
232
+ // / functionType = (A, B, C) -> R
233
+ // / ==> adds "A", B", and "C" to the set.
234
+ // /
235
+ // / functionType = (Self) -> (A, B, C) -> R
236
+ // / ==> adds "A", B", and "C" to the set.
237
+ // /
238
+ void setAllNonSelfParameters ();
239
+
240
+ // / Adds the self parameter to the set. `isMethodFlag` must be set. For
241
+ // / example,
242
+ // / functionType = (Self) -> (A, B, C) -> R
243
+ // / ==> adds "Self" to the set
244
+ // /
245
+ void setSelfParameter ();
246
+ };
247
+
231
248
// / Differentiability of a function specifies the differentiation mode,
232
249
// / parameter indices at which the function is differentiable with respect to,
233
250
// / and indices of results which can be differentiated.
@@ -413,25 +430,35 @@ struct AutoDiffAssociatedFunctionKind {
413
430
414
431
// / In conjunction with the original function decl, identifies an associated
415
432
// / autodiff function.
416
- class AutoDiffAssociatedFunctionIdentifier {
417
- AutoDiffAssociatedFunctionKind kind;
418
- unsigned differentiationOrder;
419
- AutoDiffParameterIndices *parameterIndices;
433
+ // /
434
+ // / Is uniquely allocated within an ASTContext so that it can be hashed and
435
+ // / compared by opaque pointer value.
436
+ class AutoDiffAssociatedFunctionIdentifier : public llvm ::FoldingSetNode {
437
+ const AutoDiffAssociatedFunctionKind kind;
438
+ const unsigned differentiationOrder;
439
+ AutoDiffParameterIndices * const parameterIndices;
440
+
441
+ AutoDiffAssociatedFunctionIdentifier (
442
+ AutoDiffAssociatedFunctionKind kind, unsigned differentiationOrder,
443
+ AutoDiffParameterIndices *parameterIndices) :
444
+ kind (kind), differentiationOrder(differentiationOrder),
445
+ parameterIndices (parameterIndices) {}
420
446
421
447
public:
422
448
AutoDiffAssociatedFunctionKind getKind () const { return kind; }
423
449
unsigned getDifferentiationOrder () const { return differentiationOrder; }
424
- const AutoDiffParameterIndices *getParameterIndices () const {
450
+ AutoDiffParameterIndices *getParameterIndices () const {
425
451
return parameterIndices;
426
452
}
427
453
428
454
static AutoDiffAssociatedFunctionIdentifier *get (
429
455
AutoDiffAssociatedFunctionKind kind, unsigned differentiationOrder,
430
456
AutoDiffParameterIndices *parameterIndices, ASTContext &C);
431
457
432
- bool operator ==(const AutoDiffAssociatedFunctionIdentifier &other) const {
433
- return kind == other.kind && differentiationOrder == other.differentiationOrder &&
434
- *parameterIndices == *other.parameterIndices ;
458
+ void Profile (llvm::FoldingSetNodeID &ID) {
459
+ ID.AddInteger (kind);
460
+ ID.AddInteger (differentiationOrder);
461
+ ID.AddPointer (parameterIndices);
435
462
}
436
463
};
437
464
0 commit comments