Skip to content

Commit 4a4fe83

Browse files
authored
[AutoDiff] [IRGen] Lower @differentiable(linear) function types. (swiftlang#27661)
This patch adds support for lowering `@differentiable(linear)` function types to LLVM IR. * Add `SILFunctionType` transpose type calculation utility: `SILFunctionType:: getAutoDiffTransposeFunctionType`. * Refactor `TypeClassifierBase` methods used for computing recursive properties of `@differentiable` function types and eliminated the need for repeatedly checking differentiability kind. * Add `LinearDifferentiableSILFunctionTypeLowering`. Rename the original `DifferentiableSILFunctionTypeLowering` to `NormalDifferentiableSILFunctionTypeLowering` to make clear it's not for linear functions. * Add `TypeConverter::convertLinearDifferentiableFunctionType`. Rename the original `TypeConverter::convertDifferentiableFunctionType` to `TypeConverter::convertNormalDifferentiableFunctionType` to make clear it's not for linear functions. Resolves [TF-902](https://bugs.swift.org/browse/TF-902).
1 parent 5e52226 commit 4a4fe83

File tree

8 files changed

+470
-92
lines changed

8 files changed

+470
-92
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,24 @@ class SILFunctionType;
3939
typedef CanTypeWrapper<SILFunctionType> CanSILFunctionType;
4040
enum class SILLinkage : uint8_t;
4141

42-
enum class DifferentiabilityKind: uint8_t {
42+
enum class DifferentiabilityKind : uint8_t {
4343
NonDifferentiable = 0b00,
4444
Normal = 0b01,
4545
Linear = 0b11
4646
};
4747

48+
// TODO(TF-904): Replace `DifferentiableFunctionExtractInst::Extractee`.
49+
enum class NormalDifferentiableFunctionTypeComponent : uint8_t {
50+
Original = 0,
51+
JVP = 1,
52+
VJP = 2
53+
};
54+
55+
enum class LinearDifferentiableFunctionTypeComponent : uint8_t {
56+
Original = 0,
57+
Transpose = 1
58+
};
59+
4860
class ParsedAutoDiffParameter {
4961
public:
5062
enum class Kind { Named, Ordered, Self };

include/swift/AST/Types.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4220,14 +4220,19 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
42204220

42214221
CanSILFunctionType getWithoutDifferentiability();
42224222

4223-
/// Returns the type of a differentiation function that is associated with
4224-
/// a function of this type.
4223+
/// Returns the type of the derivative function.
42254224
CanSILFunctionType getAutoDiffDerivativeFunctionType(
42264225
IndexSubset *parameterIndices, unsigned resultIndex,
42274226
AutoDiffDerivativeFunctionKind kind, Lowering::TypeConverter &TC,
42284227
LookupConformanceFn lookupConformance,
42294228
CanGenericSignature derivativeFunctionGenericSignature = nullptr);
42304229

4230+
/// Returns the type of the transpose function.
4231+
CanSILFunctionType getAutoDiffTransposeFunctionType(
4232+
IndexSubset *parameterIndices, Lowering::TypeConverter &TC,
4233+
LookupConformanceFn lookupConformance,
4234+
CanGenericSignature derivativeFunctionGenericSignature = nullptr);
4235+
42314236
/// Returns a bit vector that specifices which parameters you can
42324237
/// differentiate with respect to for this differentiable function type. (e.g.
42334238
/// which parameters are not `@nondiff`). The function type must be

lib/IRGen/GenDiffFunc.cpp

Lines changed: 201 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,28 @@
3232
using namespace swift;
3333
using namespace irgen;
3434

35-
using DiffFuncIndex = DifferentiableFunctionExtractee;
3635

36+
//----------------------------------------------------------------------------//
37+
// `@differentiable` (non-linear) function type info
38+
//----------------------------------------------------------------------------//
3739
namespace {
38-
class DiffFuncFieldInfo final : public RecordField<DiffFuncFieldInfo> {
40+
class DifferentiableFuncFieldInfo final
41+
: public RecordField<DifferentiableFuncFieldInfo> {
3942
public:
40-
DiffFuncFieldInfo(DiffFuncIndex index, const TypeInfo &type,
41-
IndexSubset *parameterIndices)
42-
: RecordField(type), Index(index), ParameterIndices(parameterIndices) {}
43+
DifferentiableFuncFieldInfo(
44+
DifferentiableFunctionExtractee component, const TypeInfo &type,
45+
IndexSubset *parameterIndices)
46+
: RecordField(type), component(component),
47+
parameterIndices(parameterIndices) {}
4348

4449
/// The field index.
45-
const DiffFuncIndex Index;
50+
const DifferentiableFunctionExtractee component;
4651

4752
/// The parameter indices.
48-
IndexSubset *ParameterIndices;
53+
IndexSubset *parameterIndices;
4954

5055
std::string getFieldName() const {
51-
switch (Index) {
56+
switch (component) {
5257
case DifferentiableFunctionExtractee::Original:
5358
return "original";
5459
case DifferentiableFunctionExtractee::JVP:
@@ -61,32 +66,32 @@ class DiffFuncFieldInfo final : public RecordField<DiffFuncFieldInfo> {
6166
SILType getType(IRGenModule &IGM, SILType t) const {
6267
auto fnTy = t.castTo<SILFunctionType>();
6368
auto origFnTy = fnTy->getWithoutDifferentiability();
64-
if (Index == DifferentiableFunctionExtractee::Original)
69+
if (component == DifferentiableFunctionExtractee::Original)
6570
return SILType::getPrimitiveObjectType(origFnTy);
66-
auto kind = *Index.getExtracteeAsDerivativeFunction();
71+
auto kind = *component.getExtracteeAsDerivativeFunction();
6772
auto assocTy = origFnTy->getAutoDiffDerivativeFunctionType(
68-
ParameterIndices, /*resultIndex*/ 0, kind,
73+
parameterIndices, /*resultIndex*/ 0, kind,
6974
IGM.getSILTypes(), LookUpConformanceInModule(IGM.getSwiftModule()));
7075
return SILType::getPrimitiveObjectType(assocTy);
7176
}
7277
};
7378

74-
class DiffFuncTypeInfo final
75-
: public RecordTypeInfo<DiffFuncTypeInfo, LoadableTypeInfo,
76-
DiffFuncFieldInfo> {
79+
class DifferentiableFuncTypeInfo final
80+
: public RecordTypeInfo<DifferentiableFuncTypeInfo, LoadableTypeInfo,
81+
DifferentiableFuncFieldInfo> {
7782
using super =
78-
RecordTypeInfo<DiffFuncTypeInfo, LoadableTypeInfo, DiffFuncFieldInfo>;
83+
RecordTypeInfo<DifferentiableFuncTypeInfo, LoadableTypeInfo, DifferentiableFuncFieldInfo>;
7984

8085
public:
81-
DiffFuncTypeInfo(ArrayRef<DiffFuncFieldInfo> fields, unsigned explosionSize,
82-
llvm::Type *ty, Size size, SpareBitVector &&spareBits,
83-
Alignment align, IsPOD_t isPOD,
84-
IsFixedSize_t alwaysFixedSize)
86+
DifferentiableFuncTypeInfo(
87+
ArrayRef<DifferentiableFuncFieldInfo> fields, unsigned explosionSize,
88+
llvm::Type *ty, Size size, SpareBitVector &&spareBits, Alignment align,
89+
IsPOD_t isPOD, IsFixedSize_t alwaysFixedSize)
8590
: super(fields, explosionSize, ty, size, std::move(spareBits), align,
8691
isPOD, alwaysFixedSize) {}
8792

8893
Address projectFieldAddress(IRGenFunction &IGF, Address addr, SILType T,
89-
const DiffFuncFieldInfo &field) const {
94+
const DifferentiableFuncFieldInfo &field) const {
9095
return field.projectAddress(IGF, addr, getNonFixedOffsets(IGF, T));
9196
}
9297

@@ -110,50 +115,52 @@ class DiffFuncTypeInfo final
110115
}
111116
};
112117

113-
class DiffFuncTypeBuilder
114-
: public RecordTypeBuilder<DiffFuncTypeBuilder, DiffFuncFieldInfo,
115-
DiffFuncIndex> {
118+
class DifferentiableFuncTypeBuilder
119+
: public RecordTypeBuilder<DifferentiableFuncTypeBuilder, DifferentiableFuncFieldInfo,
120+
DifferentiableFunctionExtractee> {
116121

117-
SILFunctionType *origFnTy;
122+
SILFunctionType *originalType;
118123
IndexSubset *parameterIndices;
119124

120125
public:
121-
DiffFuncTypeBuilder(IRGenModule &IGM, SILFunctionType *fnTy)
122-
: RecordTypeBuilder(IGM), origFnTy(fnTy->getWithoutDifferentiability()),
126+
DifferentiableFuncTypeBuilder(IRGenModule &IGM, SILFunctionType *fnTy)
127+
: RecordTypeBuilder(IGM),
128+
originalType(fnTy->getWithoutDifferentiability()),
123129
parameterIndices(fnTy->getDifferentiationParameterIndices()) {
124-
assert(fnTy->isDifferentiable());
130+
assert(fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Normal);
125131
}
126132

127-
TypeInfo *createFixed(ArrayRef<DiffFuncFieldInfo> fields,
133+
TypeInfo *createFixed(ArrayRef<DifferentiableFuncFieldInfo> fields,
128134
StructLayout &&layout) {
129135
llvm_unreachable("@differentiable functions are always loadable");
130136
}
131137

132-
DiffFuncTypeInfo *createLoadable(ArrayRef<DiffFuncFieldInfo> fields,
133-
StructLayout &&layout,
134-
unsigned explosionSize) {
135-
return DiffFuncTypeInfo::create(
138+
DifferentiableFuncTypeInfo *createLoadable(
139+
ArrayRef<DifferentiableFuncFieldInfo> fields, StructLayout &&layout,
140+
unsigned explosionSize) {
141+
return DifferentiableFuncTypeInfo::create(
136142
fields, explosionSize, layout.getType(), layout.getSize(),
137143
std::move(layout.getSpareBits()), layout.getAlignment(), layout.isPOD(),
138144
layout.isAlwaysFixedSize());
139145
}
140146

141-
TypeInfo *createNonFixed(ArrayRef<DiffFuncFieldInfo> fields,
147+
TypeInfo *createNonFixed(ArrayRef<DifferentiableFuncFieldInfo> fields,
142148
FieldsAreABIAccessible_t fieldsAccessible,
143149
StructLayout &&layout) {
144150
llvm_unreachable("@differentiable functions are always loadable");
145151
}
146152

147-
DiffFuncFieldInfo getFieldInfo(unsigned index, DiffFuncIndex field,
148-
const TypeInfo &fieldTI) {
149-
return DiffFuncFieldInfo(field, fieldTI, parameterIndices);
153+
DifferentiableFuncFieldInfo getFieldInfo(
154+
unsigned index, DifferentiableFunctionExtractee component,
155+
const TypeInfo &fieldTI) {
156+
return DifferentiableFuncFieldInfo(component, fieldTI, parameterIndices);
150157
}
151158

152-
SILType getType(DiffFuncIndex field) {
153-
if (field == DifferentiableFunctionExtractee::Original)
154-
return SILType::getPrimitiveObjectType(origFnTy->getCanonicalType());
155-
auto kind = *field.getExtracteeAsDerivativeFunction();
156-
auto assocTy = origFnTy->getAutoDiffDerivativeFunctionType(
159+
SILType getType(DifferentiableFunctionExtractee component) {
160+
if (component == DifferentiableFunctionExtractee::Original)
161+
return SILType::getPrimitiveObjectType(originalType->getCanonicalType());
162+
auto kind = *component.getExtracteeAsDerivativeFunction();
163+
auto assocTy = originalType->getAutoDiffDerivativeFunctionType(
157164
parameterIndices, /*resultIndex*/ 0, kind, IGM.getSILTypes(),
158165
LookUpConformanceInModule(IGM.getSwiftModule()));
159166
return SILType::getPrimitiveObjectType(assocTy);
@@ -166,11 +173,161 @@ class DiffFuncTypeBuilder
166173
};
167174
} // end anonymous namespace
168175

176+
//----------------------------------------------------------------------------//
177+
// `@differentiable(linear)` function type info
178+
//----------------------------------------------------------------------------//
179+
namespace {
180+
class LinearFuncFieldInfo final : public RecordField<LinearFuncFieldInfo> {
181+
public:
182+
LinearFuncFieldInfo(LinearDifferentiableFunctionTypeComponent component,
183+
const TypeInfo &type, IndexSubset *parameterIndices)
184+
: RecordField(type), component(component),
185+
parameterIndices(parameterIndices) {}
186+
187+
/// The field index.
188+
const LinearDifferentiableFunctionTypeComponent component;
189+
190+
/// The parameter indices.
191+
IndexSubset *parameterIndices;
192+
193+
std::string getFieldName() const {
194+
switch (component) {
195+
case LinearDifferentiableFunctionTypeComponent::Original:
196+
return "original";
197+
case LinearDifferentiableFunctionTypeComponent::Transpose:
198+
return "transpose";
199+
}
200+
}
201+
202+
SILType getType(IRGenModule &IGM, SILType t) const {
203+
auto fnTy = t.castTo<SILFunctionType>();
204+
auto origFnTy = fnTy->getWithoutDifferentiability();
205+
switch (component) {
206+
case LinearDifferentiableFunctionTypeComponent::Original:
207+
return SILType::getPrimitiveObjectType(origFnTy);
208+
case LinearDifferentiableFunctionTypeComponent::Transpose:
209+
auto transposeTy = origFnTy->getAutoDiffTransposeFunctionType(
210+
parameterIndices, IGM.getSILTypes(),
211+
LookUpConformanceInModule(IGM.getSwiftModule()));
212+
return SILType::getPrimitiveObjectType(transposeTy);
213+
}
214+
}
215+
};
216+
217+
class LinearFuncTypeInfo final
218+
: public RecordTypeInfo<LinearFuncTypeInfo, LoadableTypeInfo,
219+
LinearFuncFieldInfo> {
220+
using super =
221+
RecordTypeInfo<LinearFuncTypeInfo, LoadableTypeInfo, LinearFuncFieldInfo>;
222+
223+
public:
224+
LinearFuncTypeInfo(
225+
ArrayRef<LinearFuncFieldInfo> fields, unsigned explosionSize,
226+
llvm::Type *ty, Size size, SpareBitVector &&spareBits, Alignment align,
227+
IsPOD_t isPOD, IsFixedSize_t alwaysFixedSize)
228+
: super(fields, explosionSize, ty, size, std::move(spareBits), align,
229+
isPOD, alwaysFixedSize) {}
230+
231+
Address projectFieldAddress(IRGenFunction &IGF, Address addr, SILType T,
232+
const LinearFuncFieldInfo &field) const {
233+
return field.projectAddress(IGF, addr, getNonFixedOffsets(IGF, T));
234+
}
235+
236+
void initializeFromParams(IRGenFunction &IGF, Explosion &params, Address src,
237+
SILType T, bool isOutlined) const override {
238+
llvm_unreachable("unexploded @differentiable function as argument?");
239+
}
240+
241+
void addToAggLowering(IRGenModule &IGM, SwiftAggLowering &lowering,
242+
Size offset) const override {
243+
for (auto &field : getFields()) {
244+
auto fieldOffset = offset + field.getFixedByteOffset();
245+
cast<LoadableTypeInfo>(field.getTypeInfo())
246+
.addToAggLowering(IGM, lowering, fieldOffset);
247+
}
248+
}
249+
250+
llvm::NoneType getNonFixedOffsets(IRGenFunction &IGF) const { return None; }
251+
llvm::NoneType getNonFixedOffsets(IRGenFunction &IGF, SILType T) const {
252+
return None;
253+
}
254+
};
255+
256+
class LinearFuncTypeBuilder
257+
: public RecordTypeBuilder<LinearFuncTypeBuilder, LinearFuncFieldInfo,
258+
LinearDifferentiableFunctionTypeComponent> {
259+
260+
SILFunctionType *originalType;
261+
IndexSubset *parameterIndices;
262+
263+
public:
264+
LinearFuncTypeBuilder(IRGenModule &IGM, SILFunctionType *fnTy)
265+
: RecordTypeBuilder(IGM),
266+
originalType(fnTy->getWithoutDifferentiability()),
267+
parameterIndices(fnTy->getDifferentiationParameterIndices()) {
268+
assert(fnTy->getDifferentiabilityKind() == DifferentiabilityKind::Linear);
269+
}
270+
271+
TypeInfo *createFixed(ArrayRef<LinearFuncFieldInfo> fields,
272+
StructLayout &&layout) {
273+
llvm_unreachable("@differentiable functions are always loadable");
274+
}
275+
276+
LinearFuncTypeInfo *createLoadable(ArrayRef<LinearFuncFieldInfo> fields,
277+
StructLayout &&layout,
278+
unsigned explosionSize) {
279+
return LinearFuncTypeInfo::create(
280+
fields, explosionSize, layout.getType(), layout.getSize(),
281+
std::move(layout.getSpareBits()), layout.getAlignment(), layout.isPOD(),
282+
layout.isAlwaysFixedSize());
283+
}
284+
285+
TypeInfo *createNonFixed(ArrayRef<LinearFuncFieldInfo> fields,
286+
FieldsAreABIAccessible_t fieldsAccessible,
287+
StructLayout &&layout) {
288+
llvm_unreachable("@differentiable functions are always loadable");
289+
}
290+
291+
LinearFuncFieldInfo getFieldInfo(
292+
unsigned index, LinearDifferentiableFunctionTypeComponent field,
293+
const TypeInfo &fieldTI) {
294+
return LinearFuncFieldInfo(field, fieldTI, parameterIndices);
295+
}
296+
297+
SILType getType(LinearDifferentiableFunctionTypeComponent component) {
298+
switch (component) {
299+
case LinearDifferentiableFunctionTypeComponent::Original:
300+
return SILType::getPrimitiveObjectType(originalType->getCanonicalType());
301+
case LinearDifferentiableFunctionTypeComponent::Transpose:
302+
auto transposeTy = originalType->getAutoDiffTransposeFunctionType(
303+
parameterIndices, IGM.getSILTypes(),
304+
LookUpConformanceInModule(IGM.getSwiftModule()));
305+
return SILType::getPrimitiveObjectType(transposeTy);
306+
}
307+
}
308+
309+
StructLayout performLayout(ArrayRef<const TypeInfo *> fieldTypes) {
310+
return StructLayout(IGM, /*decl=*/nullptr, LayoutKind::NonHeapObject,
311+
LayoutStrategy::Universal, fieldTypes);
312+
}
313+
};
314+
} // end anonymous namespace
315+
316+
//----------------------------------------------------------------------------//
317+
// Type converter entry points
318+
//----------------------------------------------------------------------------//
319+
169320
const TypeInfo *
170-
TypeConverter::convertDifferentiableFunctionType(SILFunctionType *type) {
171-
assert(type->isDifferentiable());
172-
DiffFuncTypeBuilder builder(IGM, type);
321+
TypeConverter::convertNormalDifferentiableFunctionType(SILFunctionType *type) {
322+
DifferentiableFuncTypeBuilder builder(IGM, type);
173323
return builder.layout({DifferentiableFunctionExtractee::Original,
174324
DifferentiableFunctionExtractee::JVP,
175325
DifferentiableFunctionExtractee::VJP});
176326
}
327+
328+
const TypeInfo *
329+
TypeConverter::convertLinearDifferentiableFunctionType(SILFunctionType *type) {
330+
LinearFuncTypeBuilder builder(IGM, type);
331+
return builder.layout({LinearDifferentiableFunctionTypeComponent::Original,
332+
LinearDifferentiableFunctionTypeComponent::Transpose});
333+
}

lib/IRGen/GenFunc.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -480,8 +480,14 @@ Address irgen::projectBlockStorageCapture(IRGenFunction &IGF,
480480

481481
const TypeInfo *TypeConverter::convertFunctionType(SILFunctionType *T) {
482482
// SWIFT_ENABLE_TENSORFLOW
483-
if (T->isDifferentiable())
484-
return convertDifferentiableFunctionType(T);
483+
switch (T->getDifferentiabilityKind()) {
484+
case DifferentiabilityKind::Normal:
485+
return convertNormalDifferentiableFunctionType(T);
486+
case DifferentiabilityKind::Linear:
487+
return convertLinearDifferentiableFunctionType(T);
488+
case DifferentiabilityKind::NonDifferentiable:
489+
break;
490+
}
485491

486492
switch (T->getRepresentation()) {
487493
case SILFunctionType::Representation::Block:

lib/IRGen/GenType.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ class TypeConverter {
138138
const TypeInfo *convertStructType(TypeBase *key, CanType type, StructDecl *D);
139139
const TypeInfo *convertFunctionType(SILFunctionType *T);
140140
// SWIFT_ENABLE_TENSORFLOW
141-
const TypeInfo *convertDifferentiableFunctionType(SILFunctionType *T);
141+
const TypeInfo *convertNormalDifferentiableFunctionType(SILFunctionType *T);
142+
const TypeInfo *convertLinearDifferentiableFunctionType(SILFunctionType *T);
142143
const TypeInfo *convertBlockStorageType(SILBlockStorageType *T);
143144
const TypeInfo *convertBoxType(SILBoxType *T);
144145
const TypeInfo *convertArchetypeType(ArchetypeType *T);

0 commit comments

Comments
 (0)