Skip to content

Commit 421798a

Browse files
authored
---
yaml --- r: 262075 b: refs/heads/tensorflow c: 932f7fa h: refs/heads/master i: 262073: 3667a78 262071: afc2ce4
1 parent ff0944a commit 421798a

File tree

10 files changed

+155
-21
lines changed

10 files changed

+155
-21
lines changed

[refs]

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -818,7 +818,7 @@ refs/tags/swift-DEVELOPMENT-SNAPSHOT-2018-04-25-a: 22f738a831d43aff2b9c9773bcb65
818818
refs/tags/swift-DEVELOPMENT-SNAPSHOT-2018-05-08-a: 7d98cc16689baba5c8a3b90a9329bdcc1a12b4e9
819819
refs/heads/cherr42: a566ad54b073c2c56ac0a705d0a5bed9743135a5
820820
"refs/heads/codable_test_comment_fix": fc8f6824f7f347e1e8db55bff62db385c5728b5a
821-
refs/heads/tensorflow: b435eef3e693c6665ce53a287026d01f48411ca8
821+
refs/heads/tensorflow: 932f7fa6ec4db24d344ac26198d25432d94dbf46
822822
refs/tags/swift-4.1-DEVELOPMENT-SNAPSHOT-2018-05-11-a: 8126fd7a652e2f70ad6d76505239e34fb2ef3e1a
823823
refs/tags/swift-4.1-DEVELOPMENT-SNAPSHOT-2018-05-12-a: b3fd3dd84df6717f2e2e9df58c6d7e99fed57086
824824
refs/tags/swift-4.1-DEVELOPMENT-SNAPSHOT-2018-05-13-a: 71135119579039dc321c5f65d870050fe36efda2

branches/tensorflow/include/swift/AST/Builtins.def

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -376,10 +376,10 @@ BUILTIN_SIL_OPERATION(AllocWithTailElems, "allocWithTailElems", Special)
376376
BUILTIN_SIL_OPERATION(ProjectTailElems, "projectTailElems", Special)
377377

378378
// SWIFT_ENABLE_TENSORFLOW
379-
/// autodifGetJVP has type <T: Differentiable, R: Differentiable>
379+
/// autodiffGetJVP has type <T: Differentiable, R: Differentiable>
380380
/// ((T) -> R) -> (T) -> (R, (T.TangentVector) -> R.TangentVector).
381381
BUILTIN_SIL_OPERATION(AutoDiffGetJVP, "autodiffGetJVP", Special)
382-
/// autodifGetVJP has type <T: Differentiable, R: Differentiable>
382+
/// autodiffGetVJP has type <T: Differentiable, R: Differentiable>
383383
/// ((T) -> R) -> (T) -> (R, (R.CotangentVector) -> T.CotangentVector).
384384
BUILTIN_SIL_OPERATION(AutoDiffGetVJP, "autodiffGetVJP", Special)
385385

branches/tensorflow/include/swift/SIL/SILInstruction.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7648,13 +7648,14 @@ class AutoDiffFunctionExtractInst :
76487648
SingleValueInstruction> {
76497649
public:
76507650
struct Extractee {
7651-
enum innerty : uint8_t {
7651+
enum innerty : unsigned {
76527652
Original = 0,
76537653
JVP = 1,
76547654
VJP = 2
76557655
} rawValue;
76567656
Extractee() = default;
76577657
Extractee(innerty rawValue) : rawValue(rawValue) {}
7658+
Extractee(unsigned rawValue) : Extractee((innerty)rawValue) {}
76587659
explicit Extractee(StringRef name);
76597660
operator innerty() const { return rawValue; }
76607661
};

branches/tensorflow/include/swift/Serialization/ModuleFormat.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ const uint16_t SWIFTMODULE_VERSION_MAJOR = 0;
5252
/// describe what change you made. The content of this comment isn't important;
5353
/// it just ensures a conflict if two people change the module format.
5454
/// Don't worry about adhering to the 80-column limit for this line.
55-
const uint16_t SWIFTMODULE_VERSION_MINOR = 457; // Last change: add jvp and vjp to sil differentiable attribute
55+
const uint16_t SWIFTMODULE_VERSION_MINOR = 458; // Last change: serialize autodiff_function and autodiff_function_extract
5656

5757
using DeclIDField = BCFixed<31>;
5858

branches/tensorflow/lib/SIL/OwnershipUtils.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ bool swift::isOwnershipForwardingValueKind(SILNodeKind kind) {
4545
case SILNodeKind::DestructureTupleInst:
4646
// SWIFT_ENABLE_TENSORFLOW
4747
case SILNodeKind::GradientInst:
48+
case SILNodeKind::AutoDiffFunctionInst:
49+
case SILNodeKind::AutoDiffFunctionExtractInst:
4850
return true;
4951
default:
5052
return false;

branches/tensorflow/lib/Serialization/DeserializeSIL.cpp

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -972,7 +972,9 @@ bool SILDeserializer::readSILInstruction(SILFunction *Fn, SILBasicBlock *BB,
972972
Builder.setInsertionPoint(BB);
973973
Builder.setCurrentDebugScope(Fn->getDebugScope());
974974
unsigned RawOpCode = 0, TyCategory = 0, TyCategory2 = 0, TyCategory3 = 0,
975-
Attr = 0, NumSubs = 0, NumConformances = 0, IsNonThrowingApply = 0;
975+
// SWIFT_ENABLE_TENSORFLOW
976+
Attr = 0, Attr2 = 0, NumSubs = 0, NumConformances = 0,
977+
IsNonThrowingApply = 0;
976978
// SWIFT_ENABLE_TENSORFLOW
977979
unsigned NumArguments = 0;
978980
unsigned GradResultIndex = 0;
@@ -1087,6 +1089,18 @@ bool SILDeserializer::readSILInstruction(SILFunction *Fn, SILBasicBlock *BB,
10871089
GradResultIndex, ListOfValues);
10881090
RawOpCode = (unsigned)SILInstructionKind::GradientInst;
10891091
break;
1092+
case SIL_INST_AUTODIFF_FUNCTION:
1093+
SILInstAutoDiffFunctionLayout::readRecord(scratch, /*order*/ Attr,
1094+
/*numParams*/ Attr2, NumArguments,
1095+
ListOfValues);
1096+
RawOpCode = (unsigned)SILInstructionKind::AutoDiffFunctionInst;
1097+
break;
1098+
case SIL_INST_AUTODIFF_FUNCTION_EXTRACT:
1099+
SILInstAutoDiffFunctionExtractLayout::readRecord(scratch, TyID, TyCategory,
1100+
ValID, /*extractee*/ Attr,
1101+
/*order*/ Attr2);
1102+
RawOpCode = (unsigned)SILInstructionKind::AutoDiffFunctionExtractInst;
1103+
break;
10901104
case SIL_INST_NO_OPERAND:
10911105
SILInstNoOperandLayout::readRecord(scratch, RawOpCode);
10921106
break;
@@ -1488,10 +1502,32 @@ bool SILDeserializer::readSILInstruction(SILFunction *Fn, SILBasicBlock *BB,
14881502
break;
14891503
}
14901504
case SILInstructionKind::AutoDiffFunctionInst: {
1491-
llvm_unreachable("FIXME: Unhandled");
1505+
auto numParamIndices = ListOfValues.size() - NumArguments * 3;
1506+
auto paramIndices = ListOfValues.take_front(numParamIndices);
1507+
auto numParams = Attr2;
1508+
llvm::SmallBitVector paramIndicesBitVec(numParams);
1509+
for (unsigned idx : paramIndices)
1510+
paramIndicesBitVec.set(idx);
1511+
SmallVector<SILValue, 4> operands;
1512+
for (auto i = numParamIndices; i < NumArguments * 3; i += 3) {
1513+
auto astTy = MF->getType(ListOfValues[i]);
1514+
auto silTy = getSILType(astTy, (SILValueCategory)ListOfValues[i+1]);
1515+
operands.push_back(getLocalValue(ListOfValues[i+2], silTy));
1516+
}
1517+
ResultVal = Builder.createAutoDiffFunction(Loc, paramIndicesBitVec,
1518+
/*differentiationOrder*/ Attr, operands[0],
1519+
ArrayRef<SILValue>(operands).drop_front());
1520+
break;
14921521
}
14931522
case SILInstructionKind::AutoDiffFunctionExtractInst: {
1494-
llvm_unreachable("FIXME: unhandled");
1523+
auto astTy = MF->getType(TyID);
1524+
auto silTy = getSILType(astTy, SILValueCategory::Object);
1525+
auto val = getLocalValue(ValID, silTy);
1526+
AutoDiffFunctionExtractee extractee(Attr);
1527+
auto order = Attr2;
1528+
ResultVal =
1529+
Builder.createAutoDiffFunctionExtract(Loc, extractee, order, val);
1530+
break;
14951531
}
14961532
case SILInstructionKind::GraphOperationInst: {
14971533
// TODO(SR-8848): Deserialize attributes.

branches/tensorflow/lib/Serialization/SILFormat.h

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,8 @@ namespace sil_block {
175175
SIL_REVERSE_DIFFERENTIABLE_ATTR,
176176
SIL_INST_GRAPH_OPERATION,
177177
SIL_INST_GRADIENT,
178+
SIL_INST_AUTODIFF_FUNCTION,
179+
SIL_INST_AUTODIFF_FUNCTION_EXTRACT,
178180

179181
// We also share these layouts from the decls block. Their enumerators must
180182
// not overlap with ours.
@@ -293,7 +295,7 @@ namespace sil_block {
293295
BCFixed<3>, // side effect info.
294296
BCVBR<8>, // number of specialize attributes
295297
// SWIFT_ENABLE_TENSORFLOW
296-
BCVBR<8>, // number of reverse differentiable attributes
298+
BCVBR<8>, // number of differentiable attributes
297299
BCFixed<1>, // has qualified ownership
298300
BCFixed<1>, // must be weakly referenced
299301
TypeIDField, // SILFunctionType
@@ -429,6 +431,23 @@ namespace sil_block {
429431
BCArray<BCFixed<1>> // parameter indices
430432
>;
431433

434+
using SILInstAutoDiffFunctionLayout = BCRecordLayout<
435+
SIL_INST_AUTODIFF_FUNCTION,
436+
BCVBR<8>, // differentiation order
437+
BCVBR<8>, // number of function parameters
438+
BCVBR<8>, // number of operands
439+
BCArray<ValueIDField> // parameter indices and operands
440+
>;
441+
442+
using SILInstAutoDiffFunctionExtractLayout = BCRecordLayout<
443+
SIL_INST_AUTODIFF_FUNCTION_EXTRACT,
444+
TypeIDField,
445+
SILTypeCategoryField,
446+
ValueIDField,
447+
BCFixed<2>, // extractee
448+
BCVBR<8> // order
449+
>;
450+
432451
// SIL instructions with one type. (alloc_stack)
433452
using SILOneTypeLayout = BCRecordLayout<
434453
SIL_ONE_TYPE,

branches/tensorflow/lib/Serialization/SerializeSIL.cpp

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -963,10 +963,34 @@ void SILSerializer::writeSILInstruction(const SILInstruction &SI) {
963963
break;
964964
}
965965
case SILInstructionKind::AutoDiffFunctionInst: {
966-
llvm_unreachable("FIXME: handle this");
966+
auto *adfi = cast<AutoDiffFunctionInst>(&SI);
967+
SmallVector<ValueID, 4> trailingInfo;
968+
auto &paramIndices = adfi->getParameterIndices();
969+
for (unsigned idx : paramIndices.set_bits())
970+
trailingInfo.push_back(idx);
971+
for (auto &op : adfi->getAllOperands()) {
972+
auto val = op.get();
973+
trailingInfo.push_back(S.addTypeRef(val->getType().getASTType()));
974+
trailingInfo.push_back((unsigned)val->getType().getCategory());
975+
trailingInfo.push_back(addValueRef(val));
976+
}
977+
SILInstAutoDiffFunctionLayout::emitRecord(Out, ScratchRecord,
978+
SILAbbrCodes[SILInstAutoDiffFunctionLayout::Code],
979+
adfi->getDifferentiationOrder(), (unsigned)paramIndices.size(),
980+
adfi->getNumOperands(), trailingInfo);
981+
break;
967982
}
968983
case SILInstructionKind::AutoDiffFunctionExtractInst: {
969-
llvm_unreachable("FIXME: handle this");
984+
auto *adfei = cast<AutoDiffFunctionExtractInst>(&SI);
985+
auto operandRef = addValueRef(adfei->getFunctionOperand());
986+
auto operandType = adfei->getFunctionOperand()->getType();
987+
auto operandTypeRef = S.addTypeRef(operandType.getASTType());
988+
auto rawExtractee = (unsigned)adfei->getExtractee();
989+
SILInstAutoDiffFunctionExtractLayout::emitRecord(Out, ScratchRecord,
990+
SILAbbrCodes[SILInstAutoDiffFunctionExtractLayout::Code],
991+
operandTypeRef, (unsigned)operandType.getCategory(), operandRef,
992+
rawExtractee, adfei->getDifferentiationOrder());
993+
break;
970994
}
971995
case SILInstructionKind::GraphOperationInst: {
972996
// TODO(SR-8848): Serialize attributes.
@@ -2461,6 +2485,8 @@ void SILSerializer::writeSILBlock(const SILModule *SILMod) {
24612485
registerSILAbbr<SILDifferentiableAttrLayout>();
24622486
registerSILAbbr<SILInstGraphOperationLayout>();
24632487
registerSILAbbr<SILInstGradientLayout>();
2488+
registerSILAbbr<SILInstAutoDiffFunctionLayout>();
2489+
registerSILAbbr<SILInstAutoDiffFunctionExtractLayout>();
24642490

24652491
// Register the abbreviation codes so these layouts can exist in both
24662492
// decl blocks and sil blocks.

branches/tensorflow/stdlib/public/core/AutoDiff.swift

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -94,20 +94,65 @@ public extension Differentiable where TangentVector == CotangentVector {
9494
// Differential Operators
9595
//===----------------------------------------------------------------------===//
9696

97-
@_transparent @_effects(readnone)
98-
public func gradient<T : Differentiable, R : FloatingPoint>(
99-
of f: /*@autodiff*/ @escaping (T) -> R
100-
) -> (T) -> T {
101-
return #gradient(f)
97+
// FIXME(rxwei): Fix SR-9458.
98+
#if false
99+
100+
@inlinable
101+
public func valueWithDifferential<T, R>(
102+
at x: T, in f: @escaping @autodiff (T) -> R
103+
) -> (value: R, differential: (T.TangentVector) -> R.TangentVector)
104+
where T : Differentiable, R : Differentiable {
105+
return Builtin.autodiffGetJVP(f)(x)
102106
}
103107

104-
@_transparent @_effects(readnone)
105-
public func gradient<T : Differentiable, R : FloatingPoint>(
106-
at x: T, in f: /*@autodiff*/ @escaping (T) -> R
107-
) -> T {
108-
return #gradient(f)(x)
108+
@inlinable
109+
public func valueWithPullback<T, R>(
110+
at x: T, in f: @escaping @autodiff (T) -> R
111+
) -> (value: R, pullback: (R.CotangentVector) -> T.CotangentVector)
112+
where T : Differentiable, R : Differentiable {
113+
return Builtin.autodiffGetVJP(f)(x)
109114
}
110115

116+
@inlinable
117+
public func derivative<T, R>(
118+
at x: T, in f: @escaping @autodiff (T) -> R
119+
) -> R.TangentVector
120+
where T : BinaryFloatingPoint & Differentiable, R : Differentiable,
121+
T.TangentVector == T {
122+
let (y, differential) = valueWithDifferential(at: x, in: f)
123+
return differential(1)
124+
}
125+
126+
@inlinable
127+
public func derivative<T, R>(
128+
of f: @escaping @autodiff (T) -> R
129+
) -> (T) -> R.TangentVector
130+
where T : BinaryFloatingPoint & Differentiable, R : Differentiable,
131+
T.TangentVector == T {
132+
return { x in derivative(at: x, in: f) }
133+
}
134+
135+
@inlinable
136+
public func gradient<T, R>(
137+
at x: T, in f: @escaping @autodiff (T) -> R
138+
) -> T.CotangentVector
139+
where T : Differentiable, R : BinaryFloatingPoint & Differentiable,
140+
R.CotangentVector == R {
141+
let (y, pullback) = valueWithPullback(at: x, in: f)
142+
return pullback(1)
143+
}
144+
145+
@inlinable
146+
public func gradient<T, R>(
147+
of f: @escaping @autodiff (T) -> R
148+
) -> (T) -> T.CotangentVector
149+
where T : Differentiable, R : BinaryFloatingPoint & Differentiable,
150+
R.CotangentVector == R {
151+
return { x in gradient(at: x, in: f) }
152+
}
153+
154+
#endif
155+
111156
//===----------------------------------------------------------------------===//
112157
// Builtins
113158
//===----------------------------------------------------------------------===//

branches/tensorflow/test/AutoDiff/autodiff_function_inst.sil

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
// RUN: %target-sil-opt %s | %FileCheck %s
22

3+
// RUN: %empty-directory(%t)
4+
// RUN: %target-sil-opt %s -emit-sib -o %t/tmp.sib -module-name autodiff_function
5+
// RUN: %target-sil-opt %t/tmp.sib -o %t/tmp.2.sib -module-name autodiff_function
6+
// RUN: %target-sil-opt %t/tmp.2.sib -module-name autodiff_function | %FileCheck %s
7+
38
sil_stage raw
49

510
import Swift

0 commit comments

Comments
 (0)