Skip to content

Commit 8c603f0

Browse files
committed
[MLIR][GPU] Add a pattern to rewrite gpu.subgroup_id
This patch implements a rewrite pattern for transforming `gpu.subgroup_id` to: ``` subgroup_id = linearized_thread_id / gpu.subgroup_size ``` where: ``` linearized_thread_id = thread_id.x + block_dim.x * (thread_id.y + block_dim.y * thread_id.z) ```
1 parent 1b5cd1d commit 8c603f0

File tree

4 files changed

+114
-0
lines changed

4 files changed

+114
-0
lines changed

mlir/include/mlir/Dialect/GPU/Transforms/Passes.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ class FuncOp;
3939
/// Collect a set of patterns to rewrite GlobalIdOp op within the GPU dialect.
4040
void populateGpuGlobalIdPatterns(RewritePatternSet &patterns);
4141

42+
/// Collect a set of patterns to rewrite SubgroupIdOp op within the GPU
43+
/// dialect.
44+
void populateGpuSubgroupIdPatterns(RewritePatternSet &patterns);
45+
4246
/// Collect a set of patterns to rewrite shuffle ops within the GPU dialect.
4347
void populateGpuShufflePatterns(RewritePatternSet &patterns);
4448

@@ -88,6 +92,7 @@ inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
8892
populateGpuAllReducePatterns(patterns);
8993
populateGpuGlobalIdPatterns(patterns);
9094
populateGpuShufflePatterns(patterns);
95+
populateGpuSubgroupIdPatterns(patterns);
9196
}
9297

9398
namespace gpu {

mlir/lib/Dialect/GPU/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
4040
Transforms/ROCDLAttachTarget.cpp
4141
Transforms/ShuffleRewriter.cpp
4242
Transforms/SPIRVAttachTarget.cpp
43+
Transforms/SubgroupIdRewriter.cpp
4344
Transforms/SubgroupReduceLowering.cpp
4445

4546
OBJECT
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
//===- SubgroupIdRewriter.cpp - Implementation of SubgroupId rewriting ----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements in-dialect rewriting of the gpu.subgroup_id op for archs
10+
// where:
11+
// subgroup_id = (tid.x + dim.x * (tid.y + dim.y * tid.z)) / subgroup_size
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
16+
#include "mlir/Dialect/GPU/Transforms/Passes.h"
17+
#include "mlir/Dialect/Index/IR/IndexOps.h"
18+
#include "mlir/IR/Builders.h"
19+
#include "mlir/IR/PatternMatch.h"
20+
#include "mlir/Pass/Pass.h"
21+
22+
using namespace mlir;
23+
24+
namespace {
25+
/// Rewrite pattern that replaces a `gpu::SubgroupIdOp` with explicit index
/// arithmetic: the linearized thread id within the block divided (unsigned)
/// by the value produced by a `gpu::SubgroupSizeOp`.
///
/// NOTE(review): only the location of the matched op is read; any attributes
/// carried by the original op are not forwarded to the ops created here --
/// confirm that this is intended.
struct GpuSubgroupIdRewriter final : OpRewritePattern<gpu::SubgroupIdOp> {
26+
// Reuse the constructors of the base pattern class.
using OpRewritePattern<gpu::SubgroupIdOp>::OpRewritePattern;
27+
28+
// Builds the replacement computation at the rewriter's current insertion
// point, replaces `op` with its result, and always returns success().
LogicalResult matchAndRewrite(gpu::SubgroupIdOp op,
29+
PatternRewriter &rewriter) const override {
30+
// Calculation of the thread's subgroup identifier.
31+
//
32+
// The process involves mapping the thread's 3D identifier within its
33+
// block (tid.x, tid.y, tid.z) to a 1D linear index.
34+
// This linearization assumes a layout where the x-dimension (dim.x)
35+
// varies most rapidly (i.e., it is the innermost dimension).
36+
//
37+
// The formula for the linearized thread index is:
38+
// L = tid.x + dim.x * (tid.y + (dim.y * tid.z))
39+
//
40+
// Subsequently, the range of linearized indices [0, N_threads-1] is
41+
// divided into consecutive, non-overlapping segments, each representing
42+
// a subgroup of size 'subgroup_size'.
43+
//
44+
// Example Partitioning (N = subgroup_size):
45+
// | Subgroup 0 | Subgroup 1 | Subgroup 2 | ... |
46+
// | Indices 0..N-1 | Indices N..2N-1 | Indices 2N..3N-1| ... |
47+
//
48+
// The subgroup identifier is obtained via integer division of the
49+
// linearized thread index by the predefined 'subgroup_size'.
50+
//
51+
// subgroup_id = floor( L / subgroup_size )
52+
// = (tid.x + dim.x * (tid.y + dim.y * tid.z)) /
53+
// subgroup_size
54+
55+
// All new ops reuse the location of the op being replaced.
auto loc = op->getLoc();
56+
57+
// Only the x and y block dimensions appear in the linearization formula;
// dim.z is not needed.
Value dimX = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x);
58+
Value dimY = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::y);
59+
Value tidX = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
60+
Value tidY = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::y);
61+
Value tidZ = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::z);
62+
63+
// dim.y * tid.z
Value dimYxIdZ = rewriter.create<index::MulOp>(loc, dimY, tidZ);
64+
// dim.y * tid.z + tid.y
Value dimYxIdZPlusIdY = rewriter.create<index::AddOp>(loc, dimYxIdZ, tidY);
65+
// dim.x * (dim.y * tid.z + tid.y)
Value dimYxIdZPlusIdYTimesDimX =
66+
rewriter.create<index::MulOp>(loc, dimX, dimYxIdZPlusIdY);
67+
// L = tid.x + dim.x * (dim.y * tid.z + tid.y), the linearized thread index.
Value IdXPlusDimYxIdZPlusIdYTimesDimX =
68+
rewriter.create<index::AddOp>(loc, tidX, dimYxIdZPlusIdYTimesDimX);
69+
// Subgroup size as an index-typed value, created without an upper bound.
Value subgroupSize = rewriter.create<gpu::SubgroupSizeOp>(
70+
loc, rewriter.getIndexType(), /*upper_bound = */ nullptr);
71+
// subgroup_id = L / subgroup_size, using unsigned index division.
Value subgroupIdOp = rewriter.create<index::DivUOp>(
72+
loc, IdXPlusDimYxIdZPlusIdYTimesDimX, subgroupSize);
73+
rewriter.replaceOp(op, {subgroupIdOp});
74+
return success();
75+
}
76+
};
77+
78+
} // namespace
79+
80+
// Adds the GpuSubgroupIdRewriter pattern to `patterns`, constructing it with
// the context obtained from the pattern set itself.
void mlir::populateGpuSubgroupIdPatterns(RewritePatternSet &patterns) {
81+
patterns.add<GpuSubgroupIdRewriter>(patterns.getContext());
82+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// RUN: mlir-opt --test-gpu-rewrite -split-input-file %s | FileCheck %s
2+
3+
module {
4+
// CHECK-LABEL: func.func @subgroupId
5+
// CHECK-SAME: (%[[SZ:.*]]: index, %[[MEM:.*]]: memref<index, 1>) {
6+
func.func @subgroupId(%sz : index, %mem: memref<index, 1>) {
7+
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %sz, %grid_y = %sz, %grid_z = %sz)
8+
threads(%tx, %ty, %tz) in (%block_x = %sz, %block_y = %sz, %block_z = %sz) {
9+
// CHECK: %[[DIMX:.*]] = gpu.block_dim x
10+
// CHECK-NEXT: %[[DIMY:.*]] = gpu.block_dim y
11+
// CHECK-NEXT: %[[TIDX:.*]] = gpu.thread_id x
12+
// CHECK-NEXT: %[[TIDY:.*]] = gpu.thread_id y
13+
// CHECK-NEXT: %[[TIDZ:.*]] = gpu.thread_id z
14+
// CHECK-NEXT: %[[T0:.*]] = index.mul %[[DIMY]], %[[TIDZ]]
15+
// CHECK-NEXT: %[[T1:.*]] = index.add %[[T0]], %[[TIDY]]
16+
// CHECK-NEXT: %[[T2:.*]] = index.mul %[[DIMX]], %[[T1]]
17+
// CHECK-NEXT: %[[T3:.*]] = index.add %[[TIDX]], %[[T2]]
18+
// CHECK-NEXT: %[[T4:.*]] = gpu.subgroup_size : index
19+
// CHECK-NEXT: %[[T5:.*]] = index.divu %[[T3]], %[[T4]]
20+
%idz = gpu.subgroup_id : index
21+
memref.store %idz, %mem[] : memref<index, 1>
22+
gpu.terminator
23+
}
24+
return
25+
}
26+
}

0 commit comments

Comments
 (0)