@@ -52,25 +52,6 @@ namespace mlir {
using namespace mlir;

- // Truncate or extend the result depending on the index bitwidth specified
- // by the LLVMTypeConverter options.
- static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter,
-                                   Location loc, Value value,
-                                   const LLVMTypeConverter &converter) {
-   int64_t intWidth = cast<IntegerType>(value.getType()).getWidth();
-   int64_t indexBitwidth = converter.getIndexTypeBitwidth();
-   auto indexBitwidthType =
-       IntegerType::get(rewriter.getContext(), converter.getIndexTypeBitwidth());
-   // TODO: use <=> in C++20.
-   if (indexBitwidth > intWidth) {
-     return rewriter.create<LLVM::SExtOp>(loc, indexBitwidthType, value);
-   }
-   if (indexBitwidth < intWidth) {
-     return rewriter.create<LLVM::TruncOp>(loc, indexBitwidthType, value);
-   }
-   return value;
- }
-
/// Returns true if the given `gpu.func` can be safely called using the bare
/// pointer calling convention.
static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
@@ -99,6 +80,26 @@ static constexpr StringLiteral amdgcnDataLayout =
    "64-S32-A5-G1-ni:7:8:9";

namespace {
+
+ // Truncate or extend the result depending on the index bitwidth specified
+ // by the LLVMTypeConverter options.
+ static Value truncOrExtToLLVMType(ConversionPatternRewriter &rewriter,
+                                   Location loc, Value value,
+                                   const LLVMTypeConverter &converter) {
+   int64_t intWidth = cast<IntegerType>(value.getType()).getWidth();
+   int64_t indexBitwidth = converter.getIndexTypeBitwidth();
+   auto indexBitwidthType =
+       IntegerType::get(rewriter.getContext(), converter.getIndexTypeBitwidth());
+   // TODO: use <=> in C++20.
+   if (indexBitwidth > intWidth) {
+     return rewriter.create<LLVM::SExtOp>(loc, indexBitwidthType, value);
+   }
+   if (indexBitwidth < intWidth) {
+     return rewriter.create<LLVM::TruncOp>(loc, indexBitwidthType, value);
+   }
+   return value;
+ }
+
struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
  using ConvertOpToLLVMPattern<gpu::LaneIdOp>::ConvertOpToLLVMPattern;
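
The helper added above centralizes the sext/trunc-to-index-bitwidth boilerplate
that the lane-id and subgroup-size lowerings previously duplicated. A rough
sketch of the resulting lane-id lowering, assuming index-bitwidth = 64 (SSA
names invented for illustration; the exact ROCDL assembly format may differ):

  %mbcnt_lo = rocdl.mbcnt.lo %minus1, %zero : i32
  %lane_i32 = rocdl.mbcnt.hi %minus1, %mbcnt_lo : i32
  %lane     = llvm.sext %lane_i32 : i32 to i64

With index-bitwidth = 32 the helper returns the value unchanged; for a narrower
index type it emits llvm.trunc instead.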
@@ -117,16 +118,7 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
        rewriter.create<ROCDL::MbcntLoOp>(loc, intTy, ValueRange{minus1, zero});
    Value laneId = rewriter.create<ROCDL::MbcntHiOp>(
        loc, intTy, ValueRange{minus1, mbcntLo});
-   // Truncate or extend the result depending on the index bitwidth specified
-   // by the LLVMTypeConverter options.
-   const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
-   if (indexBitwidth > 32) {
-     laneId = rewriter.create<LLVM::SExtOp>(
-         loc, IntegerType::get(context, indexBitwidth), laneId);
-   } else if (indexBitwidth < 32) {
-     laneId = rewriter.create<LLVM::TruncOp>(
-         loc, IntegerType::get(context, indexBitwidth), laneId);
-   }
+   laneId = truncOrExtToLLVMType(rewriter, loc, laneId, *getTypeConverter());
    rewriter.replaceOp(op, {laneId});
    return success();
  }
@@ -150,11 +142,11 @@ struct GPUSubgroupSizeOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp> {
          /*bitWidth=*/32, /*lower=*/isBeforeGfx10 ? 64 : 32,
          /*upper=*/op.getUpperBoundAttr().getInt() + 1);
    }
-   Value wavefrontOp = rewriter.create<ROCDL::WavefrontSizeOp>(
+   Value wavefrontSizeOp = rewriter.create<ROCDL::WavefrontSizeOp>(
        op.getLoc(), rewriter.getI32Type(), bounds);
-   wavefrontOp = truncOrExtToLLVMType(rewriter, op.getLoc(), wavefrontOp,
-                                      *getTypeConverter());
-   rewriter.replaceOp(op, {wavefrontOp});
+   wavefrontSizeOp = truncOrExtToLLVMType(
+       rewriter, op.getLoc(), wavefrontSizeOp, *getTypeConverter());
+   rewriter.replaceOp(op, {wavefrontSizeOp});
    return success();
  }

@@ -239,6 +231,65 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
  }
};

+ struct GPUSubgroupIdOpToROCDL final
+     : ConvertOpToLLVMPattern<gpu::SubgroupIdOp> {
+   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+   LogicalResult
+   matchAndRewrite(gpu::SubgroupIdOp op, gpu::SubgroupIdOp::Adaptor adaptor,
+                   ConversionPatternRewriter &rewriter) const override {
+     // Calculation of the thread's subgroup identifier.
+     //
+     // The process involves mapping the thread's 3D identifier within its
+     // workgroup/block (w_id.x, w_id.y, w_id.z) to a 1D linear index.
+     // This linearization assumes a layout where the x-dimension (w_dim.x)
+     // varies most rapidly (i.e., it is the innermost dimension).
+     //
+     // The formula for the linearized thread index is:
+     //   L = w_id.x + w_dim.x * (w_id.y + (w_dim.y * w_id.z))
+     //
+     // Subsequently, the range of linearized indices [0, N_threads-1] is
+     // divided into consecutive, non-overlapping segments, each representing
+     // a subgroup of size 'subgroup_size'.
+     //
+     // Example partitioning (N = subgroup_size):
+     //   | Subgroup 0     | Subgroup 1      | Subgroup 2       | ... |
+     //   | Indices 0..N-1 | Indices N..2N-1 | Indices 2N..3N-1 | ... |
+     //
+     // The subgroup identifier is obtained via integer division of the
+     // linearized thread index by the predefined 'subgroup_size'.
+     //
+     //   subgroup_id = floor( L / subgroup_size )
+     //               = (w_id.x + w_dim.x * (w_id.y + w_dim.y * w_id.z)) /
+     //                 subgroup_size
+     auto int32Type = IntegerType::get(rewriter.getContext(), 32);
+     Location loc = op.getLoc();
+     LLVM::IntegerOverflowFlags flags =
+         LLVM::IntegerOverflowFlags::nsw | LLVM::IntegerOverflowFlags::nuw;
+     Value workitemIdX = rewriter.create<ROCDL::ThreadIdXOp>(loc, int32Type);
+     Value workitemIdY = rewriter.create<ROCDL::ThreadIdYOp>(loc, int32Type);
+     Value workitemIdZ = rewriter.create<ROCDL::ThreadIdZOp>(loc, int32Type);
+     Value workitemDimX = rewriter.create<ROCDL::BlockDimXOp>(loc, int32Type);
+     Value workitemDimY = rewriter.create<ROCDL::BlockDimYOp>(loc, int32Type);
+     Value dimYxIdZ = rewriter.create<LLVM::MulOp>(loc, int32Type, workitemDimY,
+                                                   workitemIdZ, flags);
+     Value dimYxIdZPlusIdY = rewriter.create<LLVM::AddOp>(
+         loc, int32Type, dimYxIdZ, workitemIdY, flags);
+     Value dimYxIdZPlusIdYTimesDimX = rewriter.create<LLVM::MulOp>(
+         loc, int32Type, workitemDimX, dimYxIdZPlusIdY, flags);
+     Value workitemIdXPlusDimYxIdZPlusIdYTimesDimX =
+         rewriter.create<LLVM::AddOp>(loc, int32Type, workitemIdX,
+                                      dimYxIdZPlusIdYTimesDimX, flags);
+     Value subgroupSize = rewriter.create<ROCDL::WavefrontSizeOp>(
+         loc, rewriter.getI32Type(), nullptr);
+     Value waveIdOp = rewriter.create<LLVM::UDivOp>(
+         loc, workitemIdXPlusDimYxIdZPlusIdYTimesDimX, subgroupSize);
+     rewriter.replaceOp(op, {truncOrExtToLLVMType(rewriter, loc, waveIdOp,
+                                                  *getTypeConverter())});
+     return success();
+   }
+ };
+
/// Import the GPU Ops to ROCDL Patterns.
#include "GPUToROCDL.cpp.inc"

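
A quick numeric check of the linearization documented in the new pattern
(values chosen arbitrarily for illustration): for a workgroup with
w_dim = (64, 4, 2), the thread at w_id = (5, 2, 1) gets

  L = 5 + 64 * (2 + 4 * 1) = 389
  subgroup_id = floor(389 / 64) = 6    (with subgroup_size = 64)

The nsw/nuw flags on the mul/add chain are justified because workitem ids and
block dimensions are non-negative and bounded by hardware limits, so the
linearized index cannot overflow i32.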
@@ -249,19 +300,7 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
// code.
struct LowerGpuOpsToROCDLOpsPass final
    : public impl::ConvertGpuOpsToROCDLOpsBase<LowerGpuOpsToROCDLOpsPass> {
-   LowerGpuOpsToROCDLOpsPass() = default;
-   LowerGpuOpsToROCDLOpsPass(const std::string &chipset, unsigned indexBitwidth,
-                             bool useBarePtrCallConv,
-                             gpu::amd::Runtime runtime) {
-     if (this->chipset.getNumOccurrences() == 0)
-       this->chipset = chipset;
-     if (this->indexBitwidth.getNumOccurrences() == 0)
-       this->indexBitwidth = indexBitwidth;
-     if (this->useBarePtrCallConv.getNumOccurrences() == 0)
-       this->useBarePtrCallConv = useBarePtrCallConv;
-     if (this->runtime.getNumOccurrences() == 0)
-       this->runtime = runtime;
-   }
+   using Base::Base;

  void getDependentDialects(DialectRegistry &registry) const override {
    Base::getDependentDialects(registry);
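
With the hand-written constructors replaced by "using Base::Base;", option
handling moves entirely to the tablegen-generated base class; pipeline-level
usage is unchanged. For example (option names as registered upstream; worth
double-checking against the pass definition in Passes.td):

  mlir-opt input.mlir --convert-gpu-to-rocdl='chipset=gfx90a index-bitwidth=32'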
@@ -456,18 +495,14 @@ void mlir::populateGpuToROCDLConversionPatterns(
  patterns.add<GPUDynamicSharedMemoryOpLowering>(converter);

  patterns
-       .add<GPUShuffleOpLowering, GPULaneIdOpToROCDL, GPUSubgroupSizeOpToROCDL>(
+       .add<GPUShuffleOpLowering, GPULaneIdOpToROCDL, GPUSubgroupIdOpToROCDL>(
          converter);
  patterns.add<GPUSubgroupSizeOpToROCDL>(converter, chipset);

  populateMathToROCDLConversionPatterns(converter, patterns);
}

std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
- mlir::createLowerGpuOpsToROCDLOpsPass(const std::string &chipset,
-                                       unsigned indexBitwidth,
-                                       bool useBarePtrCallConv,
-                                       gpu::amd::Runtime runtime) {
-   return std::make_unique<LowerGpuOpsToROCDLOpsPass>(
-       chipset, indexBitwidth, useBarePtrCallConv, runtime);
+ mlir::createLowerGpuOpsToROCDLOpsPass() {
+   return std::make_unique<LowerGpuOpsToROCDLOpsPass>();
}
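
For completeness, a minimal sketch of driving the now zero-argument factory
from a pass manager (the surrounding setup is illustrative, not part of this
patch):

  // The factory returns an OperationPass<gpu::GPUModuleOp>, per the signature
  // above, so the pass is added nested on gpu.module ops.
  PassManager pm(&context);
  pm.addNestedPass<gpu::GPUModuleOp>(mlir::createLowerGpuOpsToROCDLOpsPass());
  if (failed(pm.run(module)))
    llvm::errs() << "GPU to ROCDL lowering failed\n";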