@@ -2143,7 +2143,7 @@ class PrimalGenCloner final
2143
2143
auto *origExit = &*original->findReturnBB ();
2144
2144
auto *exit = BBMap.lookup (origExit);
2145
2145
assert (exit->getParent () == getPrimal ());
2146
- // Get the original's return value's corresponsing value in the primal.
2146
+ // Get the original's return value's corresponding value in the primal.
2147
2147
auto *origRetInst = cast<ReturnInst>(origExit->getTerminator ());
2148
2148
auto origRetVal = origRetInst->getOperand ();
2149
2149
auto origResInPrimal = getOpValue (origRetVal);
@@ -2158,8 +2158,12 @@ class PrimalGenCloner final
2158
2158
builder.setInsertionPoint (exit);
2159
2159
auto structLoweredTy =
2160
2160
getContext ().getTypeConverter ().getLoweredType (structTy);
2161
- auto primValsVal =
2162
- builder.createStruct (loc, structLoweredTy, primalValues);
2161
+ llvm::errs () << " STRUCT LOWERED TY\n " ;
2162
+ structLoweredTy.getASTType ()->dump ();
2163
+ structLoweredTy.getASTType ()->getCanonicalType ()->dump ();
2164
+ // getOpType(getPrimalInfo().getPrimalValueStruct()->getDeclaredInterfaceType())->dump();
2165
+ getOpASTType (getPrimalInfo ().getPrimalValueStruct ()->getDeclaredInterfaceType ()->getCanonicalType ())->dump ();
2166
+ auto primValsVal = builder.createStruct (loc, structLoweredTy, primalValues);
2163
2167
// FIXME: Handle tapes.
2164
2168
//
2165
2169
// If the original result was a tuple, return a tuple of all elements in the
@@ -2174,7 +2178,7 @@ class PrimalGenCloner final
2174
2178
elts.push_back (primValsVal);
2175
2179
for (unsigned i : range (numElts))
2176
2180
elts.push_back (builder.emitTupleExtract (loc, origResInPrimal, i));
2177
- retVal = builder. createTuple (loc, elts );
2181
+ retVal = joinElements (elts, builder, loc );
2178
2182
}
2179
2183
// If the original result was a single value, return a tuple of the primal
2180
2184
// value struct value and the original result.
@@ -2493,8 +2497,11 @@ bool PrimalGen::performSynthesis(FunctionSynthesisItem item) {
2493
2497
// FIXME: If the original function has indirect differentiation
2494
2498
// parameters/result, bail out since AD does not support function calls with
2495
2499
// indirect parameters yet.
2500
+ /*
2496
2501
if (diagnoseUnsupportedControlFlow(context, item.task) ||
2497
2502
diagnoseIndirectParametersOrResult(context, item.task)) {
2503
+ */
2504
+ if (diagnoseUnsupportedControlFlow (context, item.task )) {
2498
2505
errorOccurred = true ;
2499
2506
return true ;
2500
2507
}
@@ -2980,13 +2987,16 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
2980
2987
auto *origRetBB = &*original.findReturnBB ();
2981
2988
adjointBBMap.insert ({origRetBB, adjointEntry});
2982
2989
SILFunctionConventions origConv (origTy, getModule ());
2983
- // The adjoint function has type (seed, pv) -> ([arg0], ..., [argn]).
2990
+ // OLD: The adjoint function has type (seed, pv) -> ([arg0], ..., [argn]).
2991
+ // The adjoint function has type (seed, [indirect_results], pv) -> ([arg0], ..., [argn]).
2984
2992
auto adjParamArgs = getAdjoint ().getArgumentsWithoutIndirectResults ();
2985
- seed = adjParamArgs[0 ];
2986
- primalValueAggregateInAdj = adjParamArgs[1 ];
2993
+ seed = adjParamArgs.front ();
2994
+ // primalValueAggregateInAdj = adjParamArgs[1];
2995
+ primalValueAggregateInAdj = adjParamArgs.back ();
2987
2996
// NOTE: Retaining `seed` below is a temporary hotfix for SR-9804.
2988
2997
builder.setInsertionPoint (adjointEntry);
2989
- builder.createRetainValue (adjLoc, seed, builder.getDefaultAtomicity ());
2998
+ if (seed->getType ().isReferenceCounted (getModule ()))
2999
+ builder.createRetainValue (adjLoc, seed, builder.getDefaultAtomicity ());
2990
3000
2991
3001
// Assign adjoint to the return value.
2992
3002
// y = tuple (y0, ..., yn)
@@ -3046,15 +3056,29 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
3046
3056
SmallVector<SILValue, 8 > retElts;
3047
3057
auto origParams = original.getArgumentsWithoutIndirectResults ();
3048
3058
3059
+ adjoint.dump ();
3060
+
3049
3061
// Materializes the return element corresponding to the parameter
3050
3062
// `parameterIndex` into the `retElts` vector.
3051
3063
auto addRetElt = [&](unsigned parameterIndex) -> void {
3052
3064
auto origParam = origParams[parameterIndex];
3065
+ // TODO: Find right adjoint index
3066
+ auto adjParam = adjParamArgs[parameterIndex];
3053
3067
auto adjVal = getAdjointValue (origParam);
3054
3068
if (origParam->getType ().isObject ())
3055
3069
retElts.push_back (materializeAdjointDirect (adjVal, adjLoc));
3056
- else
3057
- llvm_unreachable (" Unimplemented: Handle indirect pullback results" );
3070
+ else {
3071
+ assert (origParam->getType ().isAddress ());
3072
+ // builder.createCopyAddr(adjLoc, adjVal, origParam, IsNotTake, IsNotInitialization);
3073
+ auto matAdj = materializeAdjoint (adjVal, adjLoc);
3074
+ llvm::errs () << " DUMPING TYPES\n " ;
3075
+ origParam->getType ().dump ();
3076
+ adjParam->getType ().dump ();
3077
+ matAdj->getType ().dump ();
3078
+ // builder.createCopyAddr(adjLoc, matAdj, origParam, IsNotTake, IsInitialization);
3079
+ builder.createCopyAddr (adjLoc, matAdj, adjParam, IsNotTake, IsInitialization);
3080
+ // llvm_unreachable("Unimplemented: Handle indirect pullback results");
3081
+ }
3058
3082
};
3059
3083
3060
3084
// The original's self parameter, if present, is the last parameter. But we
@@ -3437,6 +3461,38 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
3437
3461
materializeZeroIndirect (bufType.getASTType (), adjBuf, si->getLoc ());
3438
3462
}
3439
3463
3464
+ // Handle `copy_addr` instruction.
3465
+ // Original: copy_addr x to y
3466
+ // Adjoint: adj[x] += adj[y]; adj[y] = 0
3467
+ void visitCopyAddrInst (CopyAddrInst *cai) {
3468
+ auto adjDest =
3469
+ materializeAdjoint (getAdjointValue (cai->getDest ()), cai->getLoc ());
3470
+ auto bufType = remapType (adjDest->getType ());
3471
+ // builder.createLoad(<#SILLocation Loc#>, <#SILValue LV#>, <#LoadOwnershipQualifier Qualifier#>)
3472
+ // auto adjVal = builder.createLoad(cai->getLoc(), adjBuf,
3473
+ // getBufferLOQ(bufType.getASTType(), getAdjoint()));
3474
+ // addAdjointValue(cai->getSrc(), adjVal);
3475
+ auto loc = cai->getLoc ();
3476
+ auto adjSrc = materializeAdjoint (getAdjointValue (cai->getSrc ()), cai->getLoc ());
3477
+ auto *resultBuf = builder.createAllocStack (loc, bufType);
3478
+ auto *resultBufAccess = builder.createBeginAccess (
3479
+ loc, resultBuf, SILAccessKind::Init, SILAccessEnforcement::Static,
3480
+ /* noNestedConflict*/ true , /* fromBuiltin*/ false );
3481
+ auto *lhsBufReadAccess = builder.createBeginAccess (loc, adjDest,
3482
+ SILAccessKind::Read, SILAccessEnforcement::Static,
3483
+ /* noNestedConflict*/ true , /* fromBuiltin*/ false );
3484
+ auto *rhsBufReadAccess = builder.createBeginAccess (loc, adjSrc,
3485
+ SILAccessKind::Read, SILAccessEnforcement::Static,
3486
+ /* noNestedConflict*/ true , /* fromBuiltin*/ false );
3487
+ accumulateMaterializedAdjointsIndirect (lhsBufReadAccess, rhsBufReadAccess, resultBufAccess);
3488
+ builder.createEndAccess (loc, resultBufAccess, /* aborted*/ false );
3489
+ builder.createEndAccess (loc, rhsBufReadAccess, /* aborted*/ false );
3490
+ builder.createEndAccess (loc, lhsBufReadAccess, /* aborted*/ false );
3491
+ // Deallocate the temporary result buffer.
3492
+ builder.createDeallocStack (loc, resultBuf);
3493
+ materializeZeroIndirect (bufType.getASTType (), adjDest, cai->getLoc ());
3494
+ }
3495
+
3440
3496
// Handle `begin_access` instruction.
3441
3497
// Original: y = begin_access x
3442
3498
// Adjoint: end_access adj[y]
@@ -3552,13 +3608,14 @@ void AdjointEmitter::materializeZeroIndirect(CanType type,
3552
3608
assert (zeroDecl->isProtocolRequirement ());
3553
3609
auto *accessorDecl = zeroDecl->getAccessor (AccessorKind::Get);
3554
3610
SILDeclRef accessorDeclRef (accessorDecl, SILDeclRef::Kind::Func);
3555
- auto *nomTypeDecl = type->getAnyNominal ();
3556
- assert (nomTypeDecl);
3611
+ // auto *nomTypeDecl = type->getAnyNominal();
3612
+ // assert(nomTypeDecl);
3557
3613
auto methodType =
3558
3614
getContext ().getTypeConverter ().getConstantType (accessorDeclRef);
3559
3615
// Lookup conformance to `AdditiveArithmetic`.
3560
3616
auto *swiftMod = getModule ().getSwiftModule ();
3561
3617
auto conf = swiftMod->lookupConformance (type, additiveArithmeticProto);
3618
+ // TODO: Diagnose for non-conforming types due to SR-9595.
3562
3619
assert (conf.hasValue () && " No conformance to AdditiveArithmetic?" );
3563
3620
ProtocolConformanceRef confRef (*conf);
3564
3621
// %wm = witness_method ...
@@ -3619,6 +3676,9 @@ SILValue AdjointEmitter::materializeAdjointDirect(AdjointValue val,
3619
3676
void AdjointEmitter::materializeAdjointIndirectHelper (
3620
3677
AdjointValue val, SILValue destBufferAccess) {
3621
3678
auto loc = destBufferAccess.getLoc ();
3679
+ llvm::errs () << " DEBUG LOC\n " ;
3680
+ loc.dump (getModule ().getSourceManager ());
3681
+ llvm::errs () << " \n " ;
3622
3682
auto soq = getBufferSOQ (val.getType ().getASTType (), builder.getFunction ());
3623
3683
switch (val.getKind ()) {
3624
3684
// / Given a `%buf : *T, emit instructions that produce a zero or an aggregate
@@ -3870,11 +3930,13 @@ void AdjointEmitter::accumulateMaterializedAdjointsIndirect(
3870
3930
assert (cotangentSpace && " No tangent space for this type" );
3871
3931
switch (cotangentSpace->getKind ()) {
3872
3932
case VectorSpace::Kind::Vector: {
3873
- auto *adjointDecl = cotangentSpace->getNominal ();
3933
+ // auto *adjointDecl = cotangentSpace->getNominal();
3874
3934
auto *proto = getContext ().getAdditiveArithmeticProtocol ();
3875
3935
auto *combinerFuncDecl = getContext ().getPlusDecl ();
3876
3936
// Call the combiner function and return.
3877
- auto adjointParentModule = adjointDecl->getModuleContext ();
3937
+ // auto adjointParentModule = cotangentSpace->getModuleContext();
3938
+ // auto confRef = *adjointParentModule->lookupConformance(adjointASTTy, proto);
3939
+ auto *adjointParentModule = getModule ().getSwiftModule ();
3878
3940
auto confRef = *adjointParentModule->lookupConformance (adjointASTTy, proto);
3879
3941
SILDeclRef declRef (combinerFuncDecl, SILDeclRef::Kind::Func);
3880
3942
auto silFnTy = getContext ().getTypeConverter ().getConstantType (declRef);
0 commit comments