Skip to content

Commit d2ae8b4

Browse files
authored
---
yaml --- r: 262134 b: refs/heads/tensorflow c: 31290d0 h: refs/heads/master
1 parent 4bc227f commit d2ae8b4

19 files changed

+575
-174
lines changed

[refs]

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -818,7 +818,7 @@ refs/tags/swift-DEVELOPMENT-SNAPSHOT-2018-04-25-a: 22f738a831d43aff2b9c9773bcb65
818818
refs/tags/swift-DEVELOPMENT-SNAPSHOT-2018-05-08-a: 7d98cc16689baba5c8a3b90a9329bdcc1a12b4e9
819819
refs/heads/cherr42: a566ad54b073c2c56ac0a705d0a5bed9743135a5
820820
"refs/heads/codable_test_comment_fix": fc8f6824f7f347e1e8db55bff62db385c5728b5a
821-
refs/heads/tensorflow: c30ca154a84ab7befcf2f6b086bd7b29a9288c97
821+
refs/heads/tensorflow: 31290d09842654ead00dbd0b4852ece9ceb84639
822822
refs/tags/swift-4.1-DEVELOPMENT-SNAPSHOT-2018-05-11-a: 8126fd7a652e2f70ad6d76505239e34fb2ef3e1a
823823
refs/tags/swift-4.1-DEVELOPMENT-SNAPSHOT-2018-05-12-a: b3fd3dd84df6717f2e2e9df58c6d7e99fed57086
824824
refs/tags/swift-4.1-DEVELOPMENT-SNAPSHOT-2018-05-13-a: 71135119579039dc321c5f65d870050fe36efda2

branches/tensorflow/lib/IRGen/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ add_swift_host_library(swiftIRGen STATIC
2121
GenControl.cpp
2222
GenCoverage.cpp
2323
GenDecl.cpp
24+
# SWIFT_ENABLE_TENSORFLOW
25+
GenDiffFunc.cpp
2426
GenEnum.cpp
2527
GenExistential.cpp
2628
GenFunc.cpp
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
//===--- GenDiffFunc.cpp - Swift IR Generation For @autodiff Functions ---===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
// SWIFT_ENABLE_TENSORFLOW
13+
14+
#include "swift/AST/Decl.h"
15+
#include "swift/AST/Pattern.h"
16+
#include "swift/AST/Types.h"
17+
#include "swift/SIL/SILModule.h"
18+
#include "swift/SIL/SILType.h"
19+
#include "llvm/IR/DerivedTypes.h"
20+
21+
#include "Explosion.h"
22+
#include "GenHeap.h"
23+
#include "GenRecord.h"
24+
#include "GenType.h"
25+
#include "IRGenFunction.h"
26+
#include "IRGenModule.h"
27+
#include "IndirectTypeInfo.h"
28+
#include "NonFixedTypeInfo.h"
29+
30+
#pragma clang diagnostic ignored "-Winconsistent-missing-override"
31+
32+
using namespace swift;
33+
using namespace irgen;
34+
35+
using DiffFuncIndex =
36+
std::pair<AutoDiffFunctionExtractInst::Extractee, unsigned>;
37+
38+
namespace {
39+
class DiffFuncFieldInfo final : public RecordField<DiffFuncFieldInfo> {
40+
public:
41+
DiffFuncFieldInfo(DiffFuncIndex index, const TypeInfo &type)
42+
: RecordField(type), Index(index) {}
43+
44+
/// The field index.
45+
const DiffFuncIndex Index;
46+
47+
std::string getFieldName() const {
48+
auto extractee = std::get<0>(Index);
49+
auto differentiationOrder = std::get<1>(Index);
50+
switch (extractee) {
51+
case AutoDiffFunctionExtractInst::Extractee::Original:
52+
return "original";
53+
case AutoDiffFunctionExtractInst::Extractee::JVP:
54+
return "jvp_" + llvm::itostr(differentiationOrder);
55+
case AutoDiffFunctionExtractInst::Extractee::VJP:
56+
return "vjp_" + llvm::itostr(differentiationOrder);
57+
}
58+
}
59+
60+
SILType getType(IRGenModule &IGM, SILType t) const {
61+
auto fnTy = t.castTo<SILFunctionType>();
62+
auto extInfo = fnTy->getExtInfo();
63+
auto nondiffExtInfo = extInfo.withDifferentiable(false);
64+
auto origFnTy = fnTy->getWithExtInfo(nondiffExtInfo);
65+
if (std::get<0>(Index) == AutoDiffFunctionExtractInst::Extractee::Original)
66+
return SILType::getPrimitiveObjectType(origFnTy);
67+
auto differentiationOrder = std::get<1>(Index);
68+
auto kind = *std::get<0>(Index).getExtracteeAsAssociatedFunction();
69+
auto assocTy = origFnTy->getAutoDiffAssociatedFunctionType(
70+
SmallBitVector(origFnTy->getNumParameters(), true), /*resultIndex*/ 0,
71+
differentiationOrder, kind, IGM.getSILModule(),
72+
LookUpConformanceInModule(IGM.getSwiftModule()));
73+
return SILType::getPrimitiveObjectType(assocTy);
74+
}
75+
};
76+
77+
class DiffFuncTypeInfo final
78+
: public RecordTypeInfo<DiffFuncTypeInfo, LoadableTypeInfo,
79+
DiffFuncFieldInfo> {
80+
using super =
81+
RecordTypeInfo<DiffFuncTypeInfo, LoadableTypeInfo, DiffFuncFieldInfo>;
82+
83+
public:
84+
DiffFuncTypeInfo(ArrayRef<DiffFuncFieldInfo> fields, unsigned explosionSize,
85+
llvm::Type *ty, Size size, SpareBitVector &&spareBits,
86+
Alignment align, IsPOD_t isPOD,
87+
IsFixedSize_t alwaysFixedSize)
88+
: super(fields, explosionSize, ty, size, std::move(spareBits), align,
89+
isPOD, alwaysFixedSize) {}
90+
91+
Address projectFieldAddress(IRGenFunction &IGF, Address addr, SILType T,
92+
const DiffFuncFieldInfo &field) const {
93+
return field.projectAddress(IGF, addr, getNonFixedOffsets(IGF, T));
94+
}
95+
96+
void initializeFromParams(IRGenFunction &IGF, Explosion &params, Address src,
97+
SILType T, bool isOutlined) const override {
98+
llvm_unreachable("unexploded @autodiff function as argument?");
99+
}
100+
101+
void addToAggLowering(IRGenModule &IGM, SwiftAggLowering &lowering,
102+
Size offset) const override {
103+
for (auto &field : getFields()) {
104+
auto fieldOffset = offset + field.getFixedByteOffset();
105+
cast<LoadableTypeInfo>(field.getTypeInfo())
106+
.addToAggLowering(IGM, lowering, fieldOffset);
107+
}
108+
}
109+
110+
llvm::NoneType getNonFixedOffsets(IRGenFunction &IGF) const { return None; }
111+
llvm::NoneType getNonFixedOffsets(IRGenFunction &IGF, SILType T) const {
112+
return None;
113+
}
114+
};
115+
116+
class DiffFuncTypeBuilder
117+
: public RecordTypeBuilder<DiffFuncTypeBuilder, DiffFuncFieldInfo,
118+
DiffFuncIndex> {
119+
120+
SILFunctionType *origFnTy;
121+
122+
public:
123+
DiffFuncTypeBuilder(IRGenModule &IGM, SILFunctionType *fnTy)
124+
: RecordTypeBuilder(IGM) {
125+
assert(fnTy->isDifferentiable());
126+
auto extInfo = fnTy->getExtInfo();
127+
auto nondiffExtInfo = extInfo.withDifferentiable(false);
128+
origFnTy = fnTy->getWithExtInfo(nondiffExtInfo);
129+
}
130+
131+
TypeInfo *createFixed(ArrayRef<DiffFuncFieldInfo> fields,
132+
StructLayout &&layout) {
133+
llvm_unreachable("@autodiff functions are always loadable");
134+
}
135+
136+
DiffFuncTypeInfo *createLoadable(ArrayRef<DiffFuncFieldInfo> fields,
137+
StructLayout &&layout,
138+
unsigned explosionSize) {
139+
return DiffFuncTypeInfo::create(
140+
fields, explosionSize, layout.getType(), layout.getSize(),
141+
std::move(layout.getSpareBits()), layout.getAlignment(), layout.isPOD(),
142+
layout.isAlwaysFixedSize());
143+
}
144+
145+
TypeInfo *createNonFixed(ArrayRef<DiffFuncFieldInfo> fields,
146+
FieldsAreABIAccessible_t fieldsAccessible,
147+
StructLayout &&layout) {
148+
llvm_unreachable("@autodiff functions are always loadable");
149+
}
150+
151+
DiffFuncFieldInfo getFieldInfo(unsigned index, DiffFuncIndex field,
152+
const TypeInfo &fieldTI) {
153+
return DiffFuncFieldInfo(field, fieldTI);
154+
}
155+
156+
SILType getType(DiffFuncIndex field) {
157+
if (std::get<0>(field) == AutoDiffFunctionExtractInst::Extractee::Original)
158+
return SILType::getPrimitiveObjectType(origFnTy->getCanonicalType());
159+
auto differentiationOrder = std::get<1>(field);
160+
auto kind = *std::get<0>(field).getExtracteeAsAssociatedFunction();
161+
auto assocTy = origFnTy->getAutoDiffAssociatedFunctionType(
162+
SmallBitVector(origFnTy->getNumParameters(), true), /*resultIndex*/ 0,
163+
differentiationOrder, kind, IGM.getSILModule(),
164+
LookUpConformanceInModule(IGM.getSwiftModule()));
165+
return SILType::getPrimitiveObjectType(assocTy);
166+
}
167+
168+
StructLayout performLayout(ArrayRef<const TypeInfo *> fieldTypes) {
169+
return StructLayout(IGM, /*decl=*/nullptr, LayoutKind::NonHeapObject,
170+
LayoutStrategy::Universal, fieldTypes);
171+
}
172+
};
173+
} // end anonymous namespace
174+
175+
const TypeInfo *
176+
TypeConverter::convertDifferentiableFunctionType(SILFunctionType *type) {
177+
assert(type->isDifferentiable());
178+
DiffFuncTypeBuilder builder(IGM, type);
179+
SmallVector<DiffFuncIndex, 3> fields;
180+
fields.push_back(
181+
std::make_pair(AutoDiffFunctionExtractInst::Extractee::Original, 0));
182+
fields.push_back(
183+
std::make_pair(AutoDiffFunctionExtractInst::Extractee::JVP, 1));
184+
fields.push_back(
185+
std::make_pair(AutoDiffFunctionExtractInst::Extractee::VJP, 1));
186+
return builder.layout(fields);
187+
}

branches/tensorflow/lib/IRGen/GenFunc.cpp

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -474,23 +474,9 @@ Address irgen::projectBlockStorageCapture(IRGenFunction &IGF,
474474

475475
const TypeInfo *TypeConverter::convertFunctionType(SILFunctionType *T) {
476476
// SWIFT_ENABLE_TENSORFLOW
477-
if (T->isDifferentiable()) {
478-
auto extInfo = T->getExtInfo();
479-
auto nondiffExtInfo = extInfo.withDifferentiable(false);
480-
auto origTy = T->getWithExtInfo(nondiffExtInfo);
481-
// TODO(rxwei): Use the parameter indices and diff order in the @autodiff
482-
// function type.
483-
auto jvpTy = origTy->getAutoDiffAssociatedFunctionType(
484-
SmallBitVector(T->getNumParameters(), true), /*resultIndex*/ 0,
485-
/*differentiationOrder*/ 1, AutoDiffAssociatedFunctionKind::JVP,
486-
IGM.getSILModule(), LookUpConformanceInModule(IGM.getSwiftModule()));
487-
auto vjpTy = origTy->getAutoDiffAssociatedFunctionType(
488-
SmallBitVector(T->getNumParameters(), true), /*resultIndex*/ 0,
489-
/*differentiationOrder*/ 1, AutoDiffAssociatedFunctionKind::VJP,
490-
IGM.getSILModule(), LookUpConformanceInModule(IGM.getSwiftModule()));
491-
return convertTupleType(TupleType::get({origTy, jvpTy, vjpTy}, IGM.Context)
492-
->castTo<TupleType>());
493-
}
477+
if (T->isDifferentiable())
478+
return convertDifferentiableFunctionType(T);
479+
494480
switch (T->getRepresentation()) {
495481
case SILFunctionType::Representation::Block:
496482
return new BlockTypeInfo(CanSILFunctionType(T),

branches/tensorflow/lib/IRGen/GenType.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ class TypeConverter {
134134
const TypeInfo *convertEnumType(TypeBase *key, CanType type, EnumDecl *D);
135135
const TypeInfo *convertStructType(TypeBase *key, CanType type, StructDecl *D);
136136
const TypeInfo *convertFunctionType(SILFunctionType *T);
137+
// SWIFT_ENABLE_TENSORFLOW
138+
const TypeInfo *convertDifferentiableFunctionType(SILFunctionType *T);
137139
const TypeInfo *convertBlockStorageType(SILBlockStorageType *T);
138140
const TypeInfo *convertBoxType(SILBoxType *T);
139141
const TypeInfo *convertArchetypeType(ArchetypeType *T);

branches/tensorflow/lib/IRGen/IRGenSIL.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5664,13 +5664,19 @@ void IRGenSILFunction::visitConvertFunctionInst(swift::ConvertFunctionInst *i) {
56645664

56655665
void IRGenSILFunction::visitConvertEscapeToNoEscapeInst(
56665666
swift::ConvertEscapeToNoEscapeInst *i) {
5667-
// This instruction makes the context trivial.
5667+
// SWIFT_ENABLE_TENSORFLOW
5668+
// This instruction makes the context(s) trivial. A function contains multiple
5669+
// function pointers and contexts when it's differentiable.
56685670
Explosion in = getLoweredExplosion(i->getOperand());
5669-
llvm::Value *fn = in.claimNext();
5670-
llvm::Value *ctx = in.claimNext();
56715671
Explosion out;
5672-
out.add(fn);
5673-
out.add(Builder.CreateBitCast(ctx, IGM.OpaquePtrTy));
5672+
// SWIFT_ENABLE_TENSORFLOW
5673+
for (unsigned index : range(in.size() / 2)) {
5674+
(void)index;
5675+
llvm::Value *fn = in.claimNext();
5676+
llvm::Value *ctx = in.claimNext();
5677+
out.add(fn);
5678+
out.add(Builder.CreateBitCast(ctx, IGM.OpaquePtrTy));
5679+
}
56745680
setLoweredExplosion(i, out);
56755681
}
56765682

0 commit comments

Comments
 (0)