Skip to content

Commit 4c8b716

Browse files
committed
[MLIR][ROCDL] Add conversion for gpu.subgroup_id to ROCDL
1 parent 9799746 commit 4c8b716

File tree

4 files changed

+116
-58
lines changed

4 files changed

+116
-58
lines changed

mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "mlir/Conversion/GPUToROCDL/Runtimes.h"
1212
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
13+
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
1314
#include <memory>
1415

1516
namespace mlir {
@@ -46,11 +47,7 @@ void configureGpuToROCDLConversionLegality(ConversionTarget &target);
4647
/// index bitwidth used for the lowering of the device side index computations
4748
/// is configurable.
4849
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
49-
createLowerGpuOpsToROCDLOpsPass(
50-
const std::string &chipset = "gfx900",
51-
unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout,
52-
bool useBarePtrCallConv = false,
53-
gpu::amd::Runtime runtime = gpu::amd::Runtime::Unknown);
50+
createLowerGpuOpsToROCDLOpsPass();
5451

5552
} // namespace mlir
5653

mlir/include/mlir/Conversion/Passes.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,10 @@ def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> {
608608
clEnumValN(::mlir::gpu::amd::Runtime::HIP, "HIP", "HIP"),
609609
clEnumValN(::mlir::gpu::amd::Runtime::OpenCL, "OpenCL",
610610
"OpenCL"))}]>,
611+
Option<"subgroupSize", "subgroup-size", "unsigned",
612+
"0",
613+
"specify subgroup size for the kernel, if left empty, the default "
614+
"value will be decided by the target chipset.">,
611615
ListOption<"allowedDialects", "allowed-dialects", "std::string",
612616
"Run conversion patterns of only the specified dialects">,
613617
];

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 88 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -52,25 +52,6 @@ namespace mlir {
5252

5353
using namespace mlir;
5454

55-
// Truncate or extend the result depending on the index bitwidth specified
56-
// by the LLVMTypeConverter options.
57-
static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter,
58-
Location loc, Value value,
59-
const LLVMTypeConverter &converter) {
60-
int64_t intWidth = cast<IntegerType>(value.getType()).getWidth();
61-
int64_t indexBitwidth = converter.getIndexTypeBitwidth();
62-
auto indexBitwidthType =
63-
IntegerType::get(rewriter.getContext(), converter.getIndexTypeBitwidth());
64-
// TODO: use <=> in C++20.
65-
if (indexBitwidth > intWidth) {
66-
return rewriter.create<LLVM::SExtOp>(loc, indexBitwidthType, value);
67-
}
68-
if (indexBitwidth < intWidth) {
69-
return rewriter.create<LLVM::TruncOp>(loc, indexBitwidthType, value);
70-
}
71-
return value;
72-
}
73-
7455
/// Returns true if the given `gpu.func` can be safely called using the bare
7556
/// pointer calling convention.
7657
static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
@@ -99,6 +80,26 @@ static constexpr StringLiteral amdgcnDataLayout =
9980
"64-S32-A5-G1-ni:7:8:9";
10081

10182
namespace {
83+
84+
// Truncate or extend the result depending on the index bitwidth specified
85+
// by the LLVMTypeConverter options.
86+
static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter,
87+
Location loc, Value value,
88+
const LLVMTypeConverter &converter) {
89+
int64_t intWidth = cast<IntegerType>(value.getType()).getWidth();
90+
int64_t indexBitwidth = converter.getIndexTypeBitwidth();
91+
auto indexBitwidthType =
92+
IntegerType::get(rewriter.getContext(), converter.getIndexTypeBitwidth());
93+
// TODO: use <=> in C++20.
94+
if (indexBitwidth > intWidth) {
95+
return rewriter.create<LLVM::SExtOp>(loc, indexBitwidthType, value);
96+
}
97+
if (indexBitwidth < intWidth) {
98+
return rewriter.create<LLVM::TruncOp>(loc, indexBitwidthType, value);
99+
}
100+
return value;
101+
}
102+
102103
struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
103104
using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;
104105

@@ -117,16 +118,7 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
117118
rewriter.create<ROCDL::MbcntLoOp>(loc, intTy, ValueRange{minus1, zero});
118119
Value laneId = rewriter.create<ROCDL::MbcntHiOp>(
119120
loc, intTy, ValueRange{minus1, mbcntLo});
120-
// Truncate or extend the result depending on the index bitwidth specified
121-
// by the LLVMTypeConverter options.
122-
const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
123-
if (indexBitwidth > 32) {
124-
laneId = rewriter.create<LLVM::SExtOp>(
125-
loc, IntegerType::get(context, indexBitwidth), laneId);
126-
} else if (indexBitwidth < 32) {
127-
laneId = rewriter.create<LLVM::TruncOp>(
128-
loc, IntegerType::get(context, indexBitwidth), laneId);
129-
}
121+
laneId = truncOrExtToLLVMType(rewriter, loc, laneId, *getTypeConverter());
130122
rewriter.replaceOp(op, {laneId});
131123
return success();
132124
}
@@ -150,11 +142,11 @@ struct GPUSubgroupSizeOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp> {
150142
/*bitWidth=*/32, /*lower=*/isBeforeGfx10 ? 64 : 32,
151143
/*upper=*/op.getUpperBoundAttr().getInt() + 1);
152144
}
153-
Value wavefrontOp = rewriter.create<ROCDL::WavefrontSizeOp>(
145+
Value wavefrontSizeOp = rewriter.create<ROCDL::WavefrontSizeOp>(
154146
op.getLoc(), rewriter.getI32Type(), bounds);
155-
wavefrontOp = truncOrExtToLLVMType(rewriter, op.getLoc(), wavefrontOp,
156-
*getTypeConverter());
157-
rewriter.replaceOp(op, {wavefrontOp});
147+
wavefrontSizeOp = truncOrExtToLLVMType(
148+
rewriter, op.getLoc(), wavefrontSizeOp, *getTypeConverter());
149+
rewriter.replaceOp(op, {wavefrontSizeOp});
158150
return success();
159151
}
160152

@@ -239,6 +231,65 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
239231
}
240232
};
241233

234+
struct GPUSubgroupIdOpToROCDL final
235+
: ConvertOpToLLVMPattern<gpu::SubgroupIdOp> {
236+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
237+
238+
LogicalResult
239+
matchAndRewrite(gpu::SubgroupIdOp op, gpu::SubgroupIdOp::Adaptor adaptor,
240+
ConversionPatternRewriter &rewriter) const override {
241+
// Calculation of the thread's subgroup identifier.
242+
//
243+
// The process involves mapping the thread's 3D identifier within its
244+
// workgroup/block (w_id.x, w_id.y, w_id.z) to a 1D linear index.
245+
// This linearization assumes a layout where the x-dimension (w_dim.x)
246+
// varies most rapidly (i.e., it is the innermost dimension).
247+
//
248+
// The formula for the linearized thread index is:
249+
// L = w_id.x + w_dim.x * (w_id.y + (w_dim.y * w_id.z))
250+
//
251+
// Subsequently, the range of linearized indices [0, N_threads-1] is
252+
// divided into consecutive, non-overlapping segments, each representing
253+
// a subgroup of size 'subgroup_size'.
254+
//
255+
// Example Partitioning (N = subgroup_size):
256+
// | Subgroup 0 | Subgroup 1 | Subgroup 2 | ... |
257+
// | Indices 0..N-1 | Indices N..2N-1 | Indices 2N..3N-1| ... |
258+
//
259+
// The subgroup identifier is obtained via integer division of the
260+
// linearized thread index by the predefined 'subgroup_size'.
261+
//
262+
// subgroup_id = floor( L / subgroup_size )
263+
// = (w_id.x + w_dim.x * (w_id.y + w_dim.y * w_id.z)) /
264+
// subgroup_size
265+
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
266+
Location loc = op.getLoc();
267+
LLVM::IntegerOverflowFlags flags =
268+
LLVM::IntegerOverflowFlags::nsw | LLVM::IntegerOverflowFlags::nuw;
269+
Value workitemIdX = rewriter.create<ROCDL::ThreadIdXOp>(loc, int32Type);
270+
Value workitemIdY = rewriter.create<ROCDL::ThreadIdYOp>(loc, int32Type);
271+
Value workitemIdZ = rewriter.create<ROCDL::ThreadIdZOp>(loc, int32Type);
272+
Value workitemDimX = rewriter.create<ROCDL::BlockDimXOp>(loc, int32Type);
273+
Value workitemDimY = rewriter.create<ROCDL::BlockDimYOp>(loc, int32Type);
274+
Value dimYxIdZ = rewriter.create<LLVM::MulOp>(loc, int32Type, workitemDimY,
275+
workitemIdZ, flags);
276+
Value dimYxIdZPlusIdY = rewriter.create<LLVM::AddOp>(
277+
loc, int32Type, dimYxIdZ, workitemIdY, flags);
278+
Value dimYxIdZPlusIdYTimesDimX = rewriter.create<LLVM::MulOp>(
279+
loc, int32Type, workitemDimX, dimYxIdZPlusIdY, flags);
280+
Value workitemIdXPlusDimYxIdZPlusIdYTimesDimX =
281+
rewriter.create<LLVM::AddOp>(loc, int32Type, workitemIdX,
282+
dimYxIdZPlusIdYTimesDimX, flags);
283+
Value subgroupSize = rewriter.create<ROCDL::WavefrontSizeOp>(
284+
loc, rewriter.getI32Type(), nullptr);
285+
Value waveIdOp = rewriter.create<LLVM::UDivOp>(
286+
loc, workitemIdXPlusDimYxIdZPlusIdYTimesDimX, subgroupSize);
287+
rewriter.replaceOp(op, {truncOrExtToLLVMType(rewriter, loc, waveIdOp,
288+
*getTypeConverter())});
289+
return success();
290+
}
291+
};
292+
242293
/// Import the GPU Ops to ROCDL Patterns.
243294
#include "GPUToROCDL.cpp.inc"
244295

@@ -249,19 +300,7 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
249300
// code.
250301
struct LowerGpuOpsToROCDLOpsPass final
251302
: public impl::ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
252-
LowerGpuOpsToROCDLOpsPass() = default;
253-
LowerGpuOpsToROCDLOpsPass(const std::string &chipset, unsigned indexBitwidth,
254-
bool useBarePtrCallConv,
255-
gpu::amd::Runtime runtime) {
256-
if (this->chipset.getNumOccurrences() == 0)
257-
this->chipset = chipset;
258-
if (this->indexBitwidth.getNumOccurrences() == 0)
259-
this->indexBitwidth = indexBitwidth;
260-
if (this->useBarePtrCallConv.getNumOccurrences() == 0)
261-
this->useBarePtrCallConv = useBarePtrCallConv;
262-
if (this->runtime.getNumOccurrences() == 0)
263-
this->runtime = runtime;
264-
}
303+
using Base::Base;
265304

266305
void getDependentDialects(DialectRegistry &registry) const override {
267306
Base::getDependentDialects(registry);
@@ -456,18 +495,14 @@ void mlir::populateGpuToROCDLConversionPatterns(
456495
patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);
457496

458497
patterns
459-
.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL, GPUSubgroupSizeOpToROCDL>(
498+
.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL, GPUSubgroupIdOpToROCDL>(
460499
converter);
461500
patterns.add<GPUSubgroupSizeOpToROCDL>(converter, chipset);
462501

463502
populateMathToROCDLConversionPatterns(converter, patterns);
464503
}
465504

466505
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
467-
mlir::createLowerGpuOpsToROCDLOpsPass(const std::string &chipset,
468-
unsigned indexBitwidth,
469-
bool useBarePtrCallConv,
470-
gpu::amd::Runtime runtime) {
471-
return std::make_unique<LowerGpuOpsToROCDLOpsPass>(
472-
chipset, indexBitwidth, useBarePtrCallConv, runtime);
506+
mlir::createLowerGpuOpsToROCDLOpsPass() {
507+
return std::make_unique<LowerGpuOpsToROCDLOpsPass>();
473508
}

mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -763,3 +763,25 @@ gpu.module @test_module {
763763
gpu.module @test_custom_data_layout attributes {llvm.data_layout = "e"} {
764764

765765
}
766+
767+
// -----
768+
769+
gpu.module @test_module {
770+
// CHECK-LABEL: func @gpu_subgroup_id()
771+
func.func @gpu_subgroup_id() -> (index) {
772+
// CHECK: %[[widx:.*]] = rocdl.workitem.id.x : i32
773+
// CHECK: %[[widy:.*]] = rocdl.workitem.id.y : i32
774+
// CHECK: %[[widz:.*]] = rocdl.workitem.id.z : i32
775+
// CHECK: %[[dimx:.*]] = rocdl.workgroup.dim.x : i32
776+
// CHECK: %[[dimy:.*]] = rocdl.workgroup.dim.y : i32
777+
// CHECK: %[[int5:.*]] = llvm.mul %[[dimy]], %[[widz]] overflow<nsw, nuw> : i32
778+
// CHECK: %[[int6:.*]] = llvm.add %[[int5]], %[[widy]] overflow<nsw, nuw> : i32
779+
// CHECK: %[[int7:.*]] = llvm.mul %[[dimx]], %[[int6]] overflow<nsw, nuw> : i32
780+
// CHECK: %[[int8:.*]] = llvm.add %[[widx]], %[[int7]] overflow<nsw, nuw> : i32
781+
// CHECK: %[[wavefrontsize:.*]] = rocdl.wavefrontsize : i32
782+
// CHECK: %[[result:.*]] = llvm.udiv %[[int8]], %[[wavefrontsize]] : i32
783+
// CHECK: = llvm.sext %[[result]] : i32 to i64
784+
%subgroupId = gpu.subgroup_id : index
785+
func.return %subgroupId : index
786+
}
787+
}

0 commit comments

Comments
 (0)