Skip to content

Commit 1a81573

Browse files
authored
[AutoDiff] Start linear_function canonicalization skeleton (#33057)
Start `linear_function` canonicalization skeleton copying from `differentiable_function` canonicalization. For now, transpose function operands are filled in with `undef`.
1 parent cde03e6 commit 1a81573

File tree

8 files changed

+244
-42
lines changed

8 files changed

+244
-42
lines changed

include/swift/SILOptimizer/Differentiation/ADContext.h

Lines changed: 59 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,22 @@ class ADContext {
7373
llvm::SmallVector<DifferentiableFunctionInst *, 32>
7474
differentiableFunctionInsts;
7575

76+
/// The worklist (stack) of `linear_function` instructions to be processed.
77+
llvm::SmallVector<LinearFunctionInst *, 32> linearFunctionInsts;
78+
7679
/// The set of `differentiable_function` instructions that have been
7780
/// processed. Used to avoid reprocessing invalidated instructions.
7881
/// NOTE(TF-784): if we use `CanonicalizeInstruction` subclass to replace
7982
/// `ADContext::processDifferentiableFunctionInst`, this field may be removed.
8083
llvm::SmallPtrSet<DifferentiableFunctionInst *, 32>
8184
processedDifferentiableFunctionInsts;
8285

86+
/// The set of `linear_function` instructions that have been processed. Used
87+
/// to avoid reprocessing invalidated instructions.
88+
/// NOTE(TF-784): if we use `CanonicalizeInstruction` subclass to replace
89+
/// `ADContext::processLinearFunctionInst`, this field may be removed.
90+
llvm::SmallPtrSet<LinearFunctionInst *, 32> processedLinearFunctionInsts;
91+
8392
/// Mapping from witnesses to invokers.
8493
/// `SmallMapVector` is used for deterministic insertion order iteration.
8594
llvm::SmallMapVector<SILDifferentiabilityWitness *, DifferentiationInvoker,
@@ -121,30 +130,19 @@ class ADContext {
121130
SILPassManager &getPassManager() const { return passManager; }
122131
Lowering::TypeConverter &getTypeConverter() { return module.Types; }
123132

124-
/// Get or create the synthesized file for the given `SILFunction`.
125-
/// Used by `LinearMapInfo` for adding generated linear map struct and
126-
/// branching trace enum declarations.
127-
SynthesizedFileUnit &getOrCreateSynthesizedFile(SILFunction *original);
128-
129-
/// Returns true if the `differentiable_function` instruction worklist is
130-
/// empty.
131-
bool isDifferentiableFunctionInstsWorklistEmpty() const {
132-
return differentiableFunctionInsts.empty();
133+
llvm::SmallVectorImpl<DifferentiableFunctionInst *> &
134+
getDifferentiableFunctionInstWorklist() {
135+
return differentiableFunctionInsts;
133136
}
134137

135-
/// Pops and returns a `differentiable_function` instruction from the
136-
/// worklist. Returns nullptr if the worklist is empty.
137-
DifferentiableFunctionInst *popDifferentiableFunctionInstFromWorklist() {
138-
if (differentiableFunctionInsts.empty())
139-
return nullptr;
140-
return differentiableFunctionInsts.pop_back_val();
138+
llvm::SmallVectorImpl<LinearFunctionInst *> &getLinearFunctionInstWorklist() {
139+
return linearFunctionInsts;
141140
}
142141

143-
/// Adds the given `differentiable_function` instruction to the worklist.
144-
void
145-
addDifferentiableFunctionInstToWorklist(DifferentiableFunctionInst *dfi) {
146-
differentiableFunctionInsts.push_back(dfi);
147-
}
142+
/// Get or create the synthesized file for the given `SILFunction`.
143+
/// Used by `LinearMapInfo` for adding generated linear map struct and
144+
/// branching trace enum declarations.
145+
SynthesizedFileUnit &getOrCreateSynthesizedFile(SILFunction *original);
148146

149147
/// Returns true if the given `differentiable_function` instruction has
150148
/// already been processed.
@@ -159,6 +157,17 @@ class ADContext {
159157
processedDifferentiableFunctionInsts.insert(dfi);
160158
}
161159

160+
/// Returns true if the given `linear_function` instruction has already been
161+
/// processed.
162+
bool isLinearFunctionInstProcessed(LinearFunctionInst *lfi) const {
163+
return processedLinearFunctionInsts.count(lfi);
164+
}
165+
166+
/// Marks the given `linear_function` instruction as processed.
167+
void markLinearFunctionInstAsProcessed(LinearFunctionInst *lfi) {
168+
processedLinearFunctionInsts.insert(lfi);
169+
}
170+
162171
const llvm::SmallMapVector<SILDifferentiabilityWitness *,
163172
DifferentiationInvoker, 32> &
164173
getInvokers() const {
@@ -204,12 +213,26 @@ class ADContext {
204213
IndexSubset *resultIndices, SILValue original,
205214
Optional<std::pair<SILValue, SILValue>> derivativeFunctions = None);
206215

207-
// Given an `differentiable_function` instruction, finds the corresponding
216+
/// Creates a `linear_function` instruction using the given builder
217+
/// and arguments. Erase the newly created instruction from the processed set,
218+
/// if it exists - it may exist in the processed set if it has the same
219+
/// pointer value as a previously processed and deleted instruction.
220+
LinearFunctionInst *
221+
createLinearFunction(SILBuilder &builder, SILLocation loc,
222+
IndexSubset *parameterIndices, SILValue original,
223+
Optional<SILValue> transposeFunction = None);
224+
225+
// Given a `differentiable_function` instruction, finds the corresponding
208226
// differential operator used in the AST. If no differential operator is
209227
// found, return nullptr.
210228
DifferentiableFunctionExpr *
211229
findDifferentialOperator(DifferentiableFunctionInst *inst);
212230

231+
// Given a `linear_function` instruction, finds the corresponding differential
232+
// operator used in the AST. If no differential operator is found, return
233+
// nullptr.
234+
LinearFunctionExpr *findDifferentialOperator(LinearFunctionInst *inst);
235+
213236
template <typename... T, typename... U>
214237
InFlightDiagnostic diagnose(SourceLoc loc, Diag<T...> diag,
215238
U &&... args) const {
@@ -300,6 +323,21 @@ ADContext::emitNondifferentiabilityError(SourceLoc loc,
300323
return diagnose(loc, diag, std::forward<U>(args)...);
301324
}
302325

326+
// For `linear_function` instructions: if the `linear_function` instruction
327+
// comes from a differential operator, emit an error on the expression and a
328+
// note on the non-differentiable operation. Otherwise, emit both an error and
329+
// note on the non-differentiable operation.
330+
case DifferentiationInvoker::Kind::LinearFunctionInst: {
331+
auto *inst = invoker.getLinearFunctionInst();
332+
if (auto *expr = findDifferentialOperator(inst)) {
333+
diagnose(expr->getLoc(), diag::autodiff_function_not_differentiable_error)
334+
.highlight(expr->getSubExpr()->getSourceRange());
335+
return diagnose(loc, diag, std::forward<U>(args)...);
336+
}
337+
diagnose(loc, diag::autodiff_expression_not_differentiable_error);
338+
return diagnose(loc, diag, std::forward<U>(args)...);
339+
}
340+
303341
// For differentiability witnesses: try to find a `@differentiable` or
304342
// `@derivative` attribute. If an attribute is found, emit an error on it;
305343
// otherwise, emit an error on the original function.

include/swift/SILOptimizer/Differentiation/DifferentiationInvoker.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ namespace swift {
2525

2626
class ApplyInst;
2727
class DifferentiableFunctionInst;
28+
class LinearFunctionInst;
2829
class SILDifferentiabilityWitness;
2930

3031
namespace autodiff {
@@ -42,6 +43,10 @@ struct DifferentiationInvoker {
4243
// expression).
4344
DifferentiableFunctionInst,
4445

46+
// Invoked by a `linear_function` instruction, which may or may not
47+
// be linked to a Swift AST node (e.g. a `LinearFunctionExpr` expression).
48+
LinearFunctionInst,
49+
4550
// Invoked by the indirect application of differentiation. This case has an
4651
// associated original `apply` instruction and
4752
// `SILDifferentiabilityWitness`.
@@ -60,6 +65,10 @@ struct DifferentiationInvoker {
6065
DifferentiableFunctionInst *diffFuncInst;
6166
Value(DifferentiableFunctionInst *inst) : diffFuncInst(inst) {}
6267

68+
/// The instruction associated with the `LinearFunctionInst` case.
69+
LinearFunctionInst *linearFuncInst;
70+
Value(LinearFunctionInst *inst) : linearFuncInst(inst) {}
71+
6372
/// The parent `apply` instruction and the witness associated with the
6473
/// `IndirectDifferentiation` case.
6574
std::pair<ApplyInst *, SILDifferentiabilityWitness *>
@@ -79,6 +88,8 @@ struct DifferentiationInvoker {
7988
public:
8089
DifferentiationInvoker(DifferentiableFunctionInst *inst)
8190
: kind(Kind::DifferentiableFunctionInst), value(inst) {}
91+
DifferentiationInvoker(LinearFunctionInst *inst)
92+
: kind(Kind::LinearFunctionInst), value(inst) {}
8293
DifferentiationInvoker(ApplyInst *applyInst,
8394
SILDifferentiabilityWitness *witness)
8495
: kind(Kind::IndirectDifferentiation), value({applyInst, witness}) {}
@@ -92,6 +103,11 @@ struct DifferentiationInvoker {
92103
return value.diffFuncInst;
93104
}
94105

106+
LinearFunctionInst *getLinearFunctionInst() const {
107+
assert(kind == Kind::LinearFunctionInst);
108+
return value.linearFuncInst;
109+
}
110+
95111
std::pair<ApplyInst *, SILDifferentiabilityWitness *>
96112
getIndirectDifferentiation() const {
97113
assert(kind == Kind::IndirectDifferentiation);

lib/SILOptimizer/Differentiation/ADContext.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,24 @@ DifferentiableFunctionInst *ADContext::createDifferentiableFunction(
123123
return dfi;
124124
}
125125

126+
LinearFunctionInst *ADContext::createLinearFunction(
127+
SILBuilder &builder, SILLocation loc, IndexSubset *parameterIndices,
128+
SILValue original, Optional<SILValue> transposeFunction) {
129+
auto *lfi = builder.createLinearFunction(loc, parameterIndices, original,
130+
transposeFunction);
131+
processedLinearFunctionInsts.erase(lfi);
132+
return lfi;
133+
}
134+
126135
DifferentiableFunctionExpr *
127136
ADContext::findDifferentialOperator(DifferentiableFunctionInst *inst) {
128137
return inst->getLoc().getAsASTNode<DifferentiableFunctionExpr>();
129138
}
130139

140+
LinearFunctionExpr *
141+
ADContext::findDifferentialOperator(LinearFunctionInst *inst) {
142+
return inst->getLoc().getAsASTNode<LinearFunctionExpr>();
143+
}
144+
131145
} // end namespace autodiff
132146
} // end namespace swift

lib/SILOptimizer/Differentiation/DifferentiationInvoker.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ SourceLoc DifferentiationInvoker::getLocation() const {
2828
switch (kind) {
2929
case Kind::DifferentiableFunctionInst:
3030
return getDifferentiableFunctionInst()->getLoc().getSourceLoc();
31+
case Kind::LinearFunctionInst:
32+
return getLinearFunctionInst()->getLoc().getSourceLoc();
3133
case Kind::IndirectDifferentiation:
3234
return getIndirectDifferentiation().first->getLoc().getSourceLoc();
3335
case Kind::SILDifferentiabilityWitnessInvoker:
@@ -46,6 +48,9 @@ void DifferentiationInvoker::print(llvm::raw_ostream &os) const {
4648
os << "differentiable_function_inst=(" << *getDifferentiableFunctionInst()
4749
<< ")";
4850
break;
51+
case Kind::LinearFunctionInst:
52+
os << "linear_function_inst=(" << *getLinearFunctionInst() << ")";
53+
break;
4954
case Kind::IndirectDifferentiation: {
5055
auto indDiff = getIndirectDifferentiation();
5156
os << "indirect_differentiation=(" << *std::get<0>(indDiff) << ')';

lib/SILOptimizer/Differentiation/JVPCloner.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,7 @@ class JVPCloner::Implementation final
601601
builder, loc, indices.parameters, indices.results, origCallee);
602602

603603
// Record the `differentiable_function` instruction.
604-
context.addDifferentiableFunctionInstToWorklist(diffFuncInst);
604+
context.getDifferentiableFunctionInstWorklist().push_back(diffFuncInst);
605605

606606
auto borrowedADFunc = builder.emitBeginBorrowOperation(loc, diffFuncInst);
607607
auto extractedJVP = builder.createDifferentiableFunctionExtract(
@@ -749,7 +749,15 @@ class JVPCloner::Implementation final
749749
// instruction to the `differentiable_function` worklist.
750750
TypeSubstCloner::visitDifferentiableFunctionInst(dfi);
751751
auto *newDFI = cast<DifferentiableFunctionInst>(getOpValue(dfi));
752-
context.addDifferentiableFunctionInstToWorklist(newDFI);
752+
context.getDifferentiableFunctionInstWorklist().push_back(newDFI);
753+
}
754+
755+
void visitLinearFunctionInst(LinearFunctionInst *lfi) {
756+
// Clone `linear_function` from original to JVP, then add the cloned
757+
// instruction to the `linear_function` worklist.
758+
TypeSubstCloner::visitLinearFunctionInst(lfi);
759+
auto *newLFI = cast<LinearFunctionInst>(getOpValue(lfi));
760+
context.getLinearFunctionInstWorklist().push_back(newLFI);
753761
}
754762

755763
//--------------------------------------------------------------------------//

lib/SILOptimizer/Differentiation/VJPCloner.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,7 @@ class VJPCloner::Implementation final
529529
getBuilder(), loc, indices.parameters, indices.results, origCallee);
530530

531531
// Record the `differentiable_function` instruction.
532-
context.addDifferentiableFunctionInstToWorklist(diffFuncInst);
532+
context.getDifferentiableFunctionInstWorklist().push_back(diffFuncInst);
533533

534534
auto borrowedADFunc = builder.emitBeginBorrowOperation(loc, diffFuncInst);
535535
auto extractedVJP = getBuilder().createDifferentiableFunctionExtract(
@@ -623,7 +623,15 @@ class VJPCloner::Implementation final
623623
// instruction to the `differentiable_function` worklist.
624624
TypeSubstCloner::visitDifferentiableFunctionInst(dfi);
625625
auto *newDFI = cast<DifferentiableFunctionInst>(getOpValue(dfi));
626-
context.addDifferentiableFunctionInstToWorklist(newDFI);
626+
context.getDifferentiableFunctionInstWorklist().push_back(newDFI);
627+
}
628+
629+
void visitLinearFunctionInst(LinearFunctionInst *lfi) {
630+
// Clone `linear_function` from original to VJP, then add the cloned
631+
// instruction to the `linear_function` worklist.
632+
TypeSubstCloner::visitLinearFunctionInst(lfi);
633+
auto *newLFI = cast<LinearFunctionInst>(getOpValue(lfi));
634+
context.getLinearFunctionInstWorklist().push_back(newLFI);
627635
}
628636
};
629637

0 commit comments

Comments
 (0)