Skip to content

Commit dfa32b8

Browse files
dan-zhengrxwei
authored andcommitted
[AutoDiff] NFC: prettify commutative diagrams. (swiftlang#31525)
Make commutative diagrams pretty using Unicode box characters. Use lowercase letters for arrow names.
1 parent 4791625 commit dfa32b8

File tree

2 files changed

+32
-30
lines changed

2 files changed

+32
-30
lines changed

include/swift/SIL/TypeSubstCloner.h

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -322,28 +322,29 @@ class TypeSubstCloner : public SILClonerWithScopes<ImplClass> {
322322
return;
323323
}
324324
// If the extractee is a derivative function, check whether the *remapped
325-
// derivative function type* (BC) is equal to the *derivative remapped
326-
// function type* (AD).
325+
// derivative function type* (bc) is equal to the *derivative remapped
326+
// function type* (ad).
327327
//
328-
// +----------------+ remap +-------------------------+
329-
// | orig. fn type | -------(A)------> | remapped orig. fn type |
330-
// +----------------+ +-------------------------+
331-
// | |
332-
// (B, SILGen) getAutoDiffDerivativeFunctionType (D, here)
333-
// V V
334-
// +----------------+ remap +-------------------------+
335-
// | deriv. fn type | -------(C)------> | remapped deriv. fn type |
336-
// +----------------+ +-------------------------+
328+
// ┌────────────────┐ remap ┌─────────────────────────┐
329+
// │ orig. fn type │ ───────(a)──────► │ remapped orig. fn type │
330+
// └────────────────┘ └─────────────────────────┘
331+
// │ │
332+
// (b, SILGen) getAutoDiffDerivativeFunctionType (d, here)
333+
// │ │
334+
// ▼ ▼
335+
// ┌────────────────┐ remap ┌─────────────────────────┐
336+
// │ deriv. fn type │ ───────(c)──────► │ remapped deriv. fn type │
337+
// └────────────────┘ └─────────────────────────┘
337338
//
338-
// (AD) does not always commute with (BC):
339-
// - (AD) is the result of remapping, then computing the derivative type.
339+
// (ad) does not always commute with (bc):
340+
// - (ad) is the result of remapping, then computing the derivative type.
340341
// This is the default cloning behavior, but may break invariants in the
341342
// initial SIL generated by SILGen.
342-
// - (BC) is the result of computing the derivative type (SILGen), then
343+
// - (bc) is the result of computing the derivative type (SILGen), then
343344
// remapping. This is the expected type, preserving invariants from
344345
// earlier transforms.
345346
//
346-
// If (AD) is not equal to (BC), use (BC) as the explicit type.
347+
// If (ad) is not equal to (bc), use (bc) as the explicit type.
347348
SILType remappedOrigType = getOpType(dfei->getOperand()->getType());
348349
auto remappedOrigFnType = remappedOrigType.castTo<SILFunctionType>();
349350
auto derivativeRemappedFnType =

lib/SIL/IR/SILFunctionType.cpp

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3144,30 +3144,31 @@ TypeConverter::getConstantInfo(TypeExpansionContext expansion,
31443144
// If the constant refers to a derivative function, get the SIL type of the
31453145
// original function and use it to compute the derivative SIL type.
31463146
//
3147-
// This is necessary because the "lowered AST derivative function type" (BC)
3147+
// This is necessary because the "lowered AST derivative function type" (bc)
31483148
// may differ from the "derivative type of the lowered original function type"
3149-
// (AD):
3149+
// (ad):
31503150
//
3151-
// +--------------------+ lowering +--------------------+
3152-
// | AST orig. fn type | -------(A)------> | SIL orig. fn type |
3153-
// +--------------------+ +--------------------+
3154-
// | |
3155-
// (B, Sema) getAutoDiffDerivativeFunctionType (D, here)
3156-
// V V
3157-
// +--------------------+ lowering +--------------------+
3158-
// | AST deriv. fn type | -------(C)------> | SIL deriv. fn type |
3159-
// +--------------------+ +--------------------+
3151+
// ┌────────────────────┐ lowering ┌────────────────────┐
3152+
// │ AST orig. fn type │ ───────(a)──────► │ SIL orig. fn type │
3153+
// └────────────────────┘ └────────────────────┘
3154+
// │ │
3155+
// (b, Sema) getAutoDiffDerivativeFunctionType (d, here)
3156+
// │ │
3157+
// ▼ ▼
3158+
// ┌────────────────────┐ lowering ┌────────────────────┐
3159+
// │ AST deriv. fn type │ ───────(c)──────► │ SIL deriv. fn type │
3160+
// └────────────────────┘ └────────────────────┘
31603161
//
3161-
// (AD) does not always commute with (BC):
3162-
// - (BC) is the result of computing the AST derivative type (Sema), then
3162+
// (ad) does not always commute with (bc):
3163+
// - (bc) is the result of computing the AST derivative type (Sema), then
31633164
// lowering it via SILGen. This is the default lowering behavior, but may
31643165
// break SIL typing invariants because expected lowered derivative types are
31653166
// computed from lowered original function types.
3166-
// - (AD) is the result of lowering the original function type, then computing
3167+
// - (ad) is the result of lowering the original function type, then computing
31673168
// its derivative type. This is the expected lowered derivative type,
31683169
// preserving SIL typing invariants.
31693170
//
3170-
// Always use (AD) to compute lowered derivative function types.
3171+
// Always use (ad) to compute lowered derivative function types.
31713172
if (auto *derivativeId = constant.derivativeFunctionIdentifier) {
31723173
// Get lowered original function type.
31733174
auto origFnConstantInfo = getConstantInfo(

0 commit comments

Comments
 (0)