Commit db0c87e

Mingsheng Hong authored and Marc Rasi committed
Part 4 of cross-device sends/recvs support: Added initial support for SIL (#17165)
Part 4 of cross-device sends/recvs support: Added initial support for SIL accelerator function partitioning based on TF devices {CPU, GPU}, including control flow.

Summary of changes:

1. Extended DeviceType with an ALL enum value, indicating that an associated instruction runs on all devices involved in the TF computation. For example, promoted scalars run on ALL devices. Also, for ease of control-flow handling, BB args are present on ALL devices. The exception is the function input arguments, which are only present in the primary device function (recall the primary function is the partitioned function that runs on a target device given by TensorFlow.enableGPU(), TensorFlow.enableTPU() or a default policy), while the helper functions take no input or output tensors.

2. Added a new pass, DevicePartitioner, that sits between the PartitionerCloner pass in TFPartition and the TFGraphLowering pass in TFLowerGraph. It has two phases:

   In the analysis/mark phase, it inserts instructions for cross-device tensor sends/recvs, represented by "__tfop_tfc.TensorTransfer" builtins. For example, when tensor x is produced on device D1 and is then consumed by tensor op foo() on device D2, the pass inserts a "__tfop_tfc.TensorTransfer" builtin right before foo() to send that tensor from D1 to D2. This builtin maintains the invariant that for any instruction I running on some device D, every operand OP of I must be present on D (either because OP is produced on D, or because it is transferred via this builtin). When the tf-dump-graph flag is on, the output SIL of this phase is dumped under a header like:
   --- TFDevicePartition Cross Device Tensor Transfer Annotation Result: $S3tmp10testScalar1fySf_tF.tf

   In the partitioning phase (DevicePartitionCloner), it extracts all instructions related to a given target device D into a new SIL function, to be lowered by TFGraphLowering. For a "__tfop_tfc.TensorTransfer" builtin:
   - If D is its source/send device, it gets lowered to a TF _Send op in the CPU/GPU device context, via a "__tfop_tfc.D2DTensorSend" builtin.
   - If D is its dest/recv device, it gets lowered to a TF _Recv op in the CPU/GPU device context, via a "__tfop_tfc.D2DTensorRecv" builtin.
   For control-flow support, each partitioned, device-specific SIL function produced by DevicePartitionCloner retains all basic blocks from the input accelerator SIL function, along with their BB args. When the tf-dump-graph flag is on, the output of this phase is dumped under a header like:
   --- TFDevicePartition Per-Device Function Extraction Result: $S3tmp10testScalar1fySf_tF.tf_CPU.device_partition

3. Extended the TFGraphLowering pass to lower D2DTensorSend/D2DTensorRecv into TF _Send and _Recv nodes. These nodes work on CPU and GPU. In the TPU device context, the above can be lowered to infeed/outfeed or HostCompute; this is to be explored later.

4. Also replaced the "tensorflowSend" and "tensorflowReceive" builtins with "tfc.SendToHost" and "tfc.RecvFromHost" builtins, with proper tfop attributes representing the tensor transfer id and the send/recv devices.
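The analysis/mark phase in point 2 can be illustrated with a small standalone sketch. This is not the compiler code: the `Inst` struct and `markTransfers` helper are hypothetical, modeling only the invariant that every operand of an instruction on device D must be present on D, with a transfer inserted right before the consumer otherwise.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical mini-model: each instruction produces one named value on a
// device and consumes at most one operand.
struct Inst {
  std::string name;     // value produced, e.g. "x"
  std::string device;   // "CPU" or "GPU"
  std::string operand;  // consumed value name, or "" if none
};

// Walk the instruction list; whenever a consumer's operand was produced on a
// different device, insert a TensorTransfer pseudo-instruction before it.
std::vector<Inst> markTransfers(const std::vector<Inst> &insts) {
  std::map<std::string, std::string> deviceOf;  // value -> device it lives on
  std::vector<Inst> out;
  for (const auto &i : insts) {
    if (!i.operand.empty() && deviceOf[i.operand] != i.device) {
      out.push_back({"TensorTransfer(" + i.operand + ")", i.device, i.operand});
      // Simplification: treat the value as now present on the new device too.
      deviceOf[i.operand] = i.device;
    }
    out.push_back(i);
    deviceOf[i.name] = i.device;
  }
  return out;
}
```

For example, a value produced on GPU and consumed on CPU gets one transfer inserted between producer and consumer, while a same-device use gets none.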
1 parent 5a3ff57 commit db0c87e

File tree

9 files changed: +1307 -167 lines changed


lib/SILOptimizer/Mandatory/TFDevicePartition.cpp

Lines changed: 727 additions & 0 deletions
Large diffs are not rendered by default.

lib/SILOptimizer/Mandatory/TFLowerGraph.cpp

Lines changed: 145 additions & 62 deletions
@@ -54,6 +54,10 @@ static const char TPU_CLUSTER_ATTR_VALUE[] = "TPUReplicate/cluster";
 // FIXME: Tune the default value for performance, and/or make it configurable.
 static const int NAMED_TENSOR_QUEUE_CAPACITY = 1;
 
+// The send device incarnation for TF sends/recvs.
+// FIXME: revisit whether using a fixed value is good enough.
+static const int DEVICE_INCARNATION_ID = 1;
+
 /// When generating a TF TPU graph, call this function to place an eligible TF
 /// graph node onto TPU device. Some nodes such as Placeholder and
 /// Dataset/Iterator nodes are not eligible for TPU.
@@ -191,7 +195,8 @@ namespace {
 struct TFGraphLowering : public SILInstructionVisitor<TFGraphLowering> {
   SILFunction &SILFn;
   // The TF device to which the generated graph is targeting.
-  DeviceType thisDeviceType;
+  const DeviceType thisDeviceType;
+  const std::string thisDeviceTypeStr;
   const GraphGlobalConfiguration &configuration;
   TF_Graph *resultGraph;
   TF_Status *status;
@@ -225,6 +230,7 @@ struct TFGraphLowering : public SILInstructionVisitor<TFGraphLowering> {
                   TF_Graph *resultGraph, TF_Status *status)
       : SILFn(fn),
         thisDeviceType(thisDeviceType),
+        thisDeviceTypeStr(getDeviceString(thisDeviceType)),
         configuration(configuration),
         resultGraph(resultGraph),
         status(status) {}
@@ -496,17 +502,7 @@ struct TFGraphLowering : public SILInstructionVisitor<TFGraphLowering> {
   void visitStringLiteralInst(StringLiteralInst *inst) {}
 
   void visitBuiltinInst(BuiltinInst *inst);
-  void visitBuiltinTFSendInst(BuiltinInst *inst);
-  void visitBuiltinTFReceiveInst(BuiltinInst *inst);
-
-  /// Create a stack of TF dataset and iterator nodes up to IteratorGetNext.
-  ///
-  /// FIXME: Dissolve this builtin into a set of finer-grained, composer
-  /// features.
-  void visitTFDataset(BuiltinInst *inst);
-  bool createDatasetIteratorNodesWithInfeedEnqueue();
 
-  void visitTFOpInst(BuiltinInst *inst);
   void visitTupleInst(TupleInst *inst);
   void visitTupleExtractInst(TupleExtractInst *inst);
   void visitUncheckedRefCastInst(UncheckedRefCastInst *inst);
@@ -534,6 +530,25 @@ struct TFGraphLowering : public SILInstructionVisitor<TFGraphLowering> {
   void lowerSequenceRegion(SequenceSESERegion *r);
   void lowerWhileLoopRegion(WhileLoopSESERegion *r);
   void lowerConditionalRegion(ConditionalSESERegion *r);
+
+private: // Helpers for lowering.
+  /// Create a stack of TF dataset and iterator nodes up to IteratorGetNext.
+  ///
+  /// FIXME: Dissolve this builtin into a set of finer-grained, composable
+  /// features.
+  void visitTFDataset(BuiltinInst *inst);
+  bool createDatasetIteratorNodesWithInfeedEnqueue();
+
+  void visitTFOpInst(BuiltinInst *inst);
+
+  void visitBuiltinSendToHostInst(SILTensorOpInfo &tfopInfo, BuiltinInst *inst);
+  void visitBuiltinRecvFromHostInst(SILTensorOpInfo &tfopInfo,
+                                    BuiltinInst *inst);
+  // D2D means device-to-device.
+  void visitBuiltinD2DTensorRecvInst(SILTensorOpInfo &tfopInfo,
+                                     BuiltinInst *inst);
+  void visitBuiltinD2DTensorSendInst(SILTensorOpInfo &tfopInfo,
+                                     BuiltinInst *inst);
 };
 }
 
@@ -577,9 +592,9 @@ std::string TFGraphLowering::getUniqueName(SILDebugLocation loc,
   auto lineCol = SM.getLineAndColumn(ds->Loc.getSourceLoc());
   auto fnName = F->getName();
 
-  // Drop ".tf_partition" suffix off function names.
-  if (fnName.endswith(".tf_partition"))
-    fnName = fnName.drop_back(strlen(".tf_partition"));
+  // Drop ".device_partition" suffix off function names.
+  if (fnName.endswith(".device_partition"))
+    fnName = fnName.drop_back(strlen(".device_partition"));
 
   name += "." + fnName.str() + "." + llvm::utostr(lineCol.first);
   name += "." + llvm::utostr(lineCol.second);
@@ -684,10 +699,6 @@ void TFGraphLowering::visitBuiltinInst(BuiltinInst *inst) {
   // handle it directly.
   if (inst->getName().str() == "tf_tensor_to_i1")
     return;
-  if (inst->getName().str().startswith("tensorflowReceive_"))
-    return visitBuiltinTFReceiveInst(inst);
-  if (inst->getName().str().startswith("tensorflowSend_"))
-    return visitBuiltinTFSendInst(inst);
   if (inst->getName().str().startswith(
           "__tfop_tfc.makeIteratorGetNextWithDatasets"))
     return visitTFDataset(inst);
@@ -793,32 +804,30 @@ static void decodeShapeArray(SILInstruction *inst,
   }
 }
 
-void TFGraphLowering::visitBuiltinTFSendInst(BuiltinInst *inst) {
+void TFGraphLowering::visitBuiltinSendToHostInst(SILTensorOpInfo &tfopInfo,
+                                                 BuiltinInst *inst) {
   auto &graphFn = getCurrentGraphFunction();
   // TODO(b/78472806): Add a more thorough and proper fix for effectful ops in
   // the while cond function.
   if (!graphFn.shouldLowerEffectfulOps) return;
 
-  // Decode the tensor id from the builtin name.
-  // Example: builtin "tensorflowSend_0"<TensorHandle<Float>>(...) : $()
-  int tensorId = -1;
-  {
-    auto name = inst->getName().str();
-    auto tensorIdStr = name.substr(strlen("tensorflowSend_"));
-    bool isInt = llvm::to_integer(tensorIdStr, tensorId, 10);
-    assert(isInt);
-  }
+  // Type check and process the parameters.
+  // SendToHost has type <T> (input$T, tensorId$int, device$str) -> ()
+  assert(inst->getNumResults() == 1);
+  assert(inst->getNumOperands() == 3);
+  assert(tfopInfo.isInput(0));
 
-  // Type check and process the parameter.
   TF_Output inputOp;
   TF_DataType inputType;
   {
-    assert(inst->getNumOperands() == 1);
     auto operand = inst->getOperand(0);
     inputOp = getOperandValue(operand);
     if (!inputOp.oper) return; // Error occurred.
     inputType = getTensorFlowDataType(operand->getType(), inst->getLoc());
   }
+  int tensorId = tfopInfo.getIntAttrOperand(1, "tensorId");
+  assert(tfopInfo.getDeviceString() == DEFAULT_CPU_DEVICE &&
+         "SendToHost must run on CPU device");
 
   // Add enqueue to the local graph function, and the corresponding dequeue to
   // the top level function, so that caller can dequeue tensors via SessionRun.
@@ -886,26 +895,24 @@
   }
 }
 
-void TFGraphLowering::visitBuiltinTFReceiveInst(BuiltinInst *inst) {
+void TFGraphLowering::visitBuiltinRecvFromHostInst(SILTensorOpInfo &tfopInfo,
+                                                   BuiltinInst *inst) {
   auto &graphFn = getCurrentGraphFunction();
   // TODO(b/78472806): Add a more thorough and proper fix for effectful ops in
   // the while cond function.
   if (!graphFn.shouldLowerEffectfulOps) return;
 
-  // Decode the tensor id from the builtin name.
-  // Example: builtin "tensorflowReceive_0"<TensorHandle<Float>>(...) : $()
-  int tensorId = -1;
-  {
-    auto name = inst->getName().str();
-    auto tensorIdStr = name.substr(strlen("tensorflowReceive_"));
-    bool isInt = llvm::to_integer(tensorIdStr, tensorId, 10);
-    assert(isInt);
-  }
+  // Type check and process the parameters.
+  // recvFromHost has type <T> (tensorId$int, device$string) -> (T)
+  assert(inst->getNumResults() == 1);
+  assert(inst->getNumOperands() == 2);
+
+  int tensorId = tfopInfo.getIntAttrOperand(0, "tensorId");
+  assert(tfopInfo.getDeviceString() == DEFAULT_CPU_DEVICE &&
+         "SendToHost must run on CPU device");
 
-  // Type check and process the result.
   TF_DataType outputType;
   {
-    assert(inst->getNumOperands() == 0);
     assert(inst->getNumResults() == 1);
     outputType =
         getTensorFlowDataType(inst->getResults()[0]->getType(), inst->getLoc());
@@ -991,6 +998,72 @@
   }
 }
 
+void TFGraphLowering::visitBuiltinD2DTensorRecvInst(SILTensorOpInfo &tfopInfo,
+                                                    BuiltinInst *inst) {
+  // Signature: "__tfop_tfc.D2DTensorRecv,transferId,srcDevice,device"
+  assert(inst->getNumResults() == 1);
+  assert(inst->getNumOperands() == 3);
+
+  int transferId = tfopInfo.getIntAttrOperand(0, "transferId");
+  auto srcDevice = tfopInfo.getStringAttrOperand(1, "srcDevice");
+  auto thisDevice = thisDeviceTypeStr;
+  assert(tfopInfo.getDeviceString() == thisDevice);
+  auto opName = "tf_recv_" + llvm::itostr(transferId);
+  auto &graphFn = getCurrentGraphFunction();
+  auto *desc = TF_NewOperation(graphFn.getGraph(), "_Recv", opName.c_str());
+  TF_SetDevice(desc, thisDevice.c_str());
+
+  auto outputFromRecvVal = inst->getResults()[0];
+  auto tfType =
+      getTensorFlowDataType(outputFromRecvVal->getType(), inst->getLoc());
+  assert(tfType > 0);
+  TF_SetAttrType(desc, "tensor_type", (TF_DataType)tfType);
+
+  auto tensorName = "tensor_transfer_" + llvm::itostr(transferId);
+  TF_SetAttrString(desc, "tensor_name", tensorName.data(), tensorName.size());
+  TF_SetAttrString(desc, "send_device", srcDevice.data(), srcDevice.size());
+  TF_SetAttrInt(desc, "send_device_incarnation", DEVICE_INCARNATION_ID);
+  TF_SetAttrString(desc, "recv_device", thisDevice.data(), thisDevice.size());
+  auto *recvOp = graphFn.finishOp(desc, /*hasSideEffects*/ false,
+                                  /*isEligibleForTPU*/ false, status);
+  if (checkStatus(getUserSourceLocation(inst->getDebugLocation()))) return;
+  addValueMapping({inst, 0}, {recvOp, 0});
+}
+
+void TFGraphLowering::visitBuiltinD2DTensorSendInst(SILTensorOpInfo &tfopInfo,
+                                                    BuiltinInst *inst) {
+  // Signature: "__tfop_tfc.D2DTensorSend,$in,transferId,destDevice,device"
+  assert(inst->getNumResults() == 1);
+  assert(inst->getNumOperands() == 4);
+
+  assert(tfopInfo.isInput(0));
+  auto inputToSendVal = inst->getOperand(0);
+  auto inputToSendOp = getOperandValue(inputToSendVal);
+  if (!inputToSendOp.oper) return; // Error occurred.
+
+  int transferId = tfopInfo.getIntAttrOperand(1, "transferId");
+  auto destDevice = tfopInfo.getStringAttrOperand(2, "destDevice");
+  auto thisDevice = thisDeviceTypeStr;
+  assert(tfopInfo.getDeviceString() == thisDevice);
+  auto opName = "tf_send_" + llvm::itostr(transferId);
+  auto &graphFn = getCurrentGraphFunction();
+  auto *desc = TF_NewOperation(graphFn.getGraph(), "_Send", opName.c_str());
+  TF_SetDevice(desc, thisDevice.c_str());
+  TF_AddInput(desc, inputToSendOp);
+  auto tfType =
+      getTensorFlowDataType(inputToSendVal->getType(), inst->getLoc());
+  assert(tfType > 0);
+  TF_SetAttrType(desc, "T", (TF_DataType)tfType);
+  auto tensorName = "tensor_transfer_" + llvm::itostr(transferId);
+  TF_SetAttrString(desc, "tensor_name", tensorName.data(), tensorName.size());
+  TF_SetAttrString(desc, "send_device", thisDevice.data(), thisDevice.size());
+  TF_SetAttrInt(desc, "send_device_incarnation", DEVICE_INCARNATION_ID);
+  TF_SetAttrString(desc, "recv_device", destDevice.data(), destDevice.size());
+  /* sendOp = */ graphFn.finishOp(desc, /*hasSideEffects*/ true,
+                                 /*isEligibleForTPU*/ false, status);
+  checkStatus(getUserSourceLocation(inst->getDebugLocation()));
+}
+
 void TFGraphLowering::visitTFDataset(BuiltinInst *inst) {
   // FIXME: Also support dataset/iterator outside of TPU context.
   if(!configuration.isTPUEnabled() || !configuration.isTPUInfeedEnabled) {
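The D2D lowering above works because TF pairs a _Send with a _Recv only when their rendezvous attributes agree: tensor_name, send_device, send_device_incarnation, and recv_device must all match on both ops. A rough standalone model of that pairing (the sendAttrs/recvAttrs helpers and kDeviceIncarnationId constant are illustrative, not compiler code):

```cpp
#include <cassert>
#include <map>
#include <string>

// Mirrors the fixed DEVICE_INCARNATION_ID used by both sides of a transfer.
static const int kDeviceIncarnationId = 1;

// Attributes the _Send side sets, keyed by a transfer id.
std::map<std::string, std::string> sendAttrs(int transferId,
                                             const std::string &thisDevice,
                                             const std::string &destDevice) {
  return {{"tensor_name", "tensor_transfer_" + std::to_string(transferId)},
          {"send_device", thisDevice},
          {"send_device_incarnation", std::to_string(kDeviceIncarnationId)},
          {"recv_device", destDevice}};
}

// Attributes the _Recv side sets for the same transfer id.
std::map<std::string, std::string> recvAttrs(int transferId,
                                             const std::string &srcDevice,
                                             const std::string &thisDevice) {
  return {{"tensor_name", "tensor_transfer_" + std::to_string(transferId)},
          {"send_device", srcDevice},
          {"send_device_incarnation", std::to_string(kDeviceIncarnationId)},
          {"recv_device", thisDevice}};
}
```

Building both sides from the same transfer id and device pair yields identical rendezvous keys, which is the property the two visitors above preserve.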
@@ -1109,6 +1182,20 @@ void TFGraphLowering::visitTFDataset(BuiltinInst *inst) {
 ///
 void TFGraphLowering::visitTFOpInst(BuiltinInst *inst) {
   SILTensorOpInfo tfopInfo = SILTensorOpInfo::decode(inst).getValue();
+
+  // Swift host <-> TF device sends/recvs.
+  if (tfopInfo.opName == "tfc.RecvFromHost")
+    return visitBuiltinRecvFromHostInst(tfopInfo, inst);
+  else if (tfopInfo.opName == "tfc.SendToHost")
+    return visitBuiltinSendToHostInst(tfopInfo, inst);
+
+  // Device-to-device sends/recvs.
+  if (tfopInfo.opName == "tfc.D2DTensorRecv")
+    return visitBuiltinD2DTensorRecvInst(tfopInfo, inst);
+  else if (tfopInfo.opName == "tfc.D2DTensorSend")
+    return visitBuiltinD2DTensorSendInst(tfopInfo, inst);
+
+  // Handle other TF ops.
   auto &graphFn = getCurrentGraphFunction();
 
   // The name label we put on the op is summarized from the "stack trace" of
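The dispatch added to visitTFOpInst above peels transfer builtins off before generic op lowering. A hypothetical classifier (not compiler code) capturing the three-way split:

```cpp
#include <cassert>
#include <string>

// Illustrative: which lowering path a tfop name takes in the dispatch above.
std::string classifyOp(const std::string &opName) {
  // Swift host <-> TF device sends/recvs.
  if (opName == "tfc.RecvFromHost" || opName == "tfc.SendToHost")
    return "host-transfer";
  // Device-to-device sends/recvs.
  if (opName == "tfc.D2DTensorRecv" || opName == "tfc.D2DTensorSend")
    return "device-transfer";
  // Everything else goes through generic TF op lowering.
  return "generic-tfop";
}
```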
@@ -1214,6 +1301,9 @@ void TFGraphLowering::visitTFOpInst(BuiltinInst *inst) {
     if (name != DEVICE_ATTR) {
       TF_SetAttrString(op, name.c_str(), value.data(), value.size());
     } else {
+      if (value.str() == ALL_DEVICES) {
+        value = thisDeviceTypeStr;
+      }
       if (value.str() != DEFAULT_TPU_DEVICE) {
         TF_SetDevice(op, value.str().c_str());
       } else {
@@ -2127,7 +2217,7 @@ bool TFGraphLowering::buildGraphNodesForTopLevelFunctionCall(
     assert(thisDeviceType == DeviceType::TPU);
     markNodeAsTPUReplicated(funcDesc);
   } else {
-    TF_SetDevice(funcDesc, getDeviceString(thisDeviceType).c_str());
+    TF_SetDevice(funcDesc, thisDeviceTypeStr.c_str());
   }
 
   // FIXME: Revisit how to enable infeed outside the context of dataset /
@@ -2422,24 +2512,6 @@ static std::vector<char> serializeGraphProtoBuf(SILFunction &SILFn,
   return std::vector<char>(bufPtr, bufPtr + buffer->length);
 }
 
-namespace {
-class GraphPartitioner : public SILInstructionVisitor<GraphPartitioner> {
-  SILFunction &srcFn;
-  const GraphGlobalConfiguration &configuration;
-
-public:
-  GraphPartitioner(SILFunction &srcFn,
-                   const GraphGlobalConfiguration &configuration)
-      : srcFn(srcFn), configuration(configuration) {}
-
-  /// Returns a function extracted from `fn`, specialized on `deviceType`.
-  SILFunction *extractFunctionForDevice(DeviceType deviceType) {
-    // FIXME: Add real impl.
-    return &srcFn;
-  }
-};
-} // end anonymous namespace
-
 #endif // SWIFT_ENABLE_TENSORFLOW
 
 /// Gets a function name that can be used as a TF op name.
@@ -2489,10 +2561,11 @@ std::vector<char> tf::lowerTFGraph(
     TF_DeleteGraph(resultGraph);
   };
 
-  GraphPartitioner partitioner(*fn, configuration);
+  DevicePartitioner partitioner(*fn, configuration);
   entryFnBaseName = getTFCompatibleFuncName(fn);
   unsigned helperFuncId = 0;
   for (const auto deviceType : configuration.usedDeviceTypes) {
+    assert(deviceType != DeviceType::ALL);
     auto *perDeviceFn = partitioner.extractFunctionForDevice(deviceType);
     bool isPrimaryFn = deviceType == configuration.deviceType;
 
@@ -2521,12 +2594,22 @@ std::vector<char> tf::lowerTFGraph(
     // The func op type is `fnName`, with the caller node name being
     // based on `funcNodeBaseName`.
     std::string funcNodeBaseName = entryFnBaseName;
+    if (!isPrimaryFn) {
+      funcNodeBaseName += "_helper_" + llvm::utostr(helperFuncId);
+      ++helperFuncId;
+      assert(inputTypes.empty());
+      assert(outputTypes.empty());
+    }
 
     // Create the graph function for the top level code.
     if (graphGen.buildGraphNodesForTopLevelFunctionCall(
             fnName.str(), funcNodeBaseName, isPrimaryFn, inputTypes,
             outputTypes))
       return {};
+
+    // Remove the partitioned function so it doesn't go through the normal
+    // compiler flow.
+    perDeviceFn->getModule().eraseFunction(perDeviceFn);
   }
 
   // Ok, we're done! Serialize the resulting graph to a protobuf and return it.
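The helper-naming scheme in the hunk above (the primary device function keeps the entry name; each non-primary function gets a "_helper_<id>" suffix with a monotonically increasing id) can be sketched in isolation. The `funcNodeNames` helper is hypothetical, modeling only the naming logic:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Illustrative: compute the func node base names for a sequence of
// per-device functions, where exactly the primary one keeps the entry name.
std::vector<std::string> funcNodeNames(const std::string &entryFnBaseName,
                                       const std::vector<bool> &isPrimary) {
  std::vector<std::string> names;
  unsigned helperFuncId = 0;
  for (bool primary : isPrimary) {
    std::string name = entryFnBaseName;
    if (!primary) {
      // Helper functions get distinct "_helper_<id>" suffixes.
      name += "_helper_" + std::to_string(helperFuncId);
      ++helperFuncId;
    }
    names.push_back(name);
  }
  return names;
}
```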
