Skip to content

Commit 207139c

Browse files
author
Marc Rasi
committed
delete isMethodFlag from the indices
1 parent 2c93727 commit 207139c

File tree

13 files changed

+99
-103
lines changed

13 files changed

+99
-103
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -96,45 +96,36 @@ class AutoDiffParameterIndices : public llvm::FoldingSetNode {
9696
/// Bits corresponding to parameters in the set are "on", and bits
9797
/// corresponding to parameters not in the set are "off".
9898
///
99-
/// Normally, the bits correspond to the function's parameters in order. For
100-
/// example,
99+
/// For non-method functions, the bits correspond to the function's
100+
//// parameters in order. For example,
101101
///
102102
/// Function type: (A, B, C) -> R
103103
/// Bits: [A][B][C]
104104
///
105-
/// When `isMethodFlag` is set, the bits correspond to the function's
106-
/// non-self parameters in order, followed by the function's self parameter.
107-
/// For example,
105+
/// For methods, the bits correspond to the function's non-self parameters
106+
/// in order, followed by the function's self parameter. For example,
108107
///
109108
/// Function type: (Self) -> (A, B, C) -> R
110109
/// Bits: [A][B][C][Self]
111110
///
112111
const llvm::SmallBitVector indices;
113112

114-
/// Whether the function is a method.
115-
///
116-
const bool isMethodFlag;
117-
118-
AutoDiffParameterIndices(llvm::SmallBitVector indices, bool isMethodFlag)
119-
: indices(indices), isMethodFlag(isMethodFlag) {}
113+
AutoDiffParameterIndices(llvm::SmallBitVector indices)
114+
: indices(indices) {}
120115

121116
static AutoDiffParameterIndices *get(llvm::SmallBitVector indices,
122-
bool isMethodFlag, ASTContext &C);
117+
ASTContext &C);
123118

124119
public:
125-
bool isMethod() const { return isMethodFlag; }
126-
127120
/// Allocates and initializes an `AutoDiffParameterIndices` corresponding to
128121
/// the given `string` generated by `getString()`. If the string is invalid,
129122
/// returns nullptr.
130123
static AutoDiffParameterIndices *create(ASTContext &C, StringRef string);
131124

132125
/// Returns a textual string description of these indices,
133126
///
134-
/// [FM][SU]+
127+
/// [SU]+
135128
///
136-
/// "F" means that `isMethodFlag` is false
137-
/// "M" means that `isMethodFlag` is true
138129
/// "S" means that the corresponding index is set
139130
/// "U" means that the corresponding index is unset
140131
std::string getString() const;
@@ -153,11 +144,14 @@ class AutoDiffParameterIndices : public llvm::FoldingSetNode {
153144
/// if "Self" and "C" are in the set,
154145
/// ==> pushes {Self, C} to `paramTypes`.
155146
///
147+
/// Pass `isMethod = true` when the function is a method.
148+
///
156149
/// Pass `selfUncurried = true` when the function type is for a method whose
157150
/// self parameter has been uncurried as in (A, B, C, Self) -> R.
158151
///
159152
void getSubsetParameterTypes(AnyFunctionType *functionType,
160153
SmallVectorImpl<Type> &paramTypes,
154+
bool isMethod,
161155
bool selfUncurried = false) const;
162156

163157
/// Returns a bitvector for the SILFunction parameters corresponding to the
@@ -179,14 +173,16 @@ class AutoDiffParameterIndices : public llvm::FoldingSetNode {
179173
/// ==> returns 1110
180174
/// (because the lowered SIL type is (A, B, C, D) -> R)
181175
///
176+
/// Pass `isMethod = true` when the function is a method.
177+
///
182178
/// Pass `selfUncurried = true` when the function type is for a method whose
183179
/// self parameter has been uncurried as in (A, B, C, Self) -> R.
184180
///
185181
llvm::SmallBitVector getLowered(AnyFunctionType *functionType,
182+
bool isMethod,
186183
bool selfUncurried = false) const;
187184

188185
void Profile(llvm::FoldingSetNodeID &ID) const {
189-
ID.AddBoolean(isMethodFlag);
190186
ID.AddInteger(indices.size());
191187
for (unsigned setBit : indices.set_bits())
192188
ID.AddInteger(setBit);
@@ -196,7 +192,7 @@ class AutoDiffParameterIndices : public llvm::FoldingSetNode {
196192
/// Builder for `AutoDiffParameterIndices`.
197193
class AutoDiffParameterIndicesBuilder {
198194
llvm::SmallBitVector indices;
199-
bool isMethodFlag;
195+
bool isMethod;
200196

201197
unsigned getNumNonSelfParameters() const;
202198

@@ -210,14 +206,14 @@ class AutoDiffParameterIndicesBuilder {
210206
/// one if it has already been allocated in the `ASTContext`.
211207
AutoDiffParameterIndices *build(ASTContext &C) const;
212208

213-
/// Adds the indexed parameter to the set. When `isMethodFlag` is not set,
209+
/// Adds the indexed parameter to the set. When `isMethod` is not set,
214210
/// the indices index into the first parameter list. For example,
215211
///
216212
/// functionType = (A, B, C) -> R
217213
/// paramIndex = 0
218214
/// ==> adds "A" to the set.
219215
///
220-
/// When `isMethodFlag` is set, the indices index into the first non-self
216+
/// When `isMethod` is set, the indices index into the first non-self
221217
/// parameter list. For example,
222218
///
223219
/// functionType = (Self) -> (A, B, C) -> R
@@ -237,7 +233,7 @@ class AutoDiffParameterIndicesBuilder {
237233
///
238234
void setAllNonSelfParameters();
239235

240-
/// Adds the self parameter to the set. `isMethodFlag` must be set. For
236+
/// Adds the self parameter to the set. `isMethod` must be set. For
241237
/// example,
242238
/// functionType = (Self) -> (A, B, C) -> R
243239
/// ==> adds "Self" to the set

include/swift/AST/Types.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3127,11 +3127,13 @@ class AnyFunctionType : public TypeBase {
31273127
AnyFunctionType *getAutoDiffAssociatedFunctionType(
31283128
AutoDiffParameterIndices *indices, unsigned resultIndex,
31293129
unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind,
3130-
LookupConformanceFn lookupConformance, bool selfUncurried = false);
3130+
LookupConformanceFn lookupConformance, bool isMethod,
3131+
bool selfUncurried = false);
31313132

31323133
AnyFunctionType *
31333134
getAutoDiffAdjointFunctionType(AutoDiffParameterIndices *indices,
3134-
const TupleType *primalResultTy);
3135+
const TupleType *primalResultTy,
3136+
bool isMethod);
31353137

31363138
/// \brief True if this type allows an implicit conversion from a function
31373139
/// argument expression of type T to a function of type () -> T.

lib/AST/ASTContext.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5278,12 +5278,10 @@ Optional<TangentSpace> ASTContext::getTangentSpace(CanType type,
52785278
}
52795279

52805280
AutoDiffParameterIndices *
5281-
AutoDiffParameterIndices::get(llvm::SmallBitVector indices, bool isMethodFlag,
5282-
ASTContext &C) {
5281+
AutoDiffParameterIndices::get(llvm::SmallBitVector indices, ASTContext &C) {
52835282
auto &foldingSet = C.getImpl().AutoDiffParameterIndicesSet;
52845283

52855284
llvm::FoldingSetNodeID id;
5286-
id.AddBoolean(isMethodFlag);
52875285
id.AddInteger(indices.size());
52885286
for (unsigned setBit : indices.set_bits())
52895287
id.AddInteger(setBit);
@@ -5298,7 +5296,7 @@ AutoDiffParameterIndices::get(llvm::SmallBitVector indices, bool isMethodFlag,
52985296
// SmallBitVector decides to allocate some heap space.
52995297
void *mem = C.Allocate(sizeof(AutoDiffParameterIndices),
53005298
alignof(AutoDiffParameterIndices));
5301-
auto *newNode = ::new (mem) AutoDiffParameterIndices(indices, isMethodFlag);
5299+
auto *newNode = ::new (mem) AutoDiffParameterIndices(indices);
53025300
foldingSet.InsertNode(newNode, insertPos);
53035301

53045302
return newNode;

lib/AST/AutoDiff.cpp

Lines changed: 22 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -110,39 +110,28 @@ static AnyFunctionType *unwrapSelfParameter(AnyFunctionType *functionType,
110110
/// returns nullptr.
111111
AutoDiffParameterIndices *
112112
AutoDiffParameterIndices::create(ASTContext &C, StringRef string) {
113-
if (string.size() < 2)
113+
if (string.size() < 1)
114114
return nullptr;
115115

116-
bool isMethod = false;
117116
llvm::SmallBitVector indices(string.size() - 1);
118-
if (string[0] == 'M')
119-
isMethod = true;
120-
else if (string[0] != 'F')
121-
return nullptr;
122117
for (unsigned i : range(indices.size())) {
123118
if (string[i + 1] == 'S')
124119
indices.set(i);
125120
else if (string[i + 1] != 'U')
126121
return nullptr;
127122
}
128123

129-
return get(indices, isMethod, C);
124+
return get(indices, C);
130125
}
131126

132127
/// Returns a textual string description of these indices,
133128
///
134-
/// [FM][SU]+
129+
/// [SU]+
135130
///
136-
/// "F" means that `isMethodFlag` is false
137-
/// "M" means that `isMethodFlag` is true
138131
/// "S" means that the corresponding index is set
139132
/// "U" means that the corresponding index is unset
140133
std::string AutoDiffParameterIndices::getString() const {
141134
std::string result;
142-
if (isMethodFlag)
143-
result += "M";
144-
else
145-
result += "F";
146135
for (unsigned i : range(indices.size())) {
147136
if (indices[i])
148137
result += "S";
@@ -163,21 +152,23 @@ std::string AutoDiffParameterIndices::getString() const {
163152
/// if "Self" and "C" are in the set,
164153
/// ==> pushes {Self, C} to `paramTypes`.
165154
///
155+
/// Pass `isMethod = true` when the function is a method.
156+
///
166157
/// Pass `selfUncurried = true` when the function type is for a method whose
167158
/// self parameter has been uncurried as in (A, B, C, Self) -> R.
168159
///
169160
void AutoDiffParameterIndices::getSubsetParameterTypes(
170161
AnyFunctionType *functionType, SmallVectorImpl<Type> &paramTypes,
171-
bool selfUncurried) const {
172-
if (selfUncurried && isMethodFlag) {
173-
if (isMethodFlag && indices[indices.size() - 1])
162+
bool isMethod, bool selfUncurried) const {
163+
if (selfUncurried && isMethod) {
164+
if (isMethod && indices[indices.size() - 1])
174165
paramTypes.push_back(functionType->getParams()[functionType->getNumParams() - 1].getPlainType());
175166
for (unsigned paramIndex : range(functionType->getNumParams() - 1))
176167
if (indices[paramIndex])
177168
paramTypes.push_back(functionType->getParams()[paramIndex].getPlainType());
178169
} else {
179-
AnyFunctionType *unwrapped = unwrapSelfParameter(functionType, isMethodFlag);
180-
if (isMethodFlag && indices[indices.size() - 1])
170+
AnyFunctionType *unwrapped = unwrapSelfParameter(functionType, isMethod);
171+
if (isMethod && indices[indices.size() - 1])
181172
paramTypes.push_back(functionType->getParams()[0].getPlainType());
182173
for (unsigned paramIndex : range(unwrapped->getNumParams()))
183174
if (indices[paramIndex])
@@ -213,16 +204,18 @@ static unsigned countNumFlattenedElementTypes(Type type) {
213204
/// ==> returns 1110
214205
/// (because the lowered SIL type is (A, B, C, D) -> R)
215206
///
207+
/// Pass `isMethod = true` when the function is a method.
208+
///
216209
/// Pass `selfUncurried = true` when the function type is a for method whose
217210
/// self parameter has been uncurried as in (A, B, C, Self) -> R.
218211
///
219212
llvm::SmallBitVector
220213
AutoDiffParameterIndices::getLowered(AnyFunctionType *functionType,
221-
bool selfUncurried) const {
214+
bool isMethod, bool selfUncurried) const {
222215
// Calculate the lowered sizes of all the parameters.
223216
AnyFunctionType *unwrapped = selfUncurried
224217
? functionType
225-
: unwrapSelfParameter(functionType, isMethodFlag);
218+
: unwrapSelfParameter(functionType, isMethod);
226219
SmallVector<unsigned, 8> paramLoweredSizes;
227220
unsigned totalLoweredSize = 0;
228221
auto addLoweredParamInfo = [&](Type type) {
@@ -232,7 +225,7 @@ AutoDiffParameterIndices::getLowered(AnyFunctionType *functionType,
232225
};
233226
for (auto &param : unwrapped->getParams())
234227
addLoweredParamInfo(param.getPlainType());
235-
if (isMethodFlag && !selfUncurried)
228+
if (isMethod && !selfUncurried)
236229
addLoweredParamInfo(functionType->getParams()[0].getPlainType());
237230

238231
// Construct the result by setting each range of bits that corresponds to each
@@ -256,29 +249,29 @@ static unsigned getNumAutoDiffParameterIndices(AnyFunctionType *functionType,
256249
}
257250

258251
unsigned AutoDiffParameterIndicesBuilder::getNumNonSelfParameters() const {
259-
return indices.size() - (isMethodFlag ? 1 : 0);
252+
return indices.size() - (isMethod ? 1 : 0);
260253
}
261254

262255
AutoDiffParameterIndicesBuilder::AutoDiffParameterIndicesBuilder(
263256
AnyFunctionType *functionType, bool isMethod, bool setAllParams) :
264257
indices(getNumAutoDiffParameterIndices(functionType, isMethod),
265258
setAllParams),
266-
isMethodFlag(isMethod) {
259+
isMethod(isMethod) {
267260
}
268261

269262
AutoDiffParameterIndices *
270263
AutoDiffParameterIndicesBuilder::build(ASTContext &C) const {
271-
return AutoDiffParameterIndices::get(indices, isMethodFlag, C);
264+
return AutoDiffParameterIndices::get(indices, C);
272265
}
273266

274-
/// Adds the indexed parameter to the set. When `isMethodFlag` is not set, the
267+
/// Adds the indexed parameter to the set. When `isMethod` is not set, the
275268
/// indices index into the first parameter list. For example,
276269
///
277270
/// functionType = (A, B, C) -> R
278271
/// paramIndex = 0
279272
/// ==> adds "A" to the set.
280273
///
281-
/// When `isMethodFlag` is set, the indices index into the first non-self
274+
/// When `isMethod` is set, the indices index into the first non-self
282275
/// parameter list. For example,
283276
///
284277
/// functionType = (Self) -> (A, B, C) -> R
@@ -303,14 +296,14 @@ void AutoDiffParameterIndicesBuilder::setAllNonSelfParameters() {
303296
indices.set(0, getNumNonSelfParameters());
304297
}
305298

306-
/// Adds the self parameter to the set. `isMethodFlag` must be set. For
299+
/// Adds the self parameter to the set. `isMethod` must be set. For
307300
/// example,
308301
///
309302
/// functionType = (Self) -> (A, B, C) -> R
310303
/// ==> adds "Self" to the set
311304
///
312305
void AutoDiffParameterIndicesBuilder::setSelfParameter() {
313-
assert(isMethodFlag &&
306+
assert(isMethod &&
314307
"trying to add self param to non-method parameter indices");
315308
indices.set(indices.size() - 1);
316309
}

lib/AST/Builtins.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1077,7 +1077,7 @@ static ValueDecl *getAutoDiffGetAssociatedFunction(
10771077
// stored in the function type.
10781078
auto *vjpType = origFnTy->getAutoDiffAssociatedFunctionType(
10791079
paramIndices, /*resultIndex*/ 0, /*differentiationOrder*/ 1,
1080-
kind, /*lookupConformance*/ nullptr);
1080+
kind, /*lookupConformance*/ nullptr, /*isMethod*/ false);
10811081
vjpType = vjpType->withExtInfo(vjpType->getExtInfo().withNoEscape(false));
10821082
return vjpType;
10831083
}

lib/AST/Type.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4102,13 +4102,14 @@ makeFunctionType(AnyFunctionType *copy, ArrayRef<AnyFunctionType::Param> params,
41024102
}
41034103

41044104
AnyFunctionType *AnyFunctionType::getAutoDiffAdjointFunctionType(
4105-
AutoDiffParameterIndices *indices, const TupleType *primalResultTy) {
4105+
AutoDiffParameterIndices *indices, const TupleType *primalResultTy,
4106+
bool isMethod) {
41064107
assert(!indices->isEmpty() && "there must be at least one wrt index");
41074108

41084109
// Compute the return type of the adjoint.
41094110
SmallVector<TupleTypeElt, 8> retElts;
41104111
SmallVector<Type, 8> wrtParamTypes;
4111-
indices->getSubsetParameterTypes(this, wrtParamTypes);
4112+
indices->getSubsetParameterTypes(this, wrtParamTypes, isMethod);
41124113
for (auto wrtParamType : wrtParamTypes)
41134114
retElts.push_back(wrtParamType);
41144115

@@ -4121,7 +4122,7 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAdjointFunctionType(
41214122
// If this is a method, unwrap the function type so that we can see the
41224123
// non-self parameters.
41234124
AnyFunctionType *unwrapped = this;
4124-
if (indices->isMethod())
4125+
if (isMethod)
41254126
unwrapped = unwrapped->getResult()->castTo<AnyFunctionType>();
41264127

41274128
// Compute the adjoint parameters.
@@ -4149,7 +4150,7 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAdjointFunctionType(
41494150

41504151
// If this is a method, wrap the adjoint type in an additional "(Self) ->"
41514152
// curry level.
4152-
if (indices->isMethod())
4153+
if (isMethod)
41534154
adjoint = makeFunctionType(this, getParams(), adjoint);
41544155

41554156
return adjoint;
@@ -4158,7 +4159,7 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAdjointFunctionType(
41584159
AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType(
41594160
AutoDiffParameterIndices *indices, unsigned resultIndex,
41604161
unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind,
4161-
LookupConformanceFn lookupConformance, bool selfUncurried) {
4162+
LookupConformanceFn lookupConformance, bool isMethod, bool selfUncurried) {
41624163
// JVP: (T...) -> ((R...),
41634164
// (T.TangentVector...) -> (R.TangentVector...))
41644165
// VJP: (T...) -> ((R...),
@@ -4221,12 +4222,12 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType(
42214222
};
42224223

42234224
SmallVector<Type, 8> wrtParamTypes;
4224-
indices->getSubsetParameterTypes(this, wrtParamTypes, selfUncurried);
4225+
indices->getSubsetParameterTypes(this, wrtParamTypes, isMethod, selfUncurried);
42254226

42264227
// If this is a method, unwrap the function type so that we can see the
42274228
// non-self parameters and the final result.
42284229
AnyFunctionType *unwrapped = this;
4229-
if (indices->isMethod() && !selfUncurried)
4230+
if (isMethod && !selfUncurried)
42304231
unwrapped = unwrapped->getResult()->castTo<AnyFunctionType>();
42314232
Type originalResult = unwrapped->getResult();
42324233

@@ -4298,7 +4299,7 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType(
42984299

42994300
// If this is a method, wrap the associated function type in an additional
43004301
// "(Self) ->" curry level.
4301-
if (indices->isMethod() && !selfUncurried)
4302+
if (isMethod && !selfUncurried)
43024303
associatedFunction =
43034304
makeFunctionType(this, getParams(), associatedFunction);
43044305

0 commit comments

Comments
 (0)