
Commit afcd55e

---
yaml
---
r: 262079
b: refs/heads/tensorflow
c: 2ab4b8a
h: refs/heads/master
i: 262077: d30379e 262075: 421798a 262071: afc2ce4 262063: fead07f 262047: 2bceee3 262015: 1263945
1 parent ad40c41 commit afcd55e

4 files changed: +117/-19 lines changed

[refs]

Lines changed: 1 addition & 1 deletion
@@ -818,7 +818,7 @@ refs/tags/swift-DEVELOPMENT-SNAPSHOT-2018-04-25-a: 22f738a831d43aff2b9c9773bcb65
 refs/tags/swift-DEVELOPMENT-SNAPSHOT-2018-05-08-a: 7d98cc16689baba5c8a3b90a9329bdcc1a12b4e9
 refs/heads/cherr42: a566ad54b073c2c56ac0a705d0a5bed9743135a5
 "refs/heads/codable_test_comment_fix": fc8f6824f7f347e1e8db55bff62db385c5728b5a
-refs/heads/tensorflow: 1cd29f7f6b9d69819306c0bb92428e0e144084fe
+refs/heads/tensorflow: 2ab4b8a8e7842ac52ea835932d132f82b7990753
 refs/tags/swift-4.1-DEVELOPMENT-SNAPSHOT-2018-05-11-a: 8126fd7a652e2f70ad6d76505239e34fb2ef3e1a
 refs/tags/swift-4.1-DEVELOPMENT-SNAPSHOT-2018-05-12-a: b3fd3dd84df6717f2e2e9df58c6d7e99fed57086
 refs/tags/swift-4.1-DEVELOPMENT-SNAPSHOT-2018-05-13-a: 71135119579039dc321c5f65d870050fe36efda2

branches/tensorflow/lib/SILOptimizer/Mandatory/TFDifferentiation.cpp

Lines changed: 96 additions & 5 deletions
@@ -265,6 +265,11 @@ struct DifferentiationInvoker {
     // a Swift AST node.
     GradientInst,

+    // No known invoker. This is the case when the differentiation is requested
+    // from SIL source via an `autodiff_function` instruction **without** being
+    // linked to a Swift AST node.
+    AutoDiffFunctionInst,
+
     // Invoked by the indirect application of differentiation. This case has an
     // associated differentiation task reference.
     IndirectDifferentiation,
@@ -287,10 +292,14 @@ struct DifferentiationInvoker {
 private:
   Kind kind;
   union Value {
-    /// The instruction associated with the `SILSource` case.
+    /// The instruction associated with the `GradientInst` case.
     GradientInst *gradientInst;
     Value(GradientInst *inst) : gradientInst(inst) {}

+    /// The instruction associated with the `AutoDiffFunctionInst` case.
+    AutoDiffFunctionInst *adFuncInst;
+    Value(AutoDiffFunctionInst *inst) : adFuncInst(inst) {}
+
     /// The parent differentiation task associated with the
     /// `IndirectDifferentiation` case.
     std::pair<ApplyInst *, DifferentiationTask *> indirectDifferentiation;
@@ -322,6 +331,8 @@ struct DifferentiationInvoker {
 public:
   DifferentiationInvoker(GradientInst *inst)
       : kind(Kind::GradientInst), value(inst) {}
+  DifferentiationInvoker(AutoDiffFunctionInst *inst)
+      : kind(Kind::AutoDiffFunctionInst), value(inst) {}
   DifferentiationInvoker(ApplyInst *applyInst, DifferentiationTask *task)
       : kind(Kind::IndirectDifferentiation), value(applyInst, task) {}
   DifferentiationInvoker(ReverseAutoDiffExpr *expr)
@@ -338,6 +349,11 @@ struct DifferentiationInvoker {
     return value.gradientInst;
   }

+  AutoDiffFunctionInst *getAutoDiffFunctionInst() const {
+    assert(kind == Kind::AutoDiffFunctionInst);
+    return value.adFuncInst;
+  }
+
   std::pair<ApplyInst *, DifferentiationTask *>
   getIndirectDifferentiation() const {
     assert(kind == Kind::IndirectDifferentiation);
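Taken together, the hunks above extend the `DifferentiationInvoker` tagged union along all four of its axes: a new `Kind` case, a new union payload, a new constructor, and a new checked accessor. As a reading aid, here is a minimal standalone sketch of that pattern; the payload structs are empty placeholders rather than the real SIL instruction classes, and `Invoker` is a hypothetical name used only for this illustration.

#include <cassert>

// Placeholder payload types; the real ones are SIL instruction classes.
struct GradientInst {};
struct AutoDiffFunctionInst {};

class Invoker {
public:
  enum class Kind { GradientInst, AutoDiffFunctionInst };

  explicit Invoker(GradientInst *inst) : kind(Kind::GradientInst) {
    value.gradientInst = inst;
  }
  explicit Invoker(AutoDiffFunctionInst *inst)
      : kind(Kind::AutoDiffFunctionInst) {
    value.adFuncInst = inst;
  }

  Kind getKind() const { return kind; }

  // Checked accessors: assert on the discriminant before reading the matching
  // union member, as getGradientInst()/getAutoDiffFunctionInst() do above.
  GradientInst *getGradientInst() const {
    assert(kind == Kind::GradientInst);
    return value.gradientInst;
  }
  AutoDiffFunctionInst *getAutoDiffFunctionInst() const {
    assert(kind == Kind::AutoDiffFunctionInst);
    return value.adFuncInst;
  }

private:
  Kind kind;
  union {
    GradientInst *gradientInst;
    AutoDiffFunctionInst *adFuncInst;
  } value;
};

int main() {
  AutoDiffFunctionInst adfi;
  Invoker invoker(&adfi);
  assert(invoker.getKind() == Invoker::Kind::AutoDiffFunctionInst);
  assert(invoker.getAutoDiffFunctionInst() == &adfi);
  return 0;
}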
@@ -812,6 +828,9 @@ void DifferentiationInvoker::print(llvm::raw_ostream &os) const {
   case Kind::GradientInst:
     os << "gradient_inst=(" << *getGradientInst() << ")";
     break;
+  case Kind::AutoDiffFunctionInst:
+    os << "autodiff_function_inst=(" << *getAutoDiffFunctionInst() << ")";
+    break;
   case Kind::IndirectDifferentiation: {
     auto indDiff = getIndirectDifferentiation();
     os << "indirect_differentiation=(apply_inst=(" << *indDiff.first
@@ -1177,6 +1196,7 @@ void ADContext::emitNondifferentiabilityError(SILInstruction *inst,
   // associated with any source location, we emit a diagnostic at the
   // instruction source location.
   case DifferentiationInvoker::Kind::GradientInst:
+  case DifferentiationInvoker::Kind::AutoDiffFunctionInst:
   case DifferentiationInvoker::Kind::SILDifferentiableAttribute:
     diagnose(opLoc,
              diag.getValueOr(diag::autodiff_expression_is_not_differentiable));
@@ -1293,7 +1313,7 @@ class DifferentiableActivityInfo;
 /// does not matter for the final result.
 ///
 /// Reference:
-/// Laurent Hascoët. Automatic Differentiation by Program Transformation. 2017.
+/// Laurent Hascoët. Automatic Differentiation by Program Transformation. 2007.
 class DifferentiableActivityAnalysis
     : public FunctionAnalysisBase<DifferentiableActivityInfo> {
 private:
@@ -4673,6 +4693,14 @@ static ReverseAutoDiffExpr *findDifferentialOperator(GradientInst *inst) {
   return inst->getLoc().getAsASTNode<ReverseAutoDiffExpr>();
 }

+/// Given an `autodiff_function` instruction, find the corresponding
+/// differential operator used in the AST. If no differential operator is found,
+/// return nullptr.
+static AutoDiffFunctionExpr *findDifferentialOperator(
+    AutoDiffFunctionInst *inst) {
+  return inst->getLoc().getAsASTNode<AutoDiffFunctionExpr>();
+}
+
 // Retrieve or create an empty gradient function based on a `gradient`
 // instruction and replace all users of the `gradient` instruction with the
 // gradient function. Returns the gradient function.
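The new `findDifferentialOperator` overload relies on the same idea as the existing one: a SIL instruction's location may point back to the AST expression that produced it, and asking for it as a specific expression type yields either that node or null. The standalone sketch below is only a rough analogue of that lookup under assumed semantics; the expression types are placeholders, `Location` is a hypothetical stand-in, and the `dynamic_cast` is just an illustration, not how `SILLocation::getAsASTNode` actually works.

#include <cassert>

// Placeholder AST node hierarchy; the real types are Swift AST expressions.
struct Expr { virtual ~Expr() = default; };
struct ReverseAutoDiffExpr : Expr {};
struct AutoDiffFunctionExpr : Expr {};

// Hypothetical stand-in for a debug location that may wrap an AST node.
// getAsASTNode<T>() hands back the node only when it has the requested type,
// otherwise nullptr, which is the behavior the helper above depends on.
struct Location {
  Expr *node = nullptr;
  template <typename T> T *getAsASTNode() const {
    return dynamic_cast<T *>(node);
  }
};

int main() {
  AutoDiffFunctionExpr expr;
  Location locWithExpr{&expr};
  Location locWithoutExpr{};
  assert(locWithExpr.getAsASTNode<AutoDiffFunctionExpr>() == &expr);
  assert(locWithExpr.getAsASTNode<ReverseAutoDiffExpr>() == nullptr);
  assert(locWithoutExpr.getAsASTNode<AutoDiffFunctionExpr>() == nullptr);
  return 0;
}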
@@ -4953,6 +4981,9 @@ class Differentiation : public SILModuleTransform {
   /// pushing a differentiation task onto the global list. Returns true if any
   /// error occurred.
   bool processGradientInst(GradientInst *gi, ADContext &context);
+
+  bool processAutoDiffFunctionInst(AutoDiffFunctionInst *adfi,
+                                   ADContext &context);
 };
 } // end anonymous namespace

@@ -5003,6 +5034,60 @@ bool Differentiation::processGradientInst(GradientInst *gi,
   return false;
 }

+bool Differentiation::processAutoDiffFunctionInst(AutoDiffFunctionInst *adfi,
+                                                  ADContext &context) {
+  SILFunction *parent = adfi->getFunction();
+  auto origFnOperand = adfi->getOriginalFunction();
+  // If it traces back to a `function_ref`, differentiate that.
+  if (auto *originalFRI = findReferenceToVisibleFunction(origFnOperand)) {
+    auto *original = originalFRI->getReferencedFunction();
+    // TODO: Find syntax-level AD invoker from `adfi`.
+    auto *task = context.lookUpOrRegisterDifferentiationTask(
+        original, SILAutoDiffIndices(0, adfi->getParameterIndices()),
+        DifferentiationInvoker(adfi));
+    // Expand the `autodiff_function` instruction by adding the JVP and VJP
+    // functions.
+    SILBuilder builder(adfi);
+    auto loc = parent->getLocation();
+    auto *vjp = task->getVJP();
+    auto *vjpRef = builder.createFunctionRef(loc, vjp);
+    auto finalVJP = reapplyFunctionConversion(
+        vjpRef, originalFRI, origFnOperand, builder, loc);
+    SILValue finalJVP;
+    // TODO: Implement "forward-mode differentiation" to get JVP. Currently it's
+    // just an undef because we won't use it.
+    {
+      auto origFnTy = origFnOperand->getType().getAs<SILFunctionType>();
+      auto jvpType = origFnTy->getAutoDiffAssociatedFunctionType(
+          adfi->getParameterIndices(), /*resultIndex*/ 0,
+          /*differentiationOrder*/ 1, AutoDiffAssociatedFunctionKind::JVP,
+          *getModule(),
+          /*lookupConformance*/
+          LookUpConformanceInModule(getModule()->getSwiftModule()));
+      finalJVP = SILUndef::get(SILType::getPrimitiveObjectType(jvpType),
+                               getModule());
+    }
+    auto *newADFI = builder.createAutoDiffFunction(
+        loc, adfi->getParameterIndices(), adfi->getDifferentiationOrder(),
+        origFnOperand, {finalJVP, finalVJP});
+    adfi->replaceAllUsesWith(newADFI);
+    adfi->eraseFromParent();
+  }
+  // Differentiating opaque functions is not supported yet.
+  else {
+    // Find the original differential operator expression. Show an error at the
+    // operator, highlight the argument, and show a note at the definition site
+    // of the argument.
+    if (auto *expr = findDifferentialOperator(adfi))
+      context.diagnose(expr->getSubExpr()->getLoc(),
+                       diag::autodiff_function_not_differentiable)
+          .highlight(expr->getSubExpr()->getSourceRange());
+    return true;
+  }
+  PM->invalidateAnalysis(parent, SILAnalysis::InvalidationKind::FunctionBody);
+  return false;
+}
+
 /// AD pass entry.
 void Differentiation::run() {
   auto &module = *getModule();
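The new `processAutoDiffFunctionInst` follows a rewrite-or-diagnose shape: trace the operand back to a visible `function_ref`, register a differentiation task, and rebuild the `autodiff_function` instruction with the task's VJP and an undef JVP placeholder; if no visible function is found, emit a diagnostic at the differential operator and report failure. The sketch below is a heavily simplified, standalone rendering of that control flow; `FakeAutoDiffFunctionInst`, `processSketch`, and the string-based diagnostics are all hypothetical stand-ins for the real SIL and ADContext APIs. The synthesized VJP name only echoes the `AD__foo__vjp_src_0_wrt_0` symbol checked in the test further down.

#include <cstdio>
#include <string>
#include <vector>

// Standalone stand-in for the SIL instruction being rewritten; these fields
// are illustrative only and do not model the real AutoDiffFunctionInst.
struct FakeAutoDiffFunctionInst {
  std::string originalName;   // empty means the operand is opaque
  std::string vjpName;        // filled in by the rewrite
  bool usesUndefJVP = false;  // forward mode is not implemented yet
};

// Mirrors the rewrite-or-diagnose shape of processAutoDiffFunctionInst:
// returns true on error, false on success.
bool processSketch(FakeAutoDiffFunctionInst &adfi,
                   std::vector<std::string> &diagnostics) {
  // 1. Trace the operand back to a concrete, visible function.
  if (!adfi.originalName.empty()) {
    // 2. Register a differentiation task and take its VJP; here the VJP is
    //    only a synthesized name standing in for a real SILFunction.
    adfi.vjpName = "AD__" + adfi.originalName + "__vjp_src_0_wrt_0";
    // 3. Leave the JVP as undef until forward-mode differentiation exists.
    adfi.usesUndefJVP = true;
    return false;
  }
  // 4. Opaque operand: emit a diagnostic at the differential operator and
  //    report failure so the pass can back out later.
  diagnostics.push_back("function is not differentiable");
  return true;
}

int main() {
  std::vector<std::string> diags;
  FakeAutoDiffFunctionInst good{"foo"};
  FakeAutoDiffFunctionInst opaque{};
  bool anyError = false;
  anyError |= processSketch(good, diags);
  anyError |= processSketch(opaque, diags);
  std::printf("vjp = %s, diagnostics = %zu, anyError = %d\n",
              good.vjpName.c_str(), diags.size(), (int)anyError);
  return 0;
}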
@@ -5013,6 +5098,7 @@ void Differentiation::run() {
   SmallVector<std::pair<SILFunction *,
                         SILDifferentiableAttr *>, 8> diffAttrs;
   SmallVector<GradientInst *, 16> gradInsts;
+  SmallVector<AutoDiffFunctionInst *, 16> autodiffInsts;
   // Handle each `gradient` instruction and each `differentiable`
   // attribute in the module.
   for (SILFunction &f : module) {
@@ -5037,6 +5123,8 @@ void Differentiation::run() {
         // operator, push it to the work list.
         if (auto *gi = dyn_cast<GradientInst>(&i))
           gradInsts.push_back(gi);
+        else if (auto *adfi = dyn_cast<AutoDiffFunctionInst>(&i))
+          autodiffInsts.push_back(adfi);
       }
     }
   }
@@ -5081,9 +5169,12 @@ void Differentiation::run() {
   // turned into a differentiation task. But we don't back out just yet - primal
   // synthesis and adjoint synthesis for newly created differentiation tasks
   // should still run because they will diagnose more errors.
-  bool errorProcessingGradInsts = false;
+  bool errorProcessingAutoDiffInsts = false;
   for (auto *gi : gradInsts)
-    errorProcessingGradInsts |= processGradientInst(gi, context);
+    errorProcessingAutoDiffInsts |= processGradientInst(gi, context);
+
+  for (auto *adfi : autodiffInsts)
+    errorProcessingAutoDiffInsts |= processAutoDiffFunctionInst(adfi, context);

   // Run primal generation for newly created differentiation tasks. If any error
   // occurs, back out.
@@ -5099,7 +5190,7 @@ void Differentiation::run() {

   // If there was any error that occurred during `gradient` instruction
   // processing, back out.
-  if (errorProcessingGradInsts)
+  if (errorProcessingAutoDiffInsts)
     return;

   // Fill the body of each empty canonical gradient function.
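The driver changes in `run()` above keep to a collect-then-process discipline: the pass first sweeps the module and gathers `gradient` and `autodiff_function` instructions into separate worklists, then processes each list, folding per-instruction failures into a single flag with `|=` so every instruction gets to emit its diagnostics before the pass backs out. Below is a minimal standalone sketch of that pattern, with placeholder instruction and module types rather than the real compiler APIs.

#include <cstdio>
#include <vector>

// Placeholder instruction kinds standing in for GradientInst and
// AutoDiffFunctionInst; a "module" here is just functions made of instructions.
enum class InstKind { Gradient, AutoDiffFunction, Other };
struct Inst { InstKind kind; bool differentiable; };
using Module = std::vector<std::vector<Inst>>;

// Per-instruction processing returns true on error, matching the convention
// of processGradientInst / processAutoDiffFunctionInst.
static bool processInst(const Inst &inst) { return !inst.differentiable; }

// Returns true if the whole pass should back out.
bool runSketch(const Module &module) {
  // Phase 1: walk the module and collect worklists, so nothing is rewritten
  // while the instruction lists are still being iterated.
  std::vector<const Inst *> gradInsts;
  std::vector<const Inst *> autodiffInsts;
  for (const auto &function : module)
    for (const auto &inst : function) {
      if (inst.kind == InstKind::Gradient)
        gradInsts.push_back(&inst);
      else if (inst.kind == InstKind::AutoDiffFunction)
        autodiffInsts.push_back(&inst);
    }

  // Phase 2: process every worklist entry, folding failures into one flag so
  // that all diagnostics are emitted before deciding to back out.
  bool anyError = false;
  for (const Inst *inst : gradInsts)
    anyError |= processInst(*inst);
  for (const Inst *inst : autodiffInsts)
    anyError |= processInst(*inst);
  return anyError;
}

int main() {
  Module module = {
      {{InstKind::Gradient, true}, {InstKind::Other, true}},
      {{InstKind::AutoDiffFunction, true}, {InstKind::AutoDiffFunction, false}},
  };
  std::printf("back out = %d\n", (int)runSketch(module));
  return 0;
}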

branches/tensorflow/test/AutoDiff/autodiff_basic.sil

Lines changed: 19 additions & 12 deletions
@@ -17,23 +17,30 @@ bb0(%0 : @trivial $Float):

 sil hidden @bar : $@convention(thin) (Float) -> (Float, Float) {
 bb0(%0 : @trivial $Float):
-  %1 = function_ref @foo : $@convention(thin) (Float) -> Float
-  %2 = gradient [source 0] [wrt 0] %1 : $@convention(thin) (Float) -> Float
-  %3 = apply %2(%0) : $@convention(thin) (Float) -> Float
-  %4 = gradient [source 0] [wrt 0] [preserving_result] %1 : $@convention(thin) (Float) -> Float
-  %5 = apply %4(%3) : $@convention(thin) (Float) -> (Float, Float)
-  return %5 : $(Float, Float)
+  %fref = function_ref @foo : $@convention(thin) (Float) -> Float
+
+  %grad = gradient [source 0] [wrt 0] %fref : $@convention(thin) (Float) -> Float
+  %grad_result = apply %grad(%0) : $@convention(thin) (Float) -> Float
+
+  %value_preserving_grad = gradient [source 0] [wrt 0] [preserving_result] %fref : $@convention(thin) (Float) -> Float
+  %value_and_grad = apply %value_preserving_grad(%0) : $@convention(thin) (Float) -> (Float, Float)
+
+  %adfunc = autodiff_function [wrt 0] [order 1] %fref : $@convention(thin) (Float) -> Float
+
+  return %value_and_grad : $(Float, Float)
 }

 // Here all `gradient` instructions have been replaced by `function_ref`s.

 // CHECK-LABEL: sil hidden @bar :
-// CHECK: bb0
-// CHECK: %1 = function_ref @AD__foo__grad_src_0_wrt_0 : $@convention(thin) (Float) -> Float
-// CHECK: %2 = apply %1(%0) : $@convention(thin) (Float) -> Float
-// CHECK: %3 = function_ref @AD__foo__grad_src_0_wrt_0_p : $@convention(thin) (Float) -> (Float, Float)
-// CHECK: %4 = apply %3(%2) : $@convention(thin) (Float) -> (Float, Float)
-// CHECK: return %4 : $(Float, Float)
+// CHECK: [[FREF:%.*]] = function_ref @foo : $@convention(thin) (Float) -> Float
+// CHECK: [[GRAD_REF:%.*]] = function_ref @AD__foo__grad_src_0_wrt_0 : $@convention(thin) (Float) -> Float
+// CHECK: [[GRAD_RESULT:%.*]] = apply [[GRAD_REF]](%0) : $@convention(thin) (Float) -> Float
+// CHECK: [[VALUE_PRESERVING_GRAD_REF:%.*]] = function_ref @AD__foo__grad_src_0_wrt_0_p : $@convention(thin) (Float) -> (Float, Float)
+// CHECK: [[VALUE_PRESERVING_GRAD_RESULT:%.*]] = apply [[VALUE_PRESERVING_GRAD_REF]](%0) : $@convention(thin) (Float) -> (Float, Float)
+// CHECK: [[VJP_REF:%.*]] = function_ref @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+// CHECK: autodiff_function [wrt 0] [order 1] [[FREF]] : $@convention(thin) (Float) -> Float with {undef : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float), [[VJP_REF]] : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)}
+// CHECK: return [[VALUE_PRESERVING_GRAD_RESULT]] : $(Float, Float)
 // CHECK: }

 // CHECK-LABEL:sil hidden @AD__foo__grad_src_0_wrt_0_s_p :

branches/tensorflow/test/AutoDiff/autodiff_function_inst_irgen.sil

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ bb0(%0 : @trivial $Float, %1 : @trivial $Float, %2 : @trivial $Float):
 }

 // The original function with an attribute that specifies the compiler-emitted pullback.
-sil hidden [differentiable source 0 wrt 0 primal @foo adjoint @foo_adj primitive] @foo : $@convention(thin) (Float) -> Float {
+sil hidden [differentiable source 0 wrt 0 primal @foo adjoint @foo_adj primitive vjp @foo_vjp] @foo : $@convention(thin) (Float) -> Float {
 bb0(%0 : @trivial $Float):
   return %0 : $Float
 }
