Skip to content

Commit 574b423

Browse files
committed
[MLIR][NVVM] Introduce special registers for CTA Cluster
This work introduces special registers such as cluster ID, dimensions, and more for managing CTA clusters, which are groups of CTAsthat can synchronize and communicate through shared memory. This is for Nvidia's sm_90 capability. Differential Revision: https://reviews.llvm.org/D158588
1 parent b09a52d commit 574b423

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,30 @@ def NVVM_GridDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.x">;
249249
def NVVM_GridDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.y">;
250250
def NVVM_GridDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.z">;
251251

252+
//===----------------------------------------------------------------------===//
253+
// CTA Cluster index and range
254+
def NVVM_ClusterIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.clusterid.x">;
255+
def NVVM_ClusterIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.clusterid.y">;
256+
def NVVM_ClusterIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.clusterid.z">;
257+
def NVVM_ClusterDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nclusterid.x">;
258+
def NVVM_ClusterDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nclusterid.y">;
259+
def NVVM_ClusterDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nclusterid.z">;
260+
261+
262+
//===----------------------------------------------------------------------===//
263+
// CTA index and range within Cluster
264+
def NVVM_BlockInClusterIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.ctaid.x">;
265+
def NVVM_BlockInClusterIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.ctaid.y">;
266+
def NVVM_BlockInClusterIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.ctaid.z">;
267+
def NVVM_GridInClusterDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.nctaid.x">;
268+
def NVVM_GridInClusterDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.nctaid.y">;
269+
def NVVM_GridInClusterDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.nctaid.z">;
270+
271+
//===----------------------------------------------------------------------===//
272+
// CTA index and across Cluster dimensions
273+
def NVVM_ClusterId : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.ctarank">;
274+
def NVVM_ClusterDim : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.nctarank">;
275+
252276
//===----------------------------------------------------------------------===//
253277
// NVVM approximate op definitions
254278
//===----------------------------------------------------------------------===//

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,35 @@ llvm.func @nvvm_special_regs() -> i32 {
3030
%13 = nvvm.read.ptx.sreg.warpsize : i32
3131
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.laneid()
3232
%14 = nvvm.read.ptx.sreg.laneid : i32
33+
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.clusterid.x
34+
%15 = nvvm.read.ptx.sreg.clusterid.x : i32
35+
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.clusterid.y
36+
%16 = nvvm.read.ptx.sreg.clusterid.y : i32
37+
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.clusterid.z
38+
%17 = nvvm.read.ptx.sreg.clusterid.z : i32
39+
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.nclusterid.x
40+
%18 = nvvm.read.ptx.sreg.nclusterid.x : i32
41+
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.nclusterid.y
42+
%19 = nvvm.read.ptx.sreg.nclusterid.y : i32
43+
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.nclusterid.z
44+
%20 = nvvm.read.ptx.sreg.nclusterid.z : i32
45+
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.cluster.ctaid
46+
%21 = nvvm.read.ptx.sreg.cluster.ctaid.x : i32
47+
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.cluster.ctaid
48+
%22 = nvvm.read.ptx.sreg.cluster.ctaid.y : i32
49+
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.cluster.ctaid
50+
%23 = nvvm.read.ptx.sreg.cluster.ctaid.z : i32
51+
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.cluster.nctaid
52+
%24 = nvvm.read.ptx.sreg.cluster.nctaid.x : i32
53+
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.cluster.nctaid
54+
%25 = nvvm.read.ptx.sreg.cluster.nctaid.y : i32
55+
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.cluster.nctaid
56+
%26 = nvvm.read.ptx.sreg.cluster.nctaid.z : i32
57+
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.cluster.ctarank
58+
%27 = nvvm.read.ptx.sreg.cluster.ctarank : i32
59+
// CHECK: call i32 @llvm.nvvm.read.ptx.sreg.cluster.nctarank
60+
%28 = nvvm.read.ptx.sreg.cluster.nctarank : i32
61+
3362
llvm.return %1 : i32
3463
}
3564

0 commit comments

Comments
 (0)