Skip to content

Commit cd4a6ba

Browse files
authored
[AutoDiff] Add differentiability_witness_function verification. (#28505)
Add `differentiability_witness_function` instruction assertions: - Must never be constructed with null witness. - Can only have explicit type in lowered SIL. Add `differentiability_witness_function` instruction verification: check that type of `differentiability_witness_function` instruction matches the type of the witnessed SIL function.
1 parent f8c773b commit cd4a6ba

File tree

4 files changed

+60
-30
lines changed

4 files changed

+60
-30
lines changed

lib/SIL/SILInstructions.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -781,13 +781,21 @@ SILType DifferentiabilityWitnessFunctionInst::getDifferentiabilityWitnessType(
781781
DifferentiabilityWitnessFunctionInst::DifferentiabilityWitnessFunctionInst(
782782
SILModule &module, SILDebugLocation debugLoc,
783783
DifferentiabilityWitnessFunctionKind witnessKind,
784-
SILDifferentiabilityWitness *witness, Optional<SILType> FunctionType)
785-
: InstructionBase(debugLoc, FunctionType
786-
? *FunctionType
784+
SILDifferentiabilityWitness *witness, Optional<SILType> functionType)
785+
: InstructionBase(debugLoc, functionType
786+
? *functionType
787787
: getDifferentiabilityWitnessType(
788788
module, witnessKind, witness)),
789789
witnessKind(witnessKind), witness(witness),
790-
hasExplicitFunctionType(FunctionType) {}
790+
hasExplicitFunctionType(functionType) {
791+
assert(witness && "Differentiability witness must not be null");
792+
#ifndef NDEBUG
793+
if (functionType.hasValue()) {
794+
assert(module.getStage() == SILStage::Lowered &&
795+
"Explicit type is valid only in lowered SIL");
796+
}
797+
#endif
798+
}
791799
// SWIFT_ENABLE_TENSORFLOW END
792800

793801
FunctionRefBaseInst::FunctionRefBaseInst(SILInstructionKind Kind,

lib/SIL/SILVerifier.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1576,6 +1576,28 @@ class SILVerifier : public SILVerifierBase<SILVerifier> {
15761576
"The function operand must be a '@differentiable(linear)' "
15771577
"function");
15781578
}
1579+
1580+
void checkDifferentiabilityWitnessFunctionInst(
1581+
DifferentiabilityWitnessFunctionInst *dwfi) {
1582+
auto witnessFnTy = dwfi->getType().castTo<SILFunctionType>();
1583+
auto *witness = dwfi->getWitness();
1584+
// `DifferentiabilityWitnessFunctionInst` constructor asserts that
1585+
// `witness` is non-null.
1586+
auto witnessKind = dwfi->getWitnessKind();
1587+
// Return if not witnessing a derivative function.
1588+
auto derivKind = witnessKind.getAsDerivativeFunctionKind();
1589+
if (!derivKind)
1590+
return;
1591+
// Return if witness does not define the referenced derivative.
1592+
auto *derivativeFn = witness->getDerivative(*derivKind);
1593+
if (!derivativeFn)
1594+
return;
1595+
auto derivativeFnTy = derivativeFn->getLoweredFunctionType();
1596+
requireSameType(SILType::getPrimitiveObjectType(witnessFnTy),
1597+
SILType::getPrimitiveObjectType(derivativeFnTy),
1598+
"Type of witness instruction does not match actual type of "
1599+
"witnessed function");
1600+
}
15791601
// SWIFT_ENABLE_TENSORFLOW END
15801602

15811603
void verifyLLVMIntrinsic(BuiltinInst *BI, llvm::Intrinsic::ID ID) {

test/AutoDiff/differentiability_witness_function_inst.sil

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,6 @@ bb0:
5858
// Test "dependent" generic requirements: `T == T.TangentVector` depends on `T: Differentiable`.
5959
%generic_vjp_wrt_0_1_dependent_req = differentiability_witness_function [vjp] [parameters 0 1] [results 0] <T where T: Differentiable, T == T.TangentVector> @generic : $@convention(thin) <T> (@in_guaranteed T, Float) -> @out T
6060

61-
// Test explicit function types.
62-
%explicit_fnty = differentiability_witness_function [jvp] [parameters 0] [results 0] @foo : $@convention(thin) (Float, Float, Float) -> Float as $@convention(thin) (Float, Float, Float) -> (Float, (Float) -> Float)
63-
6461
return undef : $()
6562
}
6663

@@ -73,7 +70,6 @@ bb0:
7370
// CHECK: {{%.*}} = differentiability_witness_function [jvp] [parameters 0] [results 0] <τ_0_0 where τ_0_0 : Differentiable> @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0
7471
// CHECK: {{%.*}} = differentiability_witness_function [vjp] [parameters 0 1] [results 0] <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : Differentiable> @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0
7572
// CHECK: {{%.*}} = differentiability_witness_function [vjp] [parameters 0 1] [results 0] <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 == τ_0_0.TangentVector> @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0
76-
// CHECK: {{%.*}} = differentiability_witness_function [jvp] [parameters 0] [results 0] @foo : $@convention(thin) (Float, Float, Float) -> Float as $@convention(thin) (Float, Float, Float) -> (Float, (Float) -> Float)
7773
// CHECK: }
7874

7975
// IRGEN: @AD__foo_PSUURS = external global %swift.differentiability_witness, align 8
@@ -106,6 +102,3 @@ bb0:
106102

107103
// IRGEN: [[PTR7:%.*]] = load i8*, i8** getelementptr inbounds (%swift.differentiability_witness, %swift.differentiability_witness* @AD__generic_PSSRSs14DifferentiableRz13TangentVectorsAAPQzRszl, i32 0, i32 1), align 8
108104
// IRGEN: [[FNPTR7:%.*]] = bitcast i8* [[PTR7]] to { i8*, %swift.refcounted* } (%swift.opaque*, %swift.opaque*, float, %swift.type*, i8**)*
109-
110-
// IRGEN: [[PTR8:%.*]] = load i8*, i8** getelementptr inbounds (%swift.differentiability_witness, %swift.differentiability_witness* @AD__foo_PSUURS, i32 0, i32 0), align 8
111-
// IRGEN: [[FNPTR8:%.*]] = bitcast i8* [[PTR8]] to { float, i8*, %swift.refcounted* } (float, float, float)*

test/AutoDiff/differentiable_function_inst_lowered.sil

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// RUN: %target-sil-opt %s | %target-sil-opt | %FileCheck %s
22

3-
// Test `differentiable_function_extract` with explicit lowered type.
3+
// Test `differentiable_function_extract` and
4+
// `differentiability_witness_function` with explicit lowered type.
45
// SIL generated via `%target-sil-opt -loadable-address %s`.
56
// Note: SIL serialization/deserialization does not support lowered SIL.
67

@@ -27,37 +28,43 @@ struct Large : Differentiable {
2728
mutating func move(along direction: Large.TangentVector)
2829
}
2930

31+
sil_differentiability_witness [parameters 0 1 2] [results 0] @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
32+
3033
sil @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
3134
sil @examplemethod : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
3235

3336
// CHECK-LABEL: sil @test
3437
sil @test : $@convention(thin) () -> () {
3538
bb0:
36-
%0 = function_ref @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
37-
%1 = differentiable_function [parameters 0 1 2] %0 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
38-
%2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
39+
%func = function_ref @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
40+
%func_jvpwitness_wrt_012 = differentiability_witness_function [jvp] [parameters 0 1 2] [results 0] @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector, Large.TangentVector, Large.TangentVector) -> Large.TangentVector)
41+
%func_vjpwitness_wrt_012 = differentiability_witness_function [vjp] [parameters 0 1 2] [results 0] @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
42+
%func_diff_wrt_012 = differentiable_function [parameters 0 1 2] %func : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large with_derivative {%func_jvpwitness_wrt_012 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector, Large.TangentVector, Large.TangentVector) -> Large.TangentVector), %func_vjpwitness_wrt_012 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))}
43+
%func_vjp_wrt_012 = differentiable_function_extract [vjp] %func_diff_wrt_012 : $@differentiable @convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
3944

40-
// CHECK: %1 = differentiable_function [parameters 0 1 2] %0 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
41-
// CHECK: %2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
45+
// CHECK: [[FUNC_REF:%.*]] = function_ref @examplefunc
46+
// CHECK: [[DIFF_WRT_012:%.*]] = differentiable_function [parameters 0 1 2] [[FUNC_REF]] : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
47+
// CHECK: [[VJP_WRT_012:%.*]] = differentiable_function_extract [vjp] [[DIFF_WRT_012]] : $@differentiable @convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
4248

43-
%3 = differentiable_function [parameters 0] %0 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
44-
%4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)
49+
%func_diff_wrt_0 = differentiable_function [parameters 0] %func : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
50+
%func_vjp_wrt_0 = differentiable_function_extract [vjp] %func_diff_wrt_0 : $@differentiable @convention(thin) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)
4551

46-
// CHECK: %3 = differentiable_function [parameters 0] %0 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
47-
// CHECK: %4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)
52+
// CHECK: [[DIFF_WRT_0:%.*]] = differentiable_function [parameters 0] [[FUNC_REF]] : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
53+
// CHECK: [[VJP_WRT_0:%.*]] = differentiable_function_extract [vjp] [[DIFF_WRT_0]] : $@differentiable @convention(thin) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)
4854

49-
%5 = function_ref @examplemethod : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
50-
%6 = differentiable_function [parameters 0 1 2] %5 : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
51-
%7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
55+
%method = function_ref @examplemethod : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
56+
%method_diff_wrt_0123 = differentiable_function [parameters 0 1 2] %method : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
57+
%7 = differentiable_function_extract [vjp] %method_diff_wrt_0123 : $@differentiable @convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
5258

53-
// CHECK: %6 = differentiable_function [parameters 0 1 2] %5 : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
54-
// CHECK: %7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
59+
// CHECK: [[METHOD_REF:%.*]] = function_ref @examplemethod
60+
// CHECK: [[DIFF_WRT_0123:%.*]] = differentiable_function [parameters 0 1 2] [[METHOD_REF]] : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
61+
// CHECK: [[VJP_WRT_0123:%.*]] = differentiable_function_extract [vjp] [[DIFF_WRT_0123]] : $@differentiable @convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
5562

56-
%8 = differentiable_function [parameters 0] %5 : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
57-
%9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)
63+
%method_diff_wrt_0 = differentiable_function [parameters 0] %method : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
64+
%method_vjp_wrt_0 = differentiable_function_extract [vjp] %method_diff_wrt_0 : $@differentiable @convention(method) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)
5865

59-
// CHECK: %8 = differentiable_function [parameters 0] %5 : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
60-
// CHECK: %9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)
66+
// CHECK: [[DIFF_WRT_0:%.*]] = differentiable_function [parameters 0] [[METHOD_REF]] : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
67+
// CHECK: [[VJP_WRT_0:%.*]] = differentiable_function_extract [vjp] [[DIFF_WRT_0]] : $@differentiable @convention(method) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)
6168

6269
%10 = tuple ()
6370
return %10 : $()

0 commit comments

Comments
 (0)