
Commit f8ad6ea

[mlir] Refactor transform dialect's gpu block func
This revision refactors the GPU block id generator lambda used in the transform dialect. It removes the lambda and instead uses a static function named generateGpuBlockIds. It also simplifies the arguments that the function takes.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D134724
1 parent f4991bf commit f8ad6ea
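For context on the signature change, below is a minimal, illustrative C++ sketch of a caller-side generator that matches the new, narrower callback type and is handed to rewriteTopLevelForeachThreadToGpuBlocks. The name myBlockIdGenerator and the usage comment are hypothetical and not part of this commit; include paths are approximate for the MLIR tree around this revision. Only the callback signature, generateGpuBlockIds, and rewriteTopLevelForeachThreadToGpuBlocks come from the change itself.

// Illustrative sketch only (not part of this commit): a caller-side
// block-id generator matching the new callback signature
// function_ref<void(RewriterBase &, scf::ForeachThreadOp, SmallVector<Value> &)>.
// Include paths are approximate for the MLIR tree around this revision.
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"

using namespace mlir;

// Hypothetical generator; mirrors the in-tree generateGpuBlockIds. The rewriter
// and the scf.foreach_thread op are now passed directly, and the IndexType is
// derived inside the generator instead of being a parameter.
static void myBlockIdGenerator(RewriterBase &rewriter,
                               scf::ForeachThreadOp foreachOp,
                               SmallVector<Value> &blockOps) {
  Location loc = foreachOp->getLoc();
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(foreachOp);
  IndexType indexType = rewriter.getIndexType();
  // Emit gpu.block_id along x, y, and z.
  for (gpu::Dimension dim :
       {gpu::Dimension::x, gpu::Dimension::y, gpu::Dimension::z})
    blockOps.push_back(rewriter.create<gpu::BlockIdOp>(loc, indexType, dim));
}

// Usage sketch, assuming `rewriter`, `foreachThreadOp`, and `gridDims` exist:
//   if (failed(linalg::rewriteTopLevelForeachThreadToGpuBlocks(
//           rewriter, foreachThreadOp, myBlockIdGenerator, gridDims)))
//     return failure();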

File tree

2 files changed: +21 −18 lines

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 1 addition & 1 deletion
@@ -132,7 +132,7 @@ FailureOr<Operation *> fuseElementwiseOps(RewriterBase &rewriter,
 /// dim sizes are currently not supported.
 LogicalResult rewriteTopLevelForeachThreadToGpuBlocks(
     RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
-    function_ref<void(Operation *, const SmallVector<int64_t> &, IndexType,
+    function_ref<void(RewriterBase &, scf::ForeachThreadOp,
                       SmallVector<Value> &)>
         blockIdGenerator,
     SmallVector<int64_t> &gridDims);

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 20 additions & 17 deletions
@@ -1374,7 +1374,7 @@ transform::MapNestedForeachThreadToGpuThreads::applyToOne(
 
 LogicalResult mlir::linalg::rewriteTopLevelForeachThreadToGpuBlocks(
     RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
-    function_ref<void(Operation *, const SmallVector<int64_t> &, IndexType,
+    function_ref<void(RewriterBase &, scf::ForeachThreadOp,
                       SmallVector<Value> &)>
         blockIdGenerator,
     SmallVector<int64_t> &gridDims) {
@@ -1397,9 +1397,8 @@ LogicalResult mlir::linalg::rewriteTopLevelForeachThreadToGpuBlocks(
   for (OpFoldResult ofr : *potentialGridDim)
     gridDims.push_back(getConstantIntValue(ofr).value());
 
-  IndexType indexType = rewriter.getIndexType();
   SmallVector<Value> blockOps;
-  blockIdGenerator(foreachThreadOp, gridDims, indexType, blockOps);
+  blockIdGenerator(rewriter, foreachThreadOp, blockOps);
 
   // Step 1. Move the body of foreachThreadOp.
   // Erase the terminator first, it will not be used since we are on buffers.
@@ -1485,6 +1484,23 @@ createGpuLaunch(RewriterBase &rewriter, Location loc,
   return launchOp;
 }
 
+/// This is an helper that is only used in
+/// rewriteTopLevelForeachThreadToGpuBlocks. It generates GPU dialects block_id
+static void generateGpuBlockIds(RewriterBase &rewriter,
+                                scf::ForeachThreadOp foreachOp,
+                                SmallVector<Value> &blockOps) {
+  Location loc = foreachOp->getLoc();
+  OpBuilder::InsertionGuard guard(rewriter);
+  rewriter.setInsertionPoint(foreachOp);
+  IndexType indexType = rewriter.getIndexType();
+  SmallVector<gpu::Dimension> gpuDims{gpu::Dimension::x, gpu::Dimension::y,
+                                      gpu::Dimension::z};
+  for (int64_t idx : llvm::seq<int64_t>(0, gpuDims.size())) {
+    blockOps.push_back(
+        rewriter.create<gpu::BlockIdOp>(loc, indexType, gpuDims[idx]));
+  }
+}
+
 DiagnosedSilenceableFailure
 transform::MapNestedForeachThreadToGpuBlocks::applyToOne(
     Operation *target, SmallVectorImpl<Operation *> &results,
@@ -1520,22 +1536,9 @@ transform::MapNestedForeachThreadToGpuBlocks::applyToOne(
         dyn_cast<scf::ForeachThreadOp>(newForeachThreadOp);
   }
 
-  auto generateBlocks = [&](Operation *op, const SmallVector<int64_t> &gridDims,
-                            IndexType indexType, SmallVector<Value> &blockOps) {
-    Location loc = op->getLoc();
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPoint(op);
-    SmallVector<gpu::Dimension> gpuDims{gpu::Dimension::x, gpu::Dimension::y,
-                                        gpu::Dimension::z};
-    for (int64_t idx : llvm::seq<int64_t>(0, gridDims.size())) {
-      blockOps.push_back(
-          rewriter.create<gpu::BlockIdOp>(loc, indexType, gpuDims[idx]));
-    }
-  };
-
   SmallVector<int64_t> gridDim = extractFromI64ArrayAttr(getGridDim());
   if (failed(mlir::linalg::rewriteTopLevelForeachThreadToGpuBlocks(
-          rewriter, topLevelForeachThreadOp, generateBlocks, gridDim)))
+          rewriter, topLevelForeachThreadOp, generateGpuBlockIds, gridDim)))
     return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
 
   if (failed(alterGpuLaunch(rewriter, gpuLaunch, gridDim[0], gridDim[1],
