Skip to content

Commit 6e47937

Browse files
authored
[MLIR][ROCDL] Lower gpu.subgroup_size to wavefrontsize (#137360)
1 parent 6c78ded commit 6e47937

File tree

5 files changed

+81
-7
lines changed

5 files changed

+81
-7
lines changed

mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ class RewritePatternSet;
2020
template <typename OpT>
2121
class OperationPass;
2222

23+
namespace amdgpu {
24+
struct Chipset;
25+
} // namespace amdgpu
26+
2327
namespace gpu {
2428
class GPUModuleOp;
2529
} // namespace gpu
@@ -32,7 +36,8 @@ class GPUModuleOp;
3236
/// The resulting pattern set should be run over a gpu.module op
3337
void populateGpuToROCDLConversionPatterns(const LLVMTypeConverter &converter,
3438
RewritePatternSet &patterns,
35-
gpu::amd::Runtime runtime);
39+
gpu::amd::Runtime runtime,
40+
amdgpu::Chipset chipset);
3641

3742
/// Configure target to convert from the GPU dialect to ROCDL.
3843
void configureGpuToROCDLConversionLegality(ConversionTarget &target);

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,8 @@ def ROCDL_BlockIdXOp : ROCDL_SpecialIdRegisterOp<"workgroup.id.x">;
216216
def ROCDL_BlockIdYOp : ROCDL_SpecialIdRegisterOp<"workgroup.id.y">;
217217
def ROCDL_BlockIdZOp : ROCDL_SpecialIdRegisterOp<"workgroup.id.z">;
218218

219+
def ROCDL_WavefrontSizeOp : ROCDL_SpecialIdRegisterOp<"wavefrontsize">;
220+
219221
//===----------------------------------------------------------------------===//
220222
// Thread range and Block range
221223
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,25 @@ namespace mlir {
5252

5353
using namespace mlir;
5454

55+
// Truncate or extend the result depending on the index bitwidth specified
56+
// by the LLVMTypeConverter options.
57+
static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter,
58+
Location loc, Value value,
59+
const LLVMTypeConverter &converter) {
60+
int64_t intWidth = cast<IntegerType>(value.getType()).getWidth();
61+
int64_t indexBitwidth = converter.getIndexTypeBitwidth();
62+
auto indexBitwidthType =
63+
IntegerType::get(rewriter.getContext(), converter.getIndexTypeBitwidth());
64+
// TODO: use <=> in C++20.
65+
if (indexBitwidth > intWidth) {
66+
return rewriter.create<LLVM::SExtOp>(loc, indexBitwidthType, value);
67+
}
68+
if (indexBitwidth < intWidth) {
69+
return rewriter.create<LLVM::TruncOp>(loc, indexBitwidthType, value);
70+
}
71+
return value;
72+
}
73+
5574
/// Returns true if the given `gpu.func` can be safely called using the bare
5675
/// pointer calling convention.
5776
static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
@@ -113,6 +132,35 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
113132
}
114133
};
115134

135+
struct GPUSubgroupSizeOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp> {
136+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
137+
138+
GPUSubgroupSizeOpToROCDL(const LLVMTypeConverter &converter,
139+
amdgpu::Chipset chipset)
140+
: ConvertOpToLLVMPattern<gpu::SubgroupSizeOp>(converter),
141+
chipset(chipset) {}
142+
143+
LogicalResult
144+
matchAndRewrite(gpu::SubgroupSizeOp op, gpu::SubgroupSizeOp::Adaptor adaptor,
145+
ConversionPatternRewriter &rewriter) const override {
146+
LLVM::ConstantRangeAttr bounds = nullptr;
147+
bool isBeforeGfx10 = chipset.majorVersion < 10;
148+
if (auto upperBoundAttr = op.getUpperBoundAttr()) {
149+
bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
150+
/*bitWidth=*/32, /*lower=*/isBeforeGfx10 ? 64 : 32,
151+
/*upper=*/op.getUpperBoundAttr().getInt() + 1);
152+
}
153+
Value wavefrontOp = rewriter.create<ROCDL::WavefrontSizeOp>(
154+
op.getLoc(), rewriter.getI32Type(), bounds);
155+
wavefrontOp = truncOrExtToLLVMType(rewriter, op.getLoc(), wavefrontOp,
156+
*getTypeConverter());
157+
rewriter.replaceOp(op, {wavefrontOp});
158+
return success();
159+
}
160+
161+
const amdgpu::Chipset chipset;
162+
};
163+
116164
struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
117165
using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;
118166

@@ -322,7 +370,8 @@ struct LowerGpuOpsToROCDLOpsPass final
322370

323371
populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,
324372
*maybeChipset);
325-
populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime);
373+
populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime,
374+
*maybeChipset);
326375
configureGpuToROCDLConversionLegality(target);
327376
if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
328377
signalPassFailure();
@@ -370,7 +419,7 @@ void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
370419

371420
void mlir::populateGpuToROCDLConversionPatterns(
372421
const LLVMTypeConverter &converter, RewritePatternSet &patterns,
373-
mlir::gpu::amd::Runtime runtime) {
422+
mlir::gpu::amd::Runtime runtime, amdgpu::Chipset chipset) {
374423
using gpu::index_lowering::IndexKind;
375424
using gpu::index_lowering::IntrType;
376425
using mlir::gpu::amd::Runtime;
@@ -408,7 +457,10 @@ void mlir::populateGpuToROCDLConversionPatterns(
408457
// TODO: Add alignment for workgroup memory
409458
patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);
410459

411-
patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
460+
patterns
461+
.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL, GPUSubgroupSizeOpToROCDL>(
462+
converter);
463+
patterns.add<GPUSubgroupSizeOpToROCDL>(converter, chipset);
412464

413465
populateMathToROCDLConversionPatterns(converter, patterns);
414466
}

mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ gpu.module @test_module {
1111
func.func @gpu_index_ops()
1212
-> (index, index, index, index, index, index,
1313
index, index, index, index, index, index,
14-
index) {
14+
index, index, index) {
1515
// CHECK32-NOT: = llvm.sext %{{.*}} : i32 to i64
1616

1717
// CHECK: rocdl.workitem.id.x : i32
@@ -59,12 +59,20 @@ gpu.module @test_module {
5959
// CHECK: = llvm.sext %{{.*}} : i32 to i64
6060
%laneId = gpu.lane_id
6161

62+
// CHECK: = rocdl.wavefrontsize : i32
63+
// CHECK: = llvm.sext %{{.*}} : i32 to i64
64+
%subgroupSize = gpu.subgroup_size : index
65+
66+
// CHECK: = rocdl.wavefrontsize range <i32, 64, 65> : i32
67+
// CHECK: = llvm.sext %{{.*}} : i32 to i64
68+
%subgroupSize2 = gpu.subgroup_size upper_bound 64 : index
69+
6270
func.return %tIdX, %tIdY, %tIdZ, %bDimX, %bDimY, %bDimZ,
6371
%bIdX, %bIdY, %bIdZ, %gDimX, %gDimY, %gDimZ,
64-
%laneId
72+
%laneId, %subgroupSize, %subgroupSize2
6573
: index, index, index, index, index, index,
6674
index, index, index, index, index, index,
67-
index
75+
index, index, index
6876
}
6977
}
7078

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,13 @@ llvm.func @rocdl_special_regs() -> i32 {
3232

3333
// CHECK: call range(i64 1, 65) i64 @__ockl_get_local_size(i32 0)
3434
%14 = rocdl.workgroup.dim.x range <i32, 1, 65> : i64
35+
36+
// CHECK: call i32 @llvm.amdgcn.wavefrontsize()
37+
%15 = rocdl.wavefrontsize : i32
38+
39+
// CHECK: call range(i32 32, 65) i32 @llvm.amdgcn.wavefrontsize()
40+
%16 = rocdl.wavefrontsize range <i32, 32, 65> : i32
41+
3542
llvm.return %1 : i32
3643
}
3744

0 commit comments

Comments
 (0)