Skip to content

Commit fe20afb

Browse files
committed
[AutoDiff] NFC: gardening.
Remove `SILAutoDiffIndices` argument from `LinearMapInfo` methods. Use the `SILAutoDiffIndices` stored in `LinearMapInfo` instead.
1 parent 04686a0 commit fe20afb

File tree

2 files changed

+12
-18
lines changed

2 files changed

+12
-18
lines changed

include/swift/SILOptimizer/Utils/Differentiation/LinearMapInfo.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,30 +113,26 @@ class LinearMapInfo {
113113
/// whose cases represent the predecessors/successors of the given original
114114
/// block.
115115
EnumDecl *createBranchingTraceDecl(SILBasicBlock *originalBB,
116-
SILAutoDiffIndices indices,
117116
CanGenericSignature genericSig,
118117
SILLoopInfo *loopInfo);
119118

120119
/// Creates a struct declaration with the given JVP/VJP generic signature, for
121120
/// storing the linear map values and predecessor/successor basic block of the
122121
/// given original block.
123122
StructDecl *createLinearMapStruct(SILBasicBlock *originalBB,
124-
SILAutoDiffIndices indices,
125123
CanGenericSignature genericSig);
126124

127125
/// Adds a linear map field to the linear map struct.
128126
VarDecl *addLinearMapDecl(ApplyInst *ai, SILType linearMapType);
129127

130128
/// Given an `apply` instruction, conditionally adds a linear map struct field
131129
/// for its linear map function if it is active.
132-
void addLinearMapToStruct(ADContext &context, ApplyInst *ai,
133-
SILAutoDiffIndices indices);
130+
void addLinearMapToStruct(ADContext &context, ApplyInst *ai);
134131

135132
/// Generates linear map struct and branching enum declarations for the given
136133
/// function. Linear map structs are populated with linear map fields and a
137134
/// branching enum field.
138135
void generateDifferentiationDataStructures(ADContext &context,
139-
SILAutoDiffIndices indices,
140136
SILFunction *derivative);
141137

142138
public:

lib/SILOptimizer/Utils/Differentiation/LinearMapInfo.cpp

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ LinearMapInfo::LinearMapInfo(ADContext &context, AutoDiffLinearMapKind kind,
5959
: kind(kind), original(original), derivative(derivative),
6060
activityInfo(activityInfo), indices(indices),
6161
typeConverter(context.getTypeConverter()) {
62-
generateDifferentiationDataStructures(context, indices, derivative);
62+
generateDifferentiationDataStructures(context, derivative);
6363
}
6464

6565
SILType LinearMapInfo::remapTypeInDerivative(SILType ty) {
@@ -122,9 +122,10 @@ void LinearMapInfo::computeAccessLevel(NominalTypeDecl *nominal,
122122
}
123123
}
124124

125-
EnumDecl *LinearMapInfo::createBranchingTraceDecl(
126-
SILBasicBlock *originalBB, SILAutoDiffIndices indices,
127-
CanGenericSignature genericSig, SILLoopInfo *loopInfo) {
125+
EnumDecl *
126+
LinearMapInfo::createBranchingTraceDecl(SILBasicBlock *originalBB,
127+
CanGenericSignature genericSig,
128+
SILLoopInfo *loopInfo) {
128129
assert(originalBB->getParent() == original);
129130
auto &astCtx = original->getASTContext();
130131
auto *moduleDecl = original->getModule().getSwiftModule();
@@ -195,7 +196,6 @@ EnumDecl *LinearMapInfo::createBranchingTraceDecl(
195196

196197
StructDecl *
197198
LinearMapInfo::createLinearMapStruct(SILBasicBlock *originalBB,
198-
SILAutoDiffIndices indices,
199199
CanGenericSignature genericSig) {
200200
assert(originalBB->getParent() == original);
201201
auto *original = originalBB->getParent();
@@ -267,8 +267,7 @@ VarDecl *LinearMapInfo::addLinearMapDecl(ApplyInst *ai, SILType linearMapType) {
267267
return linearMapDecl;
268268
}
269269

270-
void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
271-
SILAutoDiffIndices indices) {
270+
void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai) {
272271
SmallVector<SILValue, 4> allResults;
273272
SmallVector<unsigned, 8> activeParamIndices;
274273
SmallVector<unsigned, 8> activeResultIndices;
@@ -372,7 +371,7 @@ void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
372371
}
373372

374373
void LinearMapInfo::generateDifferentiationDataStructures(
375-
ADContext &context, SILAutoDiffIndices indices, SILFunction *derivativeFn) {
374+
ADContext &context, SILFunction *derivativeFn) {
376375
auto &astCtx = original->getASTContext();
377376
auto *loopAnalysis = context.getPassManager().getAnalysis<SILLoopAnalysis>();
378377
auto *loopInfo = loopAnalysis->get(original);
@@ -385,8 +384,7 @@ void LinearMapInfo::generateDifferentiationDataStructures(
385384

386385
// Create linear map struct for each original block.
387386
for (auto &origBB : *original) {
388-
auto *linearMapStruct =
389-
createLinearMapStruct(&origBB, indices, derivativeFnGenSig);
387+
auto *linearMapStruct = createLinearMapStruct(&origBB, derivativeFnGenSig);
390388
linearMapStructs.insert({&origBB, linearMapStruct});
391389
}
392390

@@ -402,8 +400,8 @@ void LinearMapInfo::generateDifferentiationDataStructures(
402400
break;
403401
}
404402
for (auto &origBB : *original) {
405-
auto *traceEnum = createBranchingTraceDecl(&origBB, indices,
406-
derivativeFnGenSig, loopInfo);
403+
auto *traceEnum =
404+
createBranchingTraceDecl(&origBB, derivativeFnGenSig, loopInfo);
407405
branchingTraceDecls.insert({&origBB, traceEnum});
408406
if (origBB.isEntry())
409407
continue;
@@ -426,7 +424,7 @@ void LinearMapInfo::generateDifferentiationDataStructures(
426424
continue;
427425
LLVM_DEBUG(getADDebugStream()
428426
<< "Adding linear map struct field for " << *ai);
429-
addLinearMapToStruct(context, ai, indices);
427+
addLinearMapToStruct(context, ai);
430428
}
431429
}
432430
}

0 commit comments

Comments
 (0)