Skip to content

Commit 9fceb38

Browse files
authored
Revert "[AutoDiff] Constrain wrt parameters to conform to Differentiable. (#26406)" (#26420)
This reverts commit 6f58fd4. #26406 introduced TF-697; reverting it until the issue is fixed. Add testcase for TF-697 to prevent future regressions.
1 parent 5a73cc1 commit 9fceb38

File tree

3 files changed

+50
-34
lines changed

3 files changed

+50
-34
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 14 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -217,19 +217,6 @@ getAssociatedFunctionGenericSignature(SILDifferentiableAttr *attr,
217217
GenericSignatureBuilder::FloatingRequirementSource::forAbstract();
218218
for (auto &req : attr->getRequirements())
219219
builder.addRequirement(req, source, original->getModule().getSwiftModule());
220-
// Constrain all wrt parameters to conform to `Differentiable`.
221-
auto &ctx = original->getASTContext();
222-
auto *diffableProto = ctx.getProtocol(KnownProtocolKind::Differentiable);
223-
auto paramIndexSet = attr->getIndices().parameters;
224-
for (unsigned paramIdx : paramIndexSet->getIndices()) {
225-
if (!paramIndexSet->contains(paramIdx))
226-
continue;
227-
auto paramType =
228-
original->getConventions().getSILArgumentType(paramIdx).getASTType();
229-
Requirement req(RequirementKind::Conformance, paramType,
230-
diffableProto->getDeclaredType());
231-
builder.addRequirement(req, source, original->getModule().getSwiftModule());
232-
}
233220
return std::move(builder)
234221
.computeGenericSignature(SourceLoc(), /*allowConcreteGenericParams=*/true)
235222
->getCanonicalSignature();
@@ -2863,17 +2850,11 @@ class VJPEmitter final
28632850
auto origTy = original->getLoweredFunctionType();
28642851
auto lookupConformance = LookUpConformanceInModule(module.getSwiftModule());
28652852

2866-
auto pbGenericSig = getAssociatedFunctionGenericSignature(attr, original);
2867-
28682853
// RAII that pushes the original function's generic signature to
28692854
// `module.Types` so that the calls `module.Types.getTypeLowering()` below
2870-
// will know the pullback's generic parameter types.
2855+
// will know the original function's generic parameter types.
28712856
Lowering::GenericContextScope genericContextScope(
2872-
module.Types, pbGenericSig);
2873-
2874-
auto *pbGenericEnv = pbGenericSig
2875-
? pbGenericSig->createGenericEnvironment()
2876-
: nullptr;
2857+
module.Types, origTy->getGenericSignature());
28772858

28782859
// Given a type, returns its formal SIL parameter info.
28792860
auto getTangentParameterInfoForOriginalResult = [&](
@@ -2965,6 +2946,10 @@ class VJPEmitter final
29652946
mangler.mangleAutoDiffLinearMapHelper(
29662947
original->getName(), AutoDiffLinearMapKind::Pullback,
29672948
indices)).str();
2949+
auto pbGenericSig = getAssociatedFunctionGenericSignature(attr, original);
2950+
auto *pbGenericEnv = pbGenericSig
2951+
? pbGenericSig->createGenericEnvironment()
2952+
: nullptr;
29682953
auto pbType = SILFunctionType::get(
29692954
pbGenericSig, origTy->getExtInfo(), origTy->getCoroutineKind(),
29702955
origTy->getCalleeConvention(), pbParams, {}, adjResults, None,
@@ -3286,7 +3271,7 @@ class VJPEmitter final
32863271
auto original = getOpValue(ai->getCallee());
32873272
auto functionSource = original;
32883273
SILValue vjpValue;
3289-
// If `functionSource` is a `@differentiable` function, just extract it.
3274+
// If functionSource is a @differentiable function, just extract it.
32903275
auto originalFnTy = original->getType().castTo<SILFunctionType>();
32913276
if (originalFnTy->isDifferentiable()) {
32923277
auto paramIndices = originalFnTy->getDifferentiationParameterIndices();
@@ -3536,6 +3521,12 @@ class JVPEmitter final
35363521
auto origTy = original->getLoweredFunctionType();
35373522
auto lookupConformance = LookUpConformanceInModule(module.getSwiftModule());
35383523

3524+
// RAII that pushes the original function's generic signature to
3525+
// `module.Types` so that the calls `module.Types.getTypeLowering()` below
3526+
// will know the original function's generic parameter types.
3527+
Lowering::GenericContextScope genericContextScope(
3528+
module.Types, origTy->getGenericSignature());
3529+
35393530
SmallVector<SILParameterInfo, 8> diffParams;
35403531
SmallVector<SILResultInfo, 8> diffResults;
35413532
auto origParams = origTy->getParameters();
@@ -3562,13 +3553,6 @@ class JVPEmitter final
35623553
original->getName(), AutoDiffLinearMapKind::Differential,
35633554
indices)).str();
35643555
auto diffGenericSig = getAssociatedFunctionGenericSignature(attr, original);
3565-
3566-
// RAII that pushes the original function's generic signature to
3567-
// `module.Types` so that the calls `module.Types.getTypeLowering()` below
3568-
// will know the differential's generic parameter types.
3569-
Lowering::GenericContextScope genericContextScope(
3570-
module.Types, diffGenericSig);
3571-
35723556
auto *diffGenericEnv = diffGenericSig
35733557
? diffGenericSig->createGenericEnvironment()
35743558
: nullptr;
@@ -5991,7 +5975,7 @@ static SILFunction *createEmptyJVP(
59915975

59925976
// RAII that pushes the original function's generic signature to
59935977
// `module.Types` so that the calls `module.Types.getTypeLowering()` below
5994-
// will know the JVP's generic parameter types.
5978+
// will know the VJP's generic parameter types.
59955979
Lowering::GenericContextScope genericContextScope(
59965980
module.Types, jvpGenericSig);
59975981

test/AutoDiff/autodiff_indirect_diagnostics.swift

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@ func generic<T: Differentiable & FloatingPoint>(_ x: T) -> T {
1212
}
1313
_ = gradient(at: 1.0, in: generic)
1414

15-
//===----------------------------------------------------------------------===//
16-
// Unmet generic requirements
17-
//===----------------------------------------------------------------------===//
15+
// Test unmet generic requirements.
1816

1917
@differentiable(
2018
vjp: vjpWeirdExtraRequirements
@@ -91,6 +89,10 @@ let _: @differentiable (Float) -> TF_687<Any> = { x in TF_687<Any>(x, dummy: x)
9189
// Add `Differentiable` conformance for generic wrt parameters
9290
//===----------------------------------------------------------------------===//
9391

92+
// FIXME(TF-697): The tests below were fixed by
93+
// https://github.com/apple/swift/pull/26406, which was reverted because it
94+
// introduced TF-697.
95+
/*
9496
func id<T>(_ x: T) -> T { x }
9597
let _: @differentiable (Float) -> Float = { x in id(x) }
9698

@@ -104,4 +106,5 @@ extension TF_691: Differentiable where Scalar: Differentiable {}
104106

105107
func identity<T>(_ x: TF_691<T>) -> TF_691<T> { x }
106108
let _: @differentiable (Float) -> TF_691<Float> = { x in identity(TF_691(x)) }
107-
let _: @differentiable (Float) -> TF_691<Float> = { x in id(TF_691(x)) }
109+
let _: @differentiable (Float) -> TF_691<Float> = { x in id(TF_691(x)) }
110+
*/

test/AutoDiff/generics.swift

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,4 +247,33 @@ public func TF_688<Scalar: Differentiable>(
247247
reduction(x)
248248
}
249249

250+
// TF-697: Test generic requirements of generated AD associated function.
251+
protocol TF_697_Module: Differentiable where AllDifferentiableVariables == TangentVector {
252+
associatedtype Input
253+
associatedtype Output: Differentiable
254+
255+
@differentiable(wrt: self)
256+
func callModule(_ input: Input) -> Output
257+
}
258+
protocol TF_697_Layer: TF_697_Module where Input: Differentiable {
259+
@differentiable
260+
func callLayer(_ input: Input) -> Output
261+
}
262+
struct TF_697_Sequential<Layer1: TF_697_Module, Layer2: TF_697_Layer>: TF_697_Module
263+
where Layer1.Output == Layer2.Input {
264+
var layer1: Layer1
265+
var layer2: Layer2
266+
267+
@differentiable(wrt: self)
268+
func callModule(_ input: Layer1.Input) -> Layer2.Output {
269+
layer2.callLayer(layer1.callModule(input))
270+
}
271+
}
272+
extension TF_697_Sequential: TF_697_Layer where Layer1: TF_697_Layer {
273+
@differentiable
274+
func callLayer(_ input: Layer1.Input) -> Layer2.Output {
275+
layer2.callLayer(layer1.callLayer(input))
276+
}
277+
}
278+
250279
// TODO: add more tests.

0 commit comments

Comments
 (0)