Skip to content

Commit b402121

Browse files
committed
[WIP] [AutoDiff] Support indirect passing.
Slowly adding indirect passing support: removing old assertions, fixing latent bugs. There are no unexpected test regressions. Test programs fail during mandatory inlining with "Cannot construct Inlined loc from the given location". ``` @differentiable func generic<T : Differentiable>(_ x: T) -> T where T.TangentVector : AdditiveArithmetic, T.CotangentVector : AdditiveArithmetic { return x } print(pullback(at: Float(1), in: generic)) ``` ``` Cannot construct Inlined loc from the given location. UNREACHABLE executed at /Users/danielzheng/swift-dev/swift/lib/SIL/SILLocation.cpp:221! Stack dump: ... 1. While running pass swiftlang#46 SILModuleTransform "MandatoryInlining". ... ```
1 parent 0566e7d commit b402121

File tree

2 files changed

+77
-14
lines changed

2 files changed

+77
-14
lines changed

lib/SIL/SILLocation.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ MandatoryInlinedLocation::getMandatoryInlinedLocation(SILLocation L) {
217217
if (L.isInTopLevel())
218218
return MandatoryInlinedLocation::getModuleLocation(L.getSpecialFlags());
219219

220+
llvm::errs() << "SILLocation kind " << (unsigned)L.getKind() << "\n";
220221
llvm_unreachable("Cannot construct Inlined loc from the given location.");
221222
}
222223

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 76 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2143,7 +2143,7 @@ class PrimalGenCloner final
21432143
auto *origExit = &*original->findReturnBB();
21442144
auto *exit = BBMap.lookup(origExit);
21452145
assert(exit->getParent() == getPrimal());
2146-
// Get the original's return value's corresponsing value in the primal.
2146+
// Get the original's return value's corresponding value in the primal.
21472147
auto *origRetInst = cast<ReturnInst>(origExit->getTerminator());
21482148
auto origRetVal = origRetInst->getOperand();
21492149
auto origResInPrimal = getOpValue(origRetVal);
@@ -2158,8 +2158,12 @@ class PrimalGenCloner final
21582158
builder.setInsertionPoint(exit);
21592159
auto structLoweredTy =
21602160
getContext().getTypeConverter().getLoweredType(structTy);
2161-
auto primValsVal =
2162-
builder.createStruct(loc, structLoweredTy, primalValues);
2161+
llvm::errs() << "STRUCT LOWERED TY\n";
2162+
structLoweredTy.getASTType()->dump();
2163+
structLoweredTy.getASTType()->getCanonicalType()->dump();
2164+
// getOpType(getPrimalInfo().getPrimalValueStruct()->getDeclaredInterfaceType())->dump();
2165+
getOpASTType(getPrimalInfo().getPrimalValueStruct()->getDeclaredInterfaceType()->getCanonicalType())->dump();
2166+
auto primValsVal = builder.createStruct(loc, structLoweredTy, primalValues);
21632167
// FIXME: Handle tapes.
21642168
//
21652169
// If the original result was a tuple, return a tuple of all elements in the
@@ -2174,7 +2178,7 @@ class PrimalGenCloner final
21742178
elts.push_back(primValsVal);
21752179
for (unsigned i : range(numElts))
21762180
elts.push_back(builder.emitTupleExtract(loc, origResInPrimal, i));
2177-
retVal = builder.createTuple(loc, elts);
2181+
retVal = joinElements(elts, builder, loc);
21782182
}
21792183
// If the original result was a single value, return a tuple of the primal
21802184
// value struct value and the original result.
@@ -2493,8 +2497,11 @@ bool PrimalGen::performSynthesis(FunctionSynthesisItem item) {
24932497
// FIXME: If the original function has indirect differentiation
24942498
// parameters/result, bail out since AD does not support function calls with
24952499
// indirect parameters yet.
2500+
/*
24962501
if (diagnoseUnsupportedControlFlow(context, item.task) ||
24972502
diagnoseIndirectParametersOrResult(context, item.task)) {
2503+
*/
2504+
if (diagnoseUnsupportedControlFlow(context, item.task)) {
24982505
errorOccurred = true;
24992506
return true;
25002507
}
@@ -2980,13 +2987,16 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
29802987
auto *origRetBB = &*original.findReturnBB();
29812988
adjointBBMap.insert({origRetBB, adjointEntry});
29822989
SILFunctionConventions origConv(origTy, getModule());
2983-
// The adjoint function has type (seed, pv) -> ([arg0], ..., [argn]).
2990+
// OLD: The adjoint function has type (seed, pv) -> ([arg0], ..., [argn]).
2991+
// The adjoint function has type (seed, [indirect_results], pv) -> ([arg0], ..., [argn]).
29842992
auto adjParamArgs = getAdjoint().getArgumentsWithoutIndirectResults();
2985-
seed = adjParamArgs[0];
2986-
primalValueAggregateInAdj = adjParamArgs[1];
2993+
seed = adjParamArgs.front();
2994+
// primalValueAggregateInAdj = adjParamArgs[1];
2995+
primalValueAggregateInAdj = adjParamArgs.back();
29872996
// NOTE: Retaining `seed` below is a temporary hotfix for SR-9804.
29882997
builder.setInsertionPoint(adjointEntry);
2989-
builder.createRetainValue(adjLoc, seed, builder.getDefaultAtomicity());
2998+
if (seed->getType().isReferenceCounted(getModule()))
2999+
builder.createRetainValue(adjLoc, seed, builder.getDefaultAtomicity());
29903000

29913001
// Assign adjoint to the return value.
29923002
// y = tuple (y0, ..., yn)
@@ -3046,15 +3056,29 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
30463056
SmallVector<SILValue, 8> retElts;
30473057
auto origParams = original.getArgumentsWithoutIndirectResults();
30483058

3059+
adjoint.dump();
3060+
30493061
// Materializes the return element corresponding to the parameter
30503062
// `parameterIndex` into the `retElts` vector.
30513063
auto addRetElt = [&](unsigned parameterIndex) -> void {
30523064
auto origParam = origParams[parameterIndex];
3065+
// TODO: Find right adjoint index
3066+
auto adjParam = adjParamArgs[parameterIndex];
30533067
auto adjVal = getAdjointValue(origParam);
30543068
if (origParam->getType().isObject())
30553069
retElts.push_back(materializeAdjointDirect(adjVal, adjLoc));
3056-
else
3057-
llvm_unreachable("Unimplemented: Handle indirect pullback results");
3070+
else {
3071+
assert(origParam->getType().isAddress());
3072+
// builder.createCopyAddr(adjLoc, adjVal, origParam, IsNotTake, IsNotInitialization);
3073+
auto matAdj = materializeAdjoint(adjVal, adjLoc);
3074+
llvm::errs() << "DUMPING TYPES\n";
3075+
origParam->getType().dump();
3076+
adjParam->getType().dump();
3077+
matAdj->getType().dump();
3078+
// builder.createCopyAddr(adjLoc, matAdj, origParam, IsNotTake, IsInitialization);
3079+
builder.createCopyAddr(adjLoc, matAdj, adjParam, IsNotTake, IsInitialization);
3080+
// llvm_unreachable("Unimplemented: Handle indirect pullback results");
3081+
}
30583082
};
30593083

30603084
// The original's self parameter, if present, is the last parameter. But we
@@ -3437,6 +3461,38 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
34373461
materializeZeroIndirect(bufType.getASTType(), adjBuf, si->getLoc());
34383462
}
34393463

3464+
// Handle `copy_addr` instruction.
3465+
// Original: copy_addr x to y
3466+
// Adjoint: adj[x] += adj[y]; adj[y] = 0
3467+
void visitCopyAddrInst(CopyAddrInst *cai) {
3468+
auto adjDest =
3469+
materializeAdjoint(getAdjointValue(cai->getDest()), cai->getLoc());
3470+
auto bufType = remapType(adjDest->getType());
3471+
// builder.createLoad(<#SILLocation Loc#>, <#SILValue LV#>, <#LoadOwnershipQualifier Qualifier#>)
3472+
// auto adjVal = builder.createLoad(cai->getLoc(), adjBuf,
3473+
// getBufferLOQ(bufType.getASTType(), getAdjoint()));
3474+
// addAdjointValue(cai->getSrc(), adjVal);
3475+
auto loc = cai->getLoc();
3476+
auto adjSrc = materializeAdjoint(getAdjointValue(cai->getSrc()), cai->getLoc());
3477+
auto *resultBuf = builder.createAllocStack(loc, bufType);
3478+
auto *resultBufAccess = builder.createBeginAccess(
3479+
loc, resultBuf, SILAccessKind::Init, SILAccessEnforcement::Static,
3480+
/*noNestedConflict*/ true, /*fromBuiltin*/ false);
3481+
auto *lhsBufReadAccess = builder.createBeginAccess(loc, adjDest,
3482+
SILAccessKind::Read, SILAccessEnforcement::Static,
3483+
/*noNestedConflict*/ true, /*fromBuiltin*/ false);
3484+
auto *rhsBufReadAccess = builder.createBeginAccess(loc, adjSrc,
3485+
SILAccessKind::Read, SILAccessEnforcement::Static,
3486+
/*noNestedConflict*/ true, /*fromBuiltin*/ false);
3487+
accumulateMaterializedAdjointsIndirect(lhsBufReadAccess, rhsBufReadAccess, resultBufAccess);
3488+
builder.createEndAccess(loc, resultBufAccess, /*aborted*/ false);
3489+
builder.createEndAccess(loc, rhsBufReadAccess, /*aborted*/ false);
3490+
builder.createEndAccess(loc, lhsBufReadAccess, /*aborted*/ false);
3491+
// Deallocate the temporary result buffer.
3492+
builder.createDeallocStack(loc, resultBuf);
3493+
materializeZeroIndirect(bufType.getASTType(), adjDest, cai->getLoc());
3494+
}
3495+
34403496
// Handle `begin_access` instruction.
34413497
// Original: y = begin_access x
34423498
// Adjoint: end_access adj[y]
@@ -3552,13 +3608,14 @@ void AdjointEmitter::materializeZeroIndirect(CanType type,
35523608
assert(zeroDecl->isProtocolRequirement());
35533609
auto *accessorDecl = zeroDecl->getAccessor(AccessorKind::Get);
35543610
SILDeclRef accessorDeclRef(accessorDecl, SILDeclRef::Kind::Func);
3555-
auto *nomTypeDecl = type->getAnyNominal();
3556-
assert(nomTypeDecl);
3611+
// auto *nomTypeDecl = type->getAnyNominal();
3612+
// assert(nomTypeDecl);
35573613
auto methodType =
35583614
getContext().getTypeConverter().getConstantType(accessorDeclRef);
35593615
// Lookup conformance to `AdditiveArithmetic`.
35603616
auto *swiftMod = getModule().getSwiftModule();
35613617
auto conf = swiftMod->lookupConformance(type, additiveArithmeticProto);
3618+
// TODO: Diagnose for non-conforming types due to SR-9595.
35623619
assert(conf.hasValue() && "No conformance to AdditiveArithmetic?");
35633620
ProtocolConformanceRef confRef(*conf);
35643621
// %wm = witness_method ...
@@ -3619,6 +3676,9 @@ SILValue AdjointEmitter::materializeAdjointDirect(AdjointValue val,
36193676
void AdjointEmitter::materializeAdjointIndirectHelper(
36203677
AdjointValue val, SILValue destBufferAccess) {
36213678
auto loc = destBufferAccess.getLoc();
3679+
llvm::errs() << "DEBUG LOC\n";
3680+
loc.dump(getModule().getSourceManager());
3681+
llvm::errs() << "\n";
36223682
auto soq = getBufferSOQ(val.getType().getASTType(), builder.getFunction());
36233683
switch (val.getKind()) {
36243684
/// Given a `%buf : *T, emit instructions that produce a zero or an aggregate
@@ -3870,11 +3930,13 @@ void AdjointEmitter::accumulateMaterializedAdjointsIndirect(
38703930
assert(cotangentSpace && "No tangent space for this type");
38713931
switch (cotangentSpace->getKind()) {
38723932
case VectorSpace::Kind::Vector: {
3873-
auto *adjointDecl = cotangentSpace->getNominal();
3933+
// auto *adjointDecl = cotangentSpace->getNominal();
38743934
auto *proto = getContext().getAdditiveArithmeticProtocol();
38753935
auto *combinerFuncDecl = getContext().getPlusDecl();
38763936
// Call the combiner function and return.
3877-
auto adjointParentModule = adjointDecl->getModuleContext();
3937+
// auto adjointParentModule = cotangentSpace->getModuleContext();
3938+
// auto confRef = *adjointParentModule->lookupConformance(adjointASTTy, proto);
3939+
auto *adjointParentModule = getModule().getSwiftModule();
38783940
auto confRef = *adjointParentModule->lookupConformance(adjointASTTy, proto);
38793941
SILDeclRef declRef(combinerFuncDecl, SILDeclRef::Kind::Func);
38803942
auto silFnTy = getContext().getTypeConverter().getConstantType(declRef);

0 commit comments

Comments
 (0)