Skip to content

Commit 1f4438b

Browse files
authored
---
yaml --- r: 340954 b: refs/heads/rxwei-patch-1 c: 06a93b8 h: refs/heads/master
1 parent 643e9b0 commit 1f4438b

File tree

4 files changed

+27
-19
lines changed

4 files changed

+27
-19
lines changed

[refs]

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1015,7 +1015,7 @@ refs/tags/swift-DEVELOPMENT-SNAPSHOT-2018-08-18-a: b10b1fce14385faa6d44f6b933e95
10151015
refs/heads/rdar-43033749-fix-batch-mode-no-diags-swift-5.0-branch: a14e64eaad30de89f0f5f0b2a782eed7ecdcb255
10161016
refs/heads/revert-19006-error-bridging-integer-type: 8a9065a3696535305ea53fe9b71f91cbe6702019
10171017
refs/heads/revert-19050-revert-19006-error-bridging-integer-type: ecf752d54b05dd0a20f510f0bfa54a3fec3bcaca
1018-
refs/heads/rxwei-patch-1: 85255eb6ae9e439747954a4d0978b60ac5ecd3ca
1018+
refs/heads/rxwei-patch-1: 06a93b8e8ce41d162fe12b67c3bed55b1a058c50
10191019
refs/heads/shahmishal-patch-1: e58ec0f7488258d42bef51bc3e6d7b3dc74d7b2a
10201020
refs/heads/typelist-existential: 4046359efd541fb5c72d69a92eefc0a784df8f5e
10211021
refs/tags/swift-4.2-DEVELOPMENT-SNAPSHOT-2018-08-20-a: 4319ba09e4fb8650ee86061075c74a016b6baab9

branches/rxwei-patch-1/lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3828,6 +3828,12 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
38283828
type.getCategory());
38293829
}
38303830

3831+
/// Substitutes all replacement types of the given substitution map using the
3832+
/// adjoint function's substitution map.
3833+
SubstitutionMap remapSubstitutionMap(SubstitutionMap substMap) {
3834+
return substMap.subst(getAdjoint().getForwardingSubstitutionMap());
3835+
}
3836+
38313837
//--------------------------------------------------------------------------//
38323838
// Managed value mapping
38333839
//--------------------------------------------------------------------------//
@@ -4654,7 +4660,8 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
46544660
pullbackType, *applyInfo.originalPullbackType);
46554661
auto *thunkRef = builder.createFunctionRef(loc, thunk);
46564662
pullback = builder.createPartialApply(
4657-
loc, thunkRef, thunk->getForwardingSubstitutionMap(),
4663+
loc, thunkRef,
4664+
remapSubstitutionMap(thunk->getForwardingSubstitutionMap()),
46584665
{pullback}, pullbackType->getCalleeConvention());
46594666
}
46604667
args.push_back(seed);

branches/rxwei-patch-1/test/AutoDiff/autodiff_diagnostics.swift

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -170,23 +170,6 @@ struct TF_305 : Differentiable {
170170
}
171171
}
172172

173-
protocol TF_534_Layer : Differentiable {
174-
associatedtype Input : Differentiable
175-
associatedtype Output : Differentiable
176-
177-
@differentiable
178-
func callAsFunction(_ input: Input) -> Output
179-
}
180-
struct TF_534_Tensor<Scalar> : Differentiable {}
181-
182-
func TF_534<Model: TF_534_Layer>(
183-
_ model: inout Model, inputs: Model.Input
184-
) -> TF_534_Tensor<Float> where Model.Output == TF_534_Tensor<Float> {
185-
return valueWithPullback(at: model) { model -> Model.Output in
186-
return model(inputs)
187-
}.0
188-
}
189-
190173
//===----------------------------------------------------------------------===//
191174
// Classes and existentials (not yet supported)
192175
//===----------------------------------------------------------------------===//

branches/rxwei-patch-1/test/AutoDiff/generics.swift

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,4 +151,22 @@ func TF_523_f(_ x: TF_523_Struct) -> Float {
151151
return x.a * 2
152152
}
153153

154+
// TF_534: Thunk substitution map remapping.
155+
protocol TF_534_Layer : Differentiable {
156+
associatedtype Input : Differentiable
157+
associatedtype Output : Differentiable
158+
159+
@differentiable
160+
func callAsFunction(_ input: Input) -> Output
161+
}
162+
struct TF_534_Tensor<Scalar> : Differentiable {}
163+
164+
func TF_534<Model: TF_534_Layer>(
165+
_ model: inout Model, inputs: Model.Input
166+
) -> TF_534_Tensor<Float> where Model.Output == TF_534_Tensor<Float> {
167+
return valueWithPullback(at: model) { model -> Model.Output in
168+
return model(inputs)
169+
}.0
170+
}
171+
154172
// TODO: add more tests.

0 commit comments

Comments
 (0)