Skip to content

Commit 2cc88a6

Browse files
committed
[MLIR][GPU] Add a pattern to rewrite gpu.subgroup_id
This patch impelemnts a rewrite pattern for transforming `gpu.subgroup_id` to: ``` subgroup_id = linearized_thread_id / gpu.subgroup_size ``` where: ``` linearized_thread_id = thread_id.x + block_dim.x * (thread_id.y + block_dim.y * thread_id.z) ```
1 parent 1b5cd1d commit 2cc88a6

File tree

4 files changed

+117
-0
lines changed

4 files changed

+117
-0
lines changed

mlir/include/mlir/Dialect/GPU/Transforms/Passes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ class FuncOp;
3939
/// Collect a set of patterns to rewrite GlobalIdOp op within the GPU dialect.
4040
void populateGpuGlobalIdPatterns(RewritePatternSet &patterns);
4141

42+
/// Collect a set of patterns to rewrite SubgrouplIdOp op within the GPU dialect.
43+
void populateGpuSubgroupIdPatterns(RewritePatternSet &patterns);
44+
4245
/// Collect a set of patterns to rewrite shuffle ops within the GPU dialect.
4346
void populateGpuShufflePatterns(RewritePatternSet &patterns);
4447

@@ -88,6 +91,7 @@ inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
8891
populateGpuAllReducePatterns(patterns);
8992
populateGpuGlobalIdPatterns(patterns);
9093
populateGpuShufflePatterns(patterns);
94+
populateGpuSubgroupIdPatterns(patterns);
9195
}
9296

9397
namespace gpu {

mlir/lib/Dialect/GPU/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
4040
Transforms/ROCDLAttachTarget.cpp
4141
Transforms/ShuffleRewriter.cpp
4242
Transforms/SPIRVAttachTarget.cpp
43+
Transforms/SubgroupIdRewriter.cpp
4344
Transforms/SubgroupReduceLowering.cpp
4445

4546
OBJECT
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
//===- SubgroupIdRewriter.cpp - Implementation of SugroupId rewriting ----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements in-dialect rewriting of the gpu.subgroup_id op for archs
10+
// where:
11+
// subgroup_id = (tid.x + dim.x * (tid.y + dim.y * tid.z)) / subgroup_size
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
16+
#include "mlir/Dialect/GPU/Transforms/Passes.h"
17+
#include "mlir/Dialect/Index/IR/IndexOps.h"
18+
#include "mlir/IR/Builders.h"
19+
#include "mlir/IR/PatternMatch.h"
20+
#include "mlir/Pass/Pass.h"
21+
22+
using namespace mlir;
23+
24+
namespace {
25+
struct GpuSubgroupIdRewriter final : OpRewritePattern<gpu::SubgroupIdOp> {
26+
using OpRewritePattern<gpu::SubgroupIdOp>::OpRewritePattern;
27+
28+
LogicalResult
29+
matchAndRewrite(gpu::SubgroupIdOp op,
30+
PatternRewriter &rewriter) const override {
31+
// Calculation of the thread's subgroup identifier.
32+
//
33+
// The process involves mapping the thread's 3D identifier within its
34+
// block (b_id.x, b_id.y, b_id.z) to a 1D linear index.
35+
// This linearization assumes a layout where the x-dimension (w_dim.x)
36+
// varies most rapidly (i.e., it is the innermost dimension).
37+
//
38+
// The formula for the linearized thread index is:
39+
// L = tid.x + dim.x * (tid.y + (dim.y * tid.z))
40+
//
41+
// Subsequently, the range of linearized indices [0, N_threads-1] is
42+
// divided into consecutive, non-overlapping segments, each representing
43+
// a subgroup of size 'subgroup_size'.
44+
//
45+
// Example Partitioning (N = subgroup_size):
46+
// | Subgroup 0 | Subgroup 1 | Subgroup 2 | ... |
47+
// | Indices 0..N-1 | Indices N..2N-1 | Indices 2N..3N-1| ... |
48+
//
49+
// The subgroup identifier is obtained via integer division of the
50+
// linearized thread index by the predefined 'subgroup_size'.
51+
//
52+
// subgroup_id = floor( L / subgroup_size )
53+
// = (tid.x + dim.x * (tid.y + dim.y * tid.z)) /
54+
// subgroup_size
55+
56+
auto loc = op->getLoc();
57+
58+
Value dimX = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x);
59+
Value dimY = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::y);
60+
Value tidX = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
61+
Value tidY = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::y);
62+
Value tidZ = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::z);
63+
64+
Value dimYxIdZ =
65+
rewriter.create<index::MulOp>(loc, dimY, tidZ);
66+
Value dimYxIdZPlusIdY =
67+
rewriter.create<index::AddOp>(loc, dimYxIdZ, tidY);
68+
Value dimYxIdZPlusIdYTimesDimX =
69+
rewriter.create<index::MulOp>(loc, dimX, dimYxIdZPlusIdY);
70+
Value IdXPlusDimYxIdZPlusIdYTimesDimX =
71+
rewriter.create<index::AddOp>(loc, tidX,
72+
dimYxIdZPlusIdYTimesDimX);
73+
Value subgroupSize = rewriter.create<gpu::SubgroupSizeOp>(
74+
loc, rewriter.getIndexType(), /*upper_bound = */ nullptr);
75+
Value subgroupIdOp = rewriter.create<index::DivUOp>(
76+
loc, IdXPlusDimYxIdZPlusIdYTimesDimX, subgroupSize);
77+
rewriter.replaceOp(op, {subgroupIdOp});
78+
return success();
79+
}
80+
};
81+
82+
} // namespace
83+
84+
void mlir::populateGpuSubgroupIdPatterns(RewritePatternSet &patterns) {
85+
patterns.add<GpuSubgroupIdRewriter>(patterns.getContext());
86+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// RUN: mlir-opt --test-gpu-rewrite -split-input-file %s | FileCheck %s
2+
3+
module {
4+
// CHECK-LABEL: func.func @subgroupId
5+
// CHECK-SAME: (%[[SZ:.*]]: index, %[[MEM:.*]]: memref<index, 1>) {
6+
func.func @subgroupId(%sz : index, %mem: memref<index, 1>) {
7+
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz)
8+
threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) {
9+
// CHECK: %[[DIMX:.*]] = gpu.block_dim x
10+
// CHECK-NEXT: %[[DIMY:.*]] = gpu.block_dim y
11+
// CHECK-NEXT: %[[TIDX:.*]] = gpu.thread_id x
12+
// CHECK-NEXT: %[[TIDY:.*]] = gpu.thread_id y
13+
// CHECK-NEXT: %[[TIDZ:.*]] = gpu.thread_id z
14+
// CHECK-NEXT: %[[T0:.*]] = index.mul %[[DIMY]], %[[TIDZ]]
15+
// CHECK-NEXT: %[[T1:.*]] = index.add %[[T0]], %[[TIDY]]
16+
// CHECK-NEXT: %[[T2:.*]] = index.mul %[[DIMX]], %[[T1]]
17+
// CHECK-NEXT: %[[T3:.*]] = index.add %[[TIDX]], %[[T2]]
18+
// CHECK-NEXT: %[[T4:.*]] = gpu.subgroup_size : index
19+
// CHECK-NEXT: %[[T5:.*]] = index.divu %[[T3]], %[[T4]]
20+
%idz = gpu.subgroup_id : index
21+
memref.store %idz, %mem[] : memref<index, 1>
22+
gpu.terminator
23+
}
24+
return
25+
}
26+
}

0 commit comments

Comments
 (0)