@@ -54,6 +54,10 @@ using llvm::SmallDenseMap;
54
54
using llvm::SmallDenseSet;
55
55
using llvm::SmallSet;
56
56
57
+ static llvm::cl::opt<bool > DifferentiationUseVJP (
58
+ " differentiation-use-vjp" , llvm::cl::init(false ),
59
+ llvm::cl::desc(" Use the VJP during differentiation" ));
60
+
57
61
// ===----------------------------------------------------------------------===//
58
62
// Helpers
59
63
// ===----------------------------------------------------------------------===//
@@ -475,6 +479,10 @@ class PrimalInfo {
475
479
// / corresponding tape of its type.
476
480
DenseMap<ApplyInst *, VarDecl *> nestedStaticPrimalValueMap;
477
481
482
+ // / Mapping from `apply` instructions in the original function to the
483
+ // / corresponding pullback decl in the primal struct.
484
+ DenseMap<ApplyInst *, VarDecl *> pullbackValueMap;
485
+
478
486
// / Mapping from types of control-dependent nested primal values to distinct
479
487
// / tapes.
480
488
DenseMap<CanType, VarDecl *> nestedTapeTypeMap;
@@ -571,6 +579,24 @@ class PrimalInfo {
571
579
return decl;
572
580
}
573
581
582
+ // / Add a pullback to the primal value struct.
583
+ VarDecl *addPullbackDecl (ApplyInst *inst, Type pullbackType) {
584
+ // Decls must have AST types (not `SILFunctionType`), so we convert the
585
+ // `SILFunctionType` of the pullback to a `FunctionType` with the same
586
+ // parameters and results.
587
+ auto *silFnTy = pullbackType->castTo <SILFunctionType>();
588
+ SmallVector<AnyFunctionType::Param, 8 > params;
589
+ for (auto ¶m : silFnTy->getParameters ())
590
+ params.push_back (AnyFunctionType::Param (param.getType ()));
591
+ Type astFnTy = FunctionType::get (
592
+ params, silFnTy->getAllResultsType ().getASTType ());
593
+
594
+ auto *decl = addVarDecl (" pullback_" + llvm::itostr (pullbackValueMap.size ()),
595
+ astFnTy);
596
+ pullbackValueMap.insert ({inst, decl});
597
+ return decl;
598
+ }
599
+
574
600
// / Finds the primal value decl in the primal value struct for a static primal
575
601
// / value in the original function.
576
602
VarDecl *lookupDirectStaticPrimalValueDecl (SILValue originalValue) const {
@@ -586,6 +612,14 @@ class PrimalInfo {
586
612
: lookup->getSecond ();
587
613
}
588
614
615
+ // / Finds the pullback decl in the primal value struct for an `apply` in the
616
+ // / original function.
617
+ VarDecl *lookUpPullbackDecl (ApplyInst *inst) {
618
+ auto lookup = pullbackValueMap.find (inst);
619
+ return lookup == pullbackValueMap.end () ? nullptr
620
+ : lookup->getSecond ();
621
+ }
622
+
589
623
// / Retrieves the tape decl in the primal value struct for the specified type.
590
624
VarDecl *getOrCreateTapeDeclForType (CanType type) {
591
625
auto &astCtx = primalValueStruct->getASTContext ();
@@ -2390,11 +2424,139 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
2390
2424
SILClonerWithScopes::visitReleaseValueInst (rvi);
2391
2425
}
2392
2426
2427
+ void visitApplyInst (ApplyInst *ai) {
2428
+ if (DifferentiationUseVJP)
2429
+ visitApplyInstWithVJP (ai);
2430
+ else
2431
+ visitApplyInstWithoutVJP (ai);
2432
+ }
2433
+
2434
+ void visitApplyInstWithVJP (ApplyInst *ai) {
2435
+ auto &context = getContext ();
2436
+ SILBuilder &builder = getBuilder ();
2437
+
2438
+ // Special handling logic only applies when `apply` is active. If not, just
2439
+ // do standard cloning.
2440
+ if (!activityInfo.isActive (ai, synthesis.indices )) {
2441
+ LLVM_DEBUG (getADDebugStream () << " Not active:\n " << *ai << ' \n ' );
2442
+ SILClonerWithScopes::visitApplyInst (ai);
2443
+ return ;
2444
+ }
2445
+
2446
+ // This instruction is active. Replace it with a call to the VJP.
2447
+
2448
+ // Get the indices required for differentiating this function.
2449
+ LLVM_DEBUG (getADDebugStream () << " Primal-transforming:\n " << *ai << ' \n ' );
2450
+ SmallVector<unsigned , 8 > activeParamIndices;
2451
+ SmallVector<unsigned , 8 > activeResultIndices;
2452
+ collectMinimalIndicesForFunctionCall (ai, synthesis.indices , activityInfo,
2453
+ activeParamIndices,
2454
+ activeResultIndices);
2455
+ assert (!activeParamIndices.empty () && " Parameter indices cannot be empty" );
2456
+ assert (!activeResultIndices.empty () && " Result indices cannot be empty" );
2457
+ LLVM_DEBUG (auto &s = getADDebugStream () << " Active indices: params={" ;
2458
+ interleave (activeParamIndices.begin (), activeParamIndices.end (),
2459
+ [&s](unsigned i) { s << i; }, [&s] { s << " , " ; });
2460
+ s << " }, results={" ; interleave (
2461
+ activeResultIndices.begin (), activeResultIndices.end (),
2462
+ [&s](unsigned i) { s << i; }, [&s] { s << " , " ; });
2463
+ s << " }\n " ;);
2464
+
2465
+ // FIXME: If there are mutiple active results, we don't support it yet.
2466
+ if (activeResultIndices.size () > 1 ) {
2467
+ context.emitNondifferentiabilityError (ai, synthesis.task );
2468
+ errorOccurred = true ;
2469
+ return ;
2470
+ }
2471
+
2472
+ // Form expected indices by assuming there's only one result.
2473
+ SILAutoDiffIndices indices (activeResultIndices.front (), activeParamIndices);
2474
+
2475
+ // Retrieve the original function being called.
2476
+ auto calleeOrigin = ai->getCalleeOrigin ();
2477
+ auto *calleeOriginFnRef = dyn_cast<FunctionRefInst>(calleeOrigin);
2478
+ // If callee does not trace back to a `function_ref`, it is an opaque
2479
+ // function. Emit a "not differentiable" diagnostic here.
2480
+ // FIXME: Handle `partial_apply`, `witness_method`.
2481
+ if (!calleeOriginFnRef) {
2482
+ context.emitNondifferentiabilityError (ai, synthesis.task );
2483
+ errorOccurred = true ;
2484
+ return ;
2485
+ }
2486
+
2487
+ // Find or register a differentiation task for this function.
2488
+ auto *newTask = context.lookUpOrRegisterDifferentiationTask (
2489
+ calleeOriginFnRef->getReferencedFunction (), indices,
2490
+ /* invoker*/ {ai, synthesis.task });
2491
+
2492
+ // Store this task so that AdjointGen can use it.
2493
+ getDifferentiationTask ()->getAssociatedTasks ().insert ({ai, newTask});
2494
+
2495
+ // If the task is newly created, then we need to schedule a synthesis item
2496
+ // for the primal.
2497
+ primalGen.lookUpPrimalAndMaybeScheduleSynthesis (newTask);
2498
+
2499
+ auto *vjpFn = newTask->getVJP ();
2500
+ assert (vjpFn);
2501
+ auto *vjp = builder.createFunctionRef (ai->getCallee ().getLoc (), vjpFn);
2502
+
2503
+ // TODO: The `visitApplyInstWithoutVJP` reapplies function conversions here,
2504
+ // but all the tests seem to pass without doing that here. Investigate.
2505
+
2506
+ // Call the VJP using the original parameters.
2507
+ SmallVector<SILValue, 8 > newArgs;
2508
+ auto vjpFnTy = vjpFn->getLoweredFunctionType ();
2509
+ auto numVJPParams = vjpFnTy->getNumParameters ();
2510
+ assert (vjpFnTy->getNumIndirectFormalResults () == 0 &&
2511
+ " FIXME: handle vjp with indirect results" );
2512
+ newArgs.reserve (numVJPParams);
2513
+ // Collect substituted arguments.
2514
+ for (auto origArg : ai->getArguments ())
2515
+ newArgs.push_back (getOpValue (origArg));
2516
+ assert (newArgs.size () == numVJPParams);
2517
+ // Apply the VJP.
2518
+ auto *vjpCall = builder.createApply (ai->getLoc (), vjp,
2519
+ ai->getSubstitutionMap (), newArgs,
2520
+ ai->isNonThrowing ());
2521
+ LLVM_DEBUG (getADDebugStream ()
2522
+ << " Applied vjp function\n " << *vjpCall);
2523
+
2524
+ // Get the VJP results (original results and pullback).
2525
+ SmallVector<SILValue, 8 > vjpDirectResults;
2526
+ extractAllElements (vjpCall, builder, vjpDirectResults);
2527
+ ArrayRef<SILValue> originalDirectResults =
2528
+ ArrayRef<SILValue>(vjpDirectResults).drop_back (1 );
2529
+ SILValue originalDirectResult = joinElements (originalDirectResults,
2530
+ builder,
2531
+ vjpCall->getLoc ());
2532
+ SILValue pullback = vjpDirectResults.back ();
2533
+
2534
+ // Store the original result to the value map.
2535
+ ValueMap.insert ({ai, originalDirectResult});
2536
+
2537
+ // Checkpoint the original results.
2538
+ getPrimalInfo ().addStaticPrimalValueDecl (ai);
2539
+ staticPrimalValues.push_back (originalDirectResult);
2540
+
2541
+ // Checkpoint the pullback.
2542
+ getPrimalInfo ().addPullbackDecl (ai, pullback->getType ().getASTType ());
2543
+ staticPrimalValues.push_back (pullback);
2544
+
2545
+ // Some instructions that produce the callee may have been cloned.
2546
+ // If the original callee did not have any users beyond this `apply`,
2547
+ // recursively kill the cloned callee.
2548
+ if (auto *origCallee = cast_or_null<SingleValueInstruction>(
2549
+ ai->getCallee ()->getDefiningInstruction ()))
2550
+ if (origCallee->hasOneUse ())
2551
+ recursivelyDeleteTriviallyDeadInstructions (
2552
+ getOpValue (origCallee)->getDefiningInstruction ());
2553
+ }
2554
+
2393
2555
// / Handle the primal transformation of an `apply` instruction. We do not
2394
2556
// / always transform `apply`. When we do, we do not just blindly differentiate
2395
2557
// / from all results w.r.t. all parameters. Instead, we let activity analysis
2396
2558
// / decide whether to transform and what differentiation indices to use.
2397
- void visitApplyInst (ApplyInst *ai) {
2559
+ void visitApplyInstWithoutVJP (ApplyInst *ai) {
2398
2560
// Special handling logic only applies when `apply` is active. If not, just
2399
2561
// do standard cloning.
2400
2562
if (!activityInfo.isActive (ai, synthesis.indices )) {
@@ -3292,9 +3454,110 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
3292
3454
return rematCloner.getMappedValue (value);
3293
3455
}
3294
3456
3457
+ void visitApplyInst (ApplyInst *ai) {
3458
+ if (DifferentiationUseVJP)
3459
+ visitApplyInstWithVJP (ai);
3460
+ else
3461
+ visitApplyInstWithoutVJP (ai);
3462
+ }
3463
+
3464
+ void visitApplyInstWithVJP (ApplyInst *ai) {
3465
+ // Replace a call to a function with a call to its pullback.
3466
+
3467
+ auto &builder = getBuilder ();
3468
+ auto loc = remapLocation (ai->getLoc ());
3469
+
3470
+ // Look for the task that differentiates the callee.
3471
+ auto &assocTasks = getDifferentiationTask ()->getAssociatedTasks ();
3472
+ auto assocTaskLookUp = assocTasks.find (ai);
3473
+ // If no task was found, then this task doesn't need to be differentiated.
3474
+ if (assocTaskLookUp == assocTasks.end ()) {
3475
+ // Must not be active.
3476
+ assert (
3477
+ !activityInfo.isActive (ai, getDifferentiationTask ()->getIndices ()));
3478
+ return ;
3479
+ }
3480
+ auto *otherTask = assocTaskLookUp->getSecond ();
3481
+ auto origTy = otherTask->getOriginal ()->getLoweredFunctionType ();
3482
+ SILFunctionConventions origConvs (origTy, getModule ());
3483
+
3484
+ // Get the pullback.
3485
+ auto *field = getPrimalInfo ().lookUpPullbackDecl (ai);
3486
+ assert (field);
3487
+ SILValue pullback = builder.createStructExtract (remapLocation (ai->getLoc ()),
3488
+ primalValueAggregateInAdj,
3489
+ field);
3490
+
3491
+ // Construct the pullback arguments.
3492
+ SmallVector<SILValue, 8 > args;
3493
+ auto seed = getAdjointValue (ai);
3494
+ auto *seedBuf = getBuilder ().createAllocStack (loc, seed.getType ());
3495
+ materializeAdjointIndirect (seed, seedBuf);
3496
+ if (seed.getType ().isAddressOnly (getModule ()))
3497
+ args.push_back (seedBuf);
3498
+ else {
3499
+ auto access = getBuilder ().createBeginAccess (
3500
+ loc, seedBuf, SILAccessKind::Read, SILAccessEnforcement::Static,
3501
+ /* noNestedConflict*/ true ,
3502
+ /* fromBuiltin*/ false );
3503
+ args.push_back (getBuilder ().createLoad (
3504
+ loc, access, getBufferLOQ (seed.getSwiftType (), getAdjoint ())));
3505
+ getBuilder ().createEndAccess (loc, access, /* aborted*/ false );
3506
+ }
3507
+
3508
+ // Call the pullback.
3509
+ auto *pullbackCall = builder.createApply (ai->getLoc (), pullback,
3510
+ SubstitutionMap (), args,
3511
+ /* isNonThrowing*/ false );
3512
+
3513
+ // Clean up seed allocation.
3514
+ getBuilder ().createDeallocStack (loc, seedBuf);
3515
+
3516
+ // If `pullbackCall` is a tuple, extract all results.
3517
+ SmallVector<SILValue, 8 > dirResults;
3518
+ extractAllElements (pullbackCall, builder, dirResults);
3519
+ // Get all results in type-defined order.
3520
+ SmallVector<SILValue, 8 > allResults;
3521
+ collectAllActualResultsInTypeOrder (
3522
+ pullbackCall, dirResults, pullbackCall->getIndirectSILResults (),
3523
+ allResults);
3524
+ LLVM_DEBUG ({
3525
+ auto &s = getADDebugStream ();
3526
+ s << " All direct results of the nested pullback call: \n " ;
3527
+ llvm::for_each (dirResults, [&](SILValue v) { s << v; });
3528
+ s << " All indirect results of the nested pullback call: \n " ;
3529
+ llvm::for_each (pullbackCall->getIndirectSILResults (),
3530
+ [&](SILValue v) { s << v; });
3531
+ s << " All results of the nested pullback call: \n " ;
3532
+ llvm::for_each (allResults, [&](SILValue v) { s << v; });
3533
+ });
3534
+
3535
+ // Set adjoints for all original parameters.
3536
+ auto originalParams = ai->getArgumentsWithoutIndirectResults ();
3537
+ auto origNumIndRes = origConvs.getNumIndirectSILResults ();
3538
+ auto allResultsIt = allResults.begin ();
3539
+ // If the applied adjoint returns the adjoint of the original self
3540
+ // parameter, then it returns it first. Set the adjoint of the original
3541
+ // self parameter.
3542
+ auto selfParamIndex = originalParams.size () - 1 ;
3543
+ if (ai->hasSelfArgument () &&
3544
+ otherTask->getIndices ().isWrtParameter (selfParamIndex))
3545
+ addAdjointValue (ai->getArgument (origNumIndRes + selfParamIndex),
3546
+ AdjointValue::getMaterialized (*allResultsIt++));
3547
+ // Set adjoints for the remaining non-self original parameters.
3548
+ for (unsigned i : otherTask->getIndices ().parameters .set_bits ()) {
3549
+ // Do not set the adjoint of the original self parameter because we
3550
+ // already added it at the beginning.
3551
+ if (ai->hasSelfArgument () && i == selfParamIndex)
3552
+ continue ;
3553
+ addAdjointValue (ai->getArgument (origNumIndRes + i),
3554
+ AdjointValue::getMaterialized (*allResultsIt++));
3555
+ }
3556
+ }
3557
+
3295
3558
// / Handle `apply` instruction. If it's active (on the differentiation path),
3296
3559
// / we replace it with its corresponding adjoint.
3297
- void visitApplyInst (ApplyInst *ai) {
3560
+ void visitApplyInstWithoutVJP (ApplyInst *ai) {
3298
3561
// Replace a call to the function with a call to its adjoint.
3299
3562
auto &assocTasks = getDifferentiationTask ()->getAssociatedTasks ();
3300
3563
auto assocTaskLookUp = assocTasks.find (ai);
0 commit comments