32
32
using namespace swift ;
33
33
using namespace irgen ;
34
34
35
- using DiffFuncIndex = DifferentiableFunctionExtractee;
36
35
36
+ // ----------------------------------------------------------------------------//
37
+ // `@differentiable` (non-linear) function type info
38
+ // ----------------------------------------------------------------------------//
37
39
namespace {
38
- class DiffFuncFieldInfo final : public RecordField<DiffFuncFieldInfo> {
40
+ class DifferentiableFuncFieldInfo final
41
+ : public RecordField<DifferentiableFuncFieldInfo> {
39
42
public:
40
- DiffFuncFieldInfo (DiffFuncIndex index, const TypeInfo &type,
41
- IndexSubset *parameterIndices)
42
- : RecordField(type), Index(index), ParameterIndices(parameterIndices) {}
43
+ DifferentiableFuncFieldInfo (
44
+ DifferentiableFunctionExtractee component, const TypeInfo &type,
45
+ IndexSubset *parameterIndices)
46
+ : RecordField(type), component(component),
47
+ parameterIndices (parameterIndices) {}
43
48
44
49
// / The field index.
45
- const DiffFuncIndex Index ;
50
+ const DifferentiableFunctionExtractee component ;
46
51
47
52
// / The parameter indices.
48
- IndexSubset *ParameterIndices ;
53
+ IndexSubset *parameterIndices ;
49
54
50
55
std::string getFieldName () const {
51
- switch (Index ) {
56
+ switch (component ) {
52
57
case DifferentiableFunctionExtractee::Original:
53
58
return " original" ;
54
59
case DifferentiableFunctionExtractee::JVP:
@@ -61,32 +66,32 @@ class DiffFuncFieldInfo final : public RecordField<DiffFuncFieldInfo> {
61
66
SILType getType (IRGenModule &IGM, SILType t) const {
62
67
auto fnTy = t.castTo <SILFunctionType>();
63
68
auto origFnTy = fnTy->getWithoutDifferentiability ();
64
- if (Index == DifferentiableFunctionExtractee::Original)
69
+ if (component == DifferentiableFunctionExtractee::Original)
65
70
return SILType::getPrimitiveObjectType (origFnTy);
66
- auto kind = *Index .getExtracteeAsDerivativeFunction ();
71
+ auto kind = *component .getExtracteeAsDerivativeFunction ();
67
72
auto assocTy = origFnTy->getAutoDiffDerivativeFunctionType (
68
- ParameterIndices , /* resultIndex*/ 0 , kind,
73
+ parameterIndices , /* resultIndex*/ 0 , kind,
69
74
IGM.getSILTypes (), LookUpConformanceInModule (IGM.getSwiftModule ()));
70
75
return SILType::getPrimitiveObjectType (assocTy);
71
76
}
72
77
};
73
78
74
- class DiffFuncTypeInfo final
75
- : public RecordTypeInfo<DiffFuncTypeInfo , LoadableTypeInfo,
76
- DiffFuncFieldInfo > {
79
+ class DifferentiableFuncTypeInfo final
80
+ : public RecordTypeInfo<DifferentiableFuncTypeInfo , LoadableTypeInfo,
81
+ DifferentiableFuncFieldInfo > {
77
82
using super =
78
- RecordTypeInfo<DiffFuncTypeInfo , LoadableTypeInfo, DiffFuncFieldInfo >;
83
+ RecordTypeInfo<DifferentiableFuncTypeInfo , LoadableTypeInfo, DifferentiableFuncFieldInfo >;
79
84
80
85
public:
81
- DiffFuncTypeInfo (ArrayRef<DiffFuncFieldInfo> fields, unsigned explosionSize,
82
- llvm::Type *ty, Size size, SpareBitVector &&spareBits ,
83
- Alignment align, IsPOD_t isPOD ,
84
- IsFixedSize_t alwaysFixedSize)
86
+ DifferentiableFuncTypeInfo (
87
+ ArrayRef<DifferentiableFuncFieldInfo> fields, unsigned explosionSize ,
88
+ llvm::Type *ty, Size size, SpareBitVector &&spareBits, Alignment align,
89
+ IsPOD_t isPOD, IsFixedSize_t alwaysFixedSize)
85
90
: super(fields, explosionSize, ty, size, std::move(spareBits), align,
86
91
isPOD, alwaysFixedSize) {}
87
92
88
93
Address projectFieldAddress (IRGenFunction &IGF, Address addr, SILType T,
89
- const DiffFuncFieldInfo &field) const {
94
+ const DifferentiableFuncFieldInfo &field) const {
90
95
return field.projectAddress (IGF, addr, getNonFixedOffsets (IGF, T));
91
96
}
92
97
@@ -110,50 +115,52 @@ class DiffFuncTypeInfo final
110
115
}
111
116
};
112
117
113
- class DiffFuncTypeBuilder
114
- : public RecordTypeBuilder<DiffFuncTypeBuilder, DiffFuncFieldInfo ,
115
- DiffFuncIndex > {
118
+ class DifferentiableFuncTypeBuilder
119
+ : public RecordTypeBuilder<DifferentiableFuncTypeBuilder, DifferentiableFuncFieldInfo ,
120
+ DifferentiableFunctionExtractee > {
116
121
117
- SILFunctionType *origFnTy ;
122
+ SILFunctionType *originalType ;
118
123
IndexSubset *parameterIndices;
119
124
120
125
public:
121
- DiffFuncTypeBuilder (IRGenModule &IGM, SILFunctionType *fnTy)
122
- : RecordTypeBuilder(IGM), origFnTy(fnTy->getWithoutDifferentiability ()),
126
+ DifferentiableFuncTypeBuilder (IRGenModule &IGM, SILFunctionType *fnTy)
127
+ : RecordTypeBuilder(IGM),
128
+ originalType (fnTy->getWithoutDifferentiability ()),
123
129
parameterIndices(fnTy->getDifferentiationParameterIndices ()) {
124
- assert (fnTy->isDifferentiable () );
130
+ assert (fnTy->getDifferentiabilityKind () == DifferentiabilityKind::Normal );
125
131
}
126
132
127
- TypeInfo *createFixed (ArrayRef<DiffFuncFieldInfo > fields,
133
+ TypeInfo *createFixed (ArrayRef<DifferentiableFuncFieldInfo > fields,
128
134
StructLayout &&layout) {
129
135
llvm_unreachable (" @differentiable functions are always loadable" );
130
136
}
131
137
132
- DiffFuncTypeInfo *createLoadable (ArrayRef<DiffFuncFieldInfo> fields,
133
- StructLayout &&layout,
134
- unsigned explosionSize) {
135
- return DiffFuncTypeInfo ::create (
138
+ DifferentiableFuncTypeInfo *createLoadable (
139
+ ArrayRef<DifferentiableFuncFieldInfo> fields, StructLayout &&layout,
140
+ unsigned explosionSize) {
141
+ return DifferentiableFuncTypeInfo ::create (
136
142
fields, explosionSize, layout.getType (), layout.getSize (),
137
143
std::move (layout.getSpareBits ()), layout.getAlignment (), layout.isPOD (),
138
144
layout.isAlwaysFixedSize ());
139
145
}
140
146
141
- TypeInfo *createNonFixed (ArrayRef<DiffFuncFieldInfo > fields,
147
+ TypeInfo *createNonFixed (ArrayRef<DifferentiableFuncFieldInfo > fields,
142
148
FieldsAreABIAccessible_t fieldsAccessible,
143
149
StructLayout &&layout) {
144
150
llvm_unreachable (" @differentiable functions are always loadable" );
145
151
}
146
152
147
- DiffFuncFieldInfo getFieldInfo (unsigned index, DiffFuncIndex field,
148
- const TypeInfo &fieldTI) {
149
- return DiffFuncFieldInfo (field, fieldTI, parameterIndices);
153
+ DifferentiableFuncFieldInfo getFieldInfo (
154
+ unsigned index, DifferentiableFunctionExtractee component,
155
+ const TypeInfo &fieldTI) {
156
+ return DifferentiableFuncFieldInfo (component, fieldTI, parameterIndices);
150
157
}
151
158
152
- SILType getType (DiffFuncIndex field ) {
153
- if (field == DifferentiableFunctionExtractee::Original)
154
- return SILType::getPrimitiveObjectType (origFnTy ->getCanonicalType ());
155
- auto kind = *field .getExtracteeAsDerivativeFunction ();
156
- auto assocTy = origFnTy ->getAutoDiffDerivativeFunctionType (
159
+ SILType getType (DifferentiableFunctionExtractee component ) {
160
+ if (component == DifferentiableFunctionExtractee::Original)
161
+ return SILType::getPrimitiveObjectType (originalType ->getCanonicalType ());
162
+ auto kind = *component .getExtracteeAsDerivativeFunction ();
163
+ auto assocTy = originalType ->getAutoDiffDerivativeFunctionType (
157
164
parameterIndices, /* resultIndex*/ 0 , kind, IGM.getSILTypes (),
158
165
LookUpConformanceInModule (IGM.getSwiftModule ()));
159
166
return SILType::getPrimitiveObjectType (assocTy);
@@ -166,11 +173,161 @@ class DiffFuncTypeBuilder
166
173
};
167
174
} // end anonymous namespace
168
175
176
+ // ----------------------------------------------------------------------------//
177
+ // `@differentiable(linear)` function type info
178
+ // ----------------------------------------------------------------------------//
179
+ namespace {
180
+ class LinearFuncFieldInfo final : public RecordField<LinearFuncFieldInfo> {
181
+ public:
182
+ LinearFuncFieldInfo (LinearDifferentiableFunctionTypeComponent component,
183
+ const TypeInfo &type, IndexSubset *parameterIndices)
184
+ : RecordField(type), component(component),
185
+ parameterIndices (parameterIndices) {}
186
+
187
+ // / The field index.
188
+ const LinearDifferentiableFunctionTypeComponent component;
189
+
190
+ // / The parameter indices.
191
+ IndexSubset *parameterIndices;
192
+
193
+ std::string getFieldName () const {
194
+ switch (component) {
195
+ case LinearDifferentiableFunctionTypeComponent::Original:
196
+ return " original" ;
197
+ case LinearDifferentiableFunctionTypeComponent::Transpose:
198
+ return " transpose" ;
199
+ }
200
+ }
201
+
202
+ SILType getType (IRGenModule &IGM, SILType t) const {
203
+ auto fnTy = t.castTo <SILFunctionType>();
204
+ auto origFnTy = fnTy->getWithoutDifferentiability ();
205
+ switch (component) {
206
+ case LinearDifferentiableFunctionTypeComponent::Original:
207
+ return SILType::getPrimitiveObjectType (origFnTy);
208
+ case LinearDifferentiableFunctionTypeComponent::Transpose:
209
+ auto transposeTy = origFnTy->getAutoDiffTransposeFunctionType (
210
+ parameterIndices, IGM.getSILTypes (),
211
+ LookUpConformanceInModule (IGM.getSwiftModule ()));
212
+ return SILType::getPrimitiveObjectType (transposeTy);
213
+ }
214
+ }
215
+ };
216
+
217
+ class LinearFuncTypeInfo final
218
+ : public RecordTypeInfo<LinearFuncTypeInfo, LoadableTypeInfo,
219
+ LinearFuncFieldInfo> {
220
+ using super =
221
+ RecordTypeInfo<LinearFuncTypeInfo, LoadableTypeInfo, LinearFuncFieldInfo>;
222
+
223
+ public:
224
+ LinearFuncTypeInfo (
225
+ ArrayRef<LinearFuncFieldInfo> fields, unsigned explosionSize,
226
+ llvm::Type *ty, Size size, SpareBitVector &&spareBits, Alignment align,
227
+ IsPOD_t isPOD, IsFixedSize_t alwaysFixedSize)
228
+ : super(fields, explosionSize, ty, size, std::move(spareBits), align,
229
+ isPOD, alwaysFixedSize) {}
230
+
231
+ Address projectFieldAddress (IRGenFunction &IGF, Address addr, SILType T,
232
+ const LinearFuncFieldInfo &field) const {
233
+ return field.projectAddress (IGF, addr, getNonFixedOffsets (IGF, T));
234
+ }
235
+
236
+ void initializeFromParams (IRGenFunction &IGF, Explosion ¶ms, Address src,
237
+ SILType T, bool isOutlined) const override {
238
+ llvm_unreachable (" unexploded @differentiable function as argument?" );
239
+ }
240
+
241
+ void addToAggLowering (IRGenModule &IGM, SwiftAggLowering &lowering,
242
+ Size offset) const override {
243
+ for (auto &field : getFields ()) {
244
+ auto fieldOffset = offset + field.getFixedByteOffset ();
245
+ cast<LoadableTypeInfo>(field.getTypeInfo ())
246
+ .addToAggLowering (IGM, lowering, fieldOffset);
247
+ }
248
+ }
249
+
250
+ llvm::NoneType getNonFixedOffsets (IRGenFunction &IGF) const { return None; }
251
+ llvm::NoneType getNonFixedOffsets (IRGenFunction &IGF, SILType T) const {
252
+ return None;
253
+ }
254
+ };
255
+
256
+ class LinearFuncTypeBuilder
257
+ : public RecordTypeBuilder<LinearFuncTypeBuilder, LinearFuncFieldInfo,
258
+ LinearDifferentiableFunctionTypeComponent> {
259
+
260
+ SILFunctionType *originalType;
261
+ IndexSubset *parameterIndices;
262
+
263
+ public:
264
+ LinearFuncTypeBuilder (IRGenModule &IGM, SILFunctionType *fnTy)
265
+ : RecordTypeBuilder(IGM),
266
+ originalType (fnTy->getWithoutDifferentiability ()),
267
+ parameterIndices(fnTy->getDifferentiationParameterIndices ()) {
268
+ assert (fnTy->getDifferentiabilityKind () == DifferentiabilityKind::Linear);
269
+ }
270
+
271
+ TypeInfo *createFixed (ArrayRef<LinearFuncFieldInfo> fields,
272
+ StructLayout &&layout) {
273
+ llvm_unreachable (" @differentiable functions are always loadable" );
274
+ }
275
+
276
+ LinearFuncTypeInfo *createLoadable (ArrayRef<LinearFuncFieldInfo> fields,
277
+ StructLayout &&layout,
278
+ unsigned explosionSize) {
279
+ return LinearFuncTypeInfo::create (
280
+ fields, explosionSize, layout.getType (), layout.getSize (),
281
+ std::move (layout.getSpareBits ()), layout.getAlignment (), layout.isPOD (),
282
+ layout.isAlwaysFixedSize ());
283
+ }
284
+
285
+ TypeInfo *createNonFixed (ArrayRef<LinearFuncFieldInfo> fields,
286
+ FieldsAreABIAccessible_t fieldsAccessible,
287
+ StructLayout &&layout) {
288
+ llvm_unreachable (" @differentiable functions are always loadable" );
289
+ }
290
+
291
+ LinearFuncFieldInfo getFieldInfo (
292
+ unsigned index, LinearDifferentiableFunctionTypeComponent field,
293
+ const TypeInfo &fieldTI) {
294
+ return LinearFuncFieldInfo (field, fieldTI, parameterIndices);
295
+ }
296
+
297
+ SILType getType (LinearDifferentiableFunctionTypeComponent component) {
298
+ switch (component) {
299
+ case LinearDifferentiableFunctionTypeComponent::Original:
300
+ return SILType::getPrimitiveObjectType (originalType->getCanonicalType ());
301
+ case LinearDifferentiableFunctionTypeComponent::Transpose:
302
+ auto transposeTy = originalType->getAutoDiffTransposeFunctionType (
303
+ parameterIndices, IGM.getSILTypes (),
304
+ LookUpConformanceInModule (IGM.getSwiftModule ()));
305
+ return SILType::getPrimitiveObjectType (transposeTy);
306
+ }
307
+ }
308
+
309
+ StructLayout performLayout (ArrayRef<const TypeInfo *> fieldTypes) {
310
+ return StructLayout (IGM, /* decl=*/ nullptr , LayoutKind::NonHeapObject,
311
+ LayoutStrategy::Universal, fieldTypes);
312
+ }
313
+ };
314
+ } // end anonymous namespace
315
+
316
+ // ----------------------------------------------------------------------------//
317
+ // Type converter entry points
318
+ // ----------------------------------------------------------------------------//
319
+
169
320
const TypeInfo *
170
- TypeConverter::convertDifferentiableFunctionType (SILFunctionType *type) {
171
- assert (type->isDifferentiable ());
172
- DiffFuncTypeBuilder builder (IGM, type);
321
+ TypeConverter::convertNormalDifferentiableFunctionType (SILFunctionType *type) {
322
+ DifferentiableFuncTypeBuilder builder (IGM, type);
173
323
return builder.layout ({DifferentiableFunctionExtractee::Original,
174
324
DifferentiableFunctionExtractee::JVP,
175
325
DifferentiableFunctionExtractee::VJP});
176
326
}
327
+
328
+ const TypeInfo *
329
+ TypeConverter::convertLinearDifferentiableFunctionType (SILFunctionType *type) {
330
+ LinearFuncTypeBuilder builder (IGM, type);
331
+ return builder.layout ({LinearDifferentiableFunctionTypeComponent::Original,
332
+ LinearDifferentiableFunctionTypeComponent::Transpose});
333
+ }
0 commit comments