Skip to content

Commit 0eabb88

Browse files
committed
[mlir][gpu] NFC let user pick the threadID values when distributing foreach_thread
Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D144219
1 parent e3a88a4 commit 0eabb88

File tree

2 files changed

+25
-15
lines changed

2 files changed

+25
-15
lines changed

mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,11 @@ namespace gpu {
4242
/// supported. Dynamic block dim sizes are currently not supported.
4343
DiagnosedSilenceableFailure mapNestedForeachToThreadsImpl(
4444
RewriterBase &rewriter, Operation *target,
45-
const SmallVectorImpl<int64_t> &blockDim, bool syncAfterDistribute,
46-
std::optional<TransformOpInterface> transformOp,
45+
const SmallVectorImpl<int64_t> &blockDim,
46+
function_ref<void(RewriterBase &, scf::ForeachThreadOp,
47+
SmallVectorImpl<Value> &)>
48+
threadIdGenerator,
49+
bool syncAfterDistribute, std::optional<TransformOpInterface> transformOp,
4750
const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes);
4851

4952
/// Maps the top level `scf.foreach_thread` op to GPU Thread Blocks. Mapping is

mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -502,8 +502,11 @@ static DiagnosedSilenceableFailure rewriteOneForeachThreadToGpuThreads(
502502

503503
DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForeachToThreadsImpl(
504504
RewriterBase &rewriter, Operation *target,
505-
const SmallVectorImpl<int64_t> &blockDim, bool syncAfterDistribute,
506-
std::optional<TransformOpInterface> transformOp,
505+
const SmallVectorImpl<int64_t> &blockDim,
506+
function_ref<void(RewriterBase &, scf::ForeachThreadOp,
507+
SmallVectorImpl<Value> &)>
508+
threadIdGenerator,
509+
bool syncAfterDistribute, std::optional<TransformOpInterface> transformOp,
507510
const ArrayRef<DeviceMappingAttrInterface> &threadMappingAttributes) {
508511
DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
509512
target->walk([&](scf::ForeachThreadOp foreachThreadOp) {
@@ -517,14 +520,8 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForeachToThreadsImpl(
517520
foreachThreadOp.getMapping(), transformOp);
518521
if (diag.succeeded()) {
519522
rewriter.setInsertionPoint(foreachThreadOp);
520-
IndexType indexType = rewriter.getIndexType();
521-
SmallVector<Value> threadOps{
522-
rewriter.create<ThreadIdOp>(foreachThreadOp.getLoc(), indexType,
523-
Dimension::x),
524-
rewriter.create<ThreadIdOp>(foreachThreadOp.getLoc(), indexType,
525-
Dimension::y),
526-
rewriter.create<ThreadIdOp>(foreachThreadOp.getLoc(), indexType,
527-
Dimension::z)};
523+
SmallVector<Value> threadOps;
524+
threadIdGenerator(rewriter, foreachThreadOp, threadOps);
528525
diag = rewriteOneForeachThreadToGpuThreads(
529526
rewriter, foreachThreadOp, blockDim, threadOps, syncAfterDistribute,
530527
transformOp, threadMappingAttributes);
@@ -562,10 +559,20 @@ DiagnosedSilenceableFailure transform::MapNestedForeachToThreads::applyToOne(
562559
GPUThreadMappingAttr::get(ctx, Threads::DimX),
563560
GPUThreadMappingAttr::get(ctx, Threads::DimY),
564561
GPUThreadMappingAttr::get(ctx, Threads::DimZ)};
565-
562+
auto threadIdGenerator = [](RewriterBase &rewriter,
563+
scf::ForeachThreadOp foreachThreadOp,
564+
SmallVectorImpl<Value> &threadIds) {
565+
IndexType indexType = rewriter.getIndexType();
566+
threadIds.assign({rewriter.create<ThreadIdOp>(foreachThreadOp->getLoc(),
567+
indexType, Dimension::x),
568+
rewriter.create<ThreadIdOp>(foreachThreadOp->getLoc(),
569+
indexType, Dimension::y),
570+
rewriter.create<ThreadIdOp>(foreachThreadOp->getLoc(),
571+
indexType, Dimension::z)});
572+
};
566573
diag = mlir::transform::gpu::mapNestedForeachToThreadsImpl(
567-
rewriter, target, blockDim, getSyncAfterDistribute(), transformOp,
568-
threadMappingAttributes);
574+
rewriter, target, blockDim, threadIdGenerator, getSyncAfterDistribute(),
575+
transformOp, threadMappingAttributes);
569576

570577
if (diag.succeeded()) {
571578
diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt,

0 commit comments

Comments
 (0)