@@ -73,13 +73,22 @@ class ADContext {
73
73
llvm::SmallVector<DifferentiableFunctionInst *, 32 >
74
74
differentiableFunctionInsts;
75
75
76
+ // / The worklist (stack) of `linear_function` instructions to be processed.
77
+ llvm::SmallVector<LinearFunctionInst *, 32 > linearFunctionInsts;
78
+
76
79
// / The set of `differentiable_function` instructions that have been
77
80
// / processed. Used to avoid reprocessing invalidated instructions.
78
81
// / NOTE(TF-784): if we use `CanonicalizeInstruction` subclass to replace
79
82
// / `ADContext::processDifferentiableFunctionInst`, this field may be removed.
80
83
llvm::SmallPtrSet<DifferentiableFunctionInst *, 32 >
81
84
processedDifferentiableFunctionInsts;
82
85
86
+ // / The set of `linear_function` instructions that have been processed. Used
87
+ // / to avoid reprocessing invalidated instructions.
88
+ // / NOTE(TF-784): if we use `CanonicalizeInstruction` subclass to replace
89
+ // / `ADContext::processLinearFunctionInst`, this field may be removed.
90
+ llvm::SmallPtrSet<LinearFunctionInst *, 32 > processedLinearFunctionInsts;
91
+
83
92
// / Mapping from witnesses to invokers.
84
93
// / `SmallMapVector` is used for deterministic insertion order iteration.
85
94
llvm::SmallMapVector<SILDifferentiabilityWitness *, DifferentiationInvoker,
@@ -121,30 +130,19 @@ class ADContext {
121
130
SILPassManager &getPassManager () const { return passManager; }
122
131
Lowering::TypeConverter &getTypeConverter () { return module .Types ; }
123
132
124
- // / Get or create the synthesized file for the given `SILFunction`.
125
- // / Used by `LinearMapInfo` for adding generated linear map struct and
126
- // / branching trace enum declarations.
127
- SynthesizedFileUnit &getOrCreateSynthesizedFile (SILFunction *original);
128
-
129
- // / Returns true if the `differentiable_function` instruction worklist is
130
- // / empty.
131
- bool isDifferentiableFunctionInstsWorklistEmpty () const {
132
- return differentiableFunctionInsts.empty ();
133
+ llvm::SmallVectorImpl<DifferentiableFunctionInst *> &
134
+ getDifferentiableFunctionInstWorklist () {
135
+ return differentiableFunctionInsts;
133
136
}
134
137
135
- // / Pops and returns a `differentiable_function` instruction from the
136
- // / worklist. Returns nullptr if the worklist is empty.
137
- DifferentiableFunctionInst *popDifferentiableFunctionInstFromWorklist () {
138
- if (differentiableFunctionInsts.empty ())
139
- return nullptr ;
140
- return differentiableFunctionInsts.pop_back_val ();
138
+ llvm::SmallVectorImpl<LinearFunctionInst *> &getLinearFunctionInstWorklist () {
139
+ return linearFunctionInsts;
141
140
}
142
141
143
- // / Adds the given `differentiable_function` instruction to the worklist.
144
- void
145
- addDifferentiableFunctionInstToWorklist (DifferentiableFunctionInst *dfi) {
146
- differentiableFunctionInsts.push_back (dfi);
147
- }
142
+ // / Get or create the synthesized file for the given `SILFunction`.
143
+ // / Used by `LinearMapInfo` for adding generated linear map struct and
144
+ // / branching trace enum declarations.
145
+ SynthesizedFileUnit &getOrCreateSynthesizedFile (SILFunction *original);
148
146
149
147
// / Returns true if the given `differentiable_function` instruction has
150
148
// / already been processed.
@@ -159,6 +157,17 @@ class ADContext {
159
157
processedDifferentiableFunctionInsts.insert (dfi);
160
158
}
161
159
160
+ // / Returns true if the given `linear_function` instruction has already been
161
+ // / processed.
162
+ bool isLinearFunctionInstProcessed (LinearFunctionInst *lfi) const {
163
+ return processedLinearFunctionInsts.count (lfi);
164
+ }
165
+
166
+ // / Adds the given `linear_function` instruction to the worklist.
167
+ void markLinearFunctionInstAsProcessed (LinearFunctionInst *lfi) {
168
+ processedLinearFunctionInsts.insert (lfi);
169
+ }
170
+
162
171
const llvm::SmallMapVector<SILDifferentiabilityWitness *,
163
172
DifferentiationInvoker, 32 > &
164
173
getInvokers () const {
@@ -204,12 +213,26 @@ class ADContext {
204
213
IndexSubset *resultIndices, SILValue original,
205
214
Optional<std::pair<SILValue, SILValue>> derivativeFunctions = None);
206
215
207
- // Given an `differentiable_function` instruction, finds the corresponding
216
+ // / Creates a `linear_function` instruction using the given builder
217
+ // / and arguments. Erase the newly created instruction from the processed set,
218
+ // / if it exists - it may exist in the processed set if it has the same
219
+ // / pointer value as a previously processed and deleted instruction.
220
+ LinearFunctionInst *
221
+ createLinearFunction (SILBuilder &builder, SILLocation loc,
222
+ IndexSubset *parameterIndices, SILValue original,
223
+ Optional<SILValue> transposeFunction = None);
224
+
225
+ // Given a `differentiable_function` instruction, finds the corresponding
208
226
// differential operator used in the AST. If no differential operator is
209
227
// found, return nullptr.
210
228
DifferentiableFunctionExpr *
211
229
findDifferentialOperator (DifferentiableFunctionInst *inst);
212
230
231
+ // Given a `linear_function` instruction, finds the corresponding differential
232
+ // operator used in the AST. If no differential operator is found, return
233
+ // nullptr.
234
+ LinearFunctionExpr *findDifferentialOperator (LinearFunctionInst *inst);
235
+
213
236
template <typename ... T, typename ... U>
214
237
InFlightDiagnostic diagnose (SourceLoc loc, Diag<T...> diag,
215
238
U &&... args) const {
@@ -300,6 +323,21 @@ ADContext::emitNondifferentiabilityError(SourceLoc loc,
300
323
return diagnose (loc, diag, std::forward<U>(args)...);
301
324
}
302
325
326
+ // For `linear_function` instructions: if the `linear_function` instruction
327
+ // comes from a differential operator, emit an error on the expression and a
328
+ // note on the non-differentiable operation. Otherwise, emit both an error and
329
+ // note on the non-differentiation operation.
330
+ case DifferentiationInvoker::Kind::LinearFunctionInst: {
331
+ auto *inst = invoker.getLinearFunctionInst ();
332
+ if (auto *expr = findDifferentialOperator (inst)) {
333
+ diagnose (expr->getLoc (), diag::autodiff_function_not_differentiable_error)
334
+ .highlight (expr->getSubExpr ()->getSourceRange ());
335
+ return diagnose (loc, diag, std::forward<U>(args)...);
336
+ }
337
+ diagnose (loc, diag::autodiff_expression_not_differentiable_error);
338
+ return diagnose (loc, diag, std::forward<U>(args)...);
339
+ }
340
+
303
341
// For differentiability witnesses: try to find a `@differentiable` or
304
342
// `@derivative` attribute. If an attribute is found, emit an error on it;
305
343
// otherwise, emit an error on the original function.
0 commit comments