@@ -348,7 +348,7 @@ struct TFGraphLowering : public SILInstructionVisitor<TFGraphLowering, GLStatus>
348
348
struct DatasetCreationContext {
349
349
// / The instruction corresponding to the builtin
350
350
// / tfc.makeIteratorGetNextWithDatasets.
351
- BuiltinInst *datasetInst = nullptr ;
351
+ SILInstruction *datasetInst = nullptr ;
352
352
353
353
// / Specifies which (hard-coded) iterator stack to create.
354
354
enum DataSource {
@@ -376,7 +376,7 @@ struct TFGraphLowering : public SILInstructionVisitor<TFGraphLowering, GLStatus>
376
376
std::vector<TF_DataType> infeedInputDtypes;
377
377
378
378
public:
379
- DatasetCreationContext (BuiltinInst *datasetInst, DataSource dataSource,
379
+ DatasetCreationContext (SILInstruction *datasetInst, DataSource dataSource,
380
380
StringRef filePath, int batchSize,
381
381
ArrayRef<int64_t > dims, ArrayRef<int > numDims,
382
382
ArrayRef<TF_DataType> dTypes)
@@ -584,7 +584,11 @@ struct TFGraphLowering : public SILInstructionVisitor<TFGraphLowering, GLStatus>
584
584
// /
585
585
// / FIXME: Dissolve this builtin into a set of finer-grained, composable
586
586
// / features.
587
- GLStatus visitTFDataset (BuiltinInst *inst);
587
+ template <typename Inst>
588
+ GLStatus createDatasetCreationContext (
589
+ Inst *inst, std::vector<SILOpResult>& results);
590
+ template <typename Inst>
591
+ GLStatus visitTFDataset (Inst *inst);
588
592
bool createDatasetIteratorNodesWithInfeedEnqueue ();
589
593
590
594
GLStatus visitTFOpInst (BuiltinInst *inst);
@@ -1538,17 +1542,72 @@ GLStatus TFGraphLowering::visitGraphOpD2DTensorSendInst(
1538
1542
}
1539
1543
}
1540
1544
1541
- GLStatus TFGraphLowering::visitTFDataset (BuiltinInst *inst) {
1542
- // FIXME: Also support dataset/iterator outside of TPU context.
1543
- if (thisDeviceType != DeviceType::TPU || !deviceInfo.isTPUInfeedEnabled ) {
1544
- internalError (
1545
- getUserSourceLocation (inst->getDebugLocation ()),
1546
- " Builtin tfc.makeIteratorGetNextWithDatasets can only be used when "
1547
- " generating TPU TF graphs with infeed support." ,
1548
- diag::tfop_invalid_tfop);
1545
+ template <>
1546
+ GLStatus TFGraphLowering::createDatasetCreationContext (
1547
+ GraphOperationInst *inst, std::vector<SILOpResult> &outputResults) {
1548
+ GraphOperationInfo graphOpInfo (inst);
1549
+ // Type check and process the first attribute: dataSource.
1550
+ auto dataSource =
1551
+ llvm::StringSwitch<DatasetCreationContext::DataSource>(
1552
+ graphOpInfo.getStringAttr (0 , " dataSource" ))
1553
+ .Case (" fake" , DatasetCreationContext::FAKE)
1554
+ .Case (" mnist" , DatasetCreationContext::MNIST)
1555
+ .Default (DatasetCreationContext::IMAGENET);
1556
+
1557
+ // Type check and process the second attribute: filePath.
1558
+ // When dataSource is FAKE, this attribute needs to be present, but is not
1559
+ // used.
1560
+ StringRef filePath = (dataSource == DatasetCreationContext::FAKE)
1561
+ ? " "
1562
+ : graphOpInfo.getStringAttr (1 , " filePath" );
1563
+ // Type check and process the third attribute: batchSize
1564
+ int batchSize = graphOpInfo.getIntAttr (2 , " batchSize" );
1565
+
1566
+ // Type check and process the fourth attribute: outputShapes
1567
+ auto attr = inst->getAttribute (3 );
1568
+ SmallVector<int64_t , 8 > dims;
1569
+ SmallVector<int , 3 > numDims;
1570
+ SmallVector<int64_t *, 8 > dimPtrs;
1571
+ decodeShapeArray (attr.value , dims, numDims, dimPtrs);
1572
+
1573
+ // Even when this built-in returns multiple tensors, they are always presented
1574
+ // by a single tuple.
1575
+ std::vector<TF_DataType> outputTypes;
1576
+ for (const SILValue &result : inst->getResults ()) {
1577
+ auto outputType = result->getType ().getASTType ();
1578
+ auto tfType = getTFDataTypeFromTensorGenericType (outputType);
1579
+ if (tfType == 0 ) {
1580
+ internalError (getUserSourceLocation (inst->getDebugLocation ()),
1581
+ " Encountered a non-tensor type during dataset creation." ,
1582
+ diag::tfop_invalid_tfop);
1583
+ return GLStatus::Error;
1584
+ }
1585
+ outputTypes.push_back (static_cast <TF_DataType>(tfType));
1586
+ outputResults.emplace_back (result, 0 );
1587
+ }
1588
+
1589
+ if (outputTypes.size () != numDims.size ()) {
1590
+ internalError (getUserSourceLocation (inst->getDebugLocation ()),
1591
+ " Must specify the same number of shapes and output tensors." ,
1592
+ diag::tfop_invalid_tfop);
1549
1593
return GLStatus::Error;
1550
1594
}
1551
1595
1596
+ // Defer the creation of the dataset / iterator related nodes, along with the
1597
+ // associated infeed enqueue till the creation of top level function
1598
+ // nodes. Here we fill in the dataset creation context, and then create an
1599
+ // infeed dequeue node to feed the user(s) of `inst`.
1600
+ datasetCreationContext.reset (new DatasetCreationContext (
1601
+ inst, dataSource, filePath, batchSize, dims, numDims, outputTypes));
1602
+
1603
+
1604
+ return GLStatus::Success;
1605
+ }
1606
+
1607
+ // TODO: Remove this version when graph op takes over completely.
1608
+ template <>
1609
+ GLStatus TFGraphLowering::createDatasetCreationContext (
1610
+ BuiltinInst *inst, std::vector<SILOpResult> &outputResults) {
1552
1611
SILTensorOpInfo tfopInfo = SILTensorOpInfo::decode (inst).getValue ();
1553
1612
// Type check and process the first attribute: dataSource.
1554
1613
DatasetCreationContext::DataSource dataSource;
@@ -1636,6 +1695,31 @@ GLStatus TFGraphLowering::visitTFDataset(BuiltinInst *inst) {
1636
1695
datasetCreationContext.reset (new DatasetCreationContext (
1637
1696
inst, dataSource, filePath, batchSize, dims, numDims, outputTypes));
1638
1697
1698
+ for (auto i : indices (outputTypes)) {
1699
+ outputResults.emplace_back (inst, i);
1700
+ }
1701
+
1702
+ return GLStatus::Success;
1703
+ }
1704
+
1705
+ template <typename Inst>
1706
+ GLStatus TFGraphLowering::visitTFDataset (Inst *inst) {
1707
+ // FIXME: Also support dataset/iterator outside of TPU context.
1708
+ if (thisDeviceType != DeviceType::TPU || !deviceInfo.isTPUInfeedEnabled ) {
1709
+ internalError (
1710
+ getUserSourceLocation (inst->getDebugLocation ()),
1711
+ " Builtin tfc.makeIteratorGetNextWithDatasets can only be used when "
1712
+ " generating TPU TF graphs with infeed support." ,
1713
+ diag::tfop_invalid_tfop);
1714
+ return GLStatus::Error;
1715
+ }
1716
+
1717
+ std::vector<SILOpResult> outputResults;
1718
+ GLStatus datasetStatus = createDatasetCreationContext (inst, outputResults);
1719
+ if (datasetStatus != GLStatus::Success) {
1720
+ // Error is already recorded.
1721
+ return datasetStatus;
1722
+ }
1639
1723
{
1640
1724
auto &graphFn = getCurrentGraphFunction ();
1641
1725
auto *desc = TF_NewOperation (graphFn.getGraph (), " InfeedDequeueTuple" ,
@@ -1646,8 +1730,8 @@ GLStatus TFGraphLowering::visitTFDataset(BuiltinInst *inst) {
1646
1730
if (checkStatus (getUserSourceLocation (inst->getDebugLocation ())))
1647
1731
return GLStatus::Error;
1648
1732
1649
- for (int i = 0 , n = outputTypes .size (); i != n; ++i) {
1650
- addValueMapping ({inst, i} , {dequeue, i});
1733
+ for (int i = 0 , n = outputResults .size (); i != n; ++i) {
1734
+ addValueMapping (outputResults[i] , {dequeue, i});
1651
1735
}
1652
1736
}
1653
1737
return GLStatus::Success;
@@ -1765,6 +1849,10 @@ GLStatus TFGraphLowering::visitGraphOperationInst(GraphOperationInst *inst) {
1765
1849
else if (opName == " tfc.D2DTensorSend" )
1766
1850
return visitGraphOpD2DTensorSendInst (decoder);
1767
1851
1852
+ // Dataset creation
1853
+ if (opName.startswith (" tfc.makeIteratorGetNextWithDatasets" ))
1854
+ return visitTFDataset<GraphOperationInst>(inst);
1855
+
1768
1856
auto &graphFn = getCurrentGraphFunction ();
1769
1857
1770
1858
// The name label we put on the op is summarized from the "stack trace" of
0 commit comments