Skip to content

Commit 005a208

Browse files
committed
[AutoDiff] Handle materializing adjoints with non-differentiable fields
Fixes #66522
1 parent f53be47 commit 005a208

File tree

2 files changed

+241
-28
lines changed

2 files changed

+241
-28
lines changed

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

Lines changed: 174 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,30 @@ class PullbackCloner::Implementation final
327327
// Adjoint value materialization
328328
//--------------------------------------------------------------------------//
329329

330+
/// Determines whether this adjoint value can be materialized by materializing
331+
/// and then combining individual fields of the adjoint. This function should
332+
/// only be called for aggregate adjoints.
333+
///
334+
/// Users are allowed to define custom tangent vectors which may contain
335+
/// fields that do not conform to `AdditiveArithmetic` and `Differentiable`
336+
/// protocols. This is fine as long as they fulfill the corresponding
337+
/// `AdditiveArithmetic` and `Differentiable` protocol requirements on the
338+
/// tangent vector itself.
339+
///
340+
/// A user-defined tangent vector with the above characteristics cannot be
341+
/// materialized by materializing individual fields as that process relies on
342+
/// the `AdditiveArithmetic` conformance of the individual fields.
343+
bool isAdjointPiecewiseMaterializable(AdjointValue val) {
344+
assert(val.getKind() == AdjointValueKind::Aggregate);
345+
for (auto i : range(val.getNumAggregateElements())) {
346+
auto fieldCanTy = val.getAggregateElement(i).getType().getASTType();
347+
if (!getTangentSpace(fieldCanTy)) {
348+
return false;
349+
}
350+
}
351+
return true;
352+
}
353+
330354
/// Materializes an adjoint value. The type of the given adjoint value must be
331355
/// loadable.
332356
SILValue materializeAdjointDirect(AdjointValue val, SILLocation loc) {
@@ -339,17 +363,11 @@ class PullbackCloner::Implementation final
339363
result = recordTemporary(builder.emitZero(loc, val.getSwiftType()));
340364
break;
341365
case AdjointValueKind::Aggregate: {
342-
SmallVector<SILValue, 8> elements;
343-
for (auto i : range(val.getNumAggregateElements())) {
344-
auto eltVal = materializeAdjointDirect(val.getAggregateElement(i), loc);
345-
elements.push_back(builder.emitCopyValueOperation(loc, eltVal));
366+
if (isAdjointPiecewiseMaterializable(val)) {
367+
result = materializeAggregateAdjointDirectPiecewise(val, loc);
368+
} else {
369+
result = materializeAggregateAdjointDirect(val, loc);
346370
}
347-
if (val.getType().is<TupleType>())
348-
result = recordTemporary(
349-
builder.createTuple(loc, val.getType(), elements));
350-
else
351-
result = recordTemporary(
352-
builder.createStruct(loc, val.getType(), elements));
353371
break;
354372
}
355373
case AdjointValueKind::Concrete:
@@ -362,6 +380,68 @@ class PullbackCloner::Implementation final
362380
return result;
363381
}
364382

383+
/// Used to materialize an aggregate adjoint directly, if
384+
/// `isAdjointPiecewiseMaterializable` returned false.
385+
SILValue materializeAggregateAdjointDirect(AdjointValue val,
386+
SILLocation loc) {
387+
SILValue result = builder.emitZero(loc, val.getSwiftType());
388+
auto isTupleType = val.getType().is<TupleType>();
389+
390+
SILInstruction *destructureInst;
391+
if (isTupleType) {
392+
destructureInst = builder.createDestructureTuple(loc, result);
393+
} else {
394+
destructureInst = builder.createDestructureStruct(loc, result);
395+
}
396+
397+
// Note - We materialize the aggregate adjoint by collecting concrete
398+
// field values and using the `struct` (or `tuple`) instruction rather than
399+
// doing a `+=` on individual field buffers because the TangentBuilder's
400+
// `emitInPlaceAdd` method requires the destination buffer to be an address,
401+
// which does not work for trivial types such as floats and doubles.
402+
SmallVector<SILValue, 8> elements;
403+
for (auto i : range(val.getNumAggregateElements())) {
404+
auto fieldCanTy = val.getAggregateElement(i).getType().getASTType();
405+
406+
if (!getTangentSpace(fieldCanTy)) {
407+
elements.push_back(destructureInst->getResult(i));
408+
} else {
409+
auto eltVal = materializeAdjointDirect(val.getAggregateElement(i), loc);
410+
elements.push_back(builder.emitCopyValueOperation(loc, eltVal));
411+
}
412+
}
413+
414+
if (isTupleType)
415+
result =
416+
recordTemporary(builder.createTuple(loc, val.getType(), elements));
417+
else
418+
result =
419+
recordTemporary(builder.createStruct(loc, val.getType(), elements));
420+
421+
return result;
422+
}
423+
424+
/// Used to materialize an aggregate adjoint directly, if
425+
/// `isAdjointPiecewiseMaterializable` returned true.
426+
SILValue materializeAggregateAdjointDirectPiecewise(AdjointValue val,
427+
SILLocation loc) {
428+
SILValue result;
429+
430+
SmallVector<SILValue, 8> elements;
431+
for (auto i : range(val.getNumAggregateElements())) {
432+
auto eltVal = materializeAdjointDirect(val.getAggregateElement(i), loc);
433+
elements.push_back(builder.emitCopyValueOperation(loc, eltVal));
434+
}
435+
if (val.getType().is<TupleType>())
436+
result =
437+
recordTemporary(builder.createTuple(loc, val.getType(), elements));
438+
else
439+
result =
440+
recordTemporary(builder.createStruct(loc, val.getType(), elements));
441+
442+
return result;
443+
}
444+
365445
/// Materializes an adjoint value indirectly to a SIL buffer.
366446
void materializeAdjointIndirect(AdjointValue val, SILValue destAddress,
367447
SILLocation loc) {
@@ -376,25 +456,10 @@ class PullbackCloner::Implementation final
376456
/// materialize the symbolic tuple or struct, filling the
377457
/// buffer.
378458
case AdjointValueKind::Aggregate: {
379-
if (auto *tupTy = val.getSwiftType()->getAs<TupleType>()) {
380-
for (auto idx : range(val.getNumAggregateElements())) {
381-
auto eltTy = SILType::getPrimitiveAddressType(
382-
tupTy->getElementType(idx)->getCanonicalType());
383-
auto *eltBuf =
384-
builder.createTupleElementAddr(loc, destAddress, idx, eltTy);
385-
materializeAdjointIndirect(val.getAggregateElement(idx), eltBuf, loc);
386-
}
387-
} else if (auto *structDecl =
388-
val.getSwiftType()->getStructOrBoundGenericStruct()) {
389-
auto fieldIt = structDecl->getStoredProperties().begin();
390-
for (unsigned i = 0; fieldIt != structDecl->getStoredProperties().end();
391-
++fieldIt, ++i) {
392-
auto eltBuf =
393-
builder.createStructElementAddr(loc, destAddress, *fieldIt);
394-
materializeAdjointIndirect(val.getAggregateElement(i), eltBuf, loc);
395-
}
459+
if (isAdjointPiecewiseMaterializable(val)) {
460+
materializeAggregateAdjointInDirectPiecewise(val, destAddress, loc);
396461
} else {
397-
llvm_unreachable("Not an aggregate type");
462+
materializeAggregateAdjointInDirect(val, destAddress, loc);
398463
}
399464
break;
400465
}
@@ -408,6 +473,87 @@ class PullbackCloner::Implementation final
408473
}
409474
}
410475

476+
/// Used to materialize an aggregate adjoint indirectly, if
477+
/// `isAdjointPiecewiseMaterializable` returned false.
478+
void materializeAggregateAdjointInDirect(AdjointValue val,
479+
SILValue destAddress,
480+
SILLocation loc) {
481+
assert(destAddress->getType().isAddress());
482+
auto zeroAggAdj = builder.emitZero(loc, val.getSwiftType());
483+
auto isTupleType = val.getType().is<TupleType>();
484+
485+
SILInstruction *destructureInst;
486+
if (isTupleType)
487+
destructureInst = builder.createDestructureTuple(loc, zeroAggAdj);
488+
else
489+
destructureInst = builder.createDestructureStruct(loc, zeroAggAdj);
490+
491+
if (auto *tupTy = val.getSwiftType()->getAs<TupleType>()) {
492+
for (auto idx : range(val.getNumAggregateElements())) {
493+
auto eltTy = SILType::getPrimitiveAddressType(
494+
tupTy->getElementType(idx)->getCanonicalType());
495+
auto *eltBuf =
496+
builder.createTupleElementAddr(loc, destAddress, idx, eltTy);
497+
498+
auto fieldCanTy = val.getAggregateElement(idx).getType().getASTType();
499+
if (!getTangentSpace(fieldCanTy)) {
500+
builder.emitStoreValueOperation(loc, destructureInst->getResult(idx),
501+
eltBuf,
502+
StoreOwnershipQualifier::Init);
503+
} else {
504+
materializeAdjointIndirect(val.getAggregateElement(idx), eltBuf, loc);
505+
}
506+
}
507+
} else if (auto *structDecl =
508+
val.getSwiftType()->getStructOrBoundGenericStruct()) {
509+
auto fieldIt = structDecl->getStoredProperties().begin();
510+
for (unsigned i = 0; fieldIt != structDecl->getStoredProperties().end();
511+
++fieldIt, ++i) {
512+
auto eltBuf =
513+
builder.createStructElementAddr(loc, destAddress, *fieldIt);
514+
515+
auto fieldCanTy = val.getAggregateElement(i).getType().getASTType();
516+
if (!getTangentSpace(fieldCanTy)) {
517+
builder.emitStoreValueOperation(loc, destructureInst->getResult(i),
518+
eltBuf,
519+
StoreOwnershipQualifier::Init);
520+
} else {
521+
materializeAdjointIndirect(val.getAggregateElement(i), eltBuf, loc);
522+
}
523+
}
524+
} else {
525+
llvm_unreachable("Not an aggregate type");
526+
}
527+
}
528+
529+
/// Used to materialize an aggregate adjoint indirectly, if
530+
/// `isAdjointPiecewiseMaterializable` returned true.
531+
void materializeAggregateAdjointInDirectPiecewise(AdjointValue val,
532+
SILValue destAddress,
533+
SILLocation loc) {
534+
assert(destAddress->getType().isAddress());
535+
if (auto *tupTy = val.getSwiftType()->getAs<TupleType>()) {
536+
for (auto idx : range(val.getNumAggregateElements())) {
537+
auto eltTy = SILType::getPrimitiveAddressType(
538+
tupTy->getElementType(idx)->getCanonicalType());
539+
auto *eltBuf =
540+
builder.createTupleElementAddr(loc, destAddress, idx, eltTy);
541+
materializeAdjointIndirect(val.getAggregateElement(idx), eltBuf, loc);
542+
}
543+
} else if (auto *structDecl =
544+
val.getSwiftType()->getStructOrBoundGenericStruct()) {
545+
auto fieldIt = structDecl->getStoredProperties().begin();
546+
for (unsigned i = 0; fieldIt != structDecl->getStoredProperties().end();
547+
++fieldIt, ++i) {
548+
auto eltBuf =
549+
builder.createStructElementAddr(loc, destAddress, *fieldIt);
550+
materializeAdjointIndirect(val.getAggregateElement(i), eltBuf, loc);
551+
}
552+
} else {
553+
llvm_unreachable("Not an aggregate type");
554+
}
555+
}
556+
411557
//--------------------------------------------------------------------------//
412558
// Adjoint value mapping
413559
//--------------------------------------------------------------------------//
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// RUN: %target-swift-frontend -emit-sil -O %s
2+
3+
import _Differentiation
4+
5+
// Issue #66522:
6+
// Pullback generation for a function mapping a differentiable input type to
7+
// one of its differentiable fields fails when the input's tangent vector contains
8+
// non-differentiable fields.
9+
public struct P<Value>: Differentiable
10+
where
11+
Value: Differentiable,
12+
Value.TangentVector == Value,
13+
Value: AdditiveArithmetic {
14+
// `P` is its own `TangentVector`
15+
public typealias TangentVector = Self
16+
17+
// Non-differentiable field in `P`'s `TangentVector`.
18+
public let name: String = ""
19+
var value: Value
20+
}
21+
22+
extension P: Equatable, AdditiveArithmetic
23+
where Value: AdditiveArithmetic {
24+
public static var zero: Self {fatalError()}
25+
public static func + (lhs: Self, rhs: Self) -> Self {fatalError()}
26+
public static func - (lhs: Self, rhs: Self) -> Self {fatalError()}
27+
}
28+
29+
@differentiable(reverse)
30+
internal func testFunction(data: P<Double>) -> Double {
31+
data.value
32+
}
33+
34+
35+
// WHAT?
36+
// This test case currently fails with the following compiler error -
37+
// ```
38+
// /Users/kshitij/workspace/swift-project/swift/test/AutoDiff/compiler_crashers_fixed/issue-66522-pullback-generation-when-tangentvector-of-input-contains-nondifferentiable-fields.swift:9:15: error: expression is not differentiable
39+
// public struct P<Value>: Differentiable
40+
// ^
41+
// /Users/kshitij/workspace/swift-project/swift/test/AutoDiff/compiler_crashers_fixed/issue-66522-pullback-generation-when-tangentvector-of-input-contains-nondifferentiable-fields.swift:9:15: note: cannot differentiate access to property 'P.name' because property type 'String' does not conform to 'Differentiable'
42+
// public struct P<Value>: Differentiable
43+
// ^
44+
// ```
45+
//
46+
// WHY?
47+
// The `TangentVector` for `P` has been defined as `Self`.
48+
// The pullback generation for `P.init` fails because it now
49+
// contains a projection to a non-differentiable stored
50+
// property, which is not allowed currently - https://github.com/apple/swift/blob/main/lib/SILOptimizer/Differentiation/PullbackCloner.cpp#L1850.
51+
//
52+
// The same issue does not happen if we let the compiler
53+
// generate `P`'s `TangentVector`, because `P.init` then
54+
// becomes differentiable.
55+
//
56+
// SO WHAT?
57+
// It feels like a bit of an edge case that for a user-defined tangent
58+
// vector, such as that for `P`, we cannot have a function mapping
59+
// individual member fields to `P`.
60+
//
61+
// NOW WHAT?
62+
// Is this edge case something that we need to handle?
63+
//
64+
// @differentiable(reverse)
65+
// internal func testFunction2(data: Double) -> P<Double> {
66+
// P(value: data)
67+
// }

0 commit comments

Comments
 (0)