@@ -54,6 +54,10 @@ static const char TPU_CLUSTER_ATTR_VALUE[] = "TPUReplicate/cluster";
 // FIXME: Tune the default value for performance, and/or make it configurable.
 static const int NAMED_TENSOR_QUEUE_CAPACITY = 1;

+// The send device incarnation for TF sends/recvs.
+// FIXME: revisit whether using a fixed value is good enough.
+static const int DEVICE_INCARNATION_ID = 1;
+
 /// When generating a TF TPU graph, call this function to place an eligible TF
 /// graph node onto TPU device. Some nodes such as Placeholder and
 /// Dataset/Iterator nodes are not eligible for TPU.
@@ -191,7 +195,8 @@ namespace {
 struct TFGraphLowering : public SILInstructionVisitor<TFGraphLowering> {
   SILFunction &SILFn;
   // The TF device to which the generated graph is targeting.
-  DeviceType thisDeviceType;
+  const DeviceType thisDeviceType;
+  const std::string thisDeviceTypeStr;
   const GraphGlobalConfiguration &configuration;
   TF_Graph *resultGraph;
   TF_Status *status;
@@ -225,6 +230,7 @@ struct TFGraphLowering : public SILInstructionVisitor<TFGraphLowering> {
                   TF_Graph *resultGraph, TF_Status *status)
       : SILFn(fn),
         thisDeviceType(thisDeviceType),
+        thisDeviceTypeStr(getDeviceString(thisDeviceType)),
         configuration(configuration),
         resultGraph(resultGraph),
         status(status) {}
@@ -496,17 +502,7 @@ struct TFGraphLowering : public SILInstructionVisitor<TFGraphLowering> {
   void visitStringLiteralInst(StringLiteralInst *inst) {}

   void visitBuiltinInst(BuiltinInst *inst);
-  void visitBuiltinTFSendInst(BuiltinInst *inst);
-  void visitBuiltinTFReceiveInst(BuiltinInst *inst);
-
-  /// Create a stack of TF dataset and iterator nodes up to IteratorGetNext.
-  ///
-  /// FIXME: Dissolve this builtin into a set of finer-grained, composable
-  /// features.
-  void visitTFDataset(BuiltinInst *inst);
-  bool createDatasetIteratorNodesWithInfeedEnqueue();

-  void visitTFOpInst(BuiltinInst *inst);
   void visitTupleInst(TupleInst *inst);
   void visitTupleExtractInst(TupleExtractInst *inst);
   void visitUncheckedRefCastInst(UncheckedRefCastInst *inst);
@@ -534,6 +530,25 @@ struct TFGraphLowering : public SILInstructionVisitor<TFGraphLowering> {
   void lowerSequenceRegion(SequenceSESERegion *r);
   void lowerWhileLoopRegion(WhileLoopSESERegion *r);
   void lowerConditionalRegion(ConditionalSESERegion *r);
+
+private:  // Helpers for lowering.
+  /// Create a stack of TF dataset and iterator nodes up to IteratorGetNext.
+  ///
+  /// FIXME: Dissolve this builtin into a set of finer-grained, composable
+  /// features.
+  void visitTFDataset(BuiltinInst *inst);
+  bool createDatasetIteratorNodesWithInfeedEnqueue();
+
+  void visitTFOpInst(BuiltinInst *inst);
+
+  void visitBuiltinSendToHostInst(SILTensorOpInfo &tfopInfo, BuiltinInst *inst);
+  void visitBuiltinRecvFromHostInst(SILTensorOpInfo &tfopInfo,
+                                    BuiltinInst *inst);
+  // D2D means device-to-device.
+  void visitBuiltinD2DTensorRecvInst(SILTensorOpInfo &tfopInfo,
+                                     BuiltinInst *inst);
+  void visitBuiltinD2DTensorSendInst(SILTensorOpInfo &tfopInfo,
+                                     BuiltinInst *inst);
 };
 }

@@ -577,9 +592,9 @@ std::string TFGraphLowering::getUniqueName(SILDebugLocation loc,
       auto lineCol = SM.getLineAndColumn(ds->Loc.getSourceLoc());
       auto fnName = F->getName();

-      // Drop ".tf_partition" suffix off function names.
-      if (fnName.endswith(".tf_partition"))
-        fnName = fnName.drop_back(strlen(".tf_partition"));
+      // Drop ".device_partition" suffix off function names.
+      if (fnName.endswith(".device_partition"))
+        fnName = fnName.drop_back(strlen(".device_partition"));

       name += "." + fnName.str() + "." + llvm::utostr(lineCol.first);
       name += "." + llvm::utostr(lineCol.second);
@@ -684,10 +699,6 @@ void TFGraphLowering::visitBuiltinInst(BuiltinInst *inst) {
   // handle it directly.
   if (inst->getName().str() == "tf_tensor_to_i1")
     return;
-  if (inst->getName().str().startswith("tensorflowReceive_"))
-    return visitBuiltinTFReceiveInst(inst);
-  if (inst->getName().str().startswith("tensorflowSend_"))
-    return visitBuiltinTFSendInst(inst);
   if (inst->getName().str().startswith(
           "__tfop_tfc.makeIteratorGetNextWithDatasets"))
     return visitTFDataset(inst);
@@ -793,32 +804,30 @@ static void decodeShapeArray(SILInstruction *inst,
   }
 }

-void TFGraphLowering::visitBuiltinTFSendInst(BuiltinInst *inst) {
+void TFGraphLowering::visitBuiltinSendToHostInst(SILTensorOpInfo &tfopInfo,
+                                                 BuiltinInst *inst) {
   auto &graphFn = getCurrentGraphFunction();
   // TODO(b/78472806): Add a more thorough and proper fix for effectful ops in
   // the while cond function.
   if (!graphFn.shouldLowerEffectfulOps) return;

-  // Decode the tensor id from the builtin name.
-  // Example: builtin "tensorflowSend_0"<TensorHandle<Float>>(...) : $()
-  int tensorId = -1;
-  {
-    auto name = inst->getName().str();
-    auto tensorIdStr = name.substr(strlen("tensorflowSend_"));
-    bool isInt = llvm::to_integer(tensorIdStr, tensorId, 10);
-    assert(isInt);
-  }
+  // Type check and process the parameters.
+  // SendToHost has type <T> (input$T, tensorId$int, device$str) -> ()
+  assert(inst->getNumResults() == 1);
+  assert(inst->getNumOperands() == 3);
+  assert(tfopInfo.isInput(0));

-  // Type check and process the parameter.
   TF_Output inputOp;
   TF_DataType inputType;
   {
-    assert(inst->getNumOperands() == 1);
     auto operand = inst->getOperand(0);
     inputOp = getOperandValue(operand);
     if (!inputOp.oper) return;  // Error occurred.
     inputType = getTensorFlowDataType(operand->getType(), inst->getLoc());
   }
+  int tensorId = tfopInfo.getIntAttrOperand(1, "tensorId");
+  assert(tfopInfo.getDeviceString() == DEFAULT_CPU_DEVICE &&
+         "SendToHost must run on CPU device");

   // Add enqueue to the local graph function, and the corresponding dequeue to
   // the top level function, so that caller can dequeue tensors via SessionRun.
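For context (not part of the patch): the comment above describes the host-side contract — the device function enqueues the tensor, and the top-level function receives a matching dequeue node that the Swift host fetches with an ordinary `TF_SessionRun` call. A minimal sketch of that host-side fetch against the TF C API follows; the op name `fn_0_dequeue` is made up for illustration (the real name is derived from the tensor id elsewhere in this file).

```cpp
#include <tensorflow/c/c_api.h>

// Sketch only: fetches one value produced by a dequeue node in `graph`.
// The node name "fn_0_dequeue" is hypothetical; the lowering derives the
// real name from the tensor id elsewhere in this file.
TF_Tensor *fetchDequeuedTensor(TF_Session *session, TF_Graph *graph,
                               TF_Status *status) {
  TF_Operation *dequeue = TF_GraphOperationByName(graph, "fn_0_dequeue");
  if (!dequeue) return nullptr;
  TF_Output fetch{dequeue, 0};
  TF_Tensor *result = nullptr;
  TF_SessionRun(session, /*run_options*/ nullptr,
                /*inputs*/ nullptr, /*input_values*/ nullptr, /*ninputs*/ 0,
                /*outputs*/ &fetch, &result, /*noutputs*/ 1,
                /*target_opers*/ nullptr, /*ntargets*/ 0,
                /*run_metadata*/ nullptr, status);
  return TF_GetCode(status) == TF_OK ? result : nullptr;
}
```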
@@ -886,26 +895,24 @@ void TFGraphLowering::visitBuiltinTFSendInst(BuiltinInst *inst) {
   }
 }

-void TFGraphLowering::visitBuiltinTFReceiveInst(BuiltinInst *inst) {
+void TFGraphLowering::visitBuiltinRecvFromHostInst(SILTensorOpInfo &tfopInfo,
+                                                   BuiltinInst *inst) {
   auto &graphFn = getCurrentGraphFunction();
   // TODO(b/78472806): Add a more thorough and proper fix for effectful ops in
   // the while cond function.
   if (!graphFn.shouldLowerEffectfulOps) return;

-  // Decode the tensor id from the builtin name.
-  // Example: builtin "tensorflowReceive_0"<TensorHandle<Float>>(...) : $()
-  int tensorId = -1;
-  {
-    auto name = inst->getName().str();
-    auto tensorIdStr = name.substr(strlen("tensorflowReceive_"));
-    bool isInt = llvm::to_integer(tensorIdStr, tensorId, 10);
-    assert(isInt);
-  }
+  // Type check and process the parameters.
+  // RecvFromHost has type <T> (tensorId$int, device$string) -> (T)
+  assert(inst->getNumResults() == 1);
+  assert(inst->getNumOperands() == 2);
+
+  int tensorId = tfopInfo.getIntAttrOperand(0, "tensorId");
+  assert(tfopInfo.getDeviceString() == DEFAULT_CPU_DEVICE &&
+         "RecvFromHost must run on CPU device");

-  // Type check and process the result.
   TF_DataType outputType;
   {
-    assert(inst->getNumOperands() == 0);
     assert(inst->getNumResults() == 1);
     outputType =
         getTensorFlowDataType(inst->getResults()[0]->getType(), inst->getLoc());
@@ -991,6 +998,72 @@ void TFGraphLowering::visitBuiltinTFReceiveInst(BuiltinInst *inst) {
   }
 }

+void TFGraphLowering::visitBuiltinD2DTensorRecvInst(SILTensorOpInfo &tfopInfo,
+                                                    BuiltinInst *inst) {
+  // Signature: "__tfop_tfc.D2DTensorRecv,transferId,srcDevice,device"
+  assert(inst->getNumResults() == 1);
+  assert(inst->getNumOperands() == 3);
+
+  int transferId = tfopInfo.getIntAttrOperand(0, "transferId");
+  auto srcDevice = tfopInfo.getStringAttrOperand(1, "srcDevice");
+  auto thisDevice = thisDeviceTypeStr;
+  assert(tfopInfo.getDeviceString() == thisDevice);
+  auto opName = "tf_recv_" + llvm::itostr(transferId);
+  auto &graphFn = getCurrentGraphFunction();
+  auto *desc = TF_NewOperation(graphFn.getGraph(), "_Recv", opName.c_str());
+  TF_SetDevice(desc, thisDevice.c_str());
+
+  auto outputFromRecvVal = inst->getResults()[0];
+  auto tfType =
+      getTensorFlowDataType(outputFromRecvVal->getType(), inst->getLoc());
+  assert(tfType > 0);
+  TF_SetAttrType(desc, "tensor_type", (TF_DataType)tfType);
+
+  auto tensorName = "tensor_transfer_" + llvm::itostr(transferId);
+  TF_SetAttrString(desc, "tensor_name", tensorName.data(), tensorName.size());
+  TF_SetAttrString(desc, "send_device", srcDevice.data(), srcDevice.size());
+  TF_SetAttrInt(desc, "send_device_incarnation", DEVICE_INCARNATION_ID);
+  TF_SetAttrString(desc, "recv_device", thisDevice.data(), thisDevice.size());
+  auto *recvOp = graphFn.finishOp(desc, /*hasSideEffects*/ false,
+                                  /*isEligibleForTPU*/ false, status);
+  if (checkStatus(getUserSourceLocation(inst->getDebugLocation()))) return;
+  addValueMapping({inst, 0}, {recvOp, 0});
+}
+
+void TFGraphLowering::visitBuiltinD2DTensorSendInst(SILTensorOpInfo &tfopInfo,
+                                                    BuiltinInst *inst) {
+  // Signature: "__tfop_tfc.D2DTensorSend,$in,transferId,destDevice,device"
+  assert(inst->getNumResults() == 1);
+  assert(inst->getNumOperands() == 4);
+
+  assert(tfopInfo.isInput(0));
+  auto inputToSendVal = inst->getOperand(0);
+  auto inputToSendOp = getOperandValue(inputToSendVal);
+  if (!inputToSendOp.oper) return;  // Error occurred.
+
+  int transferId = tfopInfo.getIntAttrOperand(1, "transferId");
+  auto destDevice = tfopInfo.getStringAttrOperand(2, "destDevice");
+  auto thisDevice = thisDeviceTypeStr;
+  assert(tfopInfo.getDeviceString() == thisDevice);
+  auto opName = "tf_send_" + llvm::itostr(transferId);
+  auto &graphFn = getCurrentGraphFunction();
+  auto *desc = TF_NewOperation(graphFn.getGraph(), "_Send", opName.c_str());
+  TF_SetDevice(desc, thisDevice.c_str());
+  TF_AddInput(desc, inputToSendOp);
+  auto tfType =
+      getTensorFlowDataType(inputToSendVal->getType(), inst->getLoc());
+  assert(tfType > 0);
+  TF_SetAttrType(desc, "T", (TF_DataType)tfType);
+  auto tensorName = "tensor_transfer_" + llvm::itostr(transferId);
+  TF_SetAttrString(desc, "tensor_name", tensorName.data(), tensorName.size());
+  TF_SetAttrString(desc, "send_device", thisDevice.data(), thisDevice.size());
+  TF_SetAttrInt(desc, "send_device_incarnation", DEVICE_INCARNATION_ID);
+  TF_SetAttrString(desc, "recv_device", destDevice.data(), destDevice.size());
+  /*sendOp =*/ graphFn.finishOp(desc, /*hasSideEffects*/ true,
+                                /*isEligibleForTPU*/ false, status);
+  checkStatus(getUserSourceLocation(inst->getDebugLocation()));
+}
+
 void TFGraphLowering::visitTFDataset(BuiltinInst *inst) {
   // FIXME: Also support dataset/iterator outside of TPU context.
   if (!configuration.isTPUEnabled() || !configuration.isTPUInfeedEnabled) {
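For context (not part of the patch): the two builders above only pair up because TensorFlow's rendezvous matches `_Send` and `_Recv` on the full key — `tensor_name`, `send_device`, `send_device_incarnation`, and `recv_device` — which is why both sides stamp the same `tensor_transfer_<id>` name and the shared `DEVICE_INCARNATION_ID`. A self-contained sketch with made-up device strings and transfer id:

```cpp
#include <cstring>
#include <tensorflow/c/c_api.h>

// Sketch only: builds one matched _Send/_Recv pair in `graph`. The device
// strings, op names, and transfer id are illustrative, not what the compiler
// actually emits; `produced` is any float output already present in the graph.
void buildSendRecvPair(TF_Graph *graph, TF_Output produced, TF_Status *status) {
  const char *sendDev = "/job:localhost/replica:0/task:0/device:CPU:0";
  const char *recvDev = "/job:localhost/replica:0/task:0/device:GPU:0";
  const char *key = "tensor_transfer_0";  // must be identical on both sides
  const int64_t incarnation = 1;          // must be identical on both sides

  TF_OperationDescription *send = TF_NewOperation(graph, "_Send", "tf_send_0");
  TF_SetDevice(send, sendDev);
  TF_AddInput(send, produced);
  TF_SetAttrType(send, "T", TF_FLOAT);
  TF_SetAttrString(send, "tensor_name", key, strlen(key));
  TF_SetAttrString(send, "send_device", sendDev, strlen(sendDev));
  TF_SetAttrInt(send, "send_device_incarnation", incarnation);
  TF_SetAttrString(send, "recv_device", recvDev, strlen(recvDev));
  TF_FinishOperation(send, status);
  if (TF_GetCode(status) != TF_OK) return;

  TF_OperationDescription *recv = TF_NewOperation(graph, "_Recv", "tf_recv_0");
  TF_SetDevice(recv, recvDev);
  TF_SetAttrType(recv, "tensor_type", TF_FLOAT);
  TF_SetAttrString(recv, "tensor_name", key, strlen(key));
  TF_SetAttrString(recv, "send_device", sendDev, strlen(sendDev));
  TF_SetAttrInt(recv, "send_device_incarnation", incarnation);
  TF_SetAttrString(recv, "recv_device", recvDev, strlen(recvDev));
  TF_FinishOperation(recv, status);
}
```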
@@ -1109,6 +1182,20 @@ void TFGraphLowering::visitTFDataset(BuiltinInst *inst) {
 ///
 void TFGraphLowering::visitTFOpInst(BuiltinInst *inst) {
   SILTensorOpInfo tfopInfo = SILTensorOpInfo::decode(inst).getValue();
+
+  // Swift host <-> TF device sends/recvs.
+  if (tfopInfo.opName == "tfc.RecvFromHost")
+    return visitBuiltinRecvFromHostInst(tfopInfo, inst);
+  else if (tfopInfo.opName == "tfc.SendToHost")
+    return visitBuiltinSendToHostInst(tfopInfo, inst);
+
+  // Device-to-device sends/recvs.
+  if (tfopInfo.opName == "tfc.D2DTensorRecv")
+    return visitBuiltinD2DTensorRecvInst(tfopInfo, inst);
+  else if (tfopInfo.opName == "tfc.D2DTensorSend")
+    return visitBuiltinD2DTensorSendInst(tfopInfo, inst);
+
+  // Handle other TF ops.
   auto &graphFn = getCurrentGraphFunction();

   // The name label we put on the op is summarized from the "stack trace" of
@@ -1214,6 +1301,9 @@ void TFGraphLowering::visitTFOpInst(BuiltinInst *inst) {
     if (name != DEVICE_ATTR) {
       TF_SetAttrString(op, name.c_str(), value.data(), value.size());
     } else {
+      if (value.str() == ALL_DEVICES) {
+        value = thisDeviceTypeStr;
+      }
       if (value.str() != DEFAULT_TPU_DEVICE) {
         TF_SetDevice(op, value.str().c_str());
       } else {
@@ -2127,7 +2217,7 @@ bool TFGraphLowering::buildGraphNodesForTopLevelFunctionCall(
     assert(thisDeviceType == DeviceType::TPU);
     markNodeAsTPUReplicated(funcDesc);
   } else {
-    TF_SetDevice(funcDesc, getDeviceString(thisDeviceType).c_str());
+    TF_SetDevice(funcDesc, thisDeviceTypeStr.c_str());
   }

   // FIXME: Revisit how to enable infeed outside the context of dataset /
@@ -2422,24 +2512,6 @@ static std::vector<char> serializeGraphProtoBuf(SILFunction &SILFn,
   return std::vector<char>(bufPtr, bufPtr + buffer->length);
 }

-namespace {
-class GraphPartitioner : public SILInstructionVisitor<GraphPartitioner> {
-  SILFunction &srcFn;
-  const GraphGlobalConfiguration &configuration;
-
-public:
-  GraphPartitioner(SILFunction &srcFn,
-                   const GraphGlobalConfiguration &configuration)
-      : srcFn(srcFn), configuration(configuration) {}
-
-  /// Returns a function extracted from `fn`, specialized on `deviceType`.
-  SILFunction *extractFunctionForDevice(DeviceType deviceType) {
-    // FIXME: Add real impl.
-    return &srcFn;
-  }
-};
-}  // end anonymous namespace
-
 #endif  // SWIFT_ENABLE_TENSORFLOW

 /// Gets a function name that can be used as a TF op name.
@@ -2489,10 +2561,11 @@ std::vector<char> tf::lowerTFGraph(
     TF_DeleteGraph(resultGraph);
   };

-  GraphPartitioner partitioner(*fn, configuration);
+  DevicePartitioner partitioner(*fn, configuration);
   entryFnBaseName = getTFCompatibleFuncName(fn);
   unsigned helperFuncId = 0;
   for (const auto deviceType : configuration.usedDeviceTypes) {
+    assert(deviceType != DeviceType::ALL);
     auto *perDeviceFn = partitioner.extractFunctionForDevice(deviceType);
     bool isPrimaryFn = deviceType == configuration.deviceType;

@@ -2521,12 +2594,22 @@ std::vector<char> tf::lowerTFGraph(
     // The func op type is `fnName`, with the caller node name being
     // based on `funcNodeBaseName`.
     std::string funcNodeBaseName = entryFnBaseName;
+    if (!isPrimaryFn) {
+      funcNodeBaseName += "_helper_" + llvm::utostr(helperFuncId);
+      ++helperFuncId;
+      assert(inputTypes.empty());
+      assert(outputTypes.empty());
+    }

     // Create the graph function for the top level code.
     if (graphGen.buildGraphNodesForTopLevelFunctionCall(
             fnName.str(), funcNodeBaseName, isPrimaryFn, inputTypes,
             outputTypes))
       return {};
+
+    // Remove the partitioned function so it doesn't go through the normal
+    // compiler flow.
+    perDeviceFn->getModule().eraseFunction(perDeviceFn);
   }

   // Ok, we're done! Serialize the resulting graph to a protobuf and return it.
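For context (not part of the patch): the naming added above means that with, say, CPU as the primary device and GPU/TPU as helpers, the caller nodes come out as `<base>`, `<base>_helper_0`, `<base>_helper_1`. A tiny standalone sketch of just that loop, with an illustrative base name:

```cpp
#include <iostream>
#include <string>
#include <vector>

// Sketch only: reproduces the caller-node naming scheme from the loop above,
// assuming three partitions with the first one primary; the base name is
// illustrative, not one the compiler generates.
int main() {
  std::string entryFnBaseName = "main_graph_13";
  std::vector<bool> isPrimaryFn = {true, false, false};  // e.g. CPU, GPU, TPU
  unsigned helperFuncId = 0;
  for (bool isPrimary : isPrimaryFn) {
    std::string funcNodeBaseName = entryFnBaseName;
    if (!isPrimary) {
      funcNodeBaseName += "_helper_" + std::to_string(helperFuncId);
      ++helperFuncId;
    }
    // Prints: main_graph_13, main_graph_13_helper_0, main_graph_13_helper_1
    std::cout << funcNodeBaseName << "\n";
  }
  return 0;
}
```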