
Commit 5e43d80

Author: hongm (committed)
---
r: 311165
b: refs/heads/tensorflow-merge
c: 730fc85
h: refs/heads/master
i:
  311163: 4695052
1 parent 90c5f2d commit 5e43d80

File tree: 13 files changed, +482 -101 lines


[refs]

Lines changed: 1 addition & 1 deletion
@@ -1379,7 +1379,7 @@ refs/heads/chase-my-tail: 8bb91443a9e81bbfac92a2621a0af887a1da8dbf
 refs/heads/consider-outer-alternatives: 708bac749ec60a22a79e2eefbe734f9488a7370d
 refs/heads/revert-25740-oops-i-linked-it-again: fdd41aeb682fc488572bdc1cf71b2ff6997ba576
 refs/heads/swift-5.1-branch-06-12-2019: e63b7b2d3b93c48232d386099d0ec525d21d8f8d
-refs/heads/tensorflow-merge: 505ef09ac32aa169823591a8842003dc52edbf7f
+refs/heads/tensorflow-merge: 730fc85b582c0997281b05e2da7a96da69a6c40b
 refs/heads/update-checkout-sha-info: 5832743c5c2a842976c42a508a4c6dcceefb0aef
 refs/tags/swift-5.1-DEVELOPMENT-SNAPSHOT-2019-06-12-a: 228f0448d9bb909aacbba4afcb7c600a405d15da
 refs/tags/swift-5.1-DEVELOPMENT-SNAPSHOT-2019-06-14-a: 922861a77b5fc2bf46bc917da70ceb15eef76836

branches/tensorflow-merge/include/swift/AST/KnownProtocols.def

Lines changed: 1 addition & 0 deletions
@@ -67,6 +67,7 @@ PROTOCOL(CodingKey)
 PROTOCOL(Encodable)
 PROTOCOL(Decodable)
 // SWIFT_ENABLE_TENSORFLOW
+PROTOCOL(AccelerableByTensorFlow)
 PROTOCOL(TensorProtocol)
 PROTOCOL(Differentiable)
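For context, this one-line addition is all the registration needed because KnownProtocols.def is an X-macro file: each consumer defines PROTOCOL(...) to whatever it needs and then includes the file. A minimal sketch of that pattern (illustrative only; the real definitions live in the Swift AST headers, and the underlying enum type is an assumption):

```cpp
// Sketch: every PROTOCOL(...) line in KnownProtocols.def expands to one
// enum case wherever the file is included under a suitable macro definition.
enum class KnownProtocolKind : uint8_t {
#define PROTOCOL(Name) Name,
#include "swift/AST/KnownProtocols.def"
};
```

This is what makes KnownProtocolKind::AccelerableByTensorFlow available to the switch statements patched in ASTContext.cpp and GenMeta.cpp below.
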
branches/tensorflow-merge/lib/AST/ASTContext.cpp

Lines changed: 1 addition & 0 deletions
@@ -869,6 +869,7 @@ ProtocolDecl *ASTContext::getProtocol(KnownProtocolKind kind) const {
     M = getLoadedModule(Id_CoreFoundation);
     break;
   // SWIFT_ENABLE_TENSORFLOW
+  case KnownProtocolKind::AccelerableByTensorFlow:
   case KnownProtocolKind::TensorProtocol:
     M = getLoadedModule(getIdentifier("TensorFlow"));
     break;
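
With the new case wired into the switch above, compiler code can resolve the protocol through the usual entry point. A hedged usage sketch (`ctx` is an assumed ASTContext reference; callers and error handling vary):

```cpp
// Illustrative lookup: returns nullptr when the TensorFlow module consulted
// by the switch above has not been loaded.
ProtocolDecl *proto =
    ctx.getProtocol(KnownProtocolKind::AccelerableByTensorFlow);
if (!proto) {
  // The TensorFlow module isn't available; callers typically diagnose or bail.
}
```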

branches/tensorflow-merge/lib/IRGen/GenMeta.cpp

Lines changed: 1 addition & 0 deletions
@@ -3472,6 +3472,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
   case KnownProtocolKind::Encodable:
   case KnownProtocolKind::Decodable:
   // SWIFT_ENABLE_TENSORFLOW
+  case KnownProtocolKind::AccelerableByTensorFlow:
   case KnownProtocolKind::TensorProtocol:
   case KnownProtocolKind::Differentiable:
     return SpecialProtocol::None;
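
Mapping the new protocol to SpecialProtocol::None means IRGen attaches no special-protocol flag to its metadata; within this enum it is treated like any ordinary protocol (Error is, to my knowledge, the one case here that returns something other than None).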

branches/tensorflow-merge/lib/SILOptimizer/Mandatory/TFLowerGraph.cpp

Lines changed: 138 additions & 10 deletions
@@ -62,6 +62,10 @@ struct GraphGlobalConfiguration {
 
 static const char DEVICE_TPU_REPLICATED_CORE[] = "TPU_REPLICATED_CORE";
 static const char DEVICE_TPU_SYSTEM[] = "TPU_SYSTEM";
+// Set a small number to exercise the bounded queue capacity more, increasing
+// test coverage.
+// FIXME: Tune the default value for performance, and/or make it configurable.
+static const int NAMED_TENSOR_QUEUE_CAPACITY = 1;
 
 /// When generating a TF TPU graph, call this function to place an eligible TF
 /// graph node onto TPU device. Some nodes such as Placeholder and
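
Semantics note: a TensorFlow FIFO queue with capacity 1 blocks a second enqueue until the first element has been dequeued, so the deliberately small default above maximizes producer/consumer interleaving, which is what the test-coverage comment is after.
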
@@ -142,6 +146,12 @@ namespace {
     /// This is a list of all of the operations that make up this function.
     std::vector<const TF_Operation*> operations;
 
+    // When true, lower effectful ops (e.g. Swift->TF send ops), if any, in the
+    // corresponding TF function. Currently in a While op context, these ops
+    // should not be run in the cond function.
+    // TODO(b/78472806): Add a more thorough and proper fix for effectful ops in
+    // the cond function.
+    bool shouldLowerEffectfulOps = true;
   public:
     GraphFunctionBody(GraphGlobalConfiguration configuration)
       : configuration(configuration), graph(TF_NewGraph(), &TF_DeleteGraph) {}
@@ -173,6 +183,21 @@
 
       return result;
     }
+
+    // If there is a control dependence value, run it before producing an output
+    // tensor in GraphFunctionBody.
+    TF_Output maybeRunEffectfulOp(TF_Output result, TF_Status *status) {
+      if (!controlDependenceValue) return result;
+
+      std::string nodeName = "RunControlDependency";
+      auto *desc = TF_NewOperation(getGraph(), "Identity", nodeName.c_str());
+      TF_AddControlInput(desc, controlDependenceValue);
+      TF_AddInput(desc, result);
+      TF_Operation *newResult = finishOp(desc, /*hasSideEffects*/ false,
+                                         /*isEligibleForTPU*/ false, status);
+      controlDependenceValue = nullptr;
+      return {newResult, 0};
+    }
   };
 }
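
The intended call-site pattern for maybeRunEffectfulOp, assembled from the visitReturnInst and visitBranchInst hunks further down, wraps every graph-function output:

```cpp
// Pattern from the hunks below: route each output through the helper so any
// pending control dependence (e.g. an enqueue emitted for a send) runs first.
auto result = getOperandValue(operand.get());
if (!result.oper) return; // Error occurred.
result = graphFn.maybeRunEffectfulOp(result, status);
if (checkStatus(SILFn.getLocation())) return;
graphFn.outputs.push_back({ /*SILArgument*/nullptr, result });
```
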
@@ -194,6 +219,10 @@ struct TFGraphLowering : public SILInstructionVisitor<TFGraphLowering> {
   // a value corresponds to, along with the scope ID of the value.
   ValueMappingScopedHashTable valueMapping;
 
+  // Track those tensor ids that have been lowered to graph ops for TF->Swift
+  // tensor sends.
+  llvm::SmallSet<int, 4> processedTensorIdsForSend;
+
   /// This flag gets set if lowering code to the graph produces a TensorFlow
   /// error and emits a diagnostic. This tells us to stop lowering and give up
   /// gracefully.
@@ -458,10 +487,8 @@ struct TFGraphLowering : public SILInstructionVisitor<TFGraphLowering> {
   void visitStringLiteralInst(StringLiteralInst *inst) {}
 
   void visitBuiltinInst(BuiltinInst *inst);
-  void visitBuiltinTFSendInst(BuiltinInst *inst) {
-    internalError(inst->getLoc(),
-                  "GraphGen cannot lower a 'send' to the host yet");
-  }
+  void visitBuiltinTFSendInst(BuiltinInst *inst);
+
   void visitBuiltinTFReceiveInst(BuiltinInst *inst) {
     internalError(inst->getLoc(),
                   "GraphGen cannot lower a 'receive' from the host yet");
@@ -751,6 +778,99 @@ static void decodeShapeArray(SILInstruction *inst,
   }
 }
 
+void TFGraphLowering::visitBuiltinTFSendInst(BuiltinInst *inst) {
+  auto &graphFn = getCurrentGraphFunction();
+  // TODO(b/78472806): Add a more thorough and proper fix for effectful ops in
+  // the while cond function.
+  if (!graphFn.shouldLowerEffectfulOps) return;
+
+  // Decode the tensor id from the builtin name.
+  // Example: builtin "tensorflowSend_0"<TensorHandle<Float>>(...) : $()
+  int tensorId = -1;
+  {
+    auto name = inst->getName().str();
+    auto tensorIdStr = name.substr(strlen("tensorflowSend_"));
+    bool isInt = llvm::to_integer(tensorIdStr, tensorId, 10);
+    assert(isInt);
+  }
+
+  // Type check and process the parameter.
+  TF_Output inputOp;
+  TF_DataType inputType;
+  {
+    assert(inst->getNumOperands() == 1);
+    auto operand = inst->getOperand(0);
+    inputOp = getOperandValue(operand);
+    if (!inputOp.oper) return; // Error occurred.
+    inputType = getTensorFlowDataType(operand->getType(), inst->getLoc());
+  }
+
+  // Add enqueue to the local graph function, and the corresponding dequeue to
+  // the top level function, so that caller can dequeue tensors via SessionRun.
+  TF_Operation *queueOp;
+  {
+    auto opName = "fifo_queue_" + llvm::itostr(tensorId);
+    auto *desc =
+        TF_NewOperation(graphFn.getGraph(), "FIFOQueueV2", opName.c_str());
+    TF_SetDevice(desc, "/device:CPU:0");
+    TF_SetAttrInt(desc, "capacity", NAMED_TENSOR_QUEUE_CAPACITY);
+    TF_SetAttrTypeList(desc, "component_types", &inputType, 1);
+    TF_SetAttrString(desc, "shared_name", opName.data(), opName.size());
+    queueOp = graphFn.finishOp(desc, /*hasSideEffects*/ false,
+                               /*isEligibleForTPU*/ false, status);
+    if (checkStatus(getUserSourceLocation(inst->getDebugLocation())))
+      return;
+  }
+
+  {
+    auto opName = "fifo_queue_enqueue_" + llvm::itostr(tensorId);
+    auto *desc =
+        TF_NewOperation(graphFn.getGraph(), "QueueEnqueueV2", opName.c_str());
+    TF_AddInput(desc, {queueOp, 0});
+    TF_AddInputList(desc, &inputOp, 1);
+    TF_SetDevice(desc, "/device:CPU:0");
+    TF_SetAttrTypeList(desc, "Tcomponents", &inputType, 1);
+
+    graphFn.finishOp(desc, /*hasSideEffects*/ true,
+                     /*isEligibleForTPU*/ false, status);
+    if (checkStatus(getUserSourceLocation(inst->getDebugLocation())))
+      return;
+  }
+
+  // Now add dequeue to the top level graph function.
+  // Multiple graph functions can have an enqueue op over the same tensorId.
+  // One example is to enqueue tensors both within the while op's body
+  // function, and also right after the while op is executed.
+  // In that case, we only generate a single dequeue op at the top level.
+  if (!processedTensorIdsForSend.insert(tensorId).second) return;
+
+  // The code here is different enough from the above that it's not worth
+  // extracting common code into functions.
+  TF_Operation *globalQueueOp;
+  {
+    auto opName = "fifo_queue_" + llvm::itostr(tensorId);
+    auto *desc = TF_NewOperation(resultGraph, "FIFOQueueV2", opName.c_str());
+    TF_SetDevice(desc, "/device:CPU:0");
+    TF_SetAttrInt(desc, "capacity", NAMED_TENSOR_QUEUE_CAPACITY);
+    TF_SetAttrTypeList(desc, "component_types", &inputType, 1);
+    // FIXME: Revisit whether to populate "shared_name".
+    TF_SetAttrString(desc, "shared_name", opName.data(), opName.size());
+    globalQueueOp = TF_FinishOperation(desc, status);
+    if (checkStatus(getUserSourceLocation(inst->getDebugLocation())))
+      return;
+  }
+
+  {
+    auto opName = "fifo_queue_dequeue_" + llvm::itostr(tensorId);
+    auto *desc = TF_NewOperation(resultGraph, "QueueDequeueV2", opName.c_str());
+    TF_AddInput(desc, {globalQueueOp, 0});
+    TF_SetDevice(desc, "/device:CPU:0");
+    TF_SetAttrTypeList(desc, "component_types", &inputType, 1);
+    TF_FinishOperation(desc, status);
+    if (checkStatus(getUserSourceLocation(inst->getDebugLocation()))) return;
+  }
+}
+
 void TFGraphLowering::visitTFDataset(BuiltinInst *inst) {
   // FIXME: Also support dataset/iterator outside of TPU context.
   if(!configuration.isTPUEnabled || !configuration.isTPUInfeedEnabled) {
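
The comment in the new function says the top-level dequeue exists "so that caller can dequeue tensors via SessionRun". A minimal host-side sketch of what that might look like (assumed, not part of this commit: `graph`, `session`, and `status` stand for handles the runtime already owns, and tensor id 0 is illustrative):

```cpp
// Fetch one sent tensor by running the top-level dequeue op created above.
TF_Operation *dequeueOp =
    TF_GraphOperationByName(graph, "fifo_queue_dequeue_0");
TF_Output fetch{dequeueOp, 0};
TF_Tensor *fetched = nullptr;
TF_SessionRun(session, /*run_options*/ nullptr,
              /*inputs*/ nullptr, /*input_values*/ nullptr, /*ninputs*/ 0,
              /*outputs*/ &fetch, &fetched, /*noutputs*/ 1,
              /*target_opers*/ nullptr, /*ntargets*/ 0,
              /*run_metadata*/ nullptr, status);
// Use `fetched` (ownership transfers to the caller), then release it.
TF_DeleteTensor(fetched);
```
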
@@ -1190,11 +1310,15 @@ void TFGraphLowering::visitReturnInst(ReturnInst *inst) {
     for (auto &operand : ti->getAllOperands()) {
       auto result = getOperandValue(operand.get());
       if (!result.oper) return; // Error occurred.
+      result = graphFn.maybeRunEffectfulOp(result, status);
+      if (checkStatus(SILFn.getLocation())) return;
       graphFn.outputs.push_back({ /*SILArgument*/nullptr, result });
     }
   } else {
     auto result = getOperandValue(inst->getOperand());
     if (!result.oper) return; // Error occurred.
+    result = graphFn.maybeRunEffectfulOp(result, status);
+    if (checkStatus(SILFn.getLocation())) return;
     graphFn.outputs.push_back({ /*SILArgument*/nullptr, result });
   }
 }
@@ -1214,6 +1338,8 @@ void TFGraphLowering::visitBranchInst(BranchInst *inst) {
   for (unsigned i = 0, e = inst->getNumArgs(); i != e; ++i) {
     auto result = getOperandValue(inst->getArg(i));
     if (!result.oper) return; // Error occurred.
+    result = graphFn.maybeRunEffectfulOp(result, status);
+    if (checkStatus(SILFn.getLocation())) return;
     graphFn.outputs.push_back({ destBB->getArgument(i), result });
   }
 }
@@ -1354,9 +1480,9 @@ void TFGraphLowering::lowerWhileLoopRegion(WhileLoopSESERegion *r) {
   // body, we are required to emit the computation into both functions, and
   // rely on XLA to CSE it where possible (which I suspect it doesn't do).
   //
-  // This will also be problematic when the condition is allowed to have
-  // side effects (e.g. because of send and recv) because they cannot be
-  // reissued in general.
+  // This will also be problematic when the condition is allowed to have side
+  // effects (e.g. because of send and recv) because they cannot be reissued
+  // in general.
   //
   // A better model for while loop is to change the condition to be a function
   // "T -> (U, bool)" and have the loop body be "U -> T". This structure
@@ -1389,9 +1515,10 @@
 
   for (unsigned i = loopBodyFn.outputs.size(), e = loopBodyFn.inputs.size();
        i != e; ++i) {
-    loopBodyFn.outputs.push_back({
-      /*SILArgument*/nullptr, loopBodyFn.inputs[i].parameter
-    });
+    auto result =
+        loopBodyFn.maybeRunEffectfulOp(loopBodyFn.inputs[i].parameter, status);
+    if (checkStatus(SILFn.getLocation())) return;
+    loopBodyFn.outputs.push_back({/*SILArgument*/ nullptr, result});
   }
 
   // Next, lower the condition function into a 'stop predicate' for the loop.
@@ -1412,6 +1539,7 @@
 
   // Lower any code in the header block, which may be used by the termination
   // condition. It ends with a conditional branch which we handle manually.
+  graphFn.shouldLowerEffectfulOps = false;
   lowerBasicBlock(r->getHeader(), /*skipTerminator:*/ true);
   if (errorOccurred) return;
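
Net effect of this flag flip: the header block is lowered into the loop's cond function with sends suppressed, because visitBuiltinTFSendInst (added above) returns early while shouldLowerEffectfulOps is false. Per the TODO, this sidesteps rather than fully solves re-running effectful ops each time the cond function is evaluated.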