@@ -327,6 +327,30 @@ class PullbackCloner::Implementation final
327
327
// Adjoint value materialization
328
328
// --------------------------------------------------------------------------//
329
329
330
+ // / Determines whether this adjoint value can be materialized by materializing
331
+ // / and then combining individual fields of the adjoint. This function should
332
+ // / only be called for aggregate adjoints.
333
+ // /
334
+ // / Users are allowed to define custom tangent vectors which may contain
335
+ // / fields that do not conform to `AdditiveArithmetic` and `Differentiable`
336
+ // / protocols. This is fine as long as they fulfill the corresponding
337
+ // / `AdditiveArithmetic` and `Differentiable` protocol requirements on the
338
+ // / tangent vector itself.
339
+ // /
340
+ // / A user-defined tangent vector with above characteristics, cannot be
341
+ // / materialized by materializing individual fields as that process relies on
342
+ // / the `AdditiveArithmeticness` of the individual fields.
343
+ bool isAdjointPiecewiseMaterializable (AdjointValue val) {
344
+ assert (val.getKind () == AdjointValueKind::Aggregate);
345
+ for (auto i : range (val.getNumAggregateElements ())) {
346
+ auto fieldCanTy = val.getAggregateElement (i).getType ().getASTType ();
347
+ if (!getTangentSpace (fieldCanTy)) {
348
+ return false ;
349
+ }
350
+ }
351
+ return true ;
352
+ }
353
+
330
354
// / Materializes an adjoint value. The type of the given adjoint value must be
331
355
// / loadable.
332
356
SILValue materializeAdjointDirect (AdjointValue val, SILLocation loc) {
@@ -339,17 +363,11 @@ class PullbackCloner::Implementation final
339
363
result = recordTemporary (builder.emitZero (loc, val.getSwiftType ()));
340
364
break ;
341
365
case AdjointValueKind::Aggregate: {
342
- SmallVector<SILValue, 8 > elements;
343
- for ( auto i : range (val. getNumAggregateElements ())) {
344
- auto eltVal = materializeAdjointDirect (val. getAggregateElement (i), loc);
345
- elements. push_back (builder. emitCopyValueOperation (loc, eltVal) );
366
+ if ( isAdjointPiecewiseMaterializable (val)) {
367
+ result = materializeAggregateAdjointDirectPiecewise (val, loc);
368
+ } else {
369
+ result = materializeAggregateAdjointDirect (val, loc );
346
370
}
347
- if (val.getType ().is <TupleType>())
348
- result = recordTemporary (
349
- builder.createTuple (loc, val.getType (), elements));
350
- else
351
- result = recordTemporary (
352
- builder.createStruct (loc, val.getType (), elements));
353
371
break ;
354
372
}
355
373
case AdjointValueKind::Concrete:
@@ -362,6 +380,68 @@ class PullbackCloner::Implementation final
362
380
return result;
363
381
}
364
382
383
+ // / Used to materialize an aggregate adjoint directly, if
384
+ // / `isAdjointPiecewiseMaterializable` returned false.
385
+ SILValue materializeAggregateAdjointDirect (AdjointValue val,
386
+ SILLocation loc) {
387
+ SILValue result = builder.emitZero (loc, val.getSwiftType ());
388
+ auto isTupleType = val.getType ().is <TupleType>();
389
+
390
+ SILInstruction *destructureInst;
391
+ if (isTupleType) {
392
+ destructureInst = builder.createDestructureTuple (loc, result);
393
+ } else {
394
+ destructureInst = builder.createDestructureStruct (loc, result);
395
+ }
396
+
397
+ // Note - Materializing the aggregate adjoints by collecting concrete
398
+ // field values and using the `struct` instruction rather than doing a
399
+ // `+=` on individual fields buffers because the TangentBuilder's
400
+ // `emitInPlaceAdd` method requires the destination buffer to be an address,
401
+ // which does not work for trivial types such as floats and doubles.
402
+ SmallVector<SILValue, 8 > elements;
403
+ for (auto i : range (val.getNumAggregateElements ())) {
404
+ auto fieldCanTy = val.getAggregateElement (i).getType ().getASTType ();
405
+
406
+ if (!getTangentSpace (fieldCanTy)) {
407
+ elements.push_back (destructureInst->getResult (i));
408
+ } else {
409
+ auto eltVal = materializeAdjointDirect (val.getAggregateElement (i), loc);
410
+ elements.push_back (builder.emitCopyValueOperation (loc, eltVal));
411
+ }
412
+ }
413
+
414
+ if (isTupleType)
415
+ result =
416
+ recordTemporary (builder.createTuple (loc, val.getType (), elements));
417
+ else
418
+ result =
419
+ recordTemporary (builder.createStruct (loc, val.getType (), elements));
420
+
421
+ return result;
422
+ }
423
+
424
+ // / Used to materialize an aggregate adjoint directly, if
425
+ // / `isAdjointPiecewiseMaterializable` returned true.
426
+ SILValue materializeAggregateAdjointDirectPiecewise (AdjointValue val,
427
+ SILLocation loc) {
428
+ SILValue result;
429
+
430
+ SmallVector<SILValue, 8 > elements;
431
+ for (auto i : range (val.getNumAggregateElements ())) {
432
+ auto eltVal = materializeAdjointDirect (val.getAggregateElement (i), loc);
433
+ elements.push_back (builder.emitCopyValueOperation (loc, eltVal));
434
+ }
435
+ if (val.getType ().is <TupleType>())
436
+ result =
437
+ recordTemporary (builder.createTuple (loc, val.getType (), elements));
438
+ else
439
+ result =
440
+ recordTemporary (builder.createStruct (loc, val.getType (), elements));
441
+
442
+ return result;
443
+ }
444
+
365
445
// / Materializes an adjoint value indirectly to a SIL buffer.
366
446
void materializeAdjointIndirect (AdjointValue val, SILValue destAddress,
367
447
SILLocation loc) {
@@ -376,25 +456,10 @@ class PullbackCloner::Implementation final
376
456
// / materialize the symbolic tuple or struct, filling the
377
457
// / buffer.
378
458
case AdjointValueKind::Aggregate: {
379
- if (auto *tupTy = val.getSwiftType ()->getAs <TupleType>()) {
380
- for (auto idx : range (val.getNumAggregateElements ())) {
381
- auto eltTy = SILType::getPrimitiveAddressType (
382
- tupTy->getElementType (idx)->getCanonicalType ());
383
- auto *eltBuf =
384
- builder.createTupleElementAddr (loc, destAddress, idx, eltTy);
385
- materializeAdjointIndirect (val.getAggregateElement (idx), eltBuf, loc);
386
- }
387
- } else if (auto *structDecl =
388
- val.getSwiftType ()->getStructOrBoundGenericStruct ()) {
389
- auto fieldIt = structDecl->getStoredProperties ().begin ();
390
- for (unsigned i = 0 ; fieldIt != structDecl->getStoredProperties ().end ();
391
- ++fieldIt, ++i) {
392
- auto eltBuf =
393
- builder.createStructElementAddr (loc, destAddress, *fieldIt);
394
- materializeAdjointIndirect (val.getAggregateElement (i), eltBuf, loc);
395
- }
459
+ if (isAdjointPiecewiseMaterializable (val)) {
460
+ materializeAggregateAdjointInDirectPiecewise (val, destAddress, loc);
396
461
} else {
397
- llvm_unreachable ( " Not an aggregate type " );
462
+ materializeAggregateAdjointInDirect (val, destAddress, loc );
398
463
}
399
464
break ;
400
465
}
@@ -408,6 +473,87 @@ class PullbackCloner::Implementation final
408
473
}
409
474
}
410
475
476
+ // / Used to materialize an aggregate adjoint indirectly, if
477
+ // / `isAdjointPiecewiseMaterializable` returned false.
478
+ void materializeAggregateAdjointInDirect (AdjointValue val,
479
+ SILValue destAddress,
480
+ SILLocation loc) {
481
+ assert (destAddress->getType ().isAddress ());
482
+ auto zeroAggAdj = builder.emitZero (loc, val.getSwiftType ());
483
+ auto isTupleType = val.getType ().is <TupleType>();
484
+
485
+ SILInstruction *destructureInst;
486
+ if (isTupleType)
487
+ destructureInst = builder.createDestructureTuple (loc, zeroAggAdj);
488
+ else
489
+ destructureInst = builder.createDestructureStruct (loc, zeroAggAdj);
490
+
491
+ if (auto *tupTy = val.getSwiftType ()->getAs <TupleType>()) {
492
+ for (auto idx : range (val.getNumAggregateElements ())) {
493
+ auto eltTy = SILType::getPrimitiveAddressType (
494
+ tupTy->getElementType (idx)->getCanonicalType ());
495
+ auto *eltBuf =
496
+ builder.createTupleElementAddr (loc, destAddress, idx, eltTy);
497
+
498
+ auto fieldCanTy = val.getAggregateElement (idx).getType ().getASTType ();
499
+ if (!getTangentSpace (fieldCanTy)) {
500
+ builder.emitStoreValueOperation (loc, destructureInst->getResult (idx),
501
+ eltBuf,
502
+ StoreOwnershipQualifier::Init);
503
+ } else {
504
+ materializeAdjointIndirect (val.getAggregateElement (idx), eltBuf, loc);
505
+ }
506
+ }
507
+ } else if (auto *structDecl =
508
+ val.getSwiftType ()->getStructOrBoundGenericStruct ()) {
509
+ auto fieldIt = structDecl->getStoredProperties ().begin ();
510
+ for (unsigned i = 0 ; fieldIt != structDecl->getStoredProperties ().end ();
511
+ ++fieldIt, ++i) {
512
+ auto eltBuf =
513
+ builder.createStructElementAddr (loc, destAddress, *fieldIt);
514
+
515
+ auto fieldCanTy = val.getAggregateElement (i).getType ().getASTType ();
516
+ if (!getTangentSpace (fieldCanTy)) {
517
+ builder.emitStoreValueOperation (loc, destructureInst->getResult (i),
518
+ eltBuf,
519
+ StoreOwnershipQualifier::Init);
520
+ } else {
521
+ materializeAdjointIndirect (val.getAggregateElement (i), eltBuf, loc);
522
+ }
523
+ }
524
+ } else {
525
+ llvm_unreachable (" Not an aggregate type" );
526
+ }
527
+ }
528
+
529
+ // / Used to materialize an aggregate adjoint indirectly, if
530
+ // / `isAdjointPiecewiseMaterializable` returned true.
531
+ void materializeAggregateAdjointInDirectPiecewise (AdjointValue val,
532
+ SILValue destAddress,
533
+ SILLocation loc) {
534
+ assert (destAddress->getType ().isAddress ());
535
+ if (auto *tupTy = val.getSwiftType ()->getAs <TupleType>()) {
536
+ for (auto idx : range (val.getNumAggregateElements ())) {
537
+ auto eltTy = SILType::getPrimitiveAddressType (
538
+ tupTy->getElementType (idx)->getCanonicalType ());
539
+ auto *eltBuf =
540
+ builder.createTupleElementAddr (loc, destAddress, idx, eltTy);
541
+ materializeAdjointIndirect (val.getAggregateElement (idx), eltBuf, loc);
542
+ }
543
+ } else if (auto *structDecl =
544
+ val.getSwiftType ()->getStructOrBoundGenericStruct ()) {
545
+ auto fieldIt = structDecl->getStoredProperties ().begin ();
546
+ for (unsigned i = 0 ; fieldIt != structDecl->getStoredProperties ().end ();
547
+ ++fieldIt, ++i) {
548
+ auto eltBuf =
549
+ builder.createStructElementAddr (loc, destAddress, *fieldIt);
550
+ materializeAdjointIndirect (val.getAggregateElement (i), eltBuf, loc);
551
+ }
552
+ } else {
553
+ llvm_unreachable (" Not an aggregate type" );
554
+ }
555
+ }
556
+
411
557
// --------------------------------------------------------------------------//
412
558
// Adjoint value mapping
413
559
// --------------------------------------------------------------------------//
0 commit comments