Skip to content

Commit 80e5a51

Browse files
authored
[AutoDiff upstream] Add differentiable function type lowering. (swiftlang#30677)
Add `@differentiable` and `@differentiable(linear)` type lowering. Resolves TF-1221.
1 parent 37f75ec commit 80e5a51

File tree

2 files changed

+183
-3
lines changed

2 files changed

+183
-3
lines changed

lib/SIL/TypeLowering.cpp

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,19 @@ namespace {
240240

241241
RetTy visitSILFunctionType(CanSILFunctionType type,
242242
AbstractionPattern origType) {
243+
// Handle `@differentiable` and `@differentiable(linear)` functions.
244+
switch (type->getDifferentiabilityKind()) {
245+
case DifferentiabilityKind::Normal:
246+
return asImpl().visitNormalDifferentiableSILFunctionType(
247+
type, getNormalDifferentiableSILFunctionTypeRecursiveProperties(
248+
type, origType));
249+
case DifferentiabilityKind::Linear:
250+
return asImpl().visitLinearDifferentiableSILFunctionType(
251+
type, getLinearDifferentiableSILFunctionTypeRecursiveProperties(
252+
type, origType));
253+
case DifferentiabilityKind::NonDifferentiable:
254+
break;
255+
}
243256
// Only escaping closures are references.
244257
bool isSwiftEscaping = type->getExtInfo().isNoEscape() &&
245258
type->getExtInfo().getRepresentation() ==
@@ -250,6 +263,53 @@ namespace {
250263
return asImpl().handleTrivial(type);
251264
}
252265

266+
RecursiveProperties
267+
getNormalDifferentiableSILFunctionTypeRecursiveProperties(
268+
CanSILFunctionType type, AbstractionPattern origType) {
269+
auto &M = TC.M;
270+
auto origTy = type->getWithoutDifferentiability();
271+
// Pass the `AbstractionPattern` generic signature to
272+
// `SILFunctionType:getAutoDiffDerivativeFunctionType` for correct type
273+
// lowering.
274+
auto jvpTy = origTy->getAutoDiffDerivativeFunctionType(
275+
type->getDifferentiabilityParameterIndices(), /*resultIndex*/ 0,
276+
AutoDiffDerivativeFunctionKind::JVP, TC,
277+
LookUpConformanceInModule(&M), CanGenericSignature());
278+
auto vjpTy = origTy->getAutoDiffDerivativeFunctionType(
279+
type->getDifferentiabilityParameterIndices(), /*resultIndex*/ 0,
280+
AutoDiffDerivativeFunctionKind::VJP, TC,
281+
LookUpConformanceInModule(&M), CanGenericSignature());
282+
RecursiveProperties props;
283+
props.addSubobject(classifyType(origType, origTy, TC, Expansion));
284+
props.addSubobject(classifyType(origType, jvpTy, TC, Expansion));
285+
props.addSubobject(classifyType(origType, vjpTy, TC, Expansion));
286+
return props;
287+
}
288+
289+
RecursiveProperties
290+
getLinearDifferentiableSILFunctionTypeRecursiveProperties(
291+
CanSILFunctionType type, AbstractionPattern origType) {
292+
auto &M = TC.M;
293+
auto origTy = type->getWithoutDifferentiability();
294+
auto transposeTy = origTy->getAutoDiffTransposeFunctionType(
295+
type->getDifferentiabilityParameterIndices(), TC,
296+
LookUpConformanceInModule(&M), origType.getGenericSignatureOrNull());
297+
RecursiveProperties props;
298+
props.addSubobject(classifyType(origType, origTy, TC, Expansion));
299+
props.addSubobject(classifyType(origType, transposeTy, TC, Expansion));
300+
return props;
301+
}
302+
303+
RetTy visitNormalDifferentiableSILFunctionType(
304+
CanSILFunctionType type, RecursiveProperties props) {
305+
return handleAggregateByProperties(type, props);
306+
}
307+
308+
RetTy visitLinearDifferentiableSILFunctionType(
309+
CanSILFunctionType type, RecursiveProperties props) {
310+
return handleAggregateByProperties(type, props);
311+
}
312+
253313
RetTy visitLValueType(CanLValueType type,
254314
AbstractionPattern origType) {
255315
llvm_unreachable("shouldn't get an l-value type here");
@@ -960,6 +1020,106 @@ namespace {
9601020
}
9611021
};
9621022

1023+
/// A type lowering for `@differentiable` function types.
1024+
class NormalDifferentiableSILFunctionTypeLowering final
1025+
: public LoadableAggTypeLowering<
1026+
NormalDifferentiableSILFunctionTypeLowering,
1027+
NormalDifferentiableFunctionTypeComponent> {
1028+
public:
1029+
using LoadableAggTypeLowering::LoadableAggTypeLowering;
1030+
1031+
SILValue emitRValueProject(
1032+
SILBuilder &B, SILLocation loc, SILValue tupleValue,
1033+
NormalDifferentiableFunctionTypeComponent extractee,
1034+
const TypeLowering &eltLowering) const {
1035+
return B.createDifferentiableFunctionExtract(
1036+
loc, extractee, tupleValue);
1037+
}
1038+
1039+
SILValue rebuildAggregate(SILBuilder &B, SILLocation loc,
1040+
ArrayRef<SILValue> values) const override {
1041+
assert(values.size() == 3);
1042+
auto fnTy = getLoweredType().castTo<SILFunctionType>();
1043+
auto paramIndices = fnTy->getDifferentiabilityParameterIndices();
1044+
return B.createDifferentiableFunction(
1045+
loc, paramIndices, values[0], std::make_pair(values[1], values[2]));
1046+
}
1047+
1048+
void lowerChildren(TypeConverter &TC,
1049+
SmallVectorImpl<Child> &children) const override {
1050+
auto fnTy = getLoweredType().castTo<SILFunctionType>();
1051+
auto numDerivativeFns = 2;
1052+
children.reserve(numDerivativeFns + 1);
1053+
auto origFnTy = fnTy->getWithoutDifferentiability();
1054+
auto paramIndices = fnTy->getDifferentiabilityParameterIndices();
1055+
children.push_back(Child{
1056+
NormalDifferentiableFunctionTypeComponent::Original,
1057+
TC.getTypeLowering(origFnTy, getExpansionContext())
1058+
});
1059+
for (AutoDiffDerivativeFunctionKind kind :
1060+
{AutoDiffDerivativeFunctionKind::JVP,
1061+
AutoDiffDerivativeFunctionKind::VJP}) {
1062+
auto derivativeFnTy = origFnTy->getAutoDiffDerivativeFunctionType(
1063+
paramIndices, 0, kind, TC,
1064+
LookUpConformanceInModule(&TC.M));
1065+
auto silTy = SILType::getPrimitiveObjectType(derivativeFnTy);
1066+
NormalDifferentiableFunctionTypeComponent extractee(kind);
1067+
// Assert that we have the right extractee. A terrible bug in the past
1068+
// was caused by implicit conversions from `unsigned` to
1069+
// `NormalDifferentiableFunctionTypeComponent` which resulted into a
1070+
// wrong extractee.
1071+
assert(extractee.getAsDerivativeFunctionKind() == kind);
1072+
children.push_back(Child{
1073+
extractee, TC.getTypeLowering(silTy, getExpansionContext())});
1074+
}
1075+
assert(children.size() == 3);
1076+
}
1077+
};
1078+
1079+
/// A type lowering for `@differentiable(linear)` function types.
1080+
class LinearDifferentiableSILFunctionTypeLowering final
1081+
: public LoadableAggTypeLowering<
1082+
LinearDifferentiableSILFunctionTypeLowering,
1083+
LinearDifferentiableFunctionTypeComponent> {
1084+
public:
1085+
using LoadableAggTypeLowering::LoadableAggTypeLowering;
1086+
1087+
SILValue emitRValueProject(
1088+
SILBuilder &B, SILLocation loc, SILValue tupleValue,
1089+
LinearDifferentiableFunctionTypeComponent component,
1090+
const TypeLowering &eltLowering) const {
1091+
return B.createLinearFunctionExtract(loc, component, tupleValue);
1092+
}
1093+
1094+
SILValue rebuildAggregate(SILBuilder &B, SILLocation loc,
1095+
ArrayRef<SILValue> values) const override {
1096+
assert(values.size() == 2);
1097+
auto fnTy = getLoweredType().castTo<SILFunctionType>();
1098+
auto paramIndices = fnTy->getDifferentiabilityParameterIndices();
1099+
return B.createLinearFunction(loc, paramIndices, values[0], values[1]);
1100+
}
1101+
1102+
void lowerChildren(TypeConverter &TC,
1103+
SmallVectorImpl<Child> &children) const override {
1104+
auto fnTy = getLoweredType().castTo<SILFunctionType>();
1105+
children.reserve(2);
1106+
auto origFnTy = fnTy->getWithoutDifferentiability();
1107+
auto paramIndices = fnTy->getDifferentiabilityParameterIndices();
1108+
children.push_back(Child{
1109+
LinearDifferentiableFunctionTypeComponent::Original,
1110+
TC.getTypeLowering(origFnTy, getExpansionContext())
1111+
});
1112+
auto transposeFnTy = origFnTy->getAutoDiffTransposeFunctionType(
1113+
paramIndices, TC, LookUpConformanceInModule(&TC.M));
1114+
auto transposeSILFnTy = SILType::getPrimitiveObjectType(transposeFnTy);
1115+
children.push_back(Child{
1116+
LinearDifferentiableFunctionTypeComponent::Transpose,
1117+
TC.getTypeLowering(transposeSILFnTy, getExpansionContext())
1118+
});
1119+
assert(children.size() == 2);
1120+
}
1121+
};
1122+
9631123
class LeafLoadableTypeLowering : public NonTrivialLoadableTypeLowering {
9641124
public:
9651125
LeafLoadableTypeLowering(SILType type, RecursiveProperties properties,
@@ -1358,6 +1518,20 @@ namespace {
13581518
properties);
13591519
}
13601520

1521+
TypeLowering *
1522+
visitNormalDifferentiableSILFunctionType(CanSILFunctionType type,
1523+
RecursiveProperties props) {
1524+
return handleAggregateByProperties
1525+
<NormalDifferentiableSILFunctionTypeLowering>(type, props);
1526+
}
1527+
1528+
TypeLowering *
1529+
visitLinearDifferentiableSILFunctionType(CanSILFunctionType type,
1530+
RecursiveProperties props) {
1531+
return handleAggregateByProperties
1532+
<LinearDifferentiableSILFunctionTypeLowering>(type, props);
1533+
}
1534+
13611535
template <class LoadableLoweringClass>
13621536
TypeLowering *handleAggregateByProperties(CanType type,
13631537
RecursiveProperties props) {

test/AutoDiff/SIL/Serialization/differentiable_function.swift

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
// RUN: %empty-directory(%t)
2-
// RUN: %target-sil-opt %s -emit-sib -o %t/tmp.sib -module-name differentiation -enable-experimental-differentiable-programming
3-
// RUN: %target-sil-opt %t/tmp.sib -o %t/tmp.2.sib -module-name differentiation -enable-experimental-differentiable-programming
4-
// RUN: %target-sil-opt %t/tmp.2.sib -module-name differentiation -emit-sorted-sil -enable-experimental-differentiable-programming | %FileCheck %s
2+
// RUN: %target-sil-opt -enable-experimental-differentiable-programming %s -emit-sib -o %t/tmp.sib -module-name main
3+
// RUN: %target-sil-opt -enable-experimental-differentiable-programming %t/tmp.sib -o %t/tmp.sil -module-name main
4+
// NOTE(SR-12090): Workaround because import declarations are not preserved in .sib files.
5+
// RUN: sed -e 's/import Swift$/import Swift; import _Differentiation/' %t/tmp.sil > %t/tmp_fixed.sil
6+
// RUN: %target-sil-opt -enable-experimental-differentiable-programming %t/tmp_fixed.sil -module-name main -emit-sorted-sil | %FileCheck %s
7+
8+
// NOTE(SR-12090): `shell` is required only to run `sed` as a SR-12090 workaround.
9+
// REQUIRES: shell
510

611
sil_stage raw
712

813
import Swift
14+
import _Differentiation
915

1016
sil @a : $@convention(thin) (@differentiable (Float) -> Float) -> @differentiable (Float) -> Float {
1117
bb0(%0 : $@differentiable (Float) -> Float):

0 commit comments

Comments
 (0)