Skip to content

Commit d75c210

Browse files
committed
[MLIR][ROCDL] Add conversion for gpu.subgroup_id to ROCDL
Creates the `rocdl.wave_id` op, which translates to the LLVM intrinsic `llvm.amdgcn.s.get.waveid.in.workgroup` (the intrinsic behind `__builtin_amdgcn_s_get_waveid_in_workgroup`).
1 parent 34f3466 commit d75c210

File tree

4 files changed

+58
-14
lines changed

4 files changed

+58
-14
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,14 @@ def ROCDL_ReadlaneOp : ROCDL_IntrOp<"readlane", [], [0], [AllTypesMatch<["res",
204204
}];
205205
}
206206

207+
// `rocdl.wave_id` wraps the `amdgcn_s_get_waveid_in_workgroup` intrinsic. The
// intrinsic name is too long, so the ROCDL op uses a shorter name. The op
// takes no operands and produces a single result; its textual form is
// `rocdl.wave_id : <result type>`.
208+
def ROCDL_WaveIdOp : LLVM_IntrOpBase<ROCDL_Dialect, "wave_id",
209+
"amdgcn_s_get_waveid_in_workgroup", [], [], [Pure], 1>,
210+
Arguments<(ins)> {
211+
let results = (outs LLVM_Type:$res);
212+
let assemblyFormat = "attr-dict `:` type($res)";
213+
}
214+
207215
//===----------------------------------------------------------------------===//
208216
// Thread index and Block index
209217
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,24 @@ static constexpr StringLiteral amdgcnDataLayout =
8080
"64-S32-A5-G1-ni:7:8:9";
8181

8282
namespace {
83+
84+
// Truncate or extend the result depending on the index bitwidth specified
85+
// by the LLVMTypeConverter options.
86+
static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter,
87+
Location loc, Value value,
88+
const LLVMTypeConverter *converter) {
89+
auto intWidth = cast<IntegerType>(value.getType()).getWidth();
90+
auto indexBitwidth = converter->getIndexTypeBitwidth();
91+
if (indexBitwidth > intWidth) {
92+
return rewriter.create<LLVM::SExtOp>(
93+
loc, IntegerType::get(rewriter.getContext(), indexBitwidth), value);
94+
} else if (indexBitwidth < intWidth) {
95+
return rewriter.create<LLVM::TruncOp>(
96+
loc, IntegerType::get(rewriter.getContext(), indexBitwidth), value);
97+
}
98+
return value;
99+
}
100+
83101
struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
84102
using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;
85103

@@ -98,16 +116,7 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
98116
rewriter.create<ROCDL::MbcntLoOp>(loc, intTy, ValueRange{minus1, zero});
99117
Value laneId = rewriter.create<ROCDL::MbcntHiOp>(
100118
loc, intTy, ValueRange{minus1, mbcntLo});
101-
// Truncate or extend the result depending on the index bitwidth specified
102-
// by the LLVMTypeConverter options.
103-
const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
104-
if (indexBitwidth > 32) {
105-
laneId = rewriter.create<LLVM::SExtOp>(
106-
loc, IntegerType::get(context, indexBitwidth), laneId);
107-
} else if (indexBitwidth < 32) {
108-
laneId = rewriter.create<LLVM::TruncOp>(
109-
loc, IntegerType::get(context, indexBitwidth), laneId);
110-
}
119+
laneId = truncOrExtToLLVMType(rewriter, loc, laneId, getTypeConverter());
111120
rewriter.replaceOp(op, {laneId});
112121
return success();
113122
}
@@ -190,6 +199,21 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
190199
}
191200
};
192201

202+
/// Lowers `gpu.subgroup_id` to `rocdl.wave_id`. The ROCDL op is created with
/// an i32 result, which is then sign-extended or truncated to the index
/// bitwidth of the type converter before replacing the original op.
struct GPUSubgroupIdOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupIdOp> {
  using ConvertOpToLLVMPattern<gpu::SubgroupIdOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::SubgroupIdOp op, gpu::SubgroupIdOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    // Materialize the wave id as a 32-bit integer.
    Type i32Ty = IntegerType::get(rewriter.getContext(), 32);
    Value rawWaveId = rewriter.create<ROCDL::WaveIdOp>(loc, i32Ty);

    // Match the width that `index` lowers to under the current options.
    Value indexWaveId =
        truncOrExtToLLVMType(rewriter, loc, rawWaveId, getTypeConverter());

    rewriter.replaceOp(op, {indexWaveId});
    return success();
  }
};
216+
193217
/// Import the GPU Ops to ROCDL Patterns.
194218
#include "GPUToROCDL.cpp.inc"
195219

@@ -405,7 +429,9 @@ void mlir::populateGpuToROCDLConversionPatterns(
405429
// TODO: Add alignment for workgroup memory
406430
patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);
407431

408-
patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
432+
patterns
433+
.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL, GPUSubgroupIdOpToROCDL>(
434+
converter);
409435

410436
populateMathToROCDLConversionPatterns(converter, patterns);
411437
}

mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ gpu.module @test_module {
1111
func.func @gpu_index_ops()
1212
-> (index, index, index, index, index, index,
1313
index, index, index, index, index, index,
14-
index) {
14+
index, index) {
1515
// CHECK32-NOT: = llvm.sext %{{.*}} : i32 to i64
1616

1717
// CHECK: rocdl.workitem.id.x : i32
@@ -59,12 +59,16 @@ gpu.module @test_module {
5959
// CHECK: = llvm.sext %{{.*}} : i32 to i64
6060
%laneId = gpu.lane_id
6161

62+
// CHECK: = rocdl.wave_id : i32
63+
// CHECK: = llvm.sext %{{.*}} : i32 to i64
64+
%waveId = gpu.subgroup_id : index
65+
6266
func.return %tIdX, %tIdY, %tIdZ, %bDimX, %bDimY, %bDimZ,
6367
%bIdX, %bIdY, %bIdZ, %gDimX, %gDimY, %gDimZ,
64-
%laneId
68+
%laneId, %waveId
6569
: index, index, index, index, index, index,
6670
index, index, index, index, index, index,
67-
index
71+
index, index
6872
}
6973
}
7074

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,12 @@ llvm.func @rocdl.lane_id() -> i32 {
8888
llvm.return %3 : i32
8989
}
9090

91+
llvm.func @rocdl.wave_id() -> i32 {
92+
// CHECK: call i32 @llvm.amdgcn.s.get.waveid.in.workgroup()
93+
%0 = rocdl.wave_id : i32
94+
llvm.return %0 : i32
95+
}
96+
9197
llvm.func @rocdl.swizzle(%src : i32) -> i32 {
9298
// CHECK-LABEL: rocdl.swizzle
9399
// CHECK: call i32 @llvm.amdgcn.ds.swizzle

0 commit comments

Comments
 (0)