@@ -62,6 +62,10 @@ struct GraphGlobalConfiguration {
static const char DEVICE_TPU_REPLICATED_CORE[] = "TPU_REPLICATED_CORE";
static const char DEVICE_TPU_SYSTEM[] = "TPU_SYSTEM";
+ // Set a small number to exercise the bounded queue capacity more, increasing
+ // test coverage.
+ // FIXME: Tune the default value for performance, and/or make it configurable.
+ static const int NAMED_TENSOR_QUEUE_CAPACITY = 1;

/// When generating a TF TPU graph, call this function to place an eligible TF
/// graph node onto TPU device. Some nodes such as Placeholder and
@@ -142,6 +146,12 @@ namespace {
/// This is a list of all of the operations that make up this function.
std::vector<const TF_Operation*> operations;

+ // When true, lower effectful ops (e.g. Swift->TF send ops), if any, in the
+ // corresponding TF function. Currently in a While op context, these ops
+ // should not be run in the cond function.
+ // TODO(b/78472806): Add a more thorough and proper fix for effectful ops in
+ // the cond function.
+ bool shouldLowerEffectfulOps = true;
public:
GraphFunctionBody(GraphGlobalConfiguration configuration)
: configuration(configuration), graph(TF_NewGraph(), &TF_DeleteGraph) {}
@@ -173,6 +183,21 @@ namespace {
return result;
}
+
+ // If there is a control dependence value, run it before producing an output
+ // tensor in GraphFunctionBody.
+ TF_Output maybeRunEffectfulOp(TF_Output result, TF_Status *status) {
+ if (!controlDependenceValue) return result;
+
+ std::string nodeName = "RunControlDependency";
+ auto *desc = TF_NewOperation(getGraph(), "Identity", nodeName.c_str());
+ TF_AddControlInput(desc, controlDependenceValue);
+ TF_AddInput(desc, result);
+ TF_Operation *newResult = finishOp(desc, /*hasSideEffects*/ false,
+ /*isEligibleForTPU*/ false, status);
+ controlDependenceValue = nullptr;
+ return {newResult, 0};
+ }
};
}
@@ -194,6 +219,10 @@ struct TFGraphLowering : public SILInstructionVisitor<TFGraphLowering> {
// a value corresponds to, along with the scope ID of the value.
ValueMappingScopedHashTable valueMapping;

+ // Track those tensor ids that have been lowered to graph ops for TF->Swift
+ // tensor sends.
+ llvm::SmallSet<int, 4> processedTensorIdsForSend;
+
/// This flag gets set if lowering code to the graph produces a TensorFlow
/// error and emits a diagnostic. This tells us to stop lowering and give up
/// gracefully.
@@ -458,10 +487,8 @@ struct TFGraphLowering : public SILInstructionVisitor<TFGraphLowering> {
void visitStringLiteralInst(StringLiteralInst *inst) {}

void visitBuiltinInst(BuiltinInst *inst);
- void visitBuiltinTFSendInst(BuiltinInst *inst) {
- internalError(inst->getLoc(),
- "GraphGen cannot lower a 'send' to the host yet");
- }
+ void visitBuiltinTFSendInst(BuiltinInst *inst);
+
void visitBuiltinTFReceiveInst(BuiltinInst *inst) {
internalError(inst->getLoc(),
"GraphGen cannot lower a 'receive' from the host yet");
@@ -751,6 +778,99 @@ static void decodeShapeArray(SILInstruction *inst,
}
}

+ void TFGraphLowering::visitBuiltinTFSendInst(BuiltinInst *inst) {
+ auto &graphFn = getCurrentGraphFunction();
+ // TODO(b/78472806): Add a more thorough and proper fix for effectful ops in
+ // the while cond function.
+ if (!graphFn.shouldLowerEffectfulOps) return;
+
+ // Decode the tensor id from the builtin name.
+ // Example: builtin "tensorflowSend_0"<TensorHandle<Float>>(...) : $()
+ int tensorId = -1;
+ {
+ auto name = inst->getName().str();
+ auto tensorIdStr = name.substr(strlen("tensorflowSend_"));
+ bool isInt = llvm::to_integer(tensorIdStr, tensorId, 10);
+ assert(isInt);
+ }
+
+ // Type check and process the parameter.
+ TF_Output inputOp;
+ TF_DataType inputType;
+ {
+ assert(inst->getNumOperands() == 1);
+ auto operand = inst->getOperand(0);
+ inputOp = getOperandValue(operand);
+ if (!inputOp.oper) return; // Error occurred.
+ inputType = getTensorFlowDataType(operand->getType(), inst->getLoc());
+ }
+
+ // Add enqueue to the local graph function, and the corresponding dequeue to
+ // the top level function, so that caller can dequeue tensors via SessionRun.
+ TF_Operation *queueOp;
+ {
+ auto opName = "fifo_queue_" + llvm::itostr(tensorId);
+ auto *desc =
+ TF_NewOperation(graphFn.getGraph(), "FIFOQueueV2", opName.c_str());
+ TF_SetDevice(desc, "/device:CPU:0");
+ TF_SetAttrInt(desc, "capacity", NAMED_TENSOR_QUEUE_CAPACITY);
+ TF_SetAttrTypeList(desc, "component_types", &inputType, 1);
+ TF_SetAttrString(desc, "shared_name", opName.data(), opName.size());
+ queueOp = graphFn.finishOp(desc, /*hasSideEffects*/ false,
+ /*isEligibleForTPU*/ false, status);
+ if (checkStatus(getUserSourceLocation(inst->getDebugLocation())))
+ return;
+ }
+
+ {
+ auto opName = "fifo_queue_enqueue_" + llvm::itostr(tensorId);
+ auto *desc =
+ TF_NewOperation(graphFn.getGraph(), "QueueEnqueueV2", opName.c_str());
+ TF_AddInput(desc, {queueOp, 0});
+ TF_AddInputList(desc, &inputOp, 1);
+ TF_SetDevice(desc, "/device:CPU:0");
+ TF_SetAttrTypeList(desc, "Tcomponents", &inputType, 1);
+
+ graphFn.finishOp(desc, /*hasSideEffects*/ true,
+ /*isEligibleForTPU*/ false, status);
+ if (checkStatus(getUserSourceLocation(inst->getDebugLocation())))
+ return;
+ }
+
+ // Now add dequeue to the top level graph function.
+ // Multiple graph functions can have an enqueue op over the same tensorId.
+ // One example is to enqueue tensors both within the while op's body
+ // function, and also right after the while op is executed.
+ // In that case, we only generate a single dequeue op at the top level.
+ if (!processedTensorIdsForSend.insert(tensorId).second) return;
+
+ // The code here is different enough from the above that it's not worth
+ // extracting common code into functions.
+ TF_Operation *globalQueueOp;
+ {
+ auto opName = "fifo_queue_" + llvm::itostr(tensorId);
+ auto *desc = TF_NewOperation(resultGraph, "FIFOQueueV2", opName.c_str());
+ TF_SetDevice(desc, "/device:CPU:0");
+ TF_SetAttrInt(desc, "capacity", NAMED_TENSOR_QUEUE_CAPACITY);
+ TF_SetAttrTypeList(desc, "component_types", &inputType, 1);
+ // FIXME: Revisit whether to populate "shared_name".
+ TF_SetAttrString(desc, "shared_name", opName.data(), opName.size());
+ globalQueueOp = TF_FinishOperation(desc, status);
+ if (checkStatus(getUserSourceLocation(inst->getDebugLocation())))
+ return;
+ }
+
+ {
+ auto opName = "fifo_queue_dequeue_" + llvm::itostr(tensorId);
+ auto *desc = TF_NewOperation(resultGraph, "QueueDequeueV2", opName.c_str());
+ TF_AddInput(desc, {globalQueueOp, 0});
+ TF_SetDevice(desc, "/device:CPU:0");
+ TF_SetAttrTypeList(desc, "component_types", &inputType, 1);
+ TF_FinishOperation(desc, status);
+ if (checkStatus(getUserSourceLocation(inst->getDebugLocation()))) return;
+ }
+ }
+

void TFGraphLowering::visitTFDataset(BuiltinInst *inst) {
// FIXME: Also support dataset/iterator outside of TPU context.
if (!configuration.isTPUEnabled || !configuration.isTPUInfeedEnabled) {
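For reference only (not part of this change), a minimal sketch of how a host-side caller might consume a sent tensor: once the lowered graph function has run its QueueEnqueueV2, the matching dequeue op in the top-level graph can be looked up by name and executed through TF_SessionRun. The session, status, and tensor id 0 below are illustrative assumptions.

// Illustrative sketch, assuming a TF_Session *session and TF_Status *status
// already exist for the top-level graph (resultGraph), and tensor id 0.
TF_Operation *dequeue =
    TF_GraphOperationByName(resultGraph, "fifo_queue_dequeue_0");
TF_Output fetch = {dequeue, 0};
TF_Tensor *received = nullptr;
TF_SessionRun(session, /*run_options*/ nullptr,
              /*inputs*/ nullptr, /*input_values*/ nullptr, 0,
              /*outputs*/ &fetch, &received, 1,
              /*target_opers*/ nullptr, 0,
              /*run_metadata*/ nullptr, status);
// 'received' now holds the tensor enqueued by the graph function; the caller
// owns it and should release it with TF_DeleteTensor() when done.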
@@ -1190,11 +1310,15 @@ void TFGraphLowering::visitReturnInst(ReturnInst *inst) {
for (auto &operand : ti->getAllOperands()) {
auto result = getOperandValue(operand.get());
if (!result.oper) return; // Error occurred.
+ result = graphFn.maybeRunEffectfulOp(result, status);
+ if (checkStatus(SILFn.getLocation())) return;
graphFn.outputs.push_back({/*SILArgument*/ nullptr, result});
}
} else {
auto result = getOperandValue(inst->getOperand());
if (!result.oper) return; // Error occurred.
+ result = graphFn.maybeRunEffectfulOp(result, status);
+ if (checkStatus(SILFn.getLocation())) return;
graphFn.outputs.push_back({/*SILArgument*/ nullptr, result});
}
}
@@ -1214,6 +1338,8 @@ void TFGraphLowering::visitBranchInst(BranchInst *inst) {
for (unsigned i = 0, e = inst->getNumArgs(); i != e; ++i) {
auto result = getOperandValue(inst->getArg(i));
if (!result.oper) return; // Error occurred.
+ result = graphFn.maybeRunEffectfulOp(result, status);
+ if (checkStatus(SILFn.getLocation())) return;
graphFn.outputs.push_back({destBB->getArgument(i), result});
}
}
@@ -1354,9 +1480,9 @@ void TFGraphLowering::lowerWhileLoopRegion(WhileLoopSESERegion *r) {
// body, we are required to emit the computation into both functions, and
// rely on XLA to CSE it where possible (which I suspect it doesn't do).
//
- // This will also be problematic when the condition is allowed to have
- // side effects (e.g. because of send and recv) because they cannot be
- // reissued in general.
+ // This will also be problematic when the condition is allowed to have side
+ // effects (e.g. because of send and recv) because they cannot be reissued
+ // in general.
//
// A better model for while loop is to change the condition to be a function
// "T -> (U, bool)" and have the loop body be "U -> T". This structure
@@ -1389,9 +1515,10 @@ void TFGraphLowering::lowerWhileLoopRegion(WhileLoopSESERegion *r) {
for (unsigned i = loopBodyFn.outputs.size(), e = loopBodyFn.inputs.size();
i != e; ++i) {
- loopBodyFn.outputs.push_back({
- /*SILArgument*/ nullptr, loopBodyFn.inputs[i].parameter
- });
+ auto result =
+ loopBodyFn.maybeRunEffectfulOp(loopBodyFn.inputs[i].parameter, status);
+ if (checkStatus(SILFn.getLocation())) return;
+ loopBodyFn.outputs.push_back({/*SILArgument*/ nullptr, result});
}

// Next, lower the condition function into a 'stop predicate' for the loop.
@@ -1412,6 +1539,7 @@ void TFGraphLowering::lowerWhileLoopRegion(WhileLoopSESERegion *r) {
// Lower any code in the header block, which may be used by the termination
// condition. It ends with a conditional branch which we handle manually.
+ graphFn.shouldLowerEffectfulOps = false;
lowerBasicBlock(r->getHeader(), /*skipTerminator:*/ true);
if (errorOccurred) return;