@@ -3519,15 +3519,15 @@ class AdjointValueBase {
3519
3519
3520
3520
/// The underlying value.
3521
3521
union Value {
3522
- MutableArrayRef <AdjointValue> aggregate;
3522
+ ArrayRef <AdjointValue> aggregate;
3523
3523
ValueWithCleanup concrete;
3524
- Value (MutableArrayRef <AdjointValue> v) : aggregate (v) {}
3524
+ Value (ArrayRef <AdjointValue> v) : aggregate (v) {}
3525
3525
Value (ValueWithCleanup v) : concrete (v) {}
3526
3526
Value () {}
3527
3527
} value;
3528
3528
3529
3529
explicit AdjointValueBase (SILType type,
3530
- MutableArrayRef <AdjointValue> aggregate)
3530
+ ArrayRef <AdjointValue> aggregate)
3531
3531
: kind(AdjointValueKind::Aggregate), type(type), value(aggregate) {}
3532
3532
3533
3533
explicit AdjointValueBase (ValueWithCleanup v)
@@ -3591,11 +3591,15 @@ class AdjointValue final {
3591
3591
return base->value .aggregate .size ();
3592
3592
}
3593
3593
3594
- AdjointValue takeAggregateElement (unsigned i) {
3594
+ AdjointValue getAggregateElement (unsigned i) const {
3595
3595
assert (isAggregate ());
3596
3596
return base->value .aggregate [i];
3597
3597
}
3598
3598
3599
+ ArrayRef<AdjointValue> getAggregateElements () const {
3600
+ return base->value .aggregate ;
3601
+ }
3602
+
3599
3603
ValueWithCleanup getConcreteValue () const {
3600
3604
assert (isConcrete ());
3601
3605
return base->value .concrete ;
@@ -3684,6 +3688,11 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
3684
3688
// / Mapping from original basic blocks to dominated active values.
3685
3689
DenseMap<SILBasicBlock *, SmallVector<SILValue, 8 >> activeValues;
3686
3690
3691
+ /// Local adjoint values to be cleaned up. This is populated when adjoint
3692
+ /// emission is run on one basic block and cleared before processing another
3693
+ /// basic block.
3694
+ SmallVector<AdjointValue, 8 > blockLocalAdjointValues;
3695
+
3687
3696
// / Mapping from original basic blocks and original active values to
3688
3697
// / corresponding adjoint block arguments.
3689
3698
DenseMap<std::pair<SILBasicBlock *, SILValue>, SILArgument *>
@@ -3758,7 +3767,7 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
3758
3767
AdjointValue makeAggregateAdjointValue (SILType type, EltRange elements);
3759
3768
3760
3769
// --------------------------------------------------------------------------//
3761
- // Managed value materializers
3770
+ // Symbolic value materializers
3762
3771
// --------------------------------------------------------------------------//
3763
3772
3764
3773
// / Materialize an adjoint value. The type of the given adjoint value must be
@@ -3777,7 +3786,7 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
3777
3786
AdjointValue val, ValueWithCleanup &destBufferAccess);
3778
3787
3779
3788
// --------------------------------------------------------------------------//
3780
- // Helpers for managed value materializers
3789
+ // Helpers for symbolic value materializers
3781
3790
// --------------------------------------------------------------------------//
3782
3791
3783
3792
// / Emit a zero value by calling `AdditiveArithmetic.zero`. The given type
@@ -3788,6 +3797,12 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
3788
3797
// / must conform to `AdditiveArithmetic` and be loadable in SIL.
3789
3798
SILValue emitZeroDirect (CanType type, SILLocation loc);
3790
3799
3800
+ // --------------------------------------------------------------------------//
3801
+ // Memory cleanup tools
3802
+ // --------------------------------------------------------------------------//
3803
+
3804
+ void emitCleanupForAdjointValue (AdjointValue value);
3805
+
3791
3806
// --------------------------------------------------------------------------//
3792
3807
// Accumulator
3793
3808
// --------------------------------------------------------------------------//
@@ -3899,8 +3914,9 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
3899
3914
auto it = insertion.first ;
3900
3915
auto &&existingValue = it->getSecond ();
3901
3916
valueMap.erase (it);
3902
- initializeAdjointValue (origBB, originalValue,
3903
- accumulateAdjointsDirect (existingValue, newAdjointValue));
3917
+ auto adjVal = accumulateAdjointsDirect (existingValue, newAdjointValue);
3918
+ initializeAdjointValue (origBB, originalValue, adjVal);
3919
+ blockLocalAdjointValues.push_back (adjVal);
3904
3920
}
3905
3921
3906
3922
// / Get the adjoint block argument corresponding to the given original block
@@ -4378,7 +4394,7 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
4378
4394
// Emit cleanups for children.
4379
4395
if (auto *cleanup = concreteActiveValueAdj.getCleanup ()) {
4380
4396
cleanup->disable ();
4381
- cleanup->applyRecursively (builder, activeValue. getLoc () );
4397
+ cleanup->applyRecursively (builder, adjLoc );
4382
4398
}
4383
4399
trampolineArguments.push_back (concreteActiveValueAdj);
4384
4400
// If the adjoint block does not yet have a registered adjoint
@@ -4415,11 +4431,15 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
4415
4431
getPullbackInfo ().lookUpPredecessorEnumElement (predBB, bb);
4416
4432
adjointSuccessorCases.push_back ({enumEltDecl, adjointSuccBB});
4417
4433
}
4418
- assert (adjointSuccessorCases.size () == predEnum->getNumElements ());
4434
+ // Emit cleanups for all block-local adjoint values.
4435
+ for (auto adjVal : blockLocalAdjointValues)
4436
+ emitCleanupForAdjointValue (adjVal);
4437
+ blockLocalAdjointValues.clear ();
4419
4438
// - If the original block has exactly one predecessor, then the adjoint
4420
4439
// block has exactly one successor. Extract the pullback struct value
4421
4440
// from the predecessor enum value using `unchecked_enum_data` and
4422
4441
// branch to the adjoint successor block.
4442
+ assert (adjointSuccessorCases.size () == predEnum->getNumElements ());
4423
4443
if (adjointSuccessorCases.size () == 1 ) {
4424
4444
auto *predBB = bb->getSinglePredecessorBlock ();
4425
4445
assert (predBB);
@@ -4479,9 +4499,13 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
4479
4499
indParamAdjoints.push_back (adjBuf);
4480
4500
}
4481
4501
};
4482
- // Accumulate differentiation parameter adjoints.
4502
+ // Collect differentiation parameter adjoints.
4483
4503
for (auto i : getIndices ().parameters ->getIndices ())
4484
4504
addRetElt (i);
4505
+ // Emit cleanups for all local values.
4506
+ for (auto adjVal : blockLocalAdjointValues)
4507
+ emitCleanupForAdjointValue (adjVal);
4508
+ blockLocalAdjointValues.clear ();
4485
4509
4486
4510
// Disable cleanup for original indirect parameter adjoint buffers.
4487
4511
// Copy them to adjoint indirect results.
@@ -4667,8 +4691,9 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
4667
4691
builder.createDeallocStack (loc, tmpBuf);
4668
4692
}
4669
4693
else
4670
- addAdjointValue (bb, origArg, makeConcreteAdjointValue (ValueWithCleanup (
4671
- tan, makeCleanup (tan, emitCleanup, {seed.getCleanup ()}))));
4694
+ addAdjointValue (bb, origArg, makeConcreteAdjointValue (
4695
+ ValueWithCleanup (tan,
4696
+ makeCleanup (tan, emitCleanup, {seed.getCleanup ()}))));
4672
4697
}
4673
4698
}
4674
4699
// Deallocate pullback indirect results.
@@ -4857,7 +4882,7 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
4857
4882
for (auto i : range (ti->getElements ().size ())) {
4858
4883
if (!getTangentSpace (ti->getElement (i)->getType ().getASTType ()))
4859
4884
continue ;
4860
- addAdjointValue (bb, ti->getElement (i), av.takeAggregateElement (adjIdx++));
4885
+ addAdjointValue (bb, ti->getElement (i), av.getAggregateElement (adjIdx++));
4861
4886
}
4862
4887
break ;
4863
4888
}
@@ -4964,18 +4989,7 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
4964
4989
addAdjointValue (bb, si->getSrc (), makeConcreteAdjointValue (
4965
4990
ValueWithCleanup (adjVal, valueCleanup)));
4966
4991
// Set the buffer to zero, with a cleanup.
4967
- auto *bai = dyn_cast<BeginAccessInst>(adjBuf.getValue ());
4968
- if (bai && !(bai->getAccessKind () == SILAccessKind::Modify ||
4969
- bai->getAccessKind () == SILAccessKind::Init)) {
4970
- auto *modifyAccess = builder.createBeginAccess (
4971
- si->getLoc (), bai->getSource (), SILAccessKind::Modify,
4972
- SILAccessEnforcement::Static, /* noNestedConflict*/ true ,
4973
- /* fromBuiltin*/ false );
4974
- emitZeroIndirect (bufType.getASTType (), modifyAccess, si->getLoc ());
4975
- builder.createEndAccess (si->getLoc (), modifyAccess, /* aborted*/ false );
4976
- } else {
4977
- emitZeroIndirect (bufType.getASTType (), adjBuf, si->getLoc ());
4978
- }
4992
+ emitZeroIndirect (bufType.getASTType (), adjBuf, si->getLoc ());
4979
4993
}
4980
4994
4981
4995
// Handle `copy_addr` instruction.
@@ -5001,18 +5015,7 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
5001
5015
addToAdjointBuffer (bb, cai->getSrc (), readAccess);
5002
5016
builder.createEndAccess (cai->getLoc (), readAccess, /* aborted*/ false );
5003
5017
// Set the buffer to zero, with a cleanup.
5004
- auto *bai = dyn_cast<BeginAccessInst>(adjDest.getValue ());
5005
- if (bai && !(bai->getAccessKind () == SILAccessKind::Modify ||
5006
- bai->getAccessKind () == SILAccessKind::Init)) {
5007
- auto *modifyAccess = builder.createBeginAccess (
5008
- cai->getLoc (), bai->getSource (), SILAccessKind::Modify,
5009
- SILAccessEnforcement::Static, /* noNestedConflict*/ true ,
5010
- /* fromBuiltin*/ false );
5011
- emitZeroIndirect (destType.getASTType (), modifyAccess, cai->getLoc ());
5012
- builder.createEndAccess (cai->getLoc (), modifyAccess, /* aborted*/ false );
5013
- } else {
5014
- emitZeroIndirect (destType.getASTType (), adjDest, cai->getLoc ());
5015
- }
5018
+ emitZeroIndirect (destType.getASTType (), adjDest, cai->getLoc ());
5016
5019
auto cleanup = makeCleanup (adjDest, emitCleanup);
5017
5020
adjDest.setCleanup (cleanup);
5018
5021
}
@@ -5140,7 +5143,7 @@ ValueWithCleanup AdjointEmitter::materializeAdjointDirect(
5140
5143
SmallVector<SILValue, 8 > elements;
5141
5144
SmallVector<Cleanup *, 8 > cleanups;
5142
5145
for (auto i : range (val.getNumAggregateElements ())) {
5143
- auto eltVal = materializeAdjointDirect (val.takeAggregateElement (i), loc);
5146
+ auto eltVal = materializeAdjointDirect (val.getAggregateElement (i), loc);
5144
5147
elements.push_back (eltVal.getValue ());
5145
5148
cleanups.push_back (eltVal.getCleanup ());
5146
5149
}
@@ -5216,7 +5219,7 @@ void AdjointEmitter::materializeAdjointIndirectHelper(
5216
5219
ValueWithCleanup eltBuf (
5217
5220
builder.createTupleElementAddr (loc, destBufferAccess, idx, eltTy),
5218
5221
/* cleanup*/ nullptr );
5219
- materializeAdjointIndirectHelper (val.takeAggregateElement (idx), eltBuf);
5222
+ materializeAdjointIndirectHelper (val.getAggregateElement (idx), eltBuf);
5220
5223
destBufferAccess.setCleanup (makeCleanupFromChildren (
5221
5224
{destBufferAccess.getCleanup (), eltBuf.getCleanup ()}));
5222
5225
}
@@ -5228,7 +5231,7 @@ void AdjointEmitter::materializeAdjointIndirectHelper(
5228
5231
ValueWithCleanup eltBuf (
5229
5232
builder.createStructElementAddr (loc, destBufferAccess, *fieldIt),
5230
5233
/* cleanup*/ nullptr );
5231
- materializeAdjointIndirectHelper (val.takeAggregateElement (i), eltBuf);
5234
+ materializeAdjointIndirectHelper (val.getAggregateElement (i), eltBuf);
5232
5235
destBufferAccess.setCleanup (makeCleanupFromChildren (
5233
5236
{destBufferAccess.getCleanup (), eltBuf.getCleanup ()}));
5234
5237
}
@@ -5293,6 +5296,25 @@ SILValue AdjointEmitter::emitZeroDirect(CanType type, SILLocation loc) {
5293
5296
return loaded;
5294
5297
}
5295
5298
5299
+ void AdjointEmitter::emitCleanupForAdjointValue (AdjointValue value) {
5300
+ switch (value.getKind ()) {
5301
+ case AdjointValueKind::Zero: return ;
5302
+ case AdjointValueKind::Aggregate:
5303
+ for (auto element : value.getAggregateElements ())
5304
+ emitCleanupForAdjointValue (element);
5305
+ break ;
5306
+ case AdjointValueKind::Concrete: {
5307
+ auto concrete = value.getConcreteValue ();
5308
+ auto *cleanup = concrete.getCleanup ();
5309
+ LLVM_DEBUG (getADDebugStream () << " Applying "
5310
+ << cleanup->getNumChildren () << " for value "
5311
+ << concrete.getValue () << " child cleanups\n " );
5312
+ cleanup->applyRecursively (builder, concrete.getLoc ());
5313
+ break ;
5314
+ }
5315
+ }
5316
+ }
5317
+
5296
5318
AdjointValue
5297
5319
AdjointEmitter::accumulateAdjointsDirect (AdjointValue lhs,
5298
5320
AdjointValue rhs) {
@@ -5324,7 +5346,7 @@ AdjointEmitter::accumulateAdjointsDirect(AdjointValue lhs,
5324
5346
for (auto idx : range (rhs.getNumAggregateElements ())) {
5325
5347
auto lhsElt = builder.createTupleExtract (
5326
5348
lhsVal.getLoc (), lhsVal, idx);
5327
- auto rhsElt = rhs.takeAggregateElement (idx);
5349
+ auto rhsElt = rhs.getAggregateElement (idx);
5328
5350
newElements.push_back (accumulateAdjointsDirect (
5329
5351
makeConcreteAdjointValue (
5330
5352
ValueWithCleanup (lhsElt, lhsVal.getCleanup ())),
@@ -5336,7 +5358,7 @@ AdjointEmitter::accumulateAdjointsDirect(AdjointValue lhs,
5336
5358
++fieldIt, ++i) {
5337
5359
auto lhsElt = builder.createStructExtract (
5338
5360
lhsVal.getLoc (), lhsVal, *fieldIt);
5339
- auto rhsElt = rhs.takeAggregateElement (i);
5361
+ auto rhsElt = rhs.getAggregateElement (i);
5340
5362
newElements.push_back (accumulateAdjointsDirect (
5341
5363
makeConcreteAdjointValue (
5342
5364
ValueWithCleanup (lhsElt, lhsVal.getCleanup ())),
@@ -5365,8 +5387,8 @@ AdjointEmitter::accumulateAdjointsDirect(AdjointValue lhs,
5365
5387
SmallVector<AdjointValue, 8 > newElements;
5366
5388
for (auto i : range (lhs.getNumAggregateElements ()))
5367
5389
newElements.push_back (
5368
- accumulateAdjointsDirect (lhs.takeAggregateElement (i),
5369
- rhs.takeAggregateElement (i)));
5390
+ accumulateAdjointsDirect (lhs.getAggregateElement (i),
5391
+ rhs.getAggregateElement (i)));
5370
5392
return makeAggregateAdjointValue (lhs.getType (), newElements);
5371
5393
}
5372
5394
}
0 commit comments