Skip to content

Commit 182245a

Browse files
authored
Add support for tfc.makeIteratorGetNextWithDatasets from graph operations (#18223)
1 parent f7f62df commit 182245a

File tree

2 files changed

+137
-13
lines changed

2 files changed

+137
-13
lines changed

lib/SILOptimizer/Mandatory/TFLowerGraph.cpp

Lines changed: 101 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ struct TFGraphLowering : public SILInstructionVisitor<TFGraphLowering, GLStatus>
348348
struct DatasetCreationContext {
349349
/// The instruction corresponding to the builtin
350350
/// tfc.makeIteratorGetNextWithDatasets.
351-
BuiltinInst *datasetInst = nullptr;
351+
SILInstruction *datasetInst = nullptr;
352352

353353
/// Specifies which (hard-coded) iterator stack to create.
354354
enum DataSource {
@@ -376,7 +376,7 @@ struct TFGraphLowering : public SILInstructionVisitor<TFGraphLowering, GLStatus>
376376
std::vector<TF_DataType> infeedInputDtypes;
377377

378378
public:
379-
DatasetCreationContext(BuiltinInst *datasetInst, DataSource dataSource,
379+
DatasetCreationContext(SILInstruction *datasetInst, DataSource dataSource,
380380
StringRef filePath, int batchSize,
381381
ArrayRef<int64_t> dims, ArrayRef<int> numDims,
382382
ArrayRef<TF_DataType> dTypes)
@@ -584,7 +584,11 @@ struct TFGraphLowering : public SILInstructionVisitor<TFGraphLowering, GLStatus>
584584
///
585585
/// FIXME: Dissolve this builtin into a set of finer-grained, composable
586586
/// features.
587-
GLStatus visitTFDataset(BuiltinInst *inst);
587+
/// Decodes the dataset attributes of `inst` into `datasetCreationContext` and
/// appends to `results` the SIL results that the caller maps onto the outputs
/// of the infeed dequeue node. Explicitly specialized for BuiltinInst and
/// GraphOperationInst. Returns GLStatus::Error after reporting a diagnostic
/// if the instruction is malformed.
template <typename Inst>
588+
GLStatus createDatasetCreationContext(
589+
Inst *inst, std::vector<SILOpResult>& results);
590+
/// Lowers tfc.makeIteratorGetNextWithDatasets for either instruction form
/// (BuiltinInst or GraphOperationInst). Reports an error unless lowering for
/// TPU with infeed enabled; otherwise fills in the dataset creation context
/// and maps the instruction's results to the outputs of an
/// InfeedDequeueTuple node.
template <typename Inst>
591+
GLStatus visitTFDataset(Inst *inst);
588592
bool createDatasetIteratorNodesWithInfeedEnqueue();
589593

590594
GLStatus visitTFOpInst(BuiltinInst *inst);
@@ -1538,17 +1542,72 @@ GLStatus TFGraphLowering::visitGraphOpD2DTensorSendInst(
15381542
}
15391543
}
15401544

1541-
GLStatus TFGraphLowering::visitTFDataset(BuiltinInst *inst) {
1542-
// FIXME: Also support dataset/iterator outside of TPU context.
1543-
if (thisDeviceType != DeviceType::TPU || !deviceInfo.isTPUInfeedEnabled) {
1544-
internalError(
1545-
getUserSourceLocation(inst->getDebugLocation()),
1546-
"Builtin tfc.makeIteratorGetNextWithDatasets can only be used when "
1547-
"generating TPU TF graphs with infeed support.",
1548-
diag::tfop_invalid_tfop);
1545+
/// GraphOperationInst specialization: decodes the attributes of a
/// tfc.makeIteratorGetNextWithDatasets graph_op (dataSource, filePath,
/// batchSize, outputShapes), derives the TF dtype of every result, and stores
/// everything in `datasetCreationContext`.
///
/// \param inst the graph_op instruction being lowered.
/// \param outputResults receives one (result, 0) entry per result of `inst`.
///        NOTE: on an error return it may already hold entries for the
///        results processed before the failure.
/// \returns GLStatus::Success, or GLStatus::Error after emitting a diagnostic
///          when a result is not a tensor type or the number of shapes does
///          not match the number of results.
template <>
1546+
GLStatus TFGraphLowering::createDatasetCreationContext(
1547+
GraphOperationInst *inst, std::vector<SILOpResult> &outputResults) {
1548+
GraphOperationInfo graphOpInfo(inst);
1549+
// Type check and process the first attribute: dataSource.
// Any value other than "fake" or "mnist" falls through to IMAGENET.
1550+
auto dataSource =
1551+
llvm::StringSwitch<DatasetCreationContext::DataSource>(
1552+
graphOpInfo.getStringAttr(0, "dataSource"))
1553+
.Case("fake", DatasetCreationContext::FAKE)
1554+
.Case("mnist", DatasetCreationContext::MNIST)
1555+
.Default(DatasetCreationContext::IMAGENET);
1556+
1557+
// Type check and process the second attribute: filePath.
1558+
// When dataSource is FAKE, this attribute needs to be present, but is not
1559+
// used.
// NOTE(review): filePath is a non-owning StringRef; confirm that the value
// returned by getStringAttr outlives its use below.
1560+
StringRef filePath = (dataSource == DatasetCreationContext::FAKE)
1561+
? ""
1562+
: graphOpInfo.getStringAttr(1, "filePath");
1563+
// Type check and process the third attribute: batchSize.
1564+
int batchSize = graphOpInfo.getIntAttr(2, "batchSize");
1565+
1566+
// Type check and process the fourth attribute: outputShapes.
// dimPtrs is only passed to decodeShapeArray; it is not read afterwards.
1567+
auto attr = inst->getAttribute(3);
1568+
SmallVector<int64_t, 8> dims;
1569+
SmallVector<int, 3> numDims;
1570+
SmallVector<int64_t*, 8> dimPtrs;
1571+
decodeShapeArray(attr.value, dims, numDims, dimPtrs);
1572+
1573+
// Unlike the builtin form, every output tensor of the graph_op is visited as
1574+
// a separate SIL result here, and each is recorded as (result, 0).
1575+
std::vector<TF_DataType> outputTypes;
1576+
for (const SILValue &result : inst->getResults()) {
1577+
auto outputType = result->getType().getASTType();
1578+
auto tfType = getTFDataTypeFromTensorGenericType(outputType);
1579+
// A zero dtype is treated as "not a tensor type".
if (tfType == 0) {
1580+
internalError(getUserSourceLocation(inst->getDebugLocation()),
1581+
"Encountered a non-tensor type during dataset creation.",
1582+
diag::tfop_invalid_tfop);
1583+
return GLStatus::Error;
1584+
}
1585+
outputTypes.push_back(static_cast<TF_DataType>(tfType));
1586+
outputResults.emplace_back(result, 0);
1587+
}
1588+
1589+
// numDims holds one entry per decoded shape, so it must line up with the
// number of result tensors.
if (outputTypes.size() != numDims.size()) {
1590+
internalError(getUserSourceLocation(inst->getDebugLocation()),
1591+
"Must specify the same number of shapes and output tensors.",
1592+
diag::tfop_invalid_tfop);
15491593
return GLStatus::Error;
15501594
}
15511595

1596+
// Defer the creation of the dataset / iterator related nodes, along with the
1597+
// associated infeed enqueue till the creation of top level function
1598+
// nodes. Here we only fill in the dataset creation context; the caller
1599+
// (visitTFDataset) creates the infeed dequeue node for the user(s) of `inst`.
1600+
datasetCreationContext.reset(new DatasetCreationContext(
1601+
inst, dataSource, filePath, batchSize, dims, numDims, outputTypes));
1602+
1603+
1604+
return GLStatus::Success;
1605+
}
1606+
1607+
// TODO: Remove this version when graph op takes over completely.
1608+
template <>
1609+
GLStatus TFGraphLowering::createDatasetCreationContext(
1610+
BuiltinInst *inst, std::vector<SILOpResult> &outputResults) {
15521611
SILTensorOpInfo tfopInfo = SILTensorOpInfo::decode(inst).getValue();
15531612
// Type check and process the first attribute: dataSource.
15541613
DatasetCreationContext::DataSource dataSource;
@@ -1636,6 +1695,31 @@ GLStatus TFGraphLowering::visitTFDataset(BuiltinInst *inst) {
16361695
datasetCreationContext.reset(new DatasetCreationContext(
16371696
inst, dataSource, filePath, batchSize, dims, numDims, outputTypes));
16381697

1698+
for (auto i : indices(outputTypes)) {
1699+
outputResults.emplace_back(inst, i);
1700+
}
1701+
1702+
return GLStatus::Success;
1703+
}
1704+
1705+
template <typename Inst>
1706+
GLStatus TFGraphLowering::visitTFDataset(Inst *inst) {
1707+
// FIXME: Also support dataset/iterator outside of TPU context.
1708+
if (thisDeviceType != DeviceType::TPU || !deviceInfo.isTPUInfeedEnabled) {
1709+
internalError(
1710+
getUserSourceLocation(inst->getDebugLocation()),
1711+
"Builtin tfc.makeIteratorGetNextWithDatasets can only be used when "
1712+
"generating TPU TF graphs with infeed support.",
1713+
diag::tfop_invalid_tfop);
1714+
return GLStatus::Error;
1715+
}
1716+
1717+
std::vector<SILOpResult> outputResults;
1718+
GLStatus datasetStatus = createDatasetCreationContext(inst, outputResults);
1719+
if (datasetStatus != GLStatus::Success) {
1720+
// Error is already recorded.
1721+
return datasetStatus;
1722+
}
16391723
{
16401724
auto &graphFn = getCurrentGraphFunction();
16411725
auto *desc = TF_NewOperation(graphFn.getGraph(), "InfeedDequeueTuple",
@@ -1646,8 +1730,8 @@ GLStatus TFGraphLowering::visitTFDataset(BuiltinInst *inst) {
16461730
if (checkStatus(getUserSourceLocation(inst->getDebugLocation())))
16471731
return GLStatus::Error;
16481732

1649-
for (int i = 0, n = outputTypes.size(); i != n; ++i) {
1650-
addValueMapping({inst, i}, {dequeue, i});
1733+
for (int i = 0, n = outputResults.size(); i != n; ++i) {
1734+
addValueMapping(outputResults[i], {dequeue, i});
16511735
}
16521736
}
16531737
return GLStatus::Success;
@@ -1765,6 +1849,10 @@ GLStatus TFGraphLowering::visitGraphOperationInst(GraphOperationInst *inst) {
17651849
else if (opName == "tfc.D2DTensorSend")
17661850
return visitGraphOpD2DTensorSendInst(decoder);
17671851

1852+
// Dataset creation
1853+
if (opName.startswith("tfc.makeIteratorGetNextWithDatasets"))
1854+
return visitTFDataset<GraphOperationInst>(inst);
1855+
17681856
auto &graphFn = getCurrentGraphFunction();
17691857

17701858
// The name label we put on the op is summarized from the "stack trace" of

test/TensorFlow/dataset.swift

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: %target-swift-frontend -Xllvm -tf-dump-intermediates -Xllvm -tf-dump-graph -O -emit-sil -verify %s | %FileCheck %s
2+
// RUN: %target-swift-frontend -Xllvm -tf-dump-intermediates -Xllvm -tf-dump-graph -Xllvm -tf-strict-deabstraction -O -emit-sil -verify %s | %FileCheck %s --check-prefix=STRICTDA
23
import TensorFlow
34

45
public func testDatasetWithFakeData() {
@@ -19,6 +20,12 @@ public func testDatasetWithFakeData() {
1920
// CHECK: [[RESULT:%[0-9]+]] = builtin "__tfop_Add,$in,$in,T,__device"([[GETNEXT]] : $TensorHandle<Float>, {{.*}} : $TensorHandle<Float>
2021
// CHECK-NEXT: return [[RESULT]] : $TensorHandle<Float>
2122

23+
// STRICTDA-LABEL: --- TFPartition Accelerator Result: {{.*}}testDatasetWithFakeData{{.*}}
24+
// STRICTDA: bb0:
25+
// STRICTDA: [[GETNEXT:%[0-9]+]] = graph_op "tfc.makeIteratorGetNextWithDatasets{{.*}} : $TensorHandle<Float>
26+
// STRICTDA: [[RESULT:%[0-9]+]] = graph_op "Add,i,i"([[GETNEXT]] : $TensorHandle<Float>, {{.*}} : $TensorHandle<Float>
27+
// STRICTDA-NEXT: return [[RESULT]] : $TensorHandle<Float>
28+
2229
public func testDatasetWithMNIST() {
2330
TensorFlow.enableTPU(infeed: true)
2431
let (images1, labels1): (TensorHandle<Float>, TensorHandle<Int32>) = #tfop(
@@ -47,6 +54,14 @@ public func testDatasetWithMNIST() {
4754
// CHECK: [[RESULT:%.*]] = tuple ({{.*}} : $TensorHandle<{{.*}}>, {{.*}} : $TensorHandle<{{.*}}>)
4855
// CHECK-NEXT: return [[RESULT]] : $(TensorHandle<{{.*}}>, TensorHandle<{{.*}}>)
4956

57+
// STRICTDA-LABEL: --- TFPartition Accelerator Result: {{.*}}testDatasetWithMNIST{{.*}}
58+
// STRICTDA: bb0:
59+
// STRICTDA: (%0, %1) = graph_op "tfc.makeIteratorGetNextWithDatasets{{.*}} : $TensorHandle<Float>, $TensorHandle<Int32>
60+
// STRICTDA: graph_op "Add,i,i"(
61+
// STRICTDA: graph_op "Add,i,i"(
62+
// The operands can appear in arbitrary order here.
63+
// STRICTDA: [[RESULT:%.*]] = tuple ({{.*}} : $TensorHandle<{{.*}}>, {{.*}} : $TensorHandle<{{.*}}>)
64+
// STRICTDA-NEXT: return [[RESULT]] : $(TensorHandle<{{.*}}>, TensorHandle<{{.*}}>)
5065

5166
// Creates a dataset, which produces one float scalar value in each get next
5267
// call.
@@ -80,6 +95,16 @@ public func createMockDataSet() -> VariantHandle {
8095
// CHECK-NEXT: type: DT_VARIANT
8196
// CHECK-NEXT: }
8297

98+
// STRICTDA-LABEL: --- TFPartition Accelerator Result: {{.*}}createMockDataSet{{.*}}
99+
// STRICTDA-NOT: node {
100+
// STRICTDA: function {
101+
// STRICTDA-NEXT: signature {
102+
// STRICTDA-NEXT: name: "{{.*}}createMockDataSet{{.*}}.tf_only"
103+
// STRICTDA: output_arg {
104+
// STRICTDA-NEXT: name: "op_createmockdataset{{.*}}"
105+
// STRICTDA-NEXT: type: DT_VARIANT
106+
// STRICTDA-NEXT: }
107+
83108
// TODO: support taking the following function typed parameter.
84109
// _ datasetCreator : @convention(tensorflow) () -> VariantHandle
85110
// TODO(SR-8117): Support "" for container and shared_name.
@@ -125,3 +150,14 @@ public func model() {
125150
// CHECK: function {
126151
// CHECK: function {
127152
// CHECK-NOT: function {
153+
154+
// STRICTDA-LABEL: --- TFPartition Accelerator Result: {{.*}}model{{.*}}
155+
// STRICTDA: node {
156+
// STRICTDA-NEXT: name: "{{.*}}model{{.*}}"
157+
// STRICTDA-NEXT: op: "{{.*}}model{{.*}}.tf_CPU.device_partition"
158+
// STRICTDA: node {
159+
// STRICTDA-NEXT: name: "tfc_output_0_{{.*}}model{{.*}}"
160+
161+
// STRICTDA: function {
162+
// STRICTDA: function {
163+
// STRICTDA-NOT: function {

0 commit comments

Comments
 (0)