@@ -80,6 +80,24 @@ static constexpr StringLiteral amdgcnDataLayout =
80
80
" 64-S32-A5-G1-ni:7:8:9" ;
81
81
82
82
namespace {
83
+
84
+ // Truncate or extend the result depending on the index bitwidth specified
85
+ // by the LLVMTypeConverter options.
86
+ static Value truncOrExtToLLVMType (ConversionPatternRewriter &rewriter,
87
+ Location loc, Value value,
88
+ const LLVMTypeConverter *converter) {
89
+ auto intWidth = cast<IntegerType>(value.getType ()).getWidth ();
90
+ auto indexBitwidth = converter->getIndexTypeBitwidth ();
91
+ if (indexBitwidth > intWidth) {
92
+ return rewriter.create <LLVM::SExtOp>(
93
+ loc, IntegerType::get (rewriter.getContext (), indexBitwidth), value);
94
+ } else if (indexBitwidth < intWidth) {
95
+ return rewriter.create <LLVM::TruncOp>(
96
+ loc, IntegerType::get (rewriter.getContext (), indexBitwidth), value);
97
+ }
98
+ return value;
99
+ }
100
+
83
101
struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
84
102
using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;
85
103
@@ -98,16 +116,7 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
98
116
rewriter.create <ROCDL::MbcntLoOp>(loc, intTy, ValueRange{minus1, zero});
99
117
Value laneId = rewriter.create <ROCDL::MbcntHiOp>(
100
118
loc, intTy, ValueRange{minus1, mbcntLo});
101
- // Truncate or extend the result depending on the index bitwidth specified
102
- // by the LLVMTypeConverter options.
103
- const unsigned indexBitwidth = getTypeConverter ()->getIndexTypeBitwidth ();
104
- if (indexBitwidth > 32 ) {
105
- laneId = rewriter.create <LLVM::SExtOp>(
106
- loc, IntegerType::get (context, indexBitwidth), laneId);
107
- } else if (indexBitwidth < 32 ) {
108
- laneId = rewriter.create <LLVM::TruncOp>(
109
- loc, IntegerType::get (context, indexBitwidth), laneId);
110
- }
119
+ laneId = truncOrExtToLLVMType (rewriter, loc, laneId, getTypeConverter ());
111
120
rewriter.replaceOp (op, {laneId});
112
121
return success ();
113
122
}
@@ -190,6 +199,21 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
190
199
}
191
200
};
192
201
202
+ struct GPUSubgroupIdOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupIdOp> {
203
+ using ConvertOpToLLVMPattern<gpu::SubgroupIdOp>::ConvertOpToLLVMPattern;
204
+
205
+ LogicalResult
206
+ matchAndRewrite (gpu::SubgroupIdOp op, gpu::SubgroupIdOp::Adaptor adaptor,
207
+ ConversionPatternRewriter &rewriter) const override {
208
+ auto int32Type = IntegerType::get (rewriter.getContext (), 32 );
209
+ Value waveIdOp = rewriter.create <ROCDL::WaveIdOp>(op.getLoc (), int32Type);
210
+ waveIdOp = truncOrExtToLLVMType (rewriter, op.getLoc (), waveIdOp,
211
+ getTypeConverter ());
212
+ rewriter.replaceOp (op, {waveIdOp});
213
+ return success ();
214
+ }
215
+ };
216
+
193
217
// / Import the GPU Ops to ROCDL Patterns.
194
218
#include " GPUToROCDL.cpp.inc"
195
219
@@ -405,7 +429,9 @@ void mlir::populateGpuToROCDLConversionPatterns(
405
429
// TODO: Add alignment for workgroup memory
406
430
patterns.add <GPUDynamicSharedMemoryOpLowering>(converter);
407
431
408
- patterns.add <GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
432
+ patterns
433
+ .add <GPUShuffleOpLowering, GPULaneIdOpToROCDL, GPUSubgroupIdOpToROCDL>(
434
+ converter);
409
435
410
436
populateMathToROCDLConversionPatterns (converter, patterns);
411
437
}
0 commit comments