@@ -52,7 +52,7 @@ namespace mlir {
52
52
53
53
using namespace mlir ;
54
54
55
- // / Query function for static subgroup size lookup for given chipset.
55
+ /// Returns the statically known subgroup size for the given chipset.
56
56
// TODO: move this function to a common place.
57
57
static int64_t querySubgroupSize (const amdgpu::Chipset &chipset) {
58
58
// The subgroup size is the same as the wavefront size for all chipsets.
@@ -242,15 +242,34 @@ struct GPUSubgroupIdOpToROCDL final
242
242
LogicalResult
243
243
matchAndRewrite (gpu::SubgroupIdOp op, gpu::SubgroupIdOp::Adaptor adaptor,
244
244
ConversionPatternRewriter &rewriter) const override {
245
+ // Calculation of the thread's subgroup identifier.
246
+ //
247
+ // The process involves mapping the thread's 3D identifier within its
248
+ // workgroup/block (w_id.x, w_id.y, w_id.z) to a 1D linear index.
249
+ // This linearization assumes a layout where the x-coordinate (w_id.x)
250
+ // varies most rapidly (i.e., it is the innermost dimension).
251
+ //
252
+ // The formula for the linearized thread index is:
253
+ // L = w_id.x + w_dim.x * (w_id.y + (w_dim.y * w_id.z))
254
+ //
255
+ // Subsequently, the range of linearized indices [0, N_threads-1] is
256
+ // divided into consecutive, non-overlapping segments, each representing
257
+ // a subgroup of size 'subgroup_size'.
258
+ //
259
+ // Example Partitioning (N = subgroup_size):
260
+ // | Subgroup 0 | Subgroup 1 | Subgroup 2 | ... |
261
+ // | Indices 0..N-1 | Indices N..2N-1 | Indices 2N..3N-1| ... |
262
+ //
263
+ // The subgroup identifier is obtained via integer division of the
264
+ // linearized thread index by the predefined 'subgroup_size'.
265
+ //
266
+ // subgroup_id = floor( L / subgroup_size )
267
+ // = (w_id.x + w_dim.x * (w_id.y + w_dim.y * w_id.z)) /
268
+ // subgroup_size
245
269
auto int32Type = IntegerType::get (rewriter.getContext (), 32 );
246
- auto loc = op.getLoc ();
270
+ Location loc = op.getLoc ();
247
271
LLVM::IntegerOverflowFlags flags =
248
272
LLVM::IntegerOverflowFlags::nsw | LLVM::IntegerOverflowFlags::nuw;
249
- // linearized thread ids are divided into consecutive subgroups.
250
- // Where thread id is calculated as:
251
- // thread_id = w_id.x + w_dim.x * (w_id.y + (w_dim.y * w_id.z))
252
- // And the subgroup id of the thread is calculated as:
253
- // subgroup_id = thread_id / subgroup_size
254
273
Value workitemIdX = rewriter.create <ROCDL::ThreadIdXOp>(loc, int32Type);
255
274
Value workitemIdY = rewriter.create <ROCDL::ThreadIdYOp>(loc, int32Type);
256
275
Value workitemIdZ = rewriter.create <ROCDL::ThreadIdZOp>(loc, int32Type);
0 commit comments