Skip to content

Commit 6688575

Browse files
committed
[AudoDiff] NFC: Replace 'SILAutoDiffIndices' with 'AutoDiffConfig'.
Resolve rdar://71678394 / SR-13889.
1 parent e8c7714 commit 6688575

22 files changed

+293
-308
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 30 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -173,80 +173,59 @@ enum class AutoDiffGeneratedDeclarationKind : uint8_t {
173173
BranchingTraceEnum
174174
};
175175

176-
/// SIL-level automatic differentiation indices. Consists of:
177-
/// - The differentiability parameter indices.
178-
/// - The differentiability result indices.
179-
// TODO(TF-913): Remove `SILAutoDiffIndices` in favor of `AutoDiffConfig`.
180-
// `AutoDiffConfig` additionally stores a derivative generic signature.
181-
struct SILAutoDiffIndices {
182-
/// The indices of independent parameters to differentiate with respect to.
183-
IndexSubset *parameters;
184-
/// The indices of dependent results to differentiate from.
185-
IndexSubset *results;
186-
187-
/*implicit*/ SILAutoDiffIndices(IndexSubset *parameters, IndexSubset *results)
188-
: parameters(parameters), results(results) {
189-
assert(parameters && "Parameter indices must be non-null");
190-
assert(results && "Result indices must be non-null");
191-
}
192-
193-
bool operator==(const SILAutoDiffIndices &other) const;
194-
195-
bool operator!=(const SILAutoDiffIndices &other) const {
196-
return !(*this == other);
197-
};
176+
/// Identifies an autodiff derivative function configuration:
177+
/// - Parameter indices.
178+
/// - Result indices.
179+
/// - Derivative generic signature (optional).
180+
struct AutoDiffConfig {
181+
IndexSubset *parameterIndices;
182+
IndexSubset *resultIndices;
183+
GenericSignature derivativeGenericSignature;
184+
185+
/*implicit*/ AutoDiffConfig(
186+
IndexSubset *parameterIndices, IndexSubset *resultIndices,
187+
GenericSignature derivativeGenericSignature = GenericSignature())
188+
: parameterIndices(parameterIndices), resultIndices(resultIndices),
189+
derivativeGenericSignature(derivativeGenericSignature) {}
198190

199191
/// Returns true if `parameterIndex` is a differentiability parameter index.
200192
bool isWrtParameter(unsigned parameterIndex) const {
201-
return parameterIndex < parameters->getCapacity() &&
202-
parameters->contains(parameterIndex);
193+
return parameterIndex < parameterIndices->getCapacity() &&
194+
parameterIndices->contains(parameterIndex);
203195
}
204196

205-
void print(llvm::raw_ostream &s = llvm::outs()) const;
206-
SWIFT_DEBUG_DUMP;
197+
/// Returns true if `resultIndex` is a differentiability result index.
198+
bool isWrtResult(unsigned resultIndex) const {
199+
return resultIndex < resultIndices->getCapacity() &&
200+
resultIndices->contains(resultIndex);
201+
}
207202

203+
AutoDiffConfig withGenericSignature(GenericSignature signature) const {
204+
return AutoDiffConfig(parameterIndices, resultIndices, signature);
205+
}
206+
207+
// TODO(SR-13506): Use principled mangling for AD-generated symbols.
208208
std::string mangle() const {
209209
std::string result = "src_";
210210
interleave(
211-
results->getIndices(),
211+
resultIndices->getIndices(),
212212
[&](unsigned idx) { result += llvm::utostr(idx); },
213213
[&] { result += '_'; });
214214
result += "_wrt_";
215215
llvm::interleave(
216-
parameters->getIndices(),
216+
parameterIndices->getIndices(),
217217
[&](unsigned idx) { result += llvm::utostr(idx); },
218218
[&] { result += '_'; });
219219
return result;
220220
}
221-
};
222-
223-
/// Identifies an autodiff derivative function configuration:
224-
/// - Parameter indices.
225-
/// - Result indices.
226-
/// - Derivative generic signature (optional).
227-
struct AutoDiffConfig {
228-
IndexSubset *parameterIndices;
229-
IndexSubset *resultIndices;
230-
GenericSignature derivativeGenericSignature;
231-
232-
/*implicit*/ AutoDiffConfig(IndexSubset *parameterIndices,
233-
IndexSubset *resultIndices,
234-
GenericSignature derivativeGenericSignature)
235-
: parameterIndices(parameterIndices), resultIndices(resultIndices),
236-
derivativeGenericSignature(derivativeGenericSignature) {}
237-
238-
/// Returns the `SILAutoDiffIndices` corresponding to this config's indices.
239-
// TODO(TF-913): This is a temporary shim for incremental removal of
240-
// `SILAutoDiffIndices`. Eventually remove this.
241-
SILAutoDiffIndices getSILAutoDiffIndices() const;
242221

243222
void print(llvm::raw_ostream &s = llvm::outs()) const;
244223
SWIFT_DEBUG_DUMP;
245224
};
246225

247226
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &s,
248-
const SILAutoDiffIndices &indices) {
249-
indices.print(s);
227+
const AutoDiffConfig &config) {
228+
config.print(s);
250229
return s;
251230
}
252231

include/swift/AST/IndexSubset.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ class IndexSubset : public llvm::FoldingSetNode {
134134
/// Returns the number of bit words used to store the index subset.
135135
// Note: Use `getCapacity()` to get the total index subset capacity.
136136
// This is public only for unit testing
137-
// (in unittests/AST/SILAutoDiffIndices.cpp).
137+
// (in unittests/AST/IndexSubsetTests.cpp).
138138
unsigned getNumBitWords() const {
139139
return numBitWords;
140140
}

include/swift/SIL/SILDifferentiabilityWitness.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,11 +133,6 @@ class SILDifferentiabilityWitness
133133
bool isSerialized() const { return IsSerialized; }
134134
const DeclAttribute *getAttribute() const { return Attribute; }
135135

136-
/// Returns the `SILAutoDiffIndices` corresponding to this config's indices.
137-
// TODO(TF-893): This is a temporary shim for incremental removal of
138-
// `SILAutoDiffIndices`. Eventually remove this.
139-
SILAutoDiffIndices getSILAutoDiffIndices() const;
140-
141136
/// Verify that the differentiability witness is well-formed.
142137
void verify(const SILModule &module) const;
143138

include/swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -200,21 +200,47 @@ class DifferentiableActivityInfo {
200200
/// (dependent variable) indices.
201201
bool isUseful(SILValue value, IndexSubset *dependentVariableIndices) const;
202202

203-
/// Returns true if the given value is active for the given
204-
/// `SILAutoDiffIndices` (parameter indices and result index).
205-
bool isActive(SILValue value, const SILAutoDiffIndices &indices) const;
203+
/// Returns true if the given value is active for the given parameter indices
204+
/// and result indices.
205+
bool isActive(SILValue value,
206+
IndexSubset *parameterIndices,
207+
IndexSubset *resultIndices) const;
208+
209+
/// Returns true if the given value is active for the given config.
210+
bool isActive(SILValue value, AutoDiffConfig config) const {
211+
return isActive(value, config.parameterIndices, config.resultIndices);
212+
}
213+
214+
/// Returns the activity of the given value for the given config.
215+
Activity getActivity(SILValue value,
216+
IndexSubset *parameterIndices,
217+
IndexSubset *resultIndices) const;
206218

207-
/// Returns the activity of the given value for the given `SILAutoDiffIndices`
208-
/// (parameter indices and result index).
209-
Activity getActivity(SILValue value, const SILAutoDiffIndices &indices) const;
219+
/// Returns the activity of the given value for the given config.
220+
Activity getActivity(SILValue value, AutoDiffConfig config) const {
221+
return getActivity(value, config.parameterIndices, config.resultIndices);
222+
}
210223

211-
/// Prints activity information for the `indices` of the given `value`.
212-
void dump(SILValue value, const SILAutoDiffIndices &indices,
224+
/// Prints activity information for the config of the given value.
225+
void dump(SILValue value,
226+
IndexSubset *parameterIndices, IndexSubset *resultIndices,
213227
llvm::raw_ostream &s = llvm::dbgs()) const;
214228

215-
/// Prints activity information for the given `indices`.
216-
void dump(SILAutoDiffIndices indices,
229+
/// Prints activity information for the config of the given value.
230+
void dump(SILValue value, AutoDiffConfig config,
231+
llvm::raw_ostream &s = llvm::dbgs()) const {
232+
return dump(value, config.parameterIndices, config.resultIndices, s);
233+
}
234+
235+
/// Prints all activity information for the given parameter indices and result
236+
/// indices.
237+
void dump(IndexSubset *parameterIndices, IndexSubset *resultIndices,
217238
llvm::raw_ostream &s = llvm::dbgs()) const;
239+
240+
/// Prints all activity information for the given config.
241+
void dump(AutoDiffConfig config, llvm::raw_ostream &s = llvm::dbgs()) const {
242+
return dump(config.parameterIndices, config.resultIndices, s);
243+
}
218244
};
219245

220246
class DifferentiableActivityCollection {

include/swift/SILOptimizer/Differentiation/ADContext.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@ namespace autodiff {
4545

4646
/// Stores `apply` instruction information calculated by VJP generation.
4747
struct NestedApplyInfo {
48-
/// The differentiation indices that are used to differentiate this `apply`
48+
/// The differentiation config that is used to differentiate this `apply`
4949
/// instruction.
50-
SILAutoDiffIndices indices;
50+
AutoDiffConfig config;
5151
/// The original pullback type before reabstraction. `None` if the pullback
5252
/// type is not reabstracted.
5353
Optional<CanSILFunctionType> originalPullbackType;
@@ -120,6 +120,9 @@ class ADContext {
120120
/// Construct an ADContext for the given module.
121121
explicit ADContext(SILModuleTransform &transform);
122122

123+
// No copying.
124+
ADContext(const ADContext &) = delete;
125+
123126
//--------------------------------------------------------------------------//
124127
// General utilities
125128
//--------------------------------------------------------------------------//

include/swift/SILOptimizer/Differentiation/Common.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ void collectAllActualResultsInTypeOrder(
119119
/// - The set of minimal parameter and result indices for differentiating the
120120
/// `apply` instruction.
121121
void collectMinimalIndicesForFunctionCall(
122-
ApplyInst *ai, SILAutoDiffIndices parentIndices,
122+
ApplyInst *ai, AutoDiffConfig parentConfig,
123123
const DifferentiableActivityInfo &activityInfo,
124124
SmallVectorImpl<SILValue> &results, SmallVectorImpl<unsigned> &paramIndices,
125125
SmallVectorImpl<unsigned> &resultIndices);

include/swift/SILOptimizer/Differentiation/LinearMapInfo.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ class LinearMapInfo {
6767
SILLoopInfo *loopInfo;
6868

6969
/// Differentiation indices of the function.
70-
const SILAutoDiffIndices indices;
70+
const AutoDiffConfig config;
7171

7272
/// Mapping from original basic blocks to linear map structs.
7373
llvm::DenseMap<SILBasicBlock *, StructDecl *> linearMapStructs;
@@ -149,7 +149,7 @@ class LinearMapInfo {
149149

150150
explicit LinearMapInfo(ADContext &context, AutoDiffLinearMapKind kind,
151151
SILFunction *original, SILFunction *derivative,
152-
SILAutoDiffIndices indices,
152+
AutoDiffConfig config,
153153
const DifferentiableActivityInfo &activityInfo,
154154
SILLoopInfo *loopInfo);
155155

include/swift/SILOptimizer/Differentiation/Thunk.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,8 @@ SILValue reabstractFunction(
106106
std::pair<SILFunction *, SubstitutionMap>
107107
getOrCreateSubsetParametersThunkForDerivativeFunction(
108108
SILOptFunctionBuilder &fb, SILValue origFnOperand, SILValue derivativeFn,
109-
AutoDiffDerivativeFunctionKind kind, SILAutoDiffIndices desiredIndices,
110-
SILAutoDiffIndices actualIndices);
109+
AutoDiffDerivativeFunctionKind kind, AutoDiffConfig desiredConfig,
110+
AutoDiffConfig actualConfig);
111111

112112
/// Get or create a derivative function parameter index subset thunk from
113113
/// `actualIndices` to `desiredIndices` for the given associated function
@@ -119,7 +119,7 @@ getOrCreateSubsetParametersThunkForLinearMap(
119119
SILOptFunctionBuilder &fb, SILFunction *assocFn,
120120
CanSILFunctionType origFnType, CanSILFunctionType linearMapType,
121121
CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind,
122-
SILAutoDiffIndices desiredIndices, SILAutoDiffIndices actualIndices);
122+
AutoDiffConfig desiredConfig, AutoDiffConfig actualConfig);
123123

124124
} // end namespace autodiff
125125

include/swift/SILOptimizer/Differentiation/VJPCloner.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class VJPCloner final {
5050
SILFunction &getVJP() const;
5151
SILFunction &getPullback() const;
5252
SILDifferentiabilityWitness *getWitness() const;
53-
const SILAutoDiffIndices getIndices() const;
53+
AutoDiffConfig getConfig() const;
5454
DifferentiationInvoker getInvoker() const;
5555
LinearMapInfo &getPullbackInfo() const;
5656
SILLoopInfo *getLoopInfo() const;

lib/AST/ASTMangler.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ std::string ASTMangler::mangleAutoDiffDerivativeFunctionHelper(
420420
Buffer << "_vjp_";
421421
break;
422422
}
423-
Buffer << config.getSILAutoDiffIndices().mangle();
423+
Buffer << config.mangle();
424424
if (config.derivativeGenericSignature) {
425425
Buffer << '_';
426426
appendGenericSignature(config.derivativeGenericSignature);
@@ -445,7 +445,7 @@ std::string ASTMangler::mangleAutoDiffLinearMapHelper(
445445
Buffer << "_pullback_";
446446
break;
447447
}
448-
Buffer << config.getSILAutoDiffIndices().mangle();
448+
Buffer << config.mangle();
449449
if (config.derivativeGenericSignature) {
450450
Buffer << '_';
451451
appendGenericSignature(config.derivativeGenericSignature);
@@ -484,7 +484,7 @@ std::string ASTMangler::mangleAutoDiffGeneratedDeclaration(
484484
}
485485
break;
486486
}
487-
Buffer << config.getSILAutoDiffIndices().mangle();
487+
Buffer << config.mangle();
488488
if (config.derivativeGenericSignature) {
489489
Buffer << '_';
490490
appendGenericSignature(config.derivativeGenericSignature);

lib/AST/AutoDiff.cpp

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -97,26 +97,6 @@ DifferentiabilityWitnessFunctionKind::getAsDerivativeFunctionKind() const {
9797
llvm_unreachable("invalid derivative kind");
9898
}
9999

100-
void SILAutoDiffIndices::print(llvm::raw_ostream &s) const {
101-
s << "(parameters=(";
102-
interleave(
103-
parameters->getIndices(), [&s](unsigned p) { s << p; },
104-
[&s] { s << ' '; });
105-
s << ") results=(";
106-
interleave(
107-
results->getIndices(), [&s](unsigned p) { s << p; }, [&s] { s << ' '; });
108-
s << "))";
109-
}
110-
111-
void SILAutoDiffIndices::dump() const {
112-
print(llvm::errs());
113-
llvm::errs() << '\n';
114-
}
115-
116-
SILAutoDiffIndices AutoDiffConfig::getSILAutoDiffIndices() const {
117-
return SILAutoDiffIndices(parameterIndices, resultIndices);
118-
}
119-
120100
void AutoDiffConfig::print(llvm::raw_ostream &s) const {
121101
s << "(parameters=";
122102
parameterIndices->print(s);

lib/SIL/IR/SILDifferentiabilityWitness.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,3 @@ void SILDifferentiabilityWitness::convertToDefinition(SILFunction *jvp,
7474
SILDifferentiabilityWitnessKey SILDifferentiabilityWitness::getKey() const {
7575
return std::make_pair(getOriginalFunction()->getName(), getConfig());
7676
}
77-
78-
SILAutoDiffIndices SILDifferentiabilityWitness::getSILAutoDiffIndices() const {
79-
return getConfig().getSILAutoDiffIndices();
80-
}

lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -509,26 +509,27 @@ bool DifferentiableActivityInfo::isUseful(
509509
}
510510

511511
bool DifferentiableActivityInfo::isActive(
512-
SILValue value, const SILAutoDiffIndices &indices) const {
513-
return isVaried(value, indices.parameters) &&
514-
isUseful(value, indices.results);
512+
SILValue value, IndexSubset *parameterIndices,
513+
IndexSubset *resultIndices) const {
514+
return isVaried(value, parameterIndices) && isUseful(value, resultIndices);
515515
}
516516

517517
Activity DifferentiableActivityInfo::getActivity(
518-
SILValue value, const SILAutoDiffIndices &indices) const {
518+
SILValue value, IndexSubset *parameterIndices,
519+
IndexSubset *resultIndices) const {
519520
Activity activity;
520-
if (isVaried(value, indices.parameters))
521+
if (isVaried(value, parameterIndices))
521522
activity |= ActivityFlags::Varied;
522-
if (isUseful(value, indices.results))
523+
if (isUseful(value, resultIndices))
523524
activity |= ActivityFlags::Useful;
524525
return activity;
525526
}
526527

527-
void DifferentiableActivityInfo::dump(SILValue value,
528-
const SILAutoDiffIndices &indices,
529-
llvm::raw_ostream &s) const {
528+
void DifferentiableActivityInfo::dump(
529+
SILValue value, IndexSubset *parameterIndices, IndexSubset *resultIndices,
530+
llvm::raw_ostream &s) const {
530531
s << '[';
531-
auto activity = getActivity(value, indices);
532+
auto activity = getActivity(value, parameterIndices, resultIndices);
532533
switch (activity.toRaw()) {
533534
case 0:
534535
s << "NONE";
@@ -546,19 +547,24 @@ void DifferentiableActivityInfo::dump(SILValue value,
546547
s << "] " << value;
547548
}
548549

549-
void DifferentiableActivityInfo::dump(SILAutoDiffIndices indices,
550-
llvm::raw_ostream &s) const {
550+
void DifferentiableActivityInfo::dump(
551+
IndexSubset *parameterIndices, IndexSubset *resultIndices,
552+
llvm::raw_ostream &s) const {
551553
SILFunction &fn = getFunction();
552-
s << "Activity info for " << fn.getName() << " at " << indices << '\n';
554+
s << "Activity info for " << fn.getName() << " at parameter indices (";
555+
llvm::interleaveComma(parameterIndices->getIndices(), s);
556+
s << ") and result indices (";
557+
llvm::interleaveComma(resultIndices->getIndices(), s);
558+
s << "):\n";
553559
for (auto &bb : fn) {
554560
s << "bb" << bb.getDebugID() << ":\n";
555561
for (auto *arg : bb.getArguments())
556-
dump(arg, indices, s);
562+
dump(arg, parameterIndices, resultIndices, s);
557563
for (auto &inst : bb)
558564
for (auto res : inst.getResults())
559-
dump(res, indices, s);
565+
dump(res, parameterIndices, resultIndices, s);
560566
if (std::next(bb.getIterator()) != fn.end())
561567
s << '\n';
562568
}
563-
s << "End activity info for " << fn.getName() << " at " << indices << "\n\n";
569+
s << "End activity info for " << fn.getName() << '\n';
564570
}

0 commit comments

Comments
 (0)