Skip to content

Commit fccfa29

Browse files
authored
[AutoDiff upstream] Update LoadableByAddress. (#30825)
Update LoadableByAddress to handle AutoDiff-related instructions: - `differentiable_function` - `differentiable_function_extract` - `linear_function` - `linear_function_extract` - `differentiability_witness_function`
1 parent 3ebb81e commit fccfa29

File tree

4 files changed

+241
-4
lines changed

4 files changed

+241
-4
lines changed

lib/IRGen/LoadableByAddress.cpp

Lines changed: 89 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -403,14 +403,17 @@ SILParameterInfo LargeSILTypeMapper::getNewParameter(GenericEnvironment *env,
403403
} else if (isLargeLoadableType(env, storageType, IGM)) {
404404
if (param.getConvention() == ParameterConvention::Direct_Guaranteed)
405405
return SILParameterInfo(storageType.getASTType(),
406-
ParameterConvention::Indirect_In_Guaranteed);
406+
ParameterConvention::Indirect_In_Guaranteed,
407+
param.getDifferentiability());
407408
else
408409
return SILParameterInfo(storageType.getASTType(),
409-
ParameterConvention::Indirect_In_Constant);
410+
ParameterConvention::Indirect_In_Constant,
411+
param.getDifferentiability());
410412
} else {
411413
auto newType = getNewSILType(env, storageType, IGM);
412414
return SILParameterInfo(newType.getASTType(),
413-
param.getConvention());
415+
param.getConvention(),
416+
param.getDifferentiability());
414417
}
415418
}
416419

@@ -1704,6 +1707,9 @@ class LoadableByAddress : public SILModuleTransform {
17041707
bool fixStoreToBlockStorageInstr(SILInstruction &I,
17051708
SmallVectorImpl<SILInstruction *> &Delete);
17061709

1710+
bool recreateDifferentiabilityWitnessFunction(
1711+
SILInstruction &I, SmallVectorImpl<SILInstruction *> &Delete);
1712+
17071713
private:
17081714
llvm::SetVector<SILFunction *> modFuncs;
17091715
llvm::SetVector<SingleValueInstruction *> conversionInstrs;
@@ -2708,6 +2714,33 @@ bool LoadableByAddress::fixStoreToBlockStorageInstr(
27082714
return true;
27092715
}
27102716

2717+
bool LoadableByAddress::recreateDifferentiabilityWitnessFunction(
2718+
SILInstruction &I, SmallVectorImpl<SILInstruction *> &Delete) {
2719+
auto *instr = dyn_cast<DifferentiabilityWitnessFunctionInst>(&I);
2720+
if (!instr)
2721+
return false;
2722+
2723+
// Check if we need to recreate the instruction.
2724+
auto *currIRMod = getIRGenModule()->IRGen.getGenModule(instr->getFunction());
2725+
auto resultFnTy = instr->getType().castTo<SILFunctionType>();
2726+
auto genSig = resultFnTy->getSubstGenericSignature();
2727+
GenericEnvironment *genEnv = nullptr;
2728+
if (genSig)
2729+
genEnv = genSig->getGenericEnvironment();
2730+
auto newResultFnTy =
2731+
MapperCache.getNewSILFunctionType(genEnv, resultFnTy, *currIRMod);
2732+
if (resultFnTy == newResultFnTy)
2733+
return true;
2734+
2735+
SILBuilderWithScope builder(instr);
2736+
auto *newInstr = builder.createDifferentiabilityWitnessFunction(
2737+
instr->getLoc(), instr->getWitnessKind(), instr->getWitness(),
2738+
SILType::getPrimitiveObjectType(newResultFnTy));
2739+
instr->replaceAllUsesWith(newInstr);
2740+
Delete.push_back(instr);
2741+
return true;
2742+
}
2743+
27112744
bool LoadableByAddress::recreateTupleInstr(
27122745
SILInstruction &I, SmallVectorImpl<SILInstruction *> &Delete) {
27132746
auto *tupleInstr = dyn_cast<TupleInst>(&I);
@@ -2750,6 +2783,19 @@ bool LoadableByAddress::recreateConvInstr(SILInstruction &I,
27502783
auto currSILFunctionType = currSILType.castTo<SILFunctionType>();
27512784
GenericEnvironment *genEnv =
27522785
getSubstGenericEnvironment(convInstr->getFunction());
2786+
// Differentiable function conversion instructions can happen while the
2787+
// function is still generic. In that case, we must calculate the new type
2788+
// using the converted function's generic environment rather than the
2789+
// converting function's generic environment.
2790+
//
2791+
// This happens in witness thunks for default implementations of derivative
2792+
// requirements.
2793+
if (convInstr->getKind() == SILInstructionKind::DifferentiableFunctionInst ||
2794+
convInstr->getKind() == SILInstructionKind::DifferentiableFunctionExtractInst ||
2795+
convInstr->getKind() == SILInstructionKind::LinearFunctionInst ||
2796+
convInstr->getKind() == SILInstructionKind::LinearFunctionExtractInst)
2797+
if (auto genSig = currSILFunctionType->getSubstGenericSignature())
2798+
genEnv = genSig->getGenericEnvironment();
27532799
CanSILFunctionType newFnType = MapperCache.getNewSILFunctionType(
27542800
genEnv, currSILFunctionType, *currIRMod);
27552801
SILType newType = SILType::getPrimitiveObjectType(newFnType);
@@ -2790,6 +2836,34 @@ bool LoadableByAddress::recreateConvInstr(SILInstruction &I,
27902836
instr->getLoc(), instr->getValue(), instr->getBase());
27912837
break;
27922838
}
2839+
case SILInstructionKind::DifferentiableFunctionInst: {
2840+
auto instr = cast<DifferentiableFunctionInst>(convInstr);
2841+
newInstr = convBuilder.createDifferentiableFunction(
2842+
instr->getLoc(), instr->getParameterIndices(),
2843+
instr->getOriginalFunction(),
2844+
instr->getOptionalDerivativeFunctionPair());
2845+
break;
2846+
}
2847+
case SILInstructionKind::DifferentiableFunctionExtractInst: {
2848+
auto instr = cast<DifferentiableFunctionExtractInst>(convInstr);
2849+
// Rewrite `differentiable_function_extract` with explicit extractee type.
2850+
newInstr = convBuilder.createDifferentiableFunctionExtract(
2851+
instr->getLoc(), instr->getExtractee(), instr->getOperand(), newType);
2852+
break;
2853+
}
2854+
case SILInstructionKind::LinearFunctionInst: {
2855+
auto instr = cast<LinearFunctionInst>(convInstr);
2856+
newInstr = convBuilder.createLinearFunction(
2857+
instr->getLoc(), instr->getParameterIndices(),
2858+
instr->getOriginalFunction(), instr->getOptionalTransposeFunction());
2859+
break;
2860+
}
2861+
case SILInstructionKind::LinearFunctionExtractInst: {
2862+
auto instr = cast<LinearFunctionExtractInst>(convInstr);
2863+
newInstr = convBuilder.createLinearFunctionExtract(
2864+
instr->getLoc(), instr->getExtractee(), instr->getFunctionOperand());
2865+
break;
2866+
}
27932867
default:
27942868
llvm_unreachable("Unexpected conversion instruction");
27952869
}
@@ -2878,7 +2952,11 @@ void LoadableByAddress::run() {
28782952
case SILInstructionKind::ConvertEscapeToNoEscapeInst:
28792953
case SILInstructionKind::MarkDependenceInst:
28802954
case SILInstructionKind::ThinFunctionToPointerInst:
2881-
case SILInstructionKind::ThinToThickFunctionInst: {
2955+
case SILInstructionKind::ThinToThickFunctionInst:
2956+
case SILInstructionKind::DifferentiableFunctionInst:
2957+
case SILInstructionKind::LinearFunctionInst:
2958+
case SILInstructionKind::LinearFunctionExtractInst:
2959+
case SILInstructionKind::DifferentiableFunctionExtractInst: {
28822960
conversionInstrs.insert(
28832961
cast<SingleValueInstruction>(currInstr));
28842962
break;
@@ -2945,6 +3023,11 @@ void LoadableByAddress::run() {
29453023
if (modApplies.count(PAI) == 0) {
29463024
modApplies.insert(PAI);
29473025
}
3026+
} else if (isa<DifferentiableFunctionInst>(&I) ||
3027+
isa<LinearFunctionInst>(&I) ||
3028+
isa<DifferentiableFunctionExtractInst>(&I) ||
3029+
isa<LinearFunctionExtractInst>(&I)) {
3030+
conversionInstrs.insert(cast<SingleValueInstruction>(&I));
29483031
}
29493032
}
29503033
}
@@ -2988,6 +3071,8 @@ void LoadableByAddress::run() {
29883071
continue;
29893072
else if (recreateApply(I, Delete))
29903073
continue;
3074+
else if (recreateDifferentiabilityWitnessFunction(I, Delete))
3075+
continue;
29913076
else
29923077
fixStoreToBlockStorageInstr(I, Delete);
29933078
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import _Differentiation
2+
3+
public struct LargeLoadableType<T>: AdditiveArithmetic, Differentiable {
4+
public var a, b, c, d, e: Float
5+
6+
public init(a: Float) {
7+
self.a = a
8+
self.b = 0
9+
self.c = 0
10+
self.d = 0
11+
self.e = 0
12+
}
13+
14+
@differentiable
15+
public func externalLBAModifiedFunction(_ x: Float) -> Float {
16+
return x * a
17+
}
18+
19+
// TODO(TF-1226): Remove custom derivative when stdlib derivatives are upstreamed.
20+
@usableFromInline
21+
@derivative(of: externalLBAModifiedFunction)
22+
func externalLBAModifiedFunctionVJP(_ x: Float) -> (
23+
value: Float, pullback: (Float) -> (Self, Float)
24+
) {
25+
let value = externalLBAModifiedFunction(x)
26+
return (value, { v in (Self(a: v * x), v * a) })
27+
}
28+
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
// RUN: %target-swift-frontend -c -enable-large-loadable-types -Xllvm -sil-verify-after-pass=loadable-address %s
2+
// RUN: %target-swift-frontend -emit-sil %s | %FileCheck %s -check-prefix=CHECK-SIL
3+
// RUN: %target-swift-frontend -c -Xllvm -sil-print-after=loadable-address %s 2>&1 | %FileCheck %s -check-prefix=CHECK-LBA-SIL
4+
// RUN: %target-run-simple-swift
5+
// REQUIRES: executable_test
6+
7+
// `isLargeLoadableType` depends on the ABI and differs between architectures.
8+
// REQUIRES: CPU=x86_64
9+
10+
// TF-11: Verify that LoadableByAddress works with differentiation-related instructions:
11+
// - `differentiable_function`
12+
// - `differentiable_function_extract`
13+
14+
// TODO: Add tests for `@differentiable(linear)` functions.
15+
16+
import _Differentiation
17+
import StdlibUnittest
18+
19+
var LBATests = TestSuite("LoadableByAddress")
20+
21+
// `Large` is a large loadable type.
22+
// `Large.TangentVector` is not a large loadable type.
23+
struct Large : Differentiable {
24+
var a: Float
25+
var b: Float
26+
var c: Float
27+
var d: Float
28+
@noDerivative let e: Float
29+
}
30+
31+
@_silgen_name("large2large")
32+
@differentiable
33+
func large2large(_ foo: Large) -> Large {
34+
foo
35+
}
36+
37+
// `large2large` old verification error:
38+
// SIL verification failed: JVP type does not match expected JVP type
39+
// $@callee_guaranteed (@in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector)
40+
// $@callee_guaranteed (@in_constant Large) -> (@out Large, @owned @callee_guaranteed (@in_constant Large.TangentVector) -> @out Large.TangentVector)
41+
42+
@_silgen_name("large2small")
43+
@differentiable
44+
func large2small(_ foo: Large) -> Float {
45+
foo.a
46+
}
47+
48+
// `large2small` old verification error:
49+
// SIL verification failed: JVP type does not match expected JVP type
50+
// $@callee_guaranteed (@in_constant Large) -> (Float, @owned @callee_guaranteed (Large.TangentVector) -> Float)
51+
// $@callee_guaranteed (@in_constant Large) -> (Float, @owned @callee_guaranteed (@in_constant Large.TangentVector) -> Float)
52+
53+
// CHECK-SIL: sil hidden @large2large : $@convention(thin) (Large) -> Large {
54+
// CHECK-LBA-SIL: sil hidden @large2large : $@convention(thin) (@in_constant Large) -> @out Large {
55+
56+
// CHECK-SIL-LABEL: sil hidden @large2small : $@convention(thin) (Large) -> Float {
57+
// CHECK-LBA-SIL: sil hidden @large2small : $@convention(thin) (@in_constant Large) -> Float {
58+
59+
// CHECK-SIL: sil hidden @AD__large2large__jvp_src_0_wrt_0 : $@convention(thin) (Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector) {
60+
// CHECK-LBA-SIL: sil hidden @AD__large2large__jvp_src_0_wrt_0 : $@convention(thin) (@in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector) {
61+
62+
// CHECK-SIL: sil hidden @AD__large2large__vjp_src_0_wrt_0 : $@convention(thin) (Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector) {
63+
// CHECK-LBA-SIL: sil hidden @AD__large2large__vjp_src_0_wrt_0 : $@convention(thin) (@in_constant Large) -> (Large, @owned @callee_guaranteed (Large.TangentVector) -> Large.TangentVector) {
64+
65+
// CHECK-SIL: sil hidden @AD__large2small__jvp_src_0_wrt_0 : $@convention(thin) (Large) -> (Float, @owned @callee_guaranteed (Large.TangentVector) -> Float) {
66+
// CHECK-LBA-SIL: sil hidden @AD__large2small__jvp_src_0_wrt_0 : $@convention(thin) (@in_constant Large) -> (Float, @owned @callee_guaranteed (Large.TangentVector) -> Float) {
67+
68+
// CHECK-SIL: sil hidden @AD__large2small__vjp_src_0_wrt_0 : $@convention(thin) (Large) -> (Float, @owned @callee_guaranteed (Float) -> Large.TangentVector) {
69+
// CHECK-LBA-SIL: sil hidden @AD__large2small__vjp_src_0_wrt_0 : $@convention(thin) (@in_constant Large) -> (Float, @owned @callee_guaranteed (Float) -> Large.TangentVector) {
70+
71+
LBATests.test("Correctness") {
72+
let one = Large.TangentVector(a: 1, b: 1, c: 1, d: 1)
73+
expectEqual(one,
74+
pullback(at: Large(a: 0, b: 0, c: 0, d: 0, e: 0), in: large2large)(one))
75+
expectEqual(Large.TangentVector(a: 1, b: 0, c: 0, d: 0),
76+
gradient(at: Large(a: 0, b: 0, c: 0, d: 0, e: 0), in: large2small))
77+
}
78+
79+
runAllTests()
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// First, check that LBA actually modifies the function, so that this test is useful.
2+
3+
// RUN: %target-swift-frontend -emit-sil %S/Inputs/loadable_by_address_cross_module.swift | %FileCheck %s -check-prefix=CHECK-MODULE-PRE-LBA
4+
// RUN: %target-swift-frontend -c -Xllvm -sil-print-after=loadable-address %S/Inputs/loadable_by_address_cross_module.swift 2>&1 | %FileCheck %s -check-prefix=CHECK-MODULE-POST-LBA
5+
6+
// CHECK-MODULE-PRE-LBA: sil {{.*}}LBAModifiedFunction{{.*}} $@convention(method) <T> (Float, LargeLoadableType<T>) -> Float
7+
// CHECK-MODULE-POST-LBA: sil {{.*}}LBAModifiedFunction{{.*}} $@convention(method) <T> (Float, @in_constant LargeLoadableType<T>) -> Float
8+
9+
// Compile the module.
10+
11+
// RUN: %empty-directory(%t)
12+
// RUN: %target-build-swift -working-directory %t -parse-as-library -emit-module -module-name external -emit-module-path %t/external.swiftmodule -emit-library -static %S/Inputs/loadable_by_address_cross_module.swift
13+
14+
// Next, check that differentiability_witness_functions in the client get
15+
// correctly modified by LBA.
16+
17+
// RUN: %target-swift-frontend -emit-sil -I%t %s
18+
// RUN: %target-swift-frontend -emit-sil -I%t %s | %FileCheck %s -check-prefix=CHECK-CLIENT-PRE-LBA
19+
// RUN: %target-swift-frontend -c -I%t %s -Xllvm -sil-print-after=loadable-address 2>&1 | %FileCheck %s -check-prefix=CHECK-CLIENT-POST-LBA
20+
21+
// CHECK-CLIENT-PRE-LBA: differentiability_witness_function [jvp] [parameters 0 1] [results 0] <T> @${{.*}}LBAModifiedFunction{{.*}} : $@convention(method) <τ_0_0> (Float, LargeLoadableType<τ_0_0>) -> Float
22+
// CHECK-CLIENT-PRE-LBA: differentiability_witness_function [vjp] [parameters 0 1] [results 0] <T> @${{.*}}LBAModifiedFunction{{.*}} : $@convention(method) <τ_0_0> (Float, LargeLoadableType<τ_0_0>) -> Float
23+
24+
// CHECK-CLIENT-POST-LBA: differentiability_witness_function [jvp] [parameters 0 1] [results 0] <T> @$s8external17LargeLoadableTypeV0A19LBAModifiedFunctionyS2fF : $@convention(method) <τ_0_0> (Float, @in_constant LargeLoadableType<τ_0_0>) -> Float as $@convention(method) <τ_0_0> (Float, @in_constant LargeLoadableType<τ_0_0>) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float, τ_0_0) -> Float for <LargeLoadableType<τ_0_0>>)
25+
// CHECK-CLIENT-POST-LBA: differentiability_witness_function [vjp] [parameters 0 1] [results 0] <T> @$s8external17LargeLoadableTypeV0A19LBAModifiedFunctionyS2fF : $@convention(method) <τ_0_0> (Float, @in_constant LargeLoadableType<τ_0_0>) -> Float as $@convention(method) <τ_0_0> (Float, @in_constant LargeLoadableType<τ_0_0>) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> (Float, τ_0_0) for <LargeLoadableType<τ_0_0>>)
26+
27+
// Finally, execute the test.
28+
29+
// RUN: %target-build-swift -I%t -L%t %s -o %t/a.out -lm -lexternal
30+
// RUN: %target-run %t/a.out
31+
32+
// REQUIRES: executable_test
33+
34+
import _Differentiation
35+
import external
36+
import StdlibUnittest
37+
38+
var LBATests = TestSuite("LoadableByAddressCrossModule")
39+
40+
LBATests.test("Correctness") {
41+
let g = gradient(at: LargeLoadableType<Int>(a: 5), 10) { $0.externalLBAModifiedFunction($1) }
42+
expectEqual((LargeLoadableType<Int>(a: 10), 5), g)
43+
}
44+
45+
runAllTests()

0 commit comments

Comments
 (0)