@@ -80,6 +80,23 @@ static constexpr StringLiteral amdgcnDataLayout =
80
80
" 64-S32-A5-G1-ni:7:8:9" ;
81
81
82
82
namespace {
83
+
84
+ // Truncate or extend the result depending on the index bitwidth specified
85
+ // by the LLVMTypeConverter options.
86
+ template <int64_t N>
87
+ static Value truncOrExtToLLVMType (ConversionPatternRewriter &rewriter,
88
+ Location loc, Value value,
89
+ const unsigned indexBitwidth) {
90
+ if (indexBitwidth > N) {
91
+ return rewriter.create <LLVM::SExtOp>(
92
+ loc, IntegerType::get (rewriter.getContext (), indexBitwidth), value);
93
+ } else if (indexBitwidth < N) {
94
+ return rewriter.create <LLVM::TruncOp>(
95
+ loc, IntegerType::get (rewriter.getContext (), indexBitwidth), value);
96
+ }
97
+ return value;
98
+ }
99
+
83
100
struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
84
101
using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;
85
102
@@ -98,16 +115,8 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
98
115
rewriter.create <ROCDL::MbcntLoOp>(loc, intTy, ValueRange{minus1, zero});
99
116
Value laneId = rewriter.create <ROCDL::MbcntHiOp>(
100
117
loc, intTy, ValueRange{minus1, mbcntLo});
101
- // Truncate or extend the result depending on the index bitwidth specified
102
- // by the LLVMTypeConverter options.
103
118
const unsigned indexBitwidth = getTypeConverter ()->getIndexTypeBitwidth ();
104
- if (indexBitwidth > 32 ) {
105
- laneId = rewriter.create <LLVM::SExtOp>(
106
- loc, IntegerType::get (context, indexBitwidth), laneId);
107
- } else if (indexBitwidth < 32 ) {
108
- laneId = rewriter.create <LLVM::TruncOp>(
109
- loc, IntegerType::get (context, indexBitwidth), laneId);
110
- }
119
+ laneId = truncOrExtToLLVMType<32 >(rewriter, loc, laneId, indexBitwidth);
111
120
rewriter.replaceOp (op, {laneId});
112
121
return success ();
113
122
}
@@ -190,6 +199,24 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
190
199
}
191
200
};
192
201
202
+ struct GPUSubgroupIdOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupIdOp> {
203
+ using ConvertOpToLLVMPattern<gpu::SubgroupIdOp>::ConvertOpToLLVMPattern;
204
+
205
+ LogicalResult
206
+ matchAndRewrite (gpu::SubgroupIdOp op, gpu::SubgroupIdOp::Adaptor adaptor,
207
+ ConversionPatternRewriter &rewriter) const override {
208
+ auto int32Type = IntegerType::get (rewriter.getContext (), 32 );
209
+ Value waveIdOp = rewriter.create <ROCDL::WaveIdOp>(op.getLoc (), int32Type);
210
+
211
+ waveIdOp =
212
+ truncOrExtToLLVMType<32 >(rewriter, op.getLoc (), waveIdOp,
213
+ getTypeConverter ()->getIndexTypeBitwidth ());
214
+
215
+ rewriter.replaceOp (op, {waveIdOp});
216
+ return success ();
217
+ }
218
+ };
219
+
193
220
// / Import the GPU Ops to ROCDL Patterns.
194
221
#include " GPUToROCDL.cpp.inc"
195
222
@@ -405,7 +432,9 @@ void mlir::populateGpuToROCDLConversionPatterns(
405
432
// TODO: Add alignment for workgroup memory
406
433
patterns.add <GPUDynamicSharedMemoryOpLowering>(converter);
407
434
408
- patterns.add <GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
435
+ patterns
436
+ .add <GPUShuffleOpLowering, GPULaneIdOpToROCDL, GPUSubgroupIdOpToROCDL>(
437
+ converter);
409
438
410
439
populateMathToROCDLConversionPatterns (converter, patterns);
411
440
}
0 commit comments