@@ -327,6 +327,30 @@ class PullbackCloner::Implementation final
327
327
// Adjoint value materialization
328
328
// --------------------------------------------------------------------------//
329
329
330
+ // / Determines whether this adjoint value can be materialized by materializing
331
+ // / and then combining individual fields of the adjoint. This function should
332
+ // / only be called for aggregate adjoints.
333
+ // /
334
+ // / Users are allowed to define custom tangent vectors which may contain
335
+ // / fields that do not conform to `AdditiveArithmetic` and `Differentiable`
336
+ // / protocols. This is fine as long as they fulfill the corresponding
337
+ // / `AdditiveArithmetic` and `Differentiable` protocol requirements on the
338
+ // / tangent vector itself.
339
+ // /
340
+ // / A user-defined tangent vector with above characteristics, cannot be
341
+ // / materialized by materializing individual fields as that process relies on
342
+ // / the `AdditiveArithmeticness` of the individual fields.
343
+ bool isAdjointPiecewiseMaterializable (AdjointValue val) {
344
+ assert (val.getKind () == AdjointValueKind::Aggregate);
345
+ for (auto i : range (val.getNumAggregateElements ())) {
346
+ auto fieldCanTy = val.getAggregateElement (i).getType ().getASTType ();
347
+ if (!getTangentSpace (fieldCanTy)) {
348
+ return false ;
349
+ }
350
+ }
351
+ return true ;
352
+ }
353
+
330
354
// / Materializes an adjoint value. The type of the given adjoint value must be
331
355
// / loadable.
332
356
SILValue materializeAdjointDirect (AdjointValue val, SILLocation loc) {
@@ -336,20 +360,14 @@ class PullbackCloner::Implementation final
336
360
SILValue result;
337
361
switch (val.getKind ()) {
338
362
case AdjointValueKind::Zero:
339
- result = recordTemporary ( builder.emitZero (loc, val.getSwiftType () ));
363
+ result = builder.emitZero (loc, val.getSwiftType ());
340
364
break ;
341
365
case AdjointValueKind::Aggregate: {
342
- SmallVector<SILValue, 8 > elements;
343
- for ( auto i : range (val. getNumAggregateElements ())) {
344
- auto eltVal = materializeAdjointDirect (val. getAggregateElement (i), loc);
345
- elements. push_back (builder. emitCopyValueOperation (loc, eltVal) );
366
+ if ( isAdjointPiecewiseMaterializable (val)) {
367
+ result = materializeAggregateAdjointDirectPiecewise (val, loc);
368
+ } else {
369
+ result = materializeAggregateAdjointDirect (val, loc );
346
370
}
347
- if (val.getType ().is <TupleType>())
348
- result = recordTemporary (
349
- builder.createTuple (loc, val.getType (), elements));
350
- else
351
- result = recordTemporary (
352
- builder.createStruct (loc, val.getType (), elements));
353
371
break ;
354
372
}
355
373
case AdjointValueKind::Concrete:
@@ -362,6 +380,93 @@ class PullbackCloner::Implementation final
362
380
return result;
363
381
}
364
382
383
+ // / Used to materialize an aggregate adjoint directly, if
384
+ // / `isAdjointPiecewiseMaterializable` returned false.
385
+ SILValue materializeAggregateAdjointDirect (AdjointValue val,
386
+ SILLocation loc) {
387
+ SILValue result;
388
+ auto *resultAlloc = builder.createAllocStack (loc, val.getType ());
389
+ builder.emitZeroIntoBuffer (loc, resultAlloc, IsInitialization);
390
+
391
+ if (auto *tupTy = val.getSwiftType ()->getAs <TupleType>()) {
392
+ for (auto fieldIndex : range (val.getNumAggregateElements ())) {
393
+ auto adjField = val.getAggregateElement (fieldIndex);
394
+ // No need to materialize zero field adjoints when we
395
+ // have already materialized a zero aggregate adjoint
396
+ // for the aggregate type containing the field.
397
+ if (adjField.getKind () != AdjointValueKind::Zero) {
398
+ auto eltTy = SILType::getPrimitiveAddressType (
399
+ tupTy->getElementType (fieldIndex)->getCanonicalType ());
400
+ auto lhsAdjEltBuf =
401
+ builder
402
+ .createTupleElementAddr (loc, resultAlloc, fieldIndex, eltTy)
403
+ ->getResult (0 );
404
+ auto rhsAdjEltBuf =
405
+ builder.createAllocStack (loc, adjField.getType ())->getResult (0 );
406
+ materializeAdjointIndirect (adjField, rhsAdjEltBuf, loc);
407
+
408
+ // lhsAdjEltBuf += rhsAdjEltBuf
409
+ builder.emitInPlaceAdd (loc, lhsAdjEltBuf, rhsAdjEltBuf);
410
+ builder.createDestroyAddr (loc, rhsAdjEltBuf);
411
+ builder.createDeallocStack (loc, rhsAdjEltBuf);
412
+ }
413
+ }
414
+ } else if (auto *structDecl =
415
+ val.getSwiftType ().getStructOrBoundGenericStruct ()) {
416
+ unsigned fieldIndex = 0 ;
417
+ for (auto it = structDecl->getStoredProperties ().begin ();
418
+ it != structDecl->getStoredProperties ().end (); ++it, ++fieldIndex) {
419
+ auto adjField = val.getAggregateElement (fieldIndex);
420
+ // No need to materialize zero field adjoints when we
421
+ // have already materialized a zero aggregate adjoint
422
+ // for the aggregate type containing the field.
423
+ if (adjField.getKind () != AdjointValueKind::Zero) {
424
+ VarDecl *field = *it;
425
+ auto lhsAdjEltBuf =
426
+ builder.createStructElementAddr (loc, resultAlloc, field)
427
+ ->getResult (0 );
428
+ auto rhsAdjEltBuf =
429
+ builder.createAllocStack (loc, adjField.getType ())->getResult (0 );
430
+ materializeAdjointIndirect (adjField, rhsAdjEltBuf, loc);
431
+
432
+ // lhsAdjEltBuf += rhsAdjEltBuf
433
+ builder.emitInPlaceAdd (loc, lhsAdjEltBuf, rhsAdjEltBuf);
434
+ builder.createDestroyAddr (loc, rhsAdjEltBuf);
435
+ builder.createDeallocStack (loc, rhsAdjEltBuf);
436
+ }
437
+ }
438
+ } else {
439
+ llvm_unreachable (" Not an aggregate type" );
440
+ }
441
+
442
+ result = recordTemporary (builder.emitLoadValueOperation (
443
+ loc, resultAlloc, LoadOwnershipQualifier::Take));
444
+ builder.createDeallocStack (loc, resultAlloc);
445
+
446
+ return result;
447
+ }
448
+
449
+ // / Used to materialize an aggregate adjoint directly, if
450
+ // / `isAdjointPiecewiseMaterializable` returned true.
451
+ SILValue materializeAggregateAdjointDirectPiecewise (AdjointValue val,
452
+ SILLocation loc) {
453
+ SILValue result;
454
+
455
+ SmallVector<SILValue, 8 > elements;
456
+ for (auto i : range (val.getNumAggregateElements ())) {
457
+ auto eltVal = materializeAdjointDirect (val.getAggregateElement (i), loc);
458
+ elements.push_back (builder.emitCopyValueOperation (loc, eltVal));
459
+ }
460
+ if (val.getType ().is <TupleType>())
461
+ result =
462
+ recordTemporary (builder.createTuple (loc, val.getType (), elements));
463
+ else
464
+ result =
465
+ recordTemporary (builder.createStruct (loc, val.getType (), elements));
466
+
467
+ return result;
468
+ }
469
+
365
470
// / Materializes an adjoint value indirectly to a SIL buffer.
366
471
void materializeAdjointIndirect (AdjointValue val, SILValue destAddress,
367
472
SILLocation loc) {
@@ -376,38 +481,114 @@ class PullbackCloner::Implementation final
376
481
// / materialize the symbolic tuple or struct, filling the
377
482
// / buffer.
378
483
case AdjointValueKind::Aggregate: {
379
- if (auto *tupTy = val.getSwiftType ()->getAs <TupleType>()) {
380
- for (auto idx : range (val.getNumAggregateElements ())) {
381
- auto eltTy = SILType::getPrimitiveAddressType (
382
- tupTy->getElementType (idx)->getCanonicalType ());
383
- auto *eltBuf =
384
- builder.createTupleElementAddr (loc, destAddress, idx, eltTy);
385
- materializeAdjointIndirect (val.getAggregateElement (idx), eltBuf, loc);
386
- }
387
- } else if (auto *structDecl =
388
- val.getSwiftType ()->getStructOrBoundGenericStruct ()) {
389
- auto fieldIt = structDecl->getStoredProperties ().begin ();
390
- for (unsigned i = 0 ; fieldIt != structDecl->getStoredProperties ().end ();
391
- ++fieldIt, ++i) {
392
- auto eltBuf =
393
- builder.createStructElementAddr (loc, destAddress, *fieldIt);
394
- materializeAdjointIndirect (val.getAggregateElement (i), eltBuf, loc);
395
- }
484
+ if (isAdjointPiecewiseMaterializable (val)) {
485
+ materializeAggregateAdjointInDirectPiecewise (val, destAddress, loc);
396
486
} else {
397
- llvm_unreachable ( " Not an aggregate type " );
487
+ materializeAggregateAdjointInDirect (val, destAddress, loc );
398
488
}
399
489
break ;
400
490
}
401
491
// / If adjoint value is concrete, it is already materialized. Store it in
402
492
// / the destination address.
403
493
case AdjointValueKind::Concrete:
404
494
auto concreteVal = val.getConcreteValue ();
405
- builder.emitStoreValueOperation (loc, concreteVal, destAddress,
495
+ // `val` needs to be an owned value for storing it into `destAddress`,
496
+ // which may not always be the case. So, we create a copy of the value
497
+ // first.
498
+ auto copyConreteVal = builder.emitCopyValueOperation (loc, concreteVal);
499
+ builder.emitStoreValueOperation (loc, copyConreteVal, destAddress,
406
500
StoreOwnershipQualifier::Init);
407
501
break ;
408
502
}
409
503
}
410
504
505
+ // / Used to materialize an aggregate adjoint indirectly, if
506
+ // / `isAdjointPiecewiseMaterializable` returned false.
507
+ void materializeAggregateAdjointInDirect (AdjointValue val,
508
+ SILValue destAddress,
509
+ SILLocation loc) {
510
+ assert (destAddress->getType ().isAddress ());
511
+ auto zeroAggAdj = builder.emitZero (loc, val.getSwiftType ());
512
+ auto isTupleType = val.getType ().is <TupleType>();
513
+
514
+ SILInstruction *destructureInst;
515
+ if (isTupleType)
516
+ destructureInst = builder.createDestructureTuple (loc, zeroAggAdj);
517
+ else
518
+ destructureInst = builder.createDestructureStruct (loc, zeroAggAdj);
519
+
520
+ if (auto *tupTy = val.getSwiftType ()->getAs <TupleType>()) {
521
+ for (auto idx : range (val.getNumAggregateElements ())) {
522
+ auto adjField = val.getAggregateElement (idx);
523
+
524
+ auto eltTy = SILType::getPrimitiveAddressType (
525
+ tupTy->getElementType (idx)->getCanonicalType ());
526
+ auto *eltBuf =
527
+ builder.createTupleElementAddr (loc, destAddress, idx, eltTy);
528
+
529
+ if (adjField.getKind () != AdjointValueKind::Zero) {
530
+ materializeAdjointIndirect (adjField, eltBuf, loc);
531
+ } else {
532
+ // No need to individually materialize zero field adjoints. Instead
533
+ // we can use the corresponding adjoint values from `zeroAggAdj`.
534
+ builder.emitStoreValueOperation (loc, destructureInst->getResult (idx),
535
+ eltBuf,
536
+ StoreOwnershipQualifier::Init);
537
+ }
538
+ }
539
+ } else if (auto *structDecl =
540
+ val.getSwiftType ()->getStructOrBoundGenericStruct ()) {
541
+ auto fieldIt = structDecl->getStoredProperties ().begin ();
542
+ for (unsigned i = 0 ; fieldIt != structDecl->getStoredProperties ().end ();
543
+ ++fieldIt, ++i) {
544
+ auto adjField = val.getAggregateElement (i);
545
+
546
+ auto eltBuf =
547
+ builder.createStructElementAddr (loc, destAddress, *fieldIt);
548
+
549
+ if (adjField.getKind () != AdjointValueKind::Zero) {
550
+ materializeAdjointIndirect (adjField, eltBuf, loc);
551
+ } else {
552
+ // No need to individually materialize zero field adjoints. Instead
553
+ // we can use the corresponding adjoint values from `zeroAggAdj`.
554
+ builder.emitStoreValueOperation (loc, destructureInst->getResult (i),
555
+ eltBuf,
556
+ StoreOwnershipQualifier::Init);
557
+ }
558
+ }
559
+ } else {
560
+ llvm_unreachable (" Not an aggregate type" );
561
+ }
562
+ }
563
+
564
+ // / Used to materialize an aggregate adjoint indirectly, if
565
+ // / `isAdjointPiecewiseMaterializable` returned true.
566
+ void materializeAggregateAdjointInDirectPiecewise (AdjointValue val,
567
+ SILValue destAddress,
568
+ SILLocation loc) {
569
+ assert (destAddress->getType ().isAddress ());
570
+ if (auto *tupTy = val.getSwiftType ()->getAs <TupleType>()) {
571
+ for (auto idx : range (val.getNumAggregateElements ())) {
572
+ auto eltTy = SILType::getPrimitiveAddressType (
573
+ tupTy->getElementType (idx)->getCanonicalType ());
574
+ auto *eltBuf =
575
+ builder.createTupleElementAddr (loc, destAddress, idx, eltTy);
576
+ materializeAdjointIndirect (val.getAggregateElement (idx), eltBuf, loc);
577
+ }
578
+ } else if (auto *structDecl =
579
+ val.getSwiftType ()->getStructOrBoundGenericStruct ()) {
580
+ auto fieldIt = structDecl->getStoredProperties ().begin ();
581
+ for (unsigned i = 0 ; fieldIt != structDecl->getStoredProperties ().end ();
582
+ ++fieldIt, ++i) {
583
+ auto eltBuf =
584
+ builder.createStructElementAddr (loc, destAddress, *fieldIt);
585
+ materializeAdjointIndirect (val.getAggregateElement (i), eltBuf, loc);
586
+ }
587
+ } else {
588
+ llvm_unreachable (" Not an aggregate type" );
589
+ }
590
+ }
591
+
411
592
// --------------------------------------------------------------------------//
412
593
// Adjoint value mapping
413
594
// --------------------------------------------------------------------------//
0 commit comments