Skip to content

Commit 34dcfe3

Browse files
committed
Move self-parameter-reordering type computation from SIL to AST.
Address review feedback from @marcrasi. Computing self-parameter-reordered JVP/VJP types on AST function types instead of SIL function types is safer and simpler.
1 parent a2d0e3f commit 34dcfe3

File tree

7 files changed

+37
-134
lines changed

7 files changed

+37
-134
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,14 +149,17 @@ class AutoDiffParameterIndices : public llvm::FoldingSetNode {
149149
///
150150
/// functionType = (A, B) -> (C, D) -> R
151151
/// if "A", "C", and "D" are in the set,
152-
/// ==> pushes {A, C, D} to `paramTypes`.
152+
/// ==> pushes {A, C, D} to `paramTypes` if `reverseCurryLevels` is false,
153+
/// or pushes {C, D, A} otherwise.
153154
///
154155
/// functionType = (Self) -> (A, B, C) -> R
155156
/// if "Self" and "C" are in the set,
156-
/// ==> pushes {Self, C} to `paramTypes`.
157+
/// ==> pushes {Self, C} to `paramTypes` if `reverseCurryLevels` is false,
158+
/// or pushes {C, Self} otherwise.
157159
///
158160
void getSubsetParameterTypes(AnyFunctionType *functionType,
159-
SmallVectorImpl<Type> &paramTypes) const;
161+
SmallVectorImpl<Type> &paramTypes,
162+
bool reverseCurryLevels = false) const;
160163

161164
/// Returns a bitvector for the SILFunction parameters corresponding to the
162165
/// parameters in this set. In particular, this explodes tuples. For example,

include/swift/AST/Types.h

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3072,14 +3072,24 @@ class AnyFunctionType : public TypeBase {
30723072
/// Given `indices`, `differentiationOrder`, and `kind`, calculates the type
30733073
/// of the corresponding autodiff associated function.
30743074
///
3075-
/// \note The original function type (`self`) need not be `@differentiable`,
3076-
/// and the resulting function will preserve all `ExtInfo` of the original
3075+
/// By default, if the original type has a self parameter list and parameter
3076+
/// indices include self, the computed associated function type will return a
3077+
/// linear map taking/returning self's tangent/cotangent *last* instead of
3078+
/// first, for consistency with SIL.
3079+
///
3080+
/// If `makeSelfParamFirst` is true, self's tangent/cotangent is reordered to
3081+
/// appear first. This should be used during type-checking, e.g.
3082+
/// type-checking `@differentiable` and `@differentiating` attributes.
3083+
///
3084+
/// \note The original function type (`self`) need not be `@differentiable`.
3085+
/// The resulting function will preserve all `ExtInfo` of the original
30773086
/// function, including `@differentiable`.
30783087
AnyFunctionType *getAutoDiffAssociatedFunctionType(
30793088
AutoDiffParameterIndices *indices, unsigned resultIndex,
30803089
unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind,
30813090
LookupConformanceFn lookupConformance,
3082-
GenericSignature *whereClauseGenericSignature = nullptr);
3091+
GenericSignature *whereClauseGenericSignature = nullptr,
3092+
bool makeSelfParamFirst = false);
30833093

30843094
/// Given the type of an autodiff associated function, returns the
30853095
/// corresponding original function type.

lib/AST/AutoDiff.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -165,15 +165,16 @@ static void unwrapCurryLevels(AnyFunctionType *fnTy,
165165
/// ==> pushes {A, C} to `paramTypes`.
166166
///
167167
/// functionType = (A, B) -> (C, D) -> R
168-
/// if "A", "C", and "D" are in the set,
169-
/// ==> pushes {A, C, D} to `paramTypes`.
168+
/// ==> pushes {A, C, D} to `paramTypes` if `reverseCurryLevels` is true,
169+
/// or pushes {C, D, A} otherwise.
170170
///
171171
/// functionType = (Self) -> (A, B, C) -> R
172-
/// if "Self" and "C" are in the set,
173-
/// ==> pushes {Self, C} to `paramTypes`.
172+
/// ==> pushes {Self, C} to `paramTypes` if `reverseCurryLevels` is true,
173+
/// or pushes {C, Self} otherwise.
174174
///
175175
void AutoDiffParameterIndices::getSubsetParameterTypes(
176-
AnyFunctionType *functionType, SmallVectorImpl<Type> &paramTypes) const {
176+
AnyFunctionType *functionType, SmallVectorImpl<Type> &paramTypes,
177+
bool reverseCurryLevels) const {
177178
SmallVector<AnyFunctionType *, 2> curryLevels;
178179
unwrapCurryLevels(functionType, curryLevels);
179180

@@ -184,6 +185,13 @@ void AutoDiffParameterIndices::getSubsetParameterTypes(
184185
currentOffset += curryLevels[curryLevelIndex]->getNumParams();
185186
}
186187

188+
// If `reverseCurryLevels` is true, reverse the curry levels and offsets.
189+
if (reverseCurryLevels) {
190+
std::reverse(curryLevels.begin(), curryLevels.end());
191+
std::reverse(curryLevelParameterIndexOffsets.begin(),
192+
curryLevelParameterIndexOffsets.end());
193+
}
194+
187195
for (unsigned curryLevelIndex : indices(curryLevels)) {
188196
auto *curryLevel = curryLevels[curryLevelIndex];
189197
unsigned parameterIndexOffset =

lib/AST/Type.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4444,7 +4444,7 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType(
44444444
AutoDiffParameterIndices *indices, unsigned resultIndex,
44454445
unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind,
44464446
LookupConformanceFn lookupConformance,
4447-
GenericSignature *whereClauseGenSig) {
4447+
GenericSignature *whereClauseGenSig, bool makeSelfParamFirst) {
44484448
// JVP: (T...) -> ((R...),
44494449
// (T.TangentVector...) -> (R.TangentVector...))
44504450
// VJP: (T...) -> ((R...),
@@ -4460,7 +4460,8 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType(
44604460
auto &ctx = getASTContext();
44614461

44624462
SmallVector<Type, 8> wrtParamTypes;
4463-
indices->getSubsetParameterTypes(this, wrtParamTypes);
4463+
indices->getSubsetParameterTypes(
4464+
this, wrtParamTypes, /*reverseCurryLevels*/ !makeSelfParamFirst);
44644465

44654466
// Unwrap curry levels. At most, two parameter lists are necessary, for
44664467
// curried method types with a `(Self)` parameter list.

lib/SIL/SILFunctionType.cpp

Lines changed: 0 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,74 +1222,6 @@ static CanSILFunctionType getSILFunctionType(
12221222
? isPseudogeneric(*constant)
12231223
: false;
12241224

1225-
// SWIFT_ENABLE_TENSORFLOW
1226-
// If constant is an autodiff associated function and is differentiable wrt
1227-
// self, reorder self so that it appears as:
1228-
// - The last parameter in the differential.
1229-
// - The last result in the pullback.
1230-
if (constant && constant->hasFuncDecl() &&
1231-
constant->autoDiffAssociatedFunctionIdentifier) {
1232-
auto *AFD = constant->getAbstractFunctionDecl();
1233-
auto *autoDiffFuncId = constant->autoDiffAssociatedFunctionIdentifier;
1234-
assert(results.size() == 2);
1235-
auto linearMapResultInfo = results.back();
1236-
auto linearMapType =
1237-
linearMapResultInfo.getSILStorageType().castTo<SILFunctionType>();
1238-
// Compute autodiff indices.
1239-
auto loweredParamIndices =
1240-
autoDiffFuncId->getParameterIndices()->getLowered(
1241-
M.getASTContext(),
1242-
AFD->getInterfaceType()->castTo<AnyFunctionType>());
1243-
SILAutoDiffIndices indices(/*source*/ 0, loweredParamIndices);
1244-
auto selfParamIndex = inputs.size() - 1;
1245-
// Reorder *only* if differentiable wrt to self.
1246-
if (indices.isWrtParameter(selfParamIndex)) {
1247-
switch (autoDiffFuncId->getKind()) {
1248-
// Reorder self as the first differential parameter.
1249-
case AutoDiffAssociatedFunctionKind::JVP: {
1250-
SmallVector<SILParameterInfo, 8> differentialParams;
1251-
differentialParams.append(linearMapType->getParameters().begin(),
1252-
linearMapType->getParameters().end());
1253-
std::rotate(differentialParams.begin(), differentialParams.begin() + 1,
1254-
differentialParams.end());
1255-
auto differentialType = SILFunctionType::get(
1256-
linearMapType->getGenericSignature(), linearMapType->getExtInfo(),
1257-
linearMapType->getCoroutineKind(),
1258-
linearMapType->getCalleeConvention(), differentialParams,
1259-
linearMapType->getYields(), linearMapType->getResults(),
1260-
linearMapType->getOptionalErrorResult(), M.getASTContext());
1261-
results.pop_back();
1262-
auto newDifferentialResult = SILResultInfo(
1263-
differentialType->getCanonicalType(),
1264-
linearMapResultInfo.getConvention());
1265-
results.push_back(newDifferentialResult);
1266-
break;
1267-
}
1268-
// Reorder self as the last pullback result.
1269-
case AutoDiffAssociatedFunctionKind::VJP: {
1270-
SmallVector<SILResultInfo, 8> pullbackResults;
1271-
pullbackResults.append(linearMapType->getResults().begin(),
1272-
linearMapType->getResults().end());
1273-
std::rotate(pullbackResults.begin(), pullbackResults.begin() + 1,
1274-
pullbackResults.end());
1275-
auto pullbackType = SILFunctionType::get(
1276-
linearMapType->getGenericSignature(), linearMapType->getExtInfo(),
1277-
linearMapType->getCoroutineKind(),
1278-
linearMapType->getCalleeConvention(),
1279-
linearMapType->getParameters(), linearMapType->getYields(),
1280-
pullbackResults, linearMapType->getOptionalErrorResult(),
1281-
M.getASTContext());
1282-
auto newPullbackResult = SILResultInfo(
1283-
pullbackType->getCanonicalType(),
1284-
linearMapResultInfo.getConvention());
1285-
results.pop_back();
1286-
results.push_back(newPullbackResult);
1287-
break;
1288-
}
1289-
}
1290-
}
1291-
}
1292-
12931225
// NOTE: SILFunctionType::ExtInfo doesn't track everything that
12941226
// AnyFunctionType::ExtInfo tracks. For example: 'throws' or 'auto-closure'
12951227
auto silExtInfo = SILFunctionType::ExtInfo()

lib/SILGen/SILGenPoly.cpp

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -4210,57 +4210,6 @@ void SILGenFunction::emitProtocolWitness(AbstractionPattern reqtOrigTy,
42104210
origWitnessFTy,
42114211
witnessKind, witnessSubs,
42124212
witnessParams, loc);
4213-
// SWIFT_ENABLE_TENSORFLOW
4214-
// Perform witness method thunking.
4215-
if (auto *autoDiffFuncId = witness.autoDiffAssociatedFunctionIdentifier) {
4216-
auto *AFD = witness.getAbstractFunctionDecl();
4217-
auto loweredParamIndices =
4218-
autoDiffFuncId->getParameterIndices()->getLowered(
4219-
this->getASTContext(),
4220-
AFD->getInterfaceType()->castTo<AnyFunctionType>());
4221-
SILAutoDiffIndices indices(/*source*/ 0, loweredParamIndices);
4222-
auto assocFnKind = autoDiffFuncId->getKind();
4223-
4224-
unsigned selfParamIndex = F.getLoweredFunctionType()->getNumParameters() - 1;
4225-
bool isWrtSelf = indices.isWrtParameter(selfParamIndex);
4226-
if (isWrtSelf && indices.parameters->getNumIndices() > 1) {
4227-
// Given the type of an autodiff associated method that is differentiable
4228-
// wrt self, return a version where self's tangent/cotangent is reordered
4229-
// in the returned linear map.
4230-
// This is done by computing the original function type and recomputing
4231-
// the corresponding associated function type.
4232-
auto getReorderedFunctionType =
4233-
[&](CanAnyFunctionType fnType) -> CanAnyFunctionType {
4234-
auto *fnOrigType = fnType->getAutoDiffOriginalFunctionType();
4235-
auto *newFnType = fnOrigType->getAutoDiffAssociatedFunctionType(
4236-
autoDiffFuncId->getParameterIndices(), /*resultIndex*/ 0,
4237-
/*differentiationOrder*/ 1, assocFnKind,
4238-
LookUpConformanceInModule(F.getModule().getSwiftModule()),
4239-
fnType->getOptGenericSignature());
4240-
CanGenericSignature canGenSig;
4241-
if (auto *genSig = newFnType->getOptGenericSignature())
4242-
canGenSig = genSig->getCanonicalSignature();
4243-
return CanAnyFunctionType::get(
4244-
canGenSig, newFnType->getParams(),
4245-
newFnType->getResult()->getCanonicalType(canGenSig));
4246-
};
4247-
4248-
// Given the abstraction pattern of an autodiff associated method that is
4249-
// differentiable wrt self, return a version where self's
4250-
// tangent/cotangent is reordered in the returned linear map.
4251-
auto getReorderedUpdatePattern =
4252-
[&](AbstractionPattern pattern) -> AbstractionPattern {
4253-
auto canFnTy = pattern.getAs<AnyFunctionType>();
4254-
canFnTy = getReorderedFunctionType(canFnTy);
4255-
return AbstractionPattern(pattern.getGenericSignature(), canFnTy);
4256-
};
4257-
4258-
witnessOrigTy = getReorderedUpdatePattern(witnessOrigTy);
4259-
reqtOrigTy = getReorderedUpdatePattern(reqtOrigTy);
4260-
witnessSubstTy = getReorderedFunctionType(witnessSubstTy);
4261-
reqtSubstTy = getReorderedFunctionType(reqtSubstTy);
4262-
}
4263-
}
42644213

42654214
auto coroutineKind =
42664215
witnessFnRef->getType().castTo<SILFunctionType>()->getCoroutineKind();

lib/Sema/TypeCheckAttr.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3120,7 +3120,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
31203120
originalFnTy->getAutoDiffAssociatedFunctionType(
31213121
checkedWrtParamIndices, /*resultIndex*/ 0,
31223122
/*differentiationOrder*/ 1, AutoDiffAssociatedFunctionKind::JVP,
3123-
lookupConformance, whereClauseGenSig);
3123+
lookupConformance, whereClauseGenSig, /*makeSelfParamFirst*/ true);
31243124

31253125
auto isValidJVP = [&](FuncDecl *jvpCandidate) {
31263126
TC.validateDeclForNameLookup(jvpCandidate);
@@ -3146,7 +3146,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
31463146
originalFnTy->getAutoDiffAssociatedFunctionType(
31473147
checkedWrtParamIndices, /*resultIndex*/ 0,
31483148
/*differentiationOrder*/ 1, AutoDiffAssociatedFunctionKind::VJP,
3149-
lookupConformance, whereClauseGenSig);
3149+
lookupConformance, whereClauseGenSig, /*makeSelfParamFirst*/ true);
31503150

31513151
auto isValidVJP = [&](FuncDecl *vjpCandidate) {
31523152
TC.validateDeclForNameLookup(vjpCandidate);

0 commit comments

Comments
 (0)