Skip to content

Commit f1a604d

Browse files
committed
[AutoDiff] Support differentiation of apply with inout arguments.
Add reverse-mode differentiation support for `apply` with `inout` arguments.

Notable pullback generation changes:
- If the pullback seed argument is `inout`, assign it (rather than a copy) directly as the adjoint buffer of the original result. This is important so that the value is updated in place.
- In `visitApplyInst`: skip adjoint accumulation for `inout` arguments. Adjoint accumulation for `inout` arguments occurs when callee pullbacks are applied, so no extra accumulation is necessary.

Add derivatives for functions with `inout` parameters in the stdlib for testing:
- `FloatingPoint` operations: `+=`, `-=`, `*=`, `/=`
- `Array.append`

Resolves TF-1165.

Todos:
- Add more tests, e.g. SILGen tests for `inout` derivative typing rules.
- Evaluate performance of `inout` derivatives vs. functional derivatives + mutation.
- TF-1166: enable `@differentiable` attribute on `set` accessors.
- TF-1173: add forward-mode differentiation support for `apply` with `inout` parameters.
1 parent 2cddff0 commit f1a604d

25 files changed

+629
-234
lines changed

include/swift/SIL/AbstractionPattern.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,8 +1146,7 @@ class AbstractionPattern {
11461146
/// The arguments are the same as the arguments to
11471147
/// `AnyFunctionType::getAutoDiffDerivativeFunctionType()`.
11481148
AbstractionPattern getAutoDiffDerivativeFunctionType(
1149-
IndexSubset *indices, unsigned resultIndex,
1150-
AutoDiffDerivativeFunctionKind kind,
1149+
IndexSubset *indices, AutoDiffDerivativeFunctionKind kind,
11511150
LookupConformanceFn lookupConformance,
11521151
GenericSignature derivativeGenericSignature = GenericSignature(),
11531152
bool makeSelfParamFirst = false);

lib/AST/Builtins.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1035,7 +1035,7 @@ static ValueDecl *getAutoDiffApplyDerivativeFunction(
10351035
BuiltinFunctionBuilder::LambdaGenerator resultGen{
10361036
[=, &Context](BuiltinFunctionBuilder &builder) -> Type {
10371037
auto derivativeFnTy = diffFnType->getAutoDiffDerivativeFunctionType(
1038-
paramIndices, /*resultIndex*/ 0, kind,
1038+
paramIndices, kind,
10391039
LookUpConformanceInModule(Context.TheBuiltinModule));
10401040
return derivativeFnTy->getResult();
10411041
}};

lib/SIL/AbstractionPattern.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -926,17 +926,17 @@ const {
926926
}
927927

928928
AbstractionPattern AbstractionPattern::getAutoDiffDerivativeFunctionType(
929-
IndexSubset *indices, unsigned resultIndex,
930-
AutoDiffDerivativeFunctionKind kind, LookupConformanceFn lookupConformance,
929+
IndexSubset *indices, AutoDiffDerivativeFunctionKind kind,
930+
LookupConformanceFn lookupConformance,
931931
GenericSignature derivativeGenericSignature, bool makeSelfParamFirst) {
932932
switch (getKind()) {
933933
case Kind::Type: {
934934
auto fnTy = dyn_cast<AnyFunctionType>(getType());
935935
if (!fnTy)
936936
return getOpaqueDerivativeFunction();
937937
auto derivativeFnTy = fnTy->getAutoDiffDerivativeFunctionType(
938-
indices, resultIndex, kind, lookupConformance,
939-
derivativeGenericSignature, makeSelfParamFirst);
938+
indices, kind, lookupConformance, derivativeGenericSignature,
939+
makeSelfParamFirst);
940940
assert(derivativeFnTy);
941941
return AbstractionPattern(getGenericSignature(),
942942
derivativeFnTy->getCanonicalType());

lib/SIL/SILFunctionType.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,8 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
231231
LookupConformanceFn lookupConformance,
232232
CanGenericSignature derivativeFnGenSig, bool isReabstractionThunk) {
233233
auto &ctx = getASTContext();
234-
auto resultIndices = IndexSubset::get(ctx, getNumResults(), {resultIndex});
234+
auto *resultIndices = IndexSubset::get(
235+
ctx, getNumResults() + getNumIndirectMutatingParameters(), {resultIndex});
235236
SILAutoDiffDerivativeFunctionKey key{
236237
this, parameterIndices, resultIndices,
237238
kind, derivativeFnGenSig, isReabstractionThunk};

lib/SIL/TypeLowering.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2181,8 +2181,8 @@ CanAnyFunctionType TypeConverter::makeConstantInterfaceType(SILDeclRef c) {
21812181
auto originalFnTy =
21822182
makeConstantInterfaceType(c.asAutoDiffOriginalFunction());
21832183
auto *fnTy = originalFnTy->getAutoDiffDerivativeFunctionType(
2184-
autoDiffFuncId->getParameterIndices(), /*resultIndex*/ 0,
2185-
autoDiffFuncId->getKind(), LookUpConformanceInModule(&M));
2184+
autoDiffFuncId->getParameterIndices(), autoDiffFuncId->getKind(),
2185+
LookUpConformanceInModule(&M));
21862186
return cast<AnyFunctionType>(fnTy->getCanonicalType());
21872187
}
21882188

lib/SILGen/SILGenPoly.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3342,18 +3342,18 @@ static ManagedValue createAutoDiffThunk(SILGenFunction &SGF,
33423342
auto *parameterIndices = IndexSubset::get(SGF.getASTContext(), parameterBits);
33433343

33443344
auto getDerivativeFnTy =
3345-
[&](CanAnyFunctionType fnTy, AutoDiffDerivativeFunctionKind kind)
3346-
-> CanAnyFunctionType {
3347-
auto assocTy = fnTy->getAutoDiffDerivativeFunctionType(
3348-
parameterIndices, /*resultIndex*/ 0,
3349-
kind, LookUpConformanceInModule(SGF.SGM.M.getSwiftModule()));
3350-
return cast<AnyFunctionType>(assocTy->getCanonicalType());
3351-
};
3345+
[&](CanAnyFunctionType fnTy,
3346+
AutoDiffDerivativeFunctionKind kind) -> CanAnyFunctionType {
3347+
auto assocTy = fnTy->getAutoDiffDerivativeFunctionType(
3348+
parameterIndices, kind,
3349+
LookUpConformanceInModule(SGF.SGM.M.getSwiftModule()));
3350+
return cast<AnyFunctionType>(assocTy->getCanonicalType());
3351+
};
33523352
auto getDerivativeFnPattern =
33533353
[&](AbstractionPattern pattern,
33543354
AutoDiffDerivativeFunctionKind kind) -> AbstractionPattern {
33553355
return pattern.getAutoDiffDerivativeFunctionType(
3356-
parameterIndices, /*resultIndex*/ 0, kind,
3356+
parameterIndices, kind,
33573357
LookUpConformanceInModule(SGF.SGM.M.getSwiftModule()));
33583358
};
33593359
auto createDerivativeFnThunk =

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -508,9 +508,10 @@ emitDerivativeFunctionReference(
508508
auto *originalFn = originalFRI->getReferencedFunctionOrNull();
509509
assert(originalFn);
510510
auto originalFnTy = originalFn->getLoweredFunctionType();
511-
auto *desiredResultIndices =
512-
IndexSubset::get(context.getASTContext(), originalFnTy->getNumResults(),
513-
{desiredIndices.source});
511+
auto numResults = originalFnTy->getNumResults() +
512+
originalFnTy->getNumIndirectMutatingParameters();
513+
auto *desiredResultIndices = IndexSubset::get(
514+
context.getASTContext(), numResults, {desiredIndices.source});
514515
auto *desiredParameterIndices = desiredIndices.parameters;
515516
// NOTE(TF-893): Extending capacity is necessary when `originalFnTy` has
516517
// parameters corresponding to captured variables.
@@ -550,9 +551,19 @@ emitDerivativeFunctionReference(
550551
}
551552
}
552553
// Check and diagnose non-differentiable results.
553-
if (!originalFnTy->getResults()[desiredIndices.source]
554-
.getSILStorageInterfaceType()
555-
.isDifferentiable(context.getModule())) {
554+
SILType resultType;
555+
if (desiredIndices.source >= originalFnTy->getNumResults()) {
556+
auto inoutParamIdx =
557+
desiredIndices.source - originalFnTy->getNumResults();
558+
auto inoutParam =
559+
*std::next(originalFnTy->getIndirectMutatingParameters().begin(),
560+
inoutParamIdx);
561+
resultType = inoutParam.getSILStorageInterfaceType();
562+
} else {
563+
resultType = originalFnTy->getResults()[desiredIndices.source]
564+
.getSILStorageInterfaceType();
565+
}
566+
if (!resultType.isDifferentiable(context.getModule())) {
556567
context.emitNondifferentiabilityError(
557568
original, invoker, diag::autodiff_nondifferentiable_result);
558569
return None;

lib/SILOptimizer/Utils/Differentiation/Common.cpp

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,15 @@ void collectAllFormalResultsInTypeOrder(SILFunction &function,
119119
for (auto &resInfo : convs.getResults())
120120
results.push_back(resInfo.isFormalDirect() ? dirResults[dirResIdx++]
121121
: indResults[indResIdx++]);
122+
// Treat `inout` parameters as semantic results.
123+
// Append `inout` parameters after formal results.
124+
for (auto i : range(convs.getNumParameters())) {
125+
auto paramInfo = convs.getParameters()[i];
126+
if (!paramInfo.isIndirectMutating())
127+
continue;
128+
auto *argument = function.getArgumentsWithoutIndirectResults()[i];
129+
results.push_back(argument);
130+
}
122131
}
123132

124133
/// Given a function, gathers all of its direct results in an order defined by
@@ -194,8 +203,21 @@ void collectMinimalIndicesForFunctionCall(
194203
++indResIdx;
195204
}
196205
}
206+
// Record all `inout` parameters as results.
207+
auto inoutParamResultIndex = calleeFnTy->getNumResults();
208+
for (auto &paramAndIdx : enumerate(calleeConvs.getParameters())) {
209+
auto &param = paramAndIdx.value();
210+
if (!param.isIndirectMutating())
211+
continue;
212+
unsigned idx = paramAndIdx.index();
213+
auto inoutArg = ai->getArgument(idx);
214+
results.push_back(inoutArg);
215+
resultIndices.push_back(inoutParamResultIndex++);
216+
}
197217
// Make sure the function call has active results.
198-
assert(results.size() == calleeFnTy->getNumResults());
218+
auto numResults = calleeFnTy->getNumResults() +
219+
calleeFnTy->getNumIndirectMutatingParameters();
220+
assert(results.size() == numResults);
199221
assert(llvm::any_of(results, [&](SILValue result) {
200222
return activityInfo.isActive(result, parentIndices);
201223
}));

lib/SILOptimizer/Utils/Differentiation/JVPEmitter.cpp

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,10 +1000,12 @@ void JVPEmitter::prepareForDifferentialGeneration() {
10001000
// Initialize tangent mapping for indirect results.
10011001
auto origIndResults = original->getIndirectResults();
10021002
auto diffIndResults = differential.getIndirectResults();
1003-
assert(origIndResults.size() == diffIndResults.size());
1004-
1003+
size_t numInoutParameters = llvm::count_if(
1004+
original->getLoweredFunctionType()->getParameters(),
1005+
[](SILParameterInfo paramInfo) { return paramInfo.isIndirectInOut(); });
1006+
assert(origIndResults.size() + numInoutParameters == diffIndResults.size());
10051007
for (auto &origBB : *original)
1006-
for (auto i : indices(diffIndResults))
1008+
for (auto i : indices(origIndResults))
10071009
setTangentBuffer(&origBB, origIndResults[i], diffIndResults[i]);
10081010
}
10091011

@@ -1036,14 +1038,29 @@ JVPEmitter::createEmptyDifferential(ADContext &context,
10361038
auto indices = witness->getSILAutoDiffIndices();
10371039

10381040
// Add differential results.
1039-
auto origResult = origTy->getResults()[indices.source];
1040-
origResult = origResult.getWithInterfaceType(
1041-
origResult.getInterfaceType()->getCanonicalType(witnessCanGenSig));
1042-
dfResults.push_back(
1043-
SILResultInfo(origResult.getInterfaceType()
1044-
->getAutoDiffTangentSpace(lookupConformance)
1045-
->getCanonicalType(),
1046-
origResult.getConvention()));
1041+
Optional<SILParameterInfo> inoutDiffParam = None;
1042+
for (auto origParam : origTy->getParameters()) {
1043+
if (!origParam.isIndirectInOut())
1044+
continue;
1045+
inoutDiffParam = origParam;
1046+
}
1047+
1048+
if (inoutDiffParam) {
1049+
dfResults.push_back(
1050+
SILResultInfo(inoutDiffParam->getInterfaceType()
1051+
->getAutoDiffTangentSpace(lookupConformance)
1052+
->getCanonicalType(),
1053+
ResultConvention::Indirect));
1054+
} else {
1055+
auto origResult = origTy->getResults()[indices.source];
1056+
origResult = origResult.getWithInterfaceType(
1057+
origResult.getInterfaceType()->getCanonicalType(witnessCanGenSig));
1058+
dfResults.push_back(
1059+
SILResultInfo(origResult.getInterfaceType()
1060+
->getAutoDiffTangentSpace(lookupConformance)
1061+
->getCanonicalType(),
1062+
origResult.getConvention()));
1063+
}
10471064

10481065
// Add differential parameters for the requested wrt parameters.
10491066
for (auto i : indices.parameters->getIndices()) {

lib/SILOptimizer/Utils/Differentiation/LinearMapInfo.cpp

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -244,8 +244,13 @@ VarDecl *LinearMapInfo::addLinearMapDecl(ApplyInst *ai, SILType linearMapType) {
244244
// the same parameters and results.
245245
auto silFnTy = linearMapType.castTo<SILFunctionType>();
246246
SmallVector<AnyFunctionType::Param, 8> params;
247-
for (auto &param : silFnTy->getParameters())
248-
params.push_back(AnyFunctionType::Param(param.getInterfaceType()));
247+
for (auto &param : silFnTy->getParameters()) {
248+
ParameterTypeFlags flags;
249+
if (param.isIndirectMutating())
250+
flags = flags.withInOut(true);
251+
params.push_back(
252+
AnyFunctionType::Param(param.getInterfaceType(), Identifier(), flags));
253+
}
249254
AnyFunctionType *astFnTy;
250255
if (auto genSig = silFnTy->getSubstGenericSignature())
251256
astFnTy = GenericFunctionType::get(
@@ -283,11 +288,22 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
283288
auto hasActiveResults = llvm::any_of(allResults, [&](SILValue res) {
284289
return activityInfo.isActive(res, indices);
285290
});
286-
auto hasActiveArguments =
287-
llvm::any_of(ai->getArgumentsWithoutIndirectResults(), [&](SILValue arg) {
288-
return activityInfo.isActive(arg, indices);
289-
});
290-
if (!hasActiveResults || !hasActiveArguments)
291+
bool hasActiveInoutArgument = false;
292+
bool hasActiveArguments = false;
293+
auto numIndirectResults = ai->getNumIndirectResults();
294+
for (auto argIdx : range(ai->getSubstCalleeConv().getNumParameters())) {
295+
auto arg = ai->getArgumentsWithoutIndirectResults()[argIdx];
296+
if (activityInfo.isActive(arg, indices)) {
297+
hasActiveArguments = true;
298+
auto paramInfo = ai->getSubstCalleeConv().getParamInfoForSILArg(
299+
numIndirectResults + argIdx);
300+
if (paramInfo.isIndirectMutating())
301+
hasActiveInoutArgument = true;
302+
}
303+
}
304+
if (!hasActiveArguments)
305+
return;
306+
if (!hasActiveResults && !hasActiveInoutArgument)
291307
return;
292308

293309
// Compute differentiation result index.
@@ -323,8 +339,16 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
323339
return true;
324340
}
325341
// Check non-differentiable results.
326-
auto remappedResultType = origFnTy->getResults()[applyIndices.source]
327-
.getSILStorageInterfaceType();
342+
SILType remappedResultType;
343+
if (applyIndices.source >= origFnTy->getNumResults()) {
344+
auto inoutArgIdx = applyIndices.source - origFnTy->getNumResults();
345+
auto inoutArg =
346+
*std::next(ai->getInoutArguments().begin(), inoutArgIdx);
347+
remappedResultType = inoutArg->getType();
348+
} else {
349+
remappedResultType = origFnTy->getResults()[applyIndices.source]
350+
.getSILStorageInterfaceType();
351+
}
328352
if (!remappedResultType.isDifferentiable(derivative->getModule()))
329353
return true;
330354
return false;
@@ -338,13 +362,14 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
338362
parameters, source, derivativeFnKind, context.getTypeConverter(),
339363
LookUpConformanceInModule(derivative->getModule().getSwiftModule()));
340364

341-
auto derivativeFnResultTypes =
342-
derivativeFnType->getAllResultsInterfaceType().castTo<TupleType>();
343-
auto linearMapSILType = SILType::getPrimitiveObjectType(
344-
derivativeFnResultTypes
345-
->getElement(derivativeFnResultTypes->getElements().size() - 1)
346-
.getType()
347-
->getCanonicalType());
365+
auto derivativeFnResultTypes = derivativeFnType->getAllResultsInterfaceType();
366+
auto linearMapSILType = derivativeFnResultTypes;
367+
if (auto tupleType = derivativeFnResultTypes.getAs<TupleType>()) {
368+
linearMapSILType = SILType::getPrimitiveObjectType(
369+
tupleType->getElement(tupleType->getElements().size() - 1)
370+
.getType()
371+
->getCanonicalType());
372+
}
348373
addLinearMapDecl(ai, linearMapSILType);
349374
}
350375

@@ -396,21 +421,11 @@ void LinearMapInfo::generateDifferentiationDataStructures(
396421
for (auto &origBB : *original) {
397422
for (auto &inst : origBB) {
398423
if (auto *ai = dyn_cast<ApplyInst>(&inst)) {
399-
// Skip `apply` instructions with active `inout` arguments.
400-
// TODO(TF-129): Support `inout` argument differentiation.
401-
bool hasActiveInoutArgument =
402-
llvm::any_of(ai->getInoutArguments(), [&](SILValue inoutArg) {
403-
return activityInfo.isActive(inoutArg, indices);
404-
});
405-
if (hasActiveInoutArgument)
406-
continue;
407-
408424
// Add linear map field to struct for active `apply` instructions.
409425
// Skip array literal intrinsic applications since array literal
410426
// initialization is linear and handled separately.
411427
if (!shouldDifferentiateApplySite(ai) || isArrayLiteralIntrinsic(ai))
412428
continue;
413-
414429
LLVM_DEBUG(getADDebugStream()
415430
<< "Adding linear map struct field for " << *ai);
416431
addLinearMapToStruct(context, ai, indices);

0 commit comments

Comments
 (0)