@@ -1374,7 +1374,7 @@ transform::MapNestedForeachThreadToGpuThreads::applyToOne(
1374
1374
1375
1375
LogicalResult mlir::linalg::rewriteTopLevelForeachThreadToGpuBlocks (
1376
1376
RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
1377
- function_ref<void (Operation *, const SmallVector< int64_t > &, IndexType ,
1377
+ function_ref<void (RewriterBase &, scf::ForeachThreadOp ,
1378
1378
SmallVector<Value> &)>
1379
1379
blockIdGenerator,
1380
1380
SmallVector<int64_t> &gridDims) {
@@ -1397,9 +1397,8 @@ LogicalResult mlir::linalg::rewriteTopLevelForeachThreadToGpuBlocks(
1397
1397
for (OpFoldResult ofr : *potentialGridDim)
1398
1398
gridDims.push_back (getConstantIntValue (ofr).value ());
1399
1399
1400
- IndexType indexType = rewriter.getIndexType ();
1401
1400
SmallVector<Value> blockOps;
1402
- blockIdGenerator (foreachThreadOp, gridDims, indexType , blockOps);
1401
+ blockIdGenerator (rewriter, foreachThreadOp , blockOps);
1403
1402
1404
1403
// Step 1. Move the body of foreachThreadOp.
1405
1404
// Erase the terminator first, it will not be used since we are on buffers.
@@ -1485,6 +1484,23 @@ createGpuLaunch(RewriterBase &rewriter, Location loc,
1485
1484
return launchOp;
1486
1485
}
1487
1486
1487
+ // / This is an helper that is only used in
1488
+ // / rewriteTopLevelForeachThreadToGpuBlocks. It generates GPU dialects block_id
1489
+ static void generateGpuBlockIds (RewriterBase &rewriter,
1490
+ scf::ForeachThreadOp foreachOp,
1491
+ SmallVector<Value> &blockOps) {
1492
+ Location loc = foreachOp->getLoc ();
1493
+ OpBuilder::InsertionGuard guard (rewriter);
1494
+ rewriter.setInsertionPoint (foreachOp);
1495
+ IndexType indexType = rewriter.getIndexType ();
1496
+ SmallVector<gpu::Dimension> gpuDims{gpu::Dimension::x, gpu::Dimension::y,
1497
+ gpu::Dimension::z};
1498
+ for (int64_t idx : llvm::seq<int64_t >(0 , gpuDims.size ())) {
1499
+ blockOps.push_back (
1500
+ rewriter.create <gpu::BlockIdOp>(loc, indexType, gpuDims[idx]));
1501
+ }
1502
+ }
1503
+
1488
1504
DiagnosedSilenceableFailure
1489
1505
transform::MapNestedForeachThreadToGpuBlocks::applyToOne (
1490
1506
Operation *target, SmallVectorImpl<Operation *> &results,
@@ -1520,22 +1536,9 @@ transform::MapNestedForeachThreadToGpuBlocks::applyToOne(
1520
1536
dyn_cast<scf::ForeachThreadOp>(newForeachThreadOp);
1521
1537
}
1522
1538
1523
- auto generateBlocks = [&](Operation *op, const SmallVector<int64_t > &gridDims,
1524
- IndexType indexType, SmallVector<Value> &blockOps) {
1525
- Location loc = op->getLoc ();
1526
- OpBuilder::InsertionGuard guard (rewriter);
1527
- rewriter.setInsertionPoint (op);
1528
- SmallVector<gpu::Dimension> gpuDims{gpu::Dimension::x, gpu::Dimension::y,
1529
- gpu::Dimension::z};
1530
- for (int64_t idx : llvm::seq<int64_t >(0 , gridDims.size ())) {
1531
- blockOps.push_back (
1532
- rewriter.create <gpu::BlockIdOp>(loc, indexType, gpuDims[idx]));
1533
- }
1534
- };
1535
-
1536
1539
SmallVector<int64_t > gridDim = extractFromI64ArrayAttr (getGridDim ());
1537
1540
if (failed (mlir::linalg::rewriteTopLevelForeachThreadToGpuBlocks (
1538
- rewriter, topLevelForeachThreadOp, generateBlocks , gridDim)))
1541
+ rewriter, topLevelForeachThreadOp, generateGpuBlockIds , gridDim)))
1539
1542
return DiagnosedSilenceableFailure (reportUnknownTransformError (target));
1540
1543
1541
1544
if (failed (alterGpuLaunch (rewriter, gpuLaunch, gridDim[0 ], gridDim[1 ],
0 commit comments