@@ -240,6 +240,19 @@ namespace {
240
240
241
241
RetTy visitSILFunctionType (CanSILFunctionType type,
242
242
AbstractionPattern origType) {
243
+ // Handle `@differentiable` and `@differentiable(linear)` functions.
244
+ switch (type->getDifferentiabilityKind ()) {
245
+ case DifferentiabilityKind::Normal:
246
+ return asImpl ().visitNormalDifferentiableSILFunctionType (
247
+ type, getNormalDifferentiableSILFunctionTypeRecursiveProperties (
248
+ type, origType));
249
+ case DifferentiabilityKind::Linear:
250
+ return asImpl ().visitLinearDifferentiableSILFunctionType (
251
+ type, getLinearDifferentiableSILFunctionTypeRecursiveProperties (
252
+ type, origType));
253
+ case DifferentiabilityKind::NonDifferentiable:
254
+ break ;
255
+ }
243
256
// Only escaping closures are references.
244
257
bool isSwiftEscaping = type->getExtInfo ().isNoEscape () &&
245
258
type->getExtInfo ().getRepresentation () ==
@@ -250,6 +263,53 @@ namespace {
250
263
return asImpl ().handleTrivial (type);
251
264
}
252
265
266
+ RecursiveProperties
267
+ getNormalDifferentiableSILFunctionTypeRecursiveProperties (
268
+ CanSILFunctionType type, AbstractionPattern origType) {
269
+ auto &M = TC.M ;
270
+ auto origTy = type->getWithoutDifferentiability ();
271
+ // Pass the `AbstractionPattern` generic signature to
272
+ // `SILFunctionType:getAutoDiffDerivativeFunctionType` for correct type
273
+ // lowering.
274
+ auto jvpTy = origTy->getAutoDiffDerivativeFunctionType (
275
+ type->getDifferentiabilityParameterIndices (), /* resultIndex*/ 0 ,
276
+ AutoDiffDerivativeFunctionKind::JVP, TC,
277
+ LookUpConformanceInModule (&M), CanGenericSignature ());
278
+ auto vjpTy = origTy->getAutoDiffDerivativeFunctionType (
279
+ type->getDifferentiabilityParameterIndices (), /* resultIndex*/ 0 ,
280
+ AutoDiffDerivativeFunctionKind::VJP, TC,
281
+ LookUpConformanceInModule (&M), CanGenericSignature ());
282
+ RecursiveProperties props;
283
+ props.addSubobject (classifyType (origType, origTy, TC, Expansion));
284
+ props.addSubobject (classifyType (origType, jvpTy, TC, Expansion));
285
+ props.addSubobject (classifyType (origType, vjpTy, TC, Expansion));
286
+ return props;
287
+ }
288
+
289
+ RecursiveProperties
290
+ getLinearDifferentiableSILFunctionTypeRecursiveProperties (
291
+ CanSILFunctionType type, AbstractionPattern origType) {
292
+ auto &M = TC.M ;
293
+ auto origTy = type->getWithoutDifferentiability ();
294
+ auto transposeTy = origTy->getAutoDiffTransposeFunctionType (
295
+ type->getDifferentiabilityParameterIndices (), TC,
296
+ LookUpConformanceInModule (&M), origType.getGenericSignatureOrNull ());
297
+ RecursiveProperties props;
298
+ props.addSubobject (classifyType (origType, origTy, TC, Expansion));
299
+ props.addSubobject (classifyType (origType, transposeTy, TC, Expansion));
300
+ return props;
301
+ }
302
+
303
+ RetTy visitNormalDifferentiableSILFunctionType (
304
+ CanSILFunctionType type, RecursiveProperties props) {
305
+ return handleAggregateByProperties (type, props);
306
+ }
307
+
308
+ RetTy visitLinearDifferentiableSILFunctionType (
309
+ CanSILFunctionType type, RecursiveProperties props) {
310
+ return handleAggregateByProperties (type, props);
311
+ }
312
+
253
313
RetTy visitLValueType (CanLValueType type,
254
314
AbstractionPattern origType) {
255
315
llvm_unreachable (" shouldn't get an l-value type here" );
@@ -960,6 +1020,106 @@ namespace {
960
1020
}
961
1021
};
962
1022
1023
+ // / A type lowering for `@differentiable` function types.
1024
+ class NormalDifferentiableSILFunctionTypeLowering final
1025
+ : public LoadableAggTypeLowering<
1026
+ NormalDifferentiableSILFunctionTypeLowering,
1027
+ NormalDifferentiableFunctionTypeComponent> {
1028
+ public:
1029
+ using LoadableAggTypeLowering::LoadableAggTypeLowering;
1030
+
1031
+ SILValue emitRValueProject (
1032
+ SILBuilder &B, SILLocation loc, SILValue tupleValue,
1033
+ NormalDifferentiableFunctionTypeComponent extractee,
1034
+ const TypeLowering &eltLowering) const {
1035
+ return B.createDifferentiableFunctionExtract (
1036
+ loc, extractee, tupleValue);
1037
+ }
1038
+
1039
+ SILValue rebuildAggregate (SILBuilder &B, SILLocation loc,
1040
+ ArrayRef<SILValue> values) const override {
1041
+ assert (values.size () == 3 );
1042
+ auto fnTy = getLoweredType ().castTo <SILFunctionType>();
1043
+ auto paramIndices = fnTy->getDifferentiabilityParameterIndices ();
1044
+ return B.createDifferentiableFunction (
1045
+ loc, paramIndices, values[0 ], std::make_pair (values[1 ], values[2 ]));
1046
+ }
1047
+
1048
+ void lowerChildren (TypeConverter &TC,
1049
+ SmallVectorImpl<Child> &children) const override {
1050
+ auto fnTy = getLoweredType ().castTo <SILFunctionType>();
1051
+ auto numDerivativeFns = 2 ;
1052
+ children.reserve (numDerivativeFns + 1 );
1053
+ auto origFnTy = fnTy->getWithoutDifferentiability ();
1054
+ auto paramIndices = fnTy->getDifferentiabilityParameterIndices ();
1055
+ children.push_back (Child{
1056
+ NormalDifferentiableFunctionTypeComponent::Original,
1057
+ TC.getTypeLowering (origFnTy, getExpansionContext ())
1058
+ });
1059
+ for (AutoDiffDerivativeFunctionKind kind :
1060
+ {AutoDiffDerivativeFunctionKind::JVP,
1061
+ AutoDiffDerivativeFunctionKind::VJP}) {
1062
+ auto derivativeFnTy = origFnTy->getAutoDiffDerivativeFunctionType (
1063
+ paramIndices, 0 , kind, TC,
1064
+ LookUpConformanceInModule (&TC.M ));
1065
+ auto silTy = SILType::getPrimitiveObjectType (derivativeFnTy);
1066
+ NormalDifferentiableFunctionTypeComponent extractee (kind);
1067
+ // Assert that we have the right extractee. A terrible bug in the past
1068
+ // was caused by implicit conversions from `unsigned` to
1069
+ // `NormalDifferentiableFunctionTypeComponent` which resulted into a
1070
+ // wrong extractee.
1071
+ assert (extractee.getAsDerivativeFunctionKind () == kind);
1072
+ children.push_back (Child{
1073
+ extractee, TC.getTypeLowering (silTy, getExpansionContext ())});
1074
+ }
1075
+ assert (children.size () == 3 );
1076
+ }
1077
+ };
1078
+
1079
+ // / A type lowering for `@differentiable(linear)` function types.
1080
+ class LinearDifferentiableSILFunctionTypeLowering final
1081
+ : public LoadableAggTypeLowering<
1082
+ LinearDifferentiableSILFunctionTypeLowering,
1083
+ LinearDifferentiableFunctionTypeComponent> {
1084
+ public:
1085
+ using LoadableAggTypeLowering::LoadableAggTypeLowering;
1086
+
1087
+ SILValue emitRValueProject (
1088
+ SILBuilder &B, SILLocation loc, SILValue tupleValue,
1089
+ LinearDifferentiableFunctionTypeComponent component,
1090
+ const TypeLowering &eltLowering) const {
1091
+ return B.createLinearFunctionExtract (loc, component, tupleValue);
1092
+ }
1093
+
1094
+ SILValue rebuildAggregate (SILBuilder &B, SILLocation loc,
1095
+ ArrayRef<SILValue> values) const override {
1096
+ assert (values.size () == 2 );
1097
+ auto fnTy = getLoweredType ().castTo <SILFunctionType>();
1098
+ auto paramIndices = fnTy->getDifferentiabilityParameterIndices ();
1099
+ return B.createLinearFunction (loc, paramIndices, values[0 ], values[1 ]);
1100
+ }
1101
+
1102
+ void lowerChildren (TypeConverter &TC,
1103
+ SmallVectorImpl<Child> &children) const override {
1104
+ auto fnTy = getLoweredType ().castTo <SILFunctionType>();
1105
+ children.reserve (2 );
1106
+ auto origFnTy = fnTy->getWithoutDifferentiability ();
1107
+ auto paramIndices = fnTy->getDifferentiabilityParameterIndices ();
1108
+ children.push_back (Child{
1109
+ LinearDifferentiableFunctionTypeComponent::Original,
1110
+ TC.getTypeLowering (origFnTy, getExpansionContext ())
1111
+ });
1112
+ auto transposeFnTy = origFnTy->getAutoDiffTransposeFunctionType (
1113
+ paramIndices, TC, LookUpConformanceInModule (&TC.M ));
1114
+ auto transposeSILFnTy = SILType::getPrimitiveObjectType (transposeFnTy);
1115
+ children.push_back (Child{
1116
+ LinearDifferentiableFunctionTypeComponent::Transpose,
1117
+ TC.getTypeLowering (transposeSILFnTy, getExpansionContext ())
1118
+ });
1119
+ assert (children.size () == 2 );
1120
+ }
1121
+ };
1122
+
963
1123
class LeafLoadableTypeLowering : public NonTrivialLoadableTypeLowering {
964
1124
public:
965
1125
LeafLoadableTypeLowering (SILType type, RecursiveProperties properties,
@@ -1358,6 +1518,20 @@ namespace {
1358
1518
properties);
1359
1519
}
1360
1520
1521
+ TypeLowering *
1522
+ visitNormalDifferentiableSILFunctionType (CanSILFunctionType type,
1523
+ RecursiveProperties props) {
1524
+ return handleAggregateByProperties
1525
+ <NormalDifferentiableSILFunctionTypeLowering>(type, props);
1526
+ }
1527
+
1528
+ TypeLowering *
1529
+ visitLinearDifferentiableSILFunctionType (CanSILFunctionType type,
1530
+ RecursiveProperties props) {
1531
+ return handleAggregateByProperties
1532
+ <LinearDifferentiableSILFunctionTypeLowering>(type, props);
1533
+ }
1534
+
1361
1535
template <class LoadableLoweringClass >
1362
1536
TypeLowering *handleAggregateByProperties (CanType type,
1363
1537
RecursiveProperties props) {
0 commit comments