Skip to content

Commit da03d99

Browse files
authored
[NFC] [AutoDiff] Improve SILGen linear map thunking documentation. (#26557)
- Add documentation comments clarifying differential/pullback thunk self reordering logic. - Change `SILGenFunction::getThunkedAutoDiffLinearMap` to use `AutoDiffLinearMapKind`. - Various gardening.
1 parent 083af37 commit da03d99

File tree

4 files changed

+45
-33
lines changed

4 files changed

+45
-33
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -531,37 +531,40 @@ inline llvm::raw_ostream &operator<<(llvm::raw_ostream &s,
531531
return s;
532532
}
533533

534-
/// The kind of an associated function.
535-
struct AutoDiffAssociatedFunctionKind {
534+
/// The kind of an linear map.
535+
struct AutoDiffLinearMapKind {
536536
enum innerty : uint8_t {
537-
// The Jacobian-vector products function.
538-
JVP = 0,
539-
// The vector-Jacobian products function.
540-
VJP = 1
537+
// The differential function.
538+
Differential = 0,
539+
// The pullback function.
540+
Pullback = 1
541541
} rawValue;
542542

543-
AutoDiffAssociatedFunctionKind() = default;
544-
AutoDiffAssociatedFunctionKind(innerty rawValue) : rawValue(rawValue) {}
545-
explicit AutoDiffAssociatedFunctionKind(StringRef string);
543+
AutoDiffLinearMapKind() = default;
544+
AutoDiffLinearMapKind(innerty rawValue) : rawValue(rawValue) {}
546545
operator innerty() const { return rawValue; }
547546
};
548547

549-
/// The kind of an linear map.
550-
struct AutoDiffLinearMapKind {
548+
/// The kind of an associated function.
549+
struct AutoDiffAssociatedFunctionKind {
551550
enum innerty : uint8_t {
552-
// The differential function.
553-
Differential = 0,
554-
// The pullback function.
555-
Pullback = 1
551+
// The Jacobian-vector products function.
552+
JVP = 0,
553+
// The vector-Jacobian products function.
554+
VJP = 1
556555
} rawValue;
557556

558-
AutoDiffLinearMapKind() = default;
559-
AutoDiffLinearMapKind(innerty rawValue) : rawValue(rawValue) {}
557+
AutoDiffAssociatedFunctionKind() = default;
558+
AutoDiffAssociatedFunctionKind(innerty rawValue) : rawValue(rawValue) {}
559+
explicit AutoDiffAssociatedFunctionKind(StringRef string);
560560
operator innerty() const { return rawValue; }
561+
AutoDiffLinearMapKind getLinearMapKind() {
562+
return (AutoDiffLinearMapKind::innerty)rawValue;
563+
}
561564
};
562565

563-
/// In conjunction with the original function decl, identifies an associated
564-
/// autodiff function.
566+
/// In conjunction with the original function declaration, identifies an
567+
/// autodiff associated function.
565568
///
566569
/// Is uniquely allocated within an ASTContext so that it can be hashed and
567570
/// compared by opaque pointer value.

lib/AST/AutoDiff.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ AutoDiffAssociatedFunctionKind::
2929
AutoDiffAssociatedFunctionKind(StringRef string) {
3030
Optional<innerty> result =
3131
llvm::StringSwitch<Optional<innerty>>(string)
32-
.Case("jvp", JVP).Case("vjp", VJP);
32+
.Case("jvp", JVP).Case("vjp", VJP);
3333
assert(result && "Invalid string");
3434
rawValue = *result;
3535
}

lib/SILGen/SILGenFunction.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1788,7 +1788,7 @@ class LLVM_LIBRARY_VISIBILITY SILGenFunction
17881788
/// - The last parameter, for differentials.
17891789
/// - The last result, for pullbacks.
17901790
ManagedValue getThunkedAutoDiffLinearMap(
1791-
ManagedValue linearMap, AutoDiffAssociatedFunctionKind assocFnKind,
1791+
ManagedValue linearMap, AutoDiffLinearMapKind linearMapKind,
17921792
CanSILFunctionType fromType, CanSILFunctionType toType,
17931793
bool reorderSelf);
17941794

lib/SILGen/SILGenPoly.cpp

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3394,7 +3394,7 @@ static SILValue joinElements(ArrayRef<SILValue> elements, SILBuilder &builder,
33943394
/// Adapted from `SILGenModule::getOrCreateReabstractionThunk`.
33953395
ManagedValue
33963396
SILGenFunction::getThunkedAutoDiffLinearMap(
3397-
ManagedValue linearMap, AutoDiffAssociatedFunctionKind assocFnKind,
3397+
ManagedValue linearMap, AutoDiffLinearMapKind linearMapKind,
33983398
CanSILFunctionType fromType, CanSILFunctionType toType,
33993399
bool reorderSelf) {
34003400
// Compute the thunk type.
@@ -3418,11 +3418,11 @@ SILGenFunction::getThunkedAutoDiffLinearMap(
34183418
thunkType, fromInterfaceType, toInterfaceType,
34193419
Type(), getModule().getSwiftModule());
34203420
// TODO(TF-685): Use principled thunk mangling.
3421-
switch (assocFnKind) {
3422-
case AutoDiffAssociatedFunctionKind::JVP:
3421+
switch (linearMapKind) {
3422+
case AutoDiffLinearMapKind::Differential:
34233423
name += "_differential";
34243424
break;
3425-
case AutoDiffAssociatedFunctionKind::VJP:
3425+
case AutoDiffLinearMapKind::Pullback:
34263426
name += "_pullback";
34273427
break;
34283428
}
@@ -3476,20 +3476,30 @@ SILGenFunction::getThunkedAutoDiffLinearMap(
34763476
// - If self is direct, reorder direct results after `apply` is generated.
34773477
// - For differentials: reorder parameter infos and arguments.
34783478
auto numIndirectResults = thunkIndirectResults.size();
3479-
if (reorderSelf && assocFnKind == AutoDiffAssociatedFunctionKind::VJP &&
3479+
if (reorderSelf && linearMapKind == AutoDiffLinearMapKind::Pullback &&
34803480
toResults.size() > 1) {
34813481
auto toSelfResult = toResults.back();
34823482
if (toSelfResult.isFormalIndirect() && numIndirectResults > 1) {
3483+
// Before: [ind_res1, ind_res2, ..., ind_res_self, arg1, arg2, ..., pb]
3484+
// After: [ind_res_self, ind_res1, ind_res2, ..., arg1, arg2, ..., pb]
34833485
std::rotate(thunkArguments.begin(),
34843486
thunkArguments.begin() + numIndirectResults - 1,
34853487
thunkArguments.begin() + numIndirectResults);
3488+
// Before: [ind_res1, ind_res2, ..., ind_res_self]
3489+
// After: [ind_res_self, ind_res1, ind_res2, ...]
3490+
std::rotate(thunkIndirectResults.begin(), thunkIndirectResults.end() - 1,
3491+
thunkIndirectResults.end());
34863492
}
34873493
std::rotate(toResults.begin(), toResults.end() - 1, toResults.end());
34883494
}
3489-
if (reorderSelf && assocFnKind == AutoDiffAssociatedFunctionKind::JVP &&
3495+
if (reorderSelf && linearMapKind == AutoDiffLinearMapKind::Differential &&
34903496
thunkArguments.size() > 1) {
3497+
// Before: [ind_res1, ind_res2, ..., arg1, arg2, ..., arg_self, df]
3498+
// After: [ind_res1, ind_res2, ..., arg_self, arg1, arg2, ..., df]
34913499
std::rotate(thunkArguments.begin() + numIndirectResults,
34923500
thunkArguments.end() - 2, thunkArguments.end() - 1);
3501+
// Before: [arg1, arg2, ..., arg_self]
3502+
// After: [arg_self, arg1, arg2, ...]
34933503
std::rotate(toParameters.begin(), toParameters.end() - 1,
34943504
toParameters.end());
34953505
}
@@ -3589,14 +3599,12 @@ SILGenFunction::getThunkedAutoDiffLinearMap(
35893599

35903600
// Handle self reordering.
35913601
// For pullbacks: rotate direct results if self is direct.
3592-
if (reorderSelf && assocFnKind == AutoDiffAssociatedFunctionKind::VJP) {
3602+
if (reorderSelf && linearMapKind == AutoDiffLinearMapKind::Pullback) {
35933603
auto fromSelfResult = fromConv.getResults().front();
35943604
auto toSelfResult = toConv.getResults().back();
35953605
assert(fromSelfResult.getType() == toSelfResult.getType());
3596-
if (toSelfResult.isFormalIndirect() && thunkIndirectResults.size() > 1) {
3597-
std::rotate(thunkIndirectResults.begin(), thunkIndirectResults.end() - 1,
3598-
thunkIndirectResults.end());
3599-
}
3606+
// Before: [dir_res_self, dir_res1, dir_res2, ...]
3607+
// After: [dir_res1, dir_res2, ..., dir_res_self]
36003608
if (toSelfResult.isFormalDirect() && fromSelfResult.isFormalDirect() &&
36013609
directResults.size() > 1) {
36023610
std::rotate(directResults.begin(), directResults.begin() + 1,
@@ -3802,8 +3810,9 @@ SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk(
38023810
}
38033811

38043812
// Otherwise, apply reabstraction/self reordering thunk to linear map.
3813+
auto linearMapKind = assocFnKind.getLinearMapKind();
38053814
linearMap = thunkSGF.getThunkedAutoDiffLinearMap(
3806-
linearMap, assocFnKind, linearMapFnType, targetLinearMapFnType,
3815+
linearMap, linearMapKind, linearMapFnType, targetLinearMapFnType,
38073816
reorderSelf);
38083817

38093818
// Return original results and thunked differential/pullback.

0 commit comments

Comments
 (0)