19
19
#define DEBUG_TYPE " differentiation"
20
20
21
21
#include " swift/SILOptimizer/Utils/Differentiation/PullbackEmitter.h"
22
+ #include " swift/SIL/InstructionUtils.h"
22
23
#include " swift/SIL/Projection.h"
23
24
#include " swift/SILOptimizer/PassManager/PrettyStackTrace.h"
24
25
#include " swift/SILOptimizer/Utils/Differentiation/ADContext.h"
@@ -351,6 +352,43 @@ SILValue PullbackEmitter::getAdjointProjection(SILBasicBlock *origBB,
351
352
}
352
353
return builder.createTupleElementAddr (teai->getLoc (), adjSource, adjIndex);
353
354
}
355
+ // Handle `ref_element_addr`.
356
+ // Adjoint projection: a local allocation initialized with the corresponding
357
+ // field value from the class's base adjoint value.
358
+ if (auto *reai = dyn_cast<RefElementAddrInst>(originalProjection)) {
359
+ assert (!reai->getField ()->getAttrs ().hasAttribute <NoDerivativeAttr>() &&
360
+ " `@noDerivative` class projections should never be active" );
361
+ auto loc = reai->getLoc ();
362
+ // Get the class operand, stripping `begin_borrow`.
363
+ auto classOperand = stripBorrow (reai->getOperand ());
364
+ // Get the class operand's adjoint value. Currently, it must be a
365
+ // `TangentVector` struct.
366
+ auto adjClass =
367
+ materializeAdjointDirect (getAdjointValue (origBB, classOperand), loc);
368
+ auto *tangentVectorDecl =
369
+ adjClass->getType ().getStructOrBoundGenericStruct ();
370
+ // TODO(TF-970): Replace assertions below with diagnostics.
371
+ assert (tangentVectorDecl && " `TangentVector` of a class must be a struct" );
372
+ auto tanFieldLookup =
373
+ tangentVectorDecl->lookupDirect (reai->getField ()->getName ());
374
+ assert (tanFieldLookup.size () == 1 &&
375
+ " Class `TangentVector` must have field of the same name" );
376
+ auto *tanField = cast<VarDecl>(tanFieldLookup.front ());
377
+ // Create a local allocation for the element adjoint buffer.
378
+ auto eltTanType = tanField->getValueInterfaceType ()->getCanonicalType ();
379
+ auto eltTanSILType = SILType::getPrimitiveAddressType (eltTanType);
380
+ auto *eltAdjBuffer = createFunctionLocalAllocation (eltTanSILType, loc);
381
+ builder.emitScopedBorrowOperation (
382
+ loc, adjClass, [&](SILValue borrowedAdjClass) {
383
+ // Initialize the element adjoint buffer with the base adjoint value.
384
+ auto *adjElt =
385
+ builder.createStructExtract (loc, borrowedAdjClass, tanField);
386
+ auto adjEltCopy = builder.emitCopyValueOperation (loc, adjElt);
387
+ builder.emitStoreValueOperation (loc, adjEltCopy, eltAdjBuffer,
388
+ StoreOwnershipQualifier::Init);
389
+ });
390
+ return eltAdjBuffer;
391
+ }
354
392
// Handle `begin_access`.
355
393
// Adjoint projection: the base adjoint buffer itself.
356
394
if (auto *bai = dyn_cast<BeginAccessInst>(originalProjection)) {
@@ -412,6 +450,20 @@ PullbackEmitter::getNextFunctionLocalAllocationInsertionPoint() {
412
450
return lastLocalAlloc->getDefiningInstruction ()->getIterator ();
413
451
}
414
452
453
+ AllocStackInst *
454
+ PullbackEmitter::createFunctionLocalAllocation (SILType type, SILLocation loc) {
455
+ // Set insertion point for local allocation builder: before the last local
456
+ // allocation, or at the start of the pullback function's entry if no local
457
+ // allocations exist yet.
458
+ localAllocBuilder.setInsertionPoint (
459
+ getPullback ().getEntryBlock (),
460
+ getNextFunctionLocalAllocationInsertionPoint ());
461
+ // Create and return local allocation.
462
+ auto *alloc = localAllocBuilder.createAllocStack (loc, type);
463
+ functionLocalAllocations.push_back (alloc);
464
+ return alloc;
465
+ }
466
+
415
467
SILValue &PullbackEmitter::getAdjointBuffer (SILBasicBlock *origBB,
416
468
SILValue originalBuffer) {
417
469
assert (originalBuffer->getType ().isAddress ());
@@ -425,25 +477,19 @@ SILValue &PullbackEmitter::getAdjointBuffer(SILBasicBlock *origBB,
425
477
if (auto adjProj = getAdjointProjection (origBB, originalBuffer))
426
478
return (bufferMap[{origBB, originalBuffer}] = adjProj);
427
479
480
+ auto bufObjectType = getRemappedTangentType (originalBuffer->getType ());
428
481
// Set insertion point for local allocation builder: before the last local
429
482
// allocation, or at the start of the pullback function's entry if no local
430
483
// allocations exist yet.
431
- localAllocBuilder.setInsertionPoint (
432
- getPullback ().getEntryBlock (),
433
- getNextFunctionLocalAllocationInsertionPoint ());
434
- // Allocate local buffer and initialize to zero.
435
- auto bufObjectType = getRemappedTangentType (originalBuffer->getType ());
436
- auto *newBuf = localAllocBuilder.createAllocStack (
437
- RegularLocation::getAutoGeneratedLocation (), bufObjectType);
484
+ auto *newBuf = createFunctionLocalAllocation (
485
+ bufObjectType, RegularLocation::getAutoGeneratedLocation ());
438
486
// Temporarily change global builder insertion point and emit zero into the
439
- // local buffer .
487
+ // local allocation .
440
488
auto insertionPoint = builder.getInsertionBB ();
441
489
builder.setInsertionPoint (localAllocBuilder.getInsertionBB (),
442
490
localAllocBuilder.getInsertionPoint ());
443
491
emitZeroIndirect (bufObjectType.getASTType (), newBuf, newBuf->getLoc ());
444
492
builder.setInsertionPoint (insertionPoint);
445
- // Register the local buffer.
446
- functionLocalAllocations.push_back (newBuf);
447
493
return (insertion.first ->getSecond () = newBuf);
448
494
}
449
495
@@ -1082,13 +1128,9 @@ PullbackEmitter::getArrayAdjointElementBuffer(SILValue arrayAdjoint,
1082
1128
{addArithConf, diffConf});
1083
1129
// %elt_adj = alloc_stack $T.TangentVector
1084
1130
// Create and register a local allocation.
1085
- localAllocBuilder.setInsertionPoint (
1086
- getPullback ().getEntryBlock (),
1087
- getNextFunctionLocalAllocationInsertionPoint ());
1088
- auto *eltAdjBuffer = localAllocBuilder.createAllocStack (loc, eltTanSILType);
1089
- functionLocalAllocations.push_back (eltAdjBuffer);
1131
+ auto *eltAdjBuffer = createFunctionLocalAllocation (eltTanSILType, loc);
1090
1132
// Temporarily change global builder insertion point and emit zero into the
1091
- // local buffer .
1133
+ // local allocation .
1092
1134
auto insertionPoint = builder.getInsertionBB ();
1093
1135
builder.setInsertionPoint (localAllocBuilder.getInsertionBB (),
1094
1136
localAllocBuilder.getInsertionPoint ());
0 commit comments