Skip to content

[AutoDiff] Add differentiability_witness_function verification. #28505

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions lib/SIL/SILInstructions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -781,13 +781,21 @@ SILType DifferentiabilityWitnessFunctionInst::getDifferentiabilityWitnessType(
DifferentiabilityWitnessFunctionInst::DifferentiabilityWitnessFunctionInst(
SILModule &module, SILDebugLocation debugLoc,
DifferentiabilityWitnessFunctionKind witnessKind,
SILDifferentiabilityWitness *witness, Optional<SILType> FunctionType)
: InstructionBase(debugLoc, FunctionType
? *FunctionType
SILDifferentiabilityWitness *witness, Optional<SILType> functionType)
: InstructionBase(debugLoc, functionType
? *functionType
: getDifferentiabilityWitnessType(
module, witnessKind, witness)),
witnessKind(witnessKind), witness(witness),
hasExplicitFunctionType(FunctionType) {}
hasExplicitFunctionType(functionType) {
assert(witness && "Differentiability witness must not be null");
#ifndef NDEBUG
if (functionType.hasValue()) {
assert(module.getStage() == SILStage::Lowered &&
"Explicit type is valid only in lowered SIL");
}
#endif
}
// SWIFT_ENABLE_TENSORFLOW END

FunctionRefBaseInst::FunctionRefBaseInst(SILInstructionKind Kind,
Expand Down
22 changes: 22 additions & 0 deletions lib/SIL/SILVerifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1576,6 +1576,28 @@ class SILVerifier : public SILVerifierBase<SILVerifier> {
"The function operand must be a '@differentiable(linear)' "
"function");
}

void checkDifferentiabilityWitnessFunctionInst(
DifferentiabilityWitnessFunctionInst *dwfi) {
auto witnessFnTy = dwfi->getType().castTo<SILFunctionType>();
auto *witness = dwfi->getWitness();
// `DifferentiabilityWitnessFunctionInst` constructor asserts that
// `witness` is non-null.
auto witnessKind = dwfi->getWitnessKind();
// Return if not witnessing a derivative function.
auto derivKind = witnessKind.getAsDerivativeFunctionKind();
if (!derivKind)
return;
// Return if witness does not define the referenced derivative.
auto *derivativeFn = witness->getDerivative(*derivKind);
if (!derivativeFn)
return;
auto derivativeFnTy = derivativeFn->getLoweredFunctionType();
requireSameType(SILType::getPrimitiveObjectType(witnessFnTy),
SILType::getPrimitiveObjectType(derivativeFnTy),
"Type of witness instruction does not match actual type of "
"witnessed function");
}
// SWIFT_ENABLE_TENSORFLOW END

void verifyLLVMIntrinsic(BuiltinInst *BI, llvm::Intrinsic::ID ID) {
Expand Down
7 changes: 0 additions & 7 deletions test/AutoDiff/differentiability_witness_function_inst.sil
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,6 @@ bb0:
// Test "dependent" generic requirements: `T == T.TangentVector` depends on `T: Differentiable`.
%generic_vjp_wrt_0_1_dependent_req = differentiability_witness_function [vjp] [parameters 0 1] [results 0] <T where T: Differentiable, T == T.TangentVector> @generic : $@convention(thin) <T> (@in_guaranteed T, Float) -> @out T

// Test explicit function types.
%explicit_fnty = differentiability_witness_function [jvp] [parameters 0] [results 0] @foo : $@convention(thin) (Float, Float, Float) -> Float as $@convention(thin) (Float, Float, Float) -> (Float, (Float) -> Float)

return undef : $()
}

Expand All @@ -73,7 +70,6 @@ bb0:
// CHECK: {{%.*}} = differentiability_witness_function [jvp] [parameters 0] [results 0] <τ_0_0 where τ_0_0 : Differentiable> @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0
// CHECK: {{%.*}} = differentiability_witness_function [vjp] [parameters 0 1] [results 0] <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : Differentiable> @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0
// CHECK: {{%.*}} = differentiability_witness_function [vjp] [parameters 0 1] [results 0] <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 == τ_0_0.TangentVector> @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0
// CHECK: {{%.*}} = differentiability_witness_function [jvp] [parameters 0] [results 0] @foo : $@convention(thin) (Float, Float, Float) -> Float as $@convention(thin) (Float, Float, Float) -> (Float, (Float) -> Float)
// CHECK: }

// IRGEN: @AD__foo_PSUURS = external global %swift.differentiability_witness, align 8
Expand Down Expand Up @@ -106,6 +102,3 @@ bb0:

// IRGEN: [[PTR7:%.*]] = load i8*, i8** getelementptr inbounds (%swift.differentiability_witness, %swift.differentiability_witness* @AD__generic_PSSRSs14DifferentiableRz13TangentVectorsAAPQzRszl, i32 0, i32 1), align 8
// IRGEN: [[FNPTR7:%.*]] = bitcast i8* [[PTR7]] to { i8*, %swift.refcounted* } (%swift.opaque*, %swift.opaque*, float, %swift.type*, i8**)*

// IRGEN: [[PTR8:%.*]] = load i8*, i8** getelementptr inbounds (%swift.differentiability_witness, %swift.differentiability_witness* @AD__foo_PSUURS, i32 0, i32 0), align 8
// IRGEN: [[FNPTR8:%.*]] = bitcast i8* [[PTR8]] to { float, i8*, %swift.refcounted* } (float, float, float)*
45 changes: 26 additions & 19 deletions test/AutoDiff/differentiable_function_inst_lowered.sil
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// RUN: %target-sil-opt %s | %target-sil-opt | %FileCheck %s

// Test `differentiable_function_extract` with explicit lowered type.
// Test `differentiable_function_extract` and
// `differentiability_witness_function` with explicit lowered type.
// SIL generated via `%target-sil-opt -loadable-address %s`.
// Note: SIL serialization/deserialization does not support lowered SIL.

Expand All @@ -27,37 +28,43 @@ struct Large : Differentiable {
mutating func move(along direction: Large.TangentVector)
}

sil_differentiability_witness [parameters 0 1 2] [results 0] @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large

sil @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
sil @examplemethod : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large

// CHECK-LABEL: sil @test
sil @test : $@convention(thin) () -> () {
bb0:
%0 = function_ref @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
%1 = differentiable_function [parameters 0 1 2] %0 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
%2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
%func = function_ref @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
%func_jvpwitness_wrt_012 = differentiability_witness_function [jvp] [parameters 0 1 2] [results 0] @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector, Large.TangentVector, Large.TangentVector) -> Large.TangentVector)
%func_vjpwitness_wrt_012 = differentiability_witness_function [vjp] [parameters 0 1 2] [results 0] @examplefunc : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
%func_diff_wrt_012 = differentiable_function [parameters 0 1 2] %func : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large with_derivative {%func_jvpwitness_wrt_012 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector, Large.TangentVector, Large.TangentVector) -> Large.TangentVector), %func_vjpwitness_wrt_012 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))}
%func_vjp_wrt_012 = differentiable_function_extract [vjp] %func_diff_wrt_012 : $@differentiable @convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))

// CHECK: %1 = differentiable_function [parameters 0 1 2] %0 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
// CHECK: %2 = differentiable_function_extract [vjp] %1 : $@differentiable @convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
// CHECK: [[FUNC_REF:%.*]] = function_ref @examplefunc
// CHECK: [[DIFF_WRT_012:%.*]] = differentiable_function [parameters 0 1 2] [[FUNC_REF]] : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
// CHECK: [[VJP_WRT_012:%.*]] = differentiable_function_extract [vjp] [[DIFF_WRT_012]] : $@differentiable @convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))

%3 = differentiable_function [parameters 0] %0 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
%4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)
%func_diff_wrt_0 = differentiable_function [parameters 0] %func : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
%func_vjp_wrt_0 = differentiable_function_extract [vjp] %func_diff_wrt_0 : $@differentiable @convention(thin) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)

// CHECK: %3 = differentiable_function [parameters 0] %0 : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
// CHECK: %4 = differentiable_function_extract [vjp] %3 : $@differentiable @convention(thin) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)
// CHECK: [[DIFF_WRT_0:%.*]] = differentiable_function [parameters 0] [[FUNC_REF]] : $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
// CHECK: [[VJP_WRT_0:%.*]] = differentiable_function_extract [vjp] [[DIFF_WRT_0]] : $@differentiable @convention(thin) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(thin) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)

%5 = function_ref @examplemethod : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
%6 = differentiable_function [parameters 0 1 2] %5 : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
%7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
%method = function_ref @examplemethod : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
%method_diff_wrt_0123 = differentiable_function [parameters 0 1 2] %method : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
%7 = differentiable_function_extract [vjp] %method_diff_wrt_0123 : $@differentiable @convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))

// CHECK: %6 = differentiable_function [parameters 0 1 2] %5 : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
// CHECK: %7 = differentiable_function_extract [vjp] %6 : $@differentiable @convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))
// CHECK: [[METHOD_REF:%.*]] = function_ref @examplemethod
// CHECK: [[DIFF_WRT_0123:%.*]] = differentiable_function [parameters 0 1 2] [[METHOD_REF]] : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
// CHECK: [[VJP_WRT_0123:%.*]] = differentiable_function_extract [vjp] [[DIFF_WRT_0123]] : $@differentiable @convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> (Large.TangentVector, Large.TangentVector, Large.TangentVector))

%8 = differentiable_function [parameters 0] %5 : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
%9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)
%method_diff_wrt_0 = differentiable_function [parameters 0] %method : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
%method_vjp_wrt_0 = differentiable_function_extract [vjp] %method_diff_wrt_0 : $@differentiable @convention(method) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)

// CHECK: %8 = differentiable_function [parameters 0] %5 : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
// CHECK: %9 = differentiable_function_extract [vjp] %8 : $@differentiable @convention(method) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)
// CHECK: [[DIFF_WRT_0:%.*]] = differentiable_function [parameters 0] [[METHOD_REF]] : $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> @out Large
// CHECK: [[VJP_WRT_0:%.*]] = differentiable_function_extract [vjp] [[DIFF_WRT_0]] : $@differentiable @convention(method) (@in_constant Large, @nondiff @in_constant Large, @nondiff @in_constant Large) -> @out Large as $@convention(method) (@in_constant Large, @in_constant Large, @in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)

%10 = tuple ()
return %10 : $()
Expand Down