|
| 1 | +//===--- GenDiffFunc.cpp - Swift IR Generation For @autodiff Functions ---===// |
| 2 | +// |
| 3 | +// This source file is part of the Swift.org open source project |
| 4 | +// |
| 5 | +// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors |
| 6 | +// Licensed under Apache License v2.0 with Runtime Library Exception |
| 7 | +// |
| 8 | +// See https://swift.org/LICENSE.txt for license information |
| 9 | +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors |
| 10 | +// |
| 11 | +//===----------------------------------------------------------------------===// |
| 12 | +// SWIFT_ENABLE_TENSORFLOW |
| 13 | + |
| 14 | +#include "swift/AST/Decl.h" |
| 15 | +#include "swift/AST/Pattern.h" |
| 16 | +#include "swift/AST/Types.h" |
| 17 | +#include "swift/SIL/SILModule.h" |
| 18 | +#include "swift/SIL/SILType.h" |
| 19 | +#include "llvm/IR/DerivedTypes.h" |
| 20 | + |
| 21 | +#include "Explosion.h" |
| 22 | +#include "GenHeap.h" |
| 23 | +#include "GenRecord.h" |
| 24 | +#include "GenType.h" |
| 25 | +#include "IRGenFunction.h" |
| 26 | +#include "IRGenModule.h" |
| 27 | +#include "IndirectTypeInfo.h" |
| 28 | +#include "NonFixedTypeInfo.h" |
| 29 | + |
| 30 | +#pragma clang diagnostic ignored "-Winconsistent-missing-override" |
| 31 | + |
| 32 | +using namespace swift; |
| 33 | +using namespace irgen; |
| 34 | + |
| 35 | +using DiffFuncIndex = |
| 36 | + std::pair<AutoDiffFunctionExtractInst::Extractee, unsigned>; |
| 37 | + |
| 38 | +namespace { |
| 39 | +class DiffFuncFieldInfo final : public RecordField<DiffFuncFieldInfo> { |
| 40 | +public: |
| 41 | + DiffFuncFieldInfo(DiffFuncIndex index, const TypeInfo &type) |
| 42 | + : RecordField(type), Index(index) {} |
| 43 | + |
| 44 | + /// The field index. |
| 45 | + const DiffFuncIndex Index; |
| 46 | + |
| 47 | + std::string getFieldName() const { |
| 48 | + auto extractee = std::get<0>(Index); |
| 49 | + auto differentiationOrder = std::get<1>(Index); |
| 50 | + switch (extractee) { |
| 51 | + case AutoDiffFunctionExtractInst::Extractee::Original: |
| 52 | + return "original"; |
| 53 | + case AutoDiffFunctionExtractInst::Extractee::JVP: |
| 54 | + return "jvp_" + llvm::itostr(differentiationOrder); |
| 55 | + case AutoDiffFunctionExtractInst::Extractee::VJP: |
| 56 | + return "vjp_" + llvm::itostr(differentiationOrder); |
| 57 | + } |
| 58 | + } |
| 59 | + |
| 60 | + SILType getType(IRGenModule &IGM, SILType t) const { |
| 61 | + auto fnTy = t.castTo<SILFunctionType>(); |
| 62 | + auto extInfo = fnTy->getExtInfo(); |
| 63 | + auto nondiffExtInfo = extInfo.withDifferentiable(false); |
| 64 | + auto origFnTy = fnTy->getWithExtInfo(nondiffExtInfo); |
| 65 | + if (std::get<0>(Index) == AutoDiffFunctionExtractInst::Extractee::Original) |
| 66 | + return SILType::getPrimitiveObjectType(origFnTy); |
| 67 | + auto differentiationOrder = std::get<1>(Index); |
| 68 | + auto kind = *std::get<0>(Index).getExtracteeAsAssociatedFunction(); |
| 69 | + auto assocTy = origFnTy->getAutoDiffAssociatedFunctionType( |
| 70 | + SmallBitVector(origFnTy->getNumParameters(), true), /*resultIndex*/ 0, |
| 71 | + differentiationOrder, kind, IGM.getSILModule(), |
| 72 | + LookUpConformanceInModule(IGM.getSwiftModule())); |
| 73 | + return SILType::getPrimitiveObjectType(assocTy); |
| 74 | + } |
| 75 | +}; |
| 76 | + |
| 77 | +class DiffFuncTypeInfo final |
| 78 | + : public RecordTypeInfo<DiffFuncTypeInfo, LoadableTypeInfo, |
| 79 | + DiffFuncFieldInfo> { |
| 80 | + using super = |
| 81 | + RecordTypeInfo<DiffFuncTypeInfo, LoadableTypeInfo, DiffFuncFieldInfo>; |
| 82 | + |
| 83 | +public: |
| 84 | + DiffFuncTypeInfo(ArrayRef<DiffFuncFieldInfo> fields, unsigned explosionSize, |
| 85 | + llvm::Type *ty, Size size, SpareBitVector &&spareBits, |
| 86 | + Alignment align, IsPOD_t isPOD, |
| 87 | + IsFixedSize_t alwaysFixedSize) |
| 88 | + : super(fields, explosionSize, ty, size, std::move(spareBits), align, |
| 89 | + isPOD, alwaysFixedSize) {} |
| 90 | + |
| 91 | + Address projectFieldAddress(IRGenFunction &IGF, Address addr, SILType T, |
| 92 | + const DiffFuncFieldInfo &field) const { |
| 93 | + return field.projectAddress(IGF, addr, getNonFixedOffsets(IGF, T)); |
| 94 | + } |
| 95 | + |
| 96 | + void initializeFromParams(IRGenFunction &IGF, Explosion ¶ms, Address src, |
| 97 | + SILType T, bool isOutlined) const override { |
| 98 | + llvm_unreachable("unexploded @autodiff function as argument?"); |
| 99 | + } |
| 100 | + |
| 101 | + void addToAggLowering(IRGenModule &IGM, SwiftAggLowering &lowering, |
| 102 | + Size offset) const override { |
| 103 | + for (auto &field : getFields()) { |
| 104 | + auto fieldOffset = offset + field.getFixedByteOffset(); |
| 105 | + cast<LoadableTypeInfo>(field.getTypeInfo()) |
| 106 | + .addToAggLowering(IGM, lowering, fieldOffset); |
| 107 | + } |
| 108 | + } |
| 109 | + |
| 110 | + llvm::NoneType getNonFixedOffsets(IRGenFunction &IGF) const { return None; } |
| 111 | + llvm::NoneType getNonFixedOffsets(IRGenFunction &IGF, SILType T) const { |
| 112 | + return None; |
| 113 | + } |
| 114 | +}; |
| 115 | + |
| 116 | +class DiffFuncTypeBuilder |
| 117 | + : public RecordTypeBuilder<DiffFuncTypeBuilder, DiffFuncFieldInfo, |
| 118 | + DiffFuncIndex> { |
| 119 | + |
| 120 | + SILFunctionType *origFnTy; |
| 121 | + |
| 122 | +public: |
| 123 | + DiffFuncTypeBuilder(IRGenModule &IGM, SILFunctionType *fnTy) |
| 124 | + : RecordTypeBuilder(IGM) { |
| 125 | + assert(fnTy->isDifferentiable()); |
| 126 | + auto extInfo = fnTy->getExtInfo(); |
| 127 | + auto nondiffExtInfo = extInfo.withDifferentiable(false); |
| 128 | + origFnTy = fnTy->getWithExtInfo(nondiffExtInfo); |
| 129 | + } |
| 130 | + |
| 131 | + TypeInfo *createFixed(ArrayRef<DiffFuncFieldInfo> fields, |
| 132 | + StructLayout &&layout) { |
| 133 | + llvm_unreachable("@autodiff functions are always loadable"); |
| 134 | + } |
| 135 | + |
| 136 | + DiffFuncTypeInfo *createLoadable(ArrayRef<DiffFuncFieldInfo> fields, |
| 137 | + StructLayout &&layout, |
| 138 | + unsigned explosionSize) { |
| 139 | + return DiffFuncTypeInfo::create( |
| 140 | + fields, explosionSize, layout.getType(), layout.getSize(), |
| 141 | + std::move(layout.getSpareBits()), layout.getAlignment(), layout.isPOD(), |
| 142 | + layout.isAlwaysFixedSize()); |
| 143 | + } |
| 144 | + |
| 145 | + TypeInfo *createNonFixed(ArrayRef<DiffFuncFieldInfo> fields, |
| 146 | + FieldsAreABIAccessible_t fieldsAccessible, |
| 147 | + StructLayout &&layout) { |
| 148 | + llvm_unreachable("@autodiff functions are always loadable"); |
| 149 | + } |
| 150 | + |
| 151 | + DiffFuncFieldInfo getFieldInfo(unsigned index, DiffFuncIndex field, |
| 152 | + const TypeInfo &fieldTI) { |
| 153 | + return DiffFuncFieldInfo(field, fieldTI); |
| 154 | + } |
| 155 | + |
| 156 | + SILType getType(DiffFuncIndex field) { |
| 157 | + if (std::get<0>(field) == AutoDiffFunctionExtractInst::Extractee::Original) |
| 158 | + return SILType::getPrimitiveObjectType(origFnTy->getCanonicalType()); |
| 159 | + auto differentiationOrder = std::get<1>(field); |
| 160 | + auto kind = *std::get<0>(field).getExtracteeAsAssociatedFunction(); |
| 161 | + auto assocTy = origFnTy->getAutoDiffAssociatedFunctionType( |
| 162 | + SmallBitVector(origFnTy->getNumParameters(), true), /*resultIndex*/ 0, |
| 163 | + differentiationOrder, kind, IGM.getSILModule(), |
| 164 | + LookUpConformanceInModule(IGM.getSwiftModule())); |
| 165 | + return SILType::getPrimitiveObjectType(assocTy); |
| 166 | + } |
| 167 | + |
| 168 | + StructLayout performLayout(ArrayRef<const TypeInfo *> fieldTypes) { |
| 169 | + return StructLayout(IGM, /*decl=*/nullptr, LayoutKind::NonHeapObject, |
| 170 | + LayoutStrategy::Universal, fieldTypes); |
| 171 | + } |
| 172 | +}; |
| 173 | +} // end anonymous namespace |
| 174 | + |
| 175 | +const TypeInfo * |
| 176 | +TypeConverter::convertDifferentiableFunctionType(SILFunctionType *type) { |
| 177 | + assert(type->isDifferentiable()); |
| 178 | + DiffFuncTypeBuilder builder(IGM, type); |
| 179 | + SmallVector<DiffFuncIndex, 3> fields; |
| 180 | + fields.push_back( |
| 181 | + std::make_pair(AutoDiffFunctionExtractInst::Extractee::Original, 0)); |
| 182 | + fields.push_back( |
| 183 | + std::make_pair(AutoDiffFunctionExtractInst::Extractee::JVP, 1)); |
| 184 | + fields.push_back( |
| 185 | + std::make_pair(AutoDiffFunctionExtractInst::Extractee::VJP, 1)); |
| 186 | + return builder.layout(fields); |
| 187 | +} |
0 commit comments