Skip to content

Commit 0ba1361

Browse files
authored
[MLIR][GPU] Use arith instead of index for subgroup_id (llvm#137843)
Trying to simplify situation by using `arith` dialect instead of `index` in the rewriting of `gpu.subgroup_id`.
1 parent 6feb4a8 commit 0ba1361

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

mlir/lib/Dialect/GPU/Transforms/SubgroupIdRewriter.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,23 +53,25 @@ struct GpuSubgroupIdRewriter final : OpRewritePattern<gpu::SubgroupIdOp> {
5353
// subgroup_size
5454

5555
Location loc = op->getLoc();
56+
Type indexType = rewriter.getIndexType();
5657

5758
Value dimX = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::x);
5859
Value dimY = rewriter.create<gpu::BlockDimOp>(loc, gpu::Dimension::y);
5960
Value tidX = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
6061
Value tidY = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::y);
6162
Value tidZ = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::z);
6263

63-
Value dimYxIdZ = rewriter.create<index::MulOp>(loc, dimY, tidZ);
64-
Value dimYxIdZPlusIdY = rewriter.create<index::AddOp>(loc, dimYxIdZ, tidY);
64+
Value dimYxIdZ = rewriter.create<arith::MulIOp>(loc, indexType, dimY, tidZ);
65+
Value dimYxIdZPlusIdY =
66+
rewriter.create<arith::AddIOp>(loc, indexType, dimYxIdZ, tidY);
6567
Value dimYxIdZPlusIdYTimesDimX =
66-
rewriter.create<index::MulOp>(loc, dimX, dimYxIdZPlusIdY);
67-
Value IdXPlusDimYxIdZPlusIdYTimesDimX =
68-
rewriter.create<index::AddOp>(loc, tidX, dimYxIdZPlusIdYTimesDimX);
68+
rewriter.create<arith::MulIOp>(loc, indexType, dimX, dimYxIdZPlusIdY);
69+
Value IdXPlusDimYxIdZPlusIdYTimesDimX = rewriter.create<arith::AddIOp>(
70+
loc, indexType, tidX, dimYxIdZPlusIdYTimesDimX);
6971
Value subgroupSize = rewriter.create<gpu::SubgroupSizeOp>(
7072
loc, rewriter.getIndexType(), /*upper_bound = */ nullptr);
71-
Value subgroupIdOp = rewriter.create<index::DivUOp>(
72-
loc, IdXPlusDimYxIdZPlusIdYTimesDimX, subgroupSize);
73+
Value subgroupIdOp = rewriter.create<arith::DivUIOp>(
74+
loc, indexType, IdXPlusDimYxIdZPlusIdYTimesDimX, subgroupSize);
7375
rewriter.replaceOp(op, {subgroupIdOp});
7476
return success();
7577
}

mlir/test/Dialect/GPU/subgroupId-rewrite.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,12 @@ func.func @subgroupId(%sz : index, %mem: memref<index, 1>) {
1010
// CHECK-NEXT: %[[TIDX:.*]] = gpu.thread_id x
1111
// CHECK-NEXT: %[[TIDY:.*]] = gpu.thread_id y
1212
// CHECK-NEXT: %[[TIDZ:.*]] = gpu.thread_id z
13-
// CHECK-NEXT: %[[T0:.*]] = index.mul %[[DIMY]], %[[TIDZ]]
14-
// CHECK-NEXT: %[[T1:.*]] = index.add %[[T0]], %[[TIDY]]
15-
// CHECK-NEXT: %[[T2:.*]] = index.mul %[[DIMX]], %[[T1]]
16-
// CHECK-NEXT: %[[T3:.*]] = index.add %[[TIDX]], %[[T2]]
13+
// CHECK-NEXT: %[[T0:.*]] = arith.muli %[[DIMY]], %[[TIDZ]] : index
14+
// CHECK-NEXT: %[[T1:.*]] = arith.addi %[[T0]], %[[TIDY]] : index
15+
// CHECK-NEXT: %[[T2:.*]] = arith.muli %[[DIMX]], %[[T1]] : index
16+
// CHECK-NEXT: %[[T3:.*]] = arith.addi %[[TIDX]], %[[T2]] : index
1717
// CHECK-NEXT: %[[T4:.*]] = gpu.subgroup_size : index
18-
// CHECK-NEXT: %[[T5:.*]] = index.divu %[[T3]], %[[T4]]
18+
// CHECK-NEXT: %[[T5:.*]] = arith.divui %[[T3]], %[[T4]] : index
1919
%idz = gpu.subgroup_id : index
2020
memref.store %idz, %mem[] : memref<index, 1>
2121
gpu.terminator

0 commit comments

Comments
 (0)