Skip to content

Commit 63cfd7c

Browse files
committed
[AutoDiff] Handle materializing adjoints with non-differentiable fields
Fixes #66522
1 parent f53be47 commit 63cfd7c

File tree

2 files changed

+278
-30
lines changed

2 files changed

+278
-30
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 211 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,30 @@ class PullbackCloner::Implementation final
327327
// Adjoint value materialization
328328
//--------------------------------------------------------------------------//
329329

330+
/// Determines whether this adjoint value can be materialized by materializing
331+
/// and then combining individual fields of the adjoint. This function should
332+
/// only be called for aggregate adjoints.
333+
///
334+
/// Users are allowed to define custom tangent vectors which may contain
335+
/// fields that do not conform to `AdditiveArithmetic` and `Differentiable`
336+
/// protocols. This is fine as long as they fulfill the corresponding
337+
/// `AdditiveArithmetic` and `Differentiable` protocol requirements on the
338+
/// tangent vector itself.
339+
///
340+
/// A user-defined tangent vector with above characteristics, cannot be
341+
/// materialized by materializing individual fields as that process relies on
342+
/// the `AdditiveArithmeticness` of the individual fields.
343+
bool isAdjointPiecewiseMaterializable(AdjointValue val) {
344+
assert(val.getKind() == AdjointValueKind::Aggregate);
345+
for (auto i : range(val.getNumAggregateElements())) {
346+
auto fieldCanTy = val.getAggregateElement(i).getType().getASTType();
347+
if (!getTangentSpace(fieldCanTy)) {
348+
return false;
349+
}
350+
}
351+
return true;
352+
}
353+
330354
/// Materializes an adjoint value. The type of the given adjoint value must be
331355
/// loadable.
332356
SILValue materializeAdjointDirect(AdjointValue val, SILLocation loc) {
@@ -336,20 +360,14 @@ class PullbackCloner::Implementation final
336360
SILValue result;
337361
switch (val.getKind()) {
338362
case AdjointValueKind::Zero:
339-
result = recordTemporary(builder.emitZero(loc, val.getSwiftType()));
363+
result = builder.emitZero(loc, val.getSwiftType());
340364
break;
341365
case AdjointValueKind::Aggregate: {
342-
SmallVector<SILValue, 8> elements;
343-
for (auto i : range(val.getNumAggregateElements())) {
344-
auto eltVal = materializeAdjointDirect(val.getAggregateElement(i), loc);
345-
elements.push_back(builder.emitCopyValueOperation(loc, eltVal));
366+
if (isAdjointPiecewiseMaterializable(val)) {
367+
result = materializeAggregateAdjointDirectPiecewise(val, loc);
368+
} else {
369+
result = materializeAggregateAdjointDirect(val, loc);
346370
}
347-
if (val.getType().is<TupleType>())
348-
result = recordTemporary(
349-
builder.createTuple(loc, val.getType(), elements));
350-
else
351-
result = recordTemporary(
352-
builder.createStruct(loc, val.getType(), elements));
353371
break;
354372
}
355373
case AdjointValueKind::Concrete:
@@ -362,6 +380,93 @@ class PullbackCloner::Implementation final
362380
return result;
363381
}
364382

383+
/// Used to materialize an aggregate adjoint directly, if
384+
/// `isAdjointPiecewiseMaterializable` returned false.
385+
SILValue materializeAggregateAdjointDirect(AdjointValue val,
386+
SILLocation loc) {
387+
SILValue result;
388+
auto *resultAlloc = builder.createAllocStack(loc, val.getType());
389+
builder.emitZeroIntoBuffer(loc, resultAlloc, IsInitialization);
390+
391+
if (auto *tupTy = val.getSwiftType()->getAs<TupleType>()) {
392+
for (auto fieldIndex : range(val.getNumAggregateElements())) {
393+
auto adjField = val.getAggregateElement(fieldIndex);
394+
// No need to materialize zero field adjoints when we
395+
// have already materialized a zero aggregate adjoint
396+
// for the aggregate type containing the field.
397+
if (adjField.getKind() != AdjointValueKind::Zero) {
398+
auto eltTy = SILType::getPrimitiveAddressType(
399+
tupTy->getElementType(fieldIndex)->getCanonicalType());
400+
auto lhsAdjEltBuf =
401+
builder
402+
.createTupleElementAddr(loc, resultAlloc, fieldIndex, eltTy)
403+
->getResult(0);
404+
auto rhsAdjEltBuf =
405+
builder.createAllocStack(loc, adjField.getType())->getResult(0);
406+
materializeAdjointIndirect(adjField, rhsAdjEltBuf, loc);
407+
408+
// lhsAdjEltBuf += rhsAdjEltBuf
409+
builder.emitInPlaceAdd(loc, lhsAdjEltBuf, rhsAdjEltBuf);
410+
builder.createDestroyAddr(loc, rhsAdjEltBuf);
411+
builder.createDeallocStack(loc, rhsAdjEltBuf);
412+
}
413+
}
414+
} else if (auto *structDecl =
415+
val.getSwiftType().getStructOrBoundGenericStruct()) {
416+
unsigned fieldIndex = 0;
417+
for (auto it = structDecl->getStoredProperties().begin();
418+
it != structDecl->getStoredProperties().end(); ++it, ++fieldIndex) {
419+
auto adjField = val.getAggregateElement(fieldIndex);
420+
// No need to materialize zero field adjoints when we
421+
// have already materialized a zero aggregate adjoint
422+
// for the aggregate type containing the field.
423+
if (adjField.getKind() != AdjointValueKind::Zero) {
424+
VarDecl *field = *it;
425+
auto lhsAdjEltBuf =
426+
builder.createStructElementAddr(loc, resultAlloc, field)
427+
->getResult(0);
428+
auto rhsAdjEltBuf =
429+
builder.createAllocStack(loc, adjField.getType())->getResult(0);
430+
materializeAdjointIndirect(adjField, rhsAdjEltBuf, loc);
431+
432+
// lhsAdjEltBuf += rhsAdjEltBuf
433+
builder.emitInPlaceAdd(loc, lhsAdjEltBuf, rhsAdjEltBuf);
434+
builder.createDestroyAddr(loc, rhsAdjEltBuf);
435+
builder.createDeallocStack(loc, rhsAdjEltBuf);
436+
}
437+
}
438+
} else {
439+
llvm_unreachable("Not an aggregate type");
440+
}
441+
442+
result = recordTemporary(builder.emitLoadValueOperation(
443+
loc, resultAlloc, LoadOwnershipQualifier::Take));
444+
builder.createDeallocStack(loc, resultAlloc);
445+
446+
return result;
447+
}
448+
449+
/// Used to materialize an aggregate adjoint directly, if
450+
/// `isAdjointPiecewiseMaterializable` returned true.
451+
SILValue materializeAggregateAdjointDirectPiecewise(AdjointValue val,
452+
SILLocation loc) {
453+
SILValue result;
454+
455+
SmallVector<SILValue, 8> elements;
456+
for (auto i : range(val.getNumAggregateElements())) {
457+
auto eltVal = materializeAdjointDirect(val.getAggregateElement(i), loc);
458+
elements.push_back(builder.emitCopyValueOperation(loc, eltVal));
459+
}
460+
if (val.getType().is<TupleType>())
461+
result =
462+
recordTemporary(builder.createTuple(loc, val.getType(), elements));
463+
else
464+
result =
465+
recordTemporary(builder.createStruct(loc, val.getType(), elements));
466+
467+
return result;
468+
}
469+
365470
/// Materializes an adjoint value indirectly to a SIL buffer.
366471
void materializeAdjointIndirect(AdjointValue val, SILValue destAddress,
367472
SILLocation loc) {
@@ -376,38 +481,114 @@ class PullbackCloner::Implementation final
376481
/// materialize the symbolic tuple or struct, filling the
377482
/// buffer.
378483
case AdjointValueKind::Aggregate: {
379-
if (auto *tupTy = val.getSwiftType()->getAs<TupleType>()) {
380-
for (auto idx : range(val.getNumAggregateElements())) {
381-
auto eltTy = SILType::getPrimitiveAddressType(
382-
tupTy->getElementType(idx)->getCanonicalType());
383-
auto *eltBuf =
384-
builder.createTupleElementAddr(loc, destAddress, idx, eltTy);
385-
materializeAdjointIndirect(val.getAggregateElement(idx), eltBuf, loc);
386-
}
387-
} else if (auto *structDecl =
388-
val.getSwiftType()->getStructOrBoundGenericStruct()) {
389-
auto fieldIt = structDecl->getStoredProperties().begin();
390-
for (unsigned i = 0; fieldIt != structDecl->getStoredProperties().end();
391-
++fieldIt, ++i) {
392-
auto eltBuf =
393-
builder.createStructElementAddr(loc, destAddress, *fieldIt);
394-
materializeAdjointIndirect(val.getAggregateElement(i), eltBuf, loc);
395-
}
484+
if (isAdjointPiecewiseMaterializable(val)) {
485+
materializeAggregateAdjointInDirectPiecewise(val, destAddress, loc);
396486
} else {
397-
llvm_unreachable("Not an aggregate type");
487+
materializeAggregateAdjointInDirect(val, destAddress, loc);
398488
}
399489
break;
400490
}
401491
/// If adjoint value is concrete, it is already materialized. Store it in
402492
/// the destination address.
403493
case AdjointValueKind::Concrete:
404494
auto concreteVal = val.getConcreteValue();
405-
builder.emitStoreValueOperation(loc, concreteVal, destAddress,
495+
// `val` needs to be an owned value for storing it into `destAddress`,
496+
// which may not always be the case. So, we create a copy of the value
497+
// first.
498+
auto copyConreteVal = builder.emitCopyValueOperation(loc, concreteVal);
499+
builder.emitStoreValueOperation(loc, copyConreteVal, destAddress,
406500
StoreOwnershipQualifier::Init);
407501
break;
408502
}
409503
}
410504

505+
/// Used to materialize an aggregate adjoint indirectly, if
506+
/// `isAdjointPiecewiseMaterializable` returned false.
507+
void materializeAggregateAdjointInDirect(AdjointValue val,
508+
SILValue destAddress,
509+
SILLocation loc) {
510+
assert(destAddress->getType().isAddress());
511+
auto zeroAggAdj = builder.emitZero(loc, val.getSwiftType());
512+
auto isTupleType = val.getType().is<TupleType>();
513+
514+
SILInstruction *destructureInst;
515+
if (isTupleType)
516+
destructureInst = builder.createDestructureTuple(loc, zeroAggAdj);
517+
else
518+
destructureInst = builder.createDestructureStruct(loc, zeroAggAdj);
519+
520+
if (auto *tupTy = val.getSwiftType()->getAs<TupleType>()) {
521+
for (auto idx : range(val.getNumAggregateElements())) {
522+
auto adjField = val.getAggregateElement(idx);
523+
524+
auto eltTy = SILType::getPrimitiveAddressType(
525+
tupTy->getElementType(idx)->getCanonicalType());
526+
auto *eltBuf =
527+
builder.createTupleElementAddr(loc, destAddress, idx, eltTy);
528+
529+
if (adjField.getKind() != AdjointValueKind::Zero) {
530+
materializeAdjointIndirect(adjField, eltBuf, loc);
531+
} else {
532+
// No need to individually materialize zero field adjoints. Instead
533+
// we can use the corresponding adjoint values from `zeroAggAdj`.
534+
builder.emitStoreValueOperation(loc, destructureInst->getResult(idx),
535+
eltBuf,
536+
StoreOwnershipQualifier::Init);
537+
}
538+
}
539+
} else if (auto *structDecl =
540+
val.getSwiftType()->getStructOrBoundGenericStruct()) {
541+
auto fieldIt = structDecl->getStoredProperties().begin();
542+
for (unsigned i = 0; fieldIt != structDecl->getStoredProperties().end();
543+
++fieldIt, ++i) {
544+
auto adjField = val.getAggregateElement(i);
545+
546+
auto eltBuf =
547+
builder.createStructElementAddr(loc, destAddress, *fieldIt);
548+
549+
if (adjField.getKind() != AdjointValueKind::Zero) {
550+
materializeAdjointIndirect(adjField, eltBuf, loc);
551+
} else {
552+
// No need to individually materialize zero field adjoints. Instead
553+
// we can use the corresponding adjoint values from `zeroAggAdj`.
554+
builder.emitStoreValueOperation(loc, destructureInst->getResult(i),
555+
eltBuf,
556+
StoreOwnershipQualifier::Init);
557+
}
558+
}
559+
} else {
560+
llvm_unreachable("Not an aggregate type");
561+
}
562+
}
563+
564+
/// Used to materialize an aggregate adjoint indirectly, if
565+
/// `isAdjointPiecewiseMaterializable` returned true.
566+
void materializeAggregateAdjointInDirectPiecewise(AdjointValue val,
567+
SILValue destAddress,
568+
SILLocation loc) {
569+
assert(destAddress->getType().isAddress());
570+
if (auto *tupTy = val.getSwiftType()->getAs<TupleType>()) {
571+
for (auto idx : range(val.getNumAggregateElements())) {
572+
auto eltTy = SILType::getPrimitiveAddressType(
573+
tupTy->getElementType(idx)->getCanonicalType());
574+
auto *eltBuf =
575+
builder.createTupleElementAddr(loc, destAddress, idx, eltTy);
576+
materializeAdjointIndirect(val.getAggregateElement(idx), eltBuf, loc);
577+
}
578+
} else if (auto *structDecl =
579+
val.getSwiftType()->getStructOrBoundGenericStruct()) {
580+
auto fieldIt = structDecl->getStoredProperties().begin();
581+
for (unsigned i = 0; fieldIt != structDecl->getStoredProperties().end();
582+
++fieldIt, ++i) {
583+
auto eltBuf =
584+
builder.createStructElementAddr(loc, destAddress, *fieldIt);
585+
materializeAdjointIndirect(val.getAggregateElement(i), eltBuf, loc);
586+
}
587+
} else {
588+
llvm_unreachable("Not an aggregate type");
589+
}
590+
}
591+
411592
//--------------------------------------------------------------------------//
412593
// Adjoint value mapping
413594
//--------------------------------------------------------------------------//
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// RUN: %target-swift-frontend -emit-sil -O %s
2+
3+
import _Differentiation
4+
5+
// Issue #66522:
6+
// Pullback generation for a function mapping a differentiable input type to
7+
// one of its differentiable fields fails when the input's tangent vector contains
8+
// non-differentiable fields.
9+
public struct P<Value>: Differentiable
10+
where
11+
Value: Differentiable,
12+
Value.TangentVector == Value,
13+
Value: AdditiveArithmetic {
14+
// `P` is its own `TangentVector`
15+
public typealias TangentVector = Self
16+
17+
// Non-differentiable field in `P`'s `TangentVector`.
18+
public let name: String = ""
19+
var value: Value
20+
}
21+
22+
extension P: Equatable, AdditiveArithmetic
23+
where Value: AdditiveArithmetic {
24+
public static var zero: Self {fatalError()}
25+
public static func + (lhs: Self, rhs: Self) -> Self {fatalError()}
26+
public static func - (lhs: Self, rhs: Self) -> Self {fatalError()}
27+
}
28+
29+
@differentiable(reverse)
30+
internal func testFunction(data: P<Double>) -> Double {
31+
data.value
32+
}
33+
34+
35+
// WHAT?
36+
// This test case currently fails with the following compiler error -
37+
// ```
38+
// /Users/kshitij/workspace/swift-project/swift/test/AutoDiff/compiler_crashers_fixed/issue-66522-pullback-generation-when-tangentvector-of-input-contains-nondifferentiable-fields.swift:9:15: error: expression is not differentiable
39+
// public struct P<Value>: Differentiable
40+
// ^
41+
// /Users/kshitij/workspace/swift-project/swift/test/AutoDiff/compiler_crashers_fixed/issue-66522-pullback-generation-when-tangentvector-of-input-contains-nondifferentiable-fields.swift:9:15: note: cannot differentiate access to property 'P.name' because property type 'String' does not conform to 'Differentiable'
42+
// public struct P<Value>: Differentiable
43+
// ^
44+
// ```
45+
//
46+
// WHY?
47+
// The `TangentVector` for `P` has been defined as `Self`.
48+
// The pullback generation for `P.init` fails because it now
49+
// contains a projection to a non-differentiable stored
50+
// property, which is not allowed currently - https://github.com/apple/swift/blob/main/lib/SILOptimizer/Differentiation/PullbackCloner.cpp#L1850.
51+
//
52+
// The same issue does not happen if we let the compiler
53+
// generated `P`'s `TangentVector`, because `P.init` now
54+
// becomes differentiable.
55+
//
56+
// SO WHAT?
57+
// It feels like a bit of an edge case that for a user-defined tangent
58+
// vector, such as that for `P`, we cannot have a function mapping
59+
// individual member fields to `P`.
60+
//
61+
// NOW WHAT?
62+
// Is this edge case something that we need to handle?
63+
//
64+
// @differentiable(reverse)
65+
// internal func testFunction2(data: Double) -> P<Double> {
66+
// P(value: data)
67+
// }

0 commit comments

Comments
 (0)