Skip to content

Commit def2be9

Browse files
authored
[AutoDiff] Fix substitution map remapping bug. (#25689)
Remap pullback reabstraction thunk substitution map during VJP generation. Resolves TF-534.
1 parent 6b822d8 commit def2be9

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3336,7 +3336,8 @@ class VJPEmitter final
33363336
loweredPullbackType);
33373337
auto *thunkRef = getBuilder().createFunctionRef(loc, thunk);
33383338
pullback = getBuilder().createPartialApply(
3339-
ai->getLoc(), thunkRef, thunk->getForwardingSubstitutionMap(),
3339+
ai->getLoc(), thunkRef,
3340+
getOpSubstitutionMap(thunk->getForwardingSubstitutionMap()),
33403341
{pullback}, actualPullbackType->getCalleeConvention());
33413342
}
33423343
pullbackValues[ai->getParent()].push_back(pullback);

test/AutoDiff/autodiff_diagnostics.swift

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,23 @@ struct TF_305 : Differentiable {
170170
}
171171
}
172172

173+
protocol TF_534_Layer : Differentiable {
174+
associatedtype Input : Differentiable
175+
associatedtype Output : Differentiable
176+
177+
@differentiable
178+
func callAsFunction(_ input: Input) -> Output
179+
}
180+
struct TF_534_Tensor<Scalar> : Differentiable {}
181+
182+
func TF_534<Model: TF_534_Layer>(
183+
_ model: inout Model, inputs: Model.Input
184+
) -> TF_534_Tensor<Float> where Model.Output == TF_534_Tensor<Float> {
185+
return valueWithPullback(at: model) { model -> Model.Output in
186+
return model(inputs)
187+
}.0
188+
}
189+
173190
//===----------------------------------------------------------------------===//
174191
// Classes and existentials (not yet supported)
175192
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)