@@ -79,107 +79,60 @@ class AutoDiffParameter {
79
79
};
80
80
81
81
class AnyFunctionType ;
82
+ class AutoDiffParameterIndicesBuilder ;
82
83
class Type ;
83
84
84
85
// / Identifies a subset of a function's parameters.
85
86
// /
86
87
// / Works with AST-level function decls and types. Requires further lowering to
87
88
// / work with SIL-level functions and types. (In particular, tuples must be
88
89
// / exploded).
89
- class AutoDiffParameterIndices {
90
+ // /
91
+ // / Is uniquely allocated within an ASTContext so that it can be hashed and
92
+ // / compared by opaque pointer value.
93
+ class AutoDiffParameterIndices : public llvm ::FoldingSetNode {
94
+ friend AutoDiffParameterIndicesBuilder;
95
+
90
96
// / Bits corresponding to parameters in the set are "on", and bits
91
97
// / corresponding to parameters not in the set are "off".
92
98
// /
93
- // / Normally , the bits correspond to the function's parameters in order. For
94
- // / example,
99
+ // / For non-method functions , the bits correspond to the function's
100
+ // // parameters in order. For example,
95
101
// /
96
102
// / Function type: (A, B, C) -> R
97
103
// / Bits: [A][B][C]
98
104
// /
99
- // / When `isMethodFlag` is set, the bits correspond to the function's
100
- // / non-self parameters in order, followed by the function's self parameter.
101
- // / For example,
105
+ // / For methods, the bits correspond to the function's non-self parameters
106
+ // / in order, followed by the function's self parameter. For example,
102
107
// /
103
108
// / Function type: (Self) -> (A, B, C) -> R
104
109
// / Bits: [A][B][C][Self]
105
110
// /
106
- llvm::SmallBitVector indices;
111
+ const llvm::SmallBitVector indices;
107
112
108
- // / Whether the function is a method.
109
- // /
110
- bool isMethodFlag;
113
+ AutoDiffParameterIndices (llvm::SmallBitVector indices)
114
+ : indices(indices) {}
111
115
112
- unsigned getNumNonSelfParameters () const ;
113
-
114
- AutoDiffParameterIndices (unsigned numIndices, bool isMethodFlag,
115
- bool setAllParams = false )
116
- : indices(numIndices, setAllParams), isMethodFlag(isMethodFlag) {}
117
-
118
- AutoDiffParameterIndices (llvm::SmallBitVector indices, bool isMethodFlag)
119
- : indices(indices), isMethodFlag(isMethodFlag) {}
116
+ static AutoDiffParameterIndices *get (llvm::SmallBitVector indices,
117
+ ASTContext &C);
120
118
121
119
public:
122
- // / Allocates and initializes an empty `AutoDiffParameterIndices` for the
123
- // / given `functionType`. `isMethod` specifies whether to treat the function
124
- // / as a method.
125
- static AutoDiffParameterIndices *
126
- create (ASTContext &C, AnyFunctionType *functionType, bool isMethod,
127
- bool setAllParams = false );
128
-
129
- bool isMethod () const { return isMethodFlag; }
130
-
131
120
// / Allocates and initializes an `AutoDiffParameterIndices` corresponding to
132
121
// / the given `string` generated by `getString()`. If the string is invalid,
133
122
// / returns nullptr.
134
123
static AutoDiffParameterIndices *create (ASTContext &C, StringRef string);
135
124
136
125
// / Returns a textual string description of these indices,
137
126
// /
138
- // / [FM][ SU]+
127
+ // / [SU]+
139
128
// /
140
- // / "F" means that `isMethodFlag` is false
141
- // / "M" means that `isMethodFlag` is true
142
129
// / "S" means that the corresponding index is set
143
130
// / "U" means that the corresponding index is unset
144
131
std::string getString () const ;
145
132
146
133
// / Tests whether this set of parameters is empty.
147
134
bool isEmpty () const { return indices.none (); }
148
135
149
- // / Adds the indexed parameter to the set. When `isMethodFlag` is not set,
150
- // / the indices index into the first parameter list. For example,
151
- // /
152
- // / functionType = (A, B, C) -> R
153
- // / paramIndex = 0
154
- // / ==> adds "A" to the set.
155
- // /
156
- // / When `isMethodFlag` is set, the indices index into the first non-self
157
- // / parameter list. For example,
158
- // /
159
- // / functionType = (Self) -> (A, B, C) -> R
160
- // / paramIndex = 0
161
- // / ==> adds "A" to the set.
162
- // /
163
- void setNonSelfParameter (unsigned parameterIndex);
164
-
165
- // / Adds all the paramaters from the first non-self parameter list to the set.
166
- // / For example,
167
- // /
168
- // / functionType = (A, B, C) -> R
169
- // / ==> adds "A", B", and "C" to the set.
170
- // /
171
- // / functionType = (Self) -> (A, B, C) -> R
172
- // / ==> adds "A", B", and "C" to the set.
173
- // /
174
- void setAllNonSelfParameters ();
175
-
176
- // / Adds the self parameter to the set. `isMethodFlag` must be set. For
177
- // / example,
178
- // / functionType = (Self) -> (A, B, C) -> R
179
- // / ==> adds "Self" to the set
180
- // /
181
- void setSelfParameter ();
182
-
183
136
// / Pushes the subset's parameter's types to `paramTypes`, in the order in
184
137
// / which they appear in the function type. For example,
185
138
// /
@@ -191,11 +144,14 @@ class AutoDiffParameterIndices {
191
144
// / if "Self" and "C" are in the set,
192
145
// / ==> pushes {Self, C} to `paramTypes`.
193
146
// /
147
+ // / Pass `isMethod = true` when the function is a method.
148
+ // /
194
149
// / Pass `selfUncurried = true` when the function type is for a method whose
195
150
// / self parameter has been uncurried as in (A, B, C, Self) -> R.
196
151
// /
197
152
void getSubsetParameterTypes (AnyFunctionType *functionType,
198
153
SmallVectorImpl<Type> ¶mTypes,
154
+ bool isMethod,
199
155
bool selfUncurried = false ) const ;
200
156
201
157
// / Returns a bitvector for the SILFunction parameters corresponding to the
@@ -217,17 +173,74 @@ class AutoDiffParameterIndices {
217
173
// / ==> returns 1110
218
174
// / (because the lowered SIL type is (A, B, C, D) -> R)
219
175
// /
176
+ // / Pass `isMethod = true` when the function is a method.
177
+ // /
220
178
// / Pass `selfUncurried = true` when the function type is for a method whose
221
179
// / self parameter has been uncurried as in (A, B, C, Self) -> R.
222
180
// /
223
181
llvm::SmallBitVector getLowered (AnyFunctionType *functionType,
182
+ bool isMethod,
224
183
bool selfUncurried = false ) const ;
225
184
226
- bool operator ==(const AutoDiffParameterIndices &other) const {
227
- return isMethodFlag == other.isMethodFlag && indices == other.indices ;
185
+ void Profile (llvm::FoldingSetNodeID &ID) const {
186
+ ID.AddInteger (indices.size ());
187
+ for (unsigned setBit : indices.set_bits ())
188
+ ID.AddInteger (setBit);
228
189
}
229
190
};
230
191
192
+ // / Builder for `AutoDiffParameterIndices`.
193
+ class AutoDiffParameterIndicesBuilder {
194
+ llvm::SmallBitVector indices;
195
+ bool isMethod;
196
+
197
+ unsigned getNumNonSelfParameters () const ;
198
+
199
+ public:
200
+ // / Start building an `AutoDiffParameterIndices` for the given function type.
201
+ // / `isMethod` specifies whether to treat the function as a method.
202
+ AutoDiffParameterIndicesBuilder (AnyFunctionType *functionType, bool isMethod,
203
+ bool setAllParams = false );
204
+
205
+ // / Builds the `AutoDiffParameterIndices`, returning a pointer to an existing
206
+ // / one if it has already been allocated in the `ASTContext`.
207
+ AutoDiffParameterIndices *build (ASTContext &C) const ;
208
+
209
+ // / Adds the indexed parameter to the set. When `isMethod` is not set,
210
+ // / the indices index into the first parameter list. For example,
211
+ // /
212
+ // / functionType = (A, B, C) -> R
213
+ // / paramIndex = 0
214
+ // / ==> adds "A" to the set.
215
+ // /
216
+ // / When `isMethod` is set, the indices index into the first non-self
217
+ // / parameter list. For example,
218
+ // /
219
+ // / functionType = (Self) -> (A, B, C) -> R
220
+ // / paramIndex = 0
221
+ // / ==> adds "A" to the set.
222
+ // /
223
+ void setNonSelfParameter (unsigned parameterIndex);
224
+
225
+ // / Adds all the paramaters from the first non-self parameter list to the set.
226
+ // / For example,
227
+ // /
228
+ // / functionType = (A, B, C) -> R
229
+ // / ==> adds "A", B", and "C" to the set.
230
+ // /
231
+ // / functionType = (Self) -> (A, B, C) -> R
232
+ // / ==> adds "A", B", and "C" to the set.
233
+ // /
234
+ void setAllNonSelfParameters ();
235
+
236
+ // / Adds the self parameter to the set. `isMethod` must be set. For
237
+ // / example,
238
+ // / functionType = (Self) -> (A, B, C) -> R
239
+ // / ==> adds "Self" to the set
240
+ // /
241
+ void setSelfParameter ();
242
+ };
243
+
231
244
// / Differentiability of a function specifies the differentiation mode,
232
245
// / parameter indices at which the function is differentiable with respect to,
233
246
// / and indices of results which can be differentiated.
@@ -413,25 +426,35 @@ struct AutoDiffAssociatedFunctionKind {
413
426
414
427
// / In conjunction with the original function decl, identifies an associated
415
428
// / autodiff function.
416
- class AutoDiffAssociatedFunctionIdentifier {
417
- AutoDiffAssociatedFunctionKind kind;
418
- unsigned differentiationOrder;
419
- AutoDiffParameterIndices *parameterIndices;
429
+ // /
430
+ // / Is uniquely allocated within an ASTContext so that it can be hashed and
431
+ // / compared by opaque pointer value.
432
+ class AutoDiffAssociatedFunctionIdentifier : public llvm ::FoldingSetNode {
433
+ const AutoDiffAssociatedFunctionKind kind;
434
+ const unsigned differentiationOrder;
435
+ AutoDiffParameterIndices * const parameterIndices;
436
+
437
+ AutoDiffAssociatedFunctionIdentifier (
438
+ AutoDiffAssociatedFunctionKind kind, unsigned differentiationOrder,
439
+ AutoDiffParameterIndices *parameterIndices) :
440
+ kind (kind), differentiationOrder(differentiationOrder),
441
+ parameterIndices (parameterIndices) {}
420
442
421
443
public:
422
444
AutoDiffAssociatedFunctionKind getKind () const { return kind; }
423
445
unsigned getDifferentiationOrder () const { return differentiationOrder; }
424
- const AutoDiffParameterIndices *getParameterIndices () const {
446
+ AutoDiffParameterIndices *getParameterIndices () const {
425
447
return parameterIndices;
426
448
}
427
449
428
450
static AutoDiffAssociatedFunctionIdentifier *get (
429
451
AutoDiffAssociatedFunctionKind kind, unsigned differentiationOrder,
430
452
AutoDiffParameterIndices *parameterIndices, ASTContext &C);
431
453
432
- bool operator ==(const AutoDiffAssociatedFunctionIdentifier &other) const {
433
- return kind == other.kind && differentiationOrder == other.differentiationOrder &&
434
- *parameterIndices == *other.parameterIndices ;
454
+ void Profile (llvm::FoldingSetNodeID &ID) {
455
+ ID.AddInteger (kind);
456
+ ID.AddInteger (differentiationOrder);
457
+ ID.AddPointer (parameterIndices);
435
458
}
436
459
};
437
460
0 commit comments