Commit a953982

[mlir][GPU] Plumb range information through the NVVM lowerings (#107659)
Update the GPU to NVVM lowerings to correctly propagate range information on IDs and dimension queries, either from known_{block,grid}_size attributes or from `upperBound` annotations on the operations themselves.
1 parent 4ef16e3 commit a953982
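The lowerings attach an LLVM::ConstantRangeAttr to each NVVM special-register op they create. As a rough sketch of that derivation, adapted from the GPULaneIdOpToNVVM pattern changed in this commit (the helper function itself is illustrative, not part of the patch): the bound comes from the op's `upperBound` annotation when present, and otherwise from a statically known limit such as kWarpSize.

// Illustrative sketch only; the getAttr call and builder overload mirror the
// ones used in this commit.
static LLVM::ConstantRangeAttr
deriveLaneIdBounds(gpu::LaneIdOp op, ConversionPatternRewriter &rewriter) {
  if (std::optional<APInt> upperBound = op.getUpperBound())
    return rewriter.getAttr<LLVM::ConstantRangeAttr>(
        /*bitWidth=*/32, /*lower=*/0, upperBound->getZExtValue());
  // Lane IDs are always smaller than the warp size.
  return rewriter.getAttr<LLVM::ConstantRangeAttr>(
      /*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize);
}

On the NVVM op the attribute prints as `range <i32, lower, upper>`, and the LLVM export turns it into a `range(...)` return attribute on the intrinsic call, as the tests below check.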

6 files changed (+119 / -49 lines)

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 43 additions & 28 deletions
@@ -123,52 +123,67 @@ class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
+class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
+  NVVM_SpecialRegisterOp<mnemonic, traits> {
+  let arguments = (ins OptionalAttr<LLVM_ConstantRangeAttr>:$range);
+  let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)";
+  let llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda;
+  let mlirBuilder = baseMlirBuilder # importRangeRetAttrCode # baseMlirBuilderCoda;
+
+  // Backwards-compatibility builder for an unspecified range.
+  let builders = [
+    OpBuilder<(ins "Type":$resultType), [{
+      build($_builder, $_state, resultType, ::mlir::LLVM::ConstantRangeAttr{});
+    }]>
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // Lane index and range
-def NVVM_LaneIdOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.laneid">;
-def NVVM_WarpSizeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.warpsize">;
+def NVVM_LaneIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.laneid">;
+def NVVM_WarpSizeOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpsize">;
 
 //===----------------------------------------------------------------------===//
 // Thread index and range
-def NVVM_ThreadIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.x">;
-def NVVM_ThreadIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.y">;
-def NVVM_ThreadIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.z">;
-def NVVM_BlockDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.x">;
-def NVVM_BlockDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.y">;
-def NVVM_BlockDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.z">;
+def NVVM_ThreadIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.tid.x">;
+def NVVM_ThreadIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.tid.y">;
+def NVVM_ThreadIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.tid.z">;
+def NVVM_BlockDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ntid.x">;
+def NVVM_BlockDimYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ntid.y">;
+def NVVM_BlockDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ntid.z">;
 
 //===----------------------------------------------------------------------===//
 // Block index and range
-def NVVM_BlockIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.x">;
-def NVVM_BlockIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.y">;
-def NVVM_BlockIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.z">;
-def NVVM_GridDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.x">;
-def NVVM_GridDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.y">;
-def NVVM_GridDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.z">;
+def NVVM_BlockIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.x">;
+def NVVM_BlockIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.y">;
+def NVVM_BlockIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.z">;
+def NVVM_GridDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.x">;
+def NVVM_GridDimYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.y">;
+def NVVM_GridDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.z">;
 
 //===----------------------------------------------------------------------===//
 // CTA Cluster index and range
-def NVVM_ClusterIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.clusterid.x">;
-def NVVM_ClusterIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.clusterid.y">;
-def NVVM_ClusterIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.clusterid.z">;
-def NVVM_ClusterDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nclusterid.x">;
-def NVVM_ClusterDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nclusterid.y">;
-def NVVM_ClusterDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nclusterid.z">;
+def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x">;
+def NVVM_ClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.y">;
+def NVVM_ClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.z">;
+def NVVM_ClusterDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.x">;
+def NVVM_ClusterDimYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.y">;
+def NVVM_ClusterDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.z">;
 
 
 //===----------------------------------------------------------------------===//
 // CTA index and range within Cluster
-def NVVM_BlockInClusterIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.ctaid.x">;
-def NVVM_BlockInClusterIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.ctaid.y">;
-def NVVM_BlockInClusterIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.ctaid.z">;
-def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.nctaid.x">;
-def NVVM_ClusterDimBlocksYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.nctaid.y">;
-def NVVM_ClusterDimBlocksZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.nctaid.z">;
+def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x">;
+def NVVM_BlockInClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y">;
+def NVVM_BlockInClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z">;
+def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x">;
+def NVVM_ClusterDimBlocksYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.y">;
+def NVVM_ClusterDimBlocksZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.z">;
 
 //===----------------------------------------------------------------------===//
 // CTA index and across Cluster dimensions
-def NVVM_ClusterId : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.ctarank">;
-def NVVM_ClusterDim : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.nctarank">;
+def NVVM_ClusterId : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctarank">;
+def NVVM_ClusterDim : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctarank">;
 
 //===----------------------------------------------------------------------===//
 // Clock registers
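With the new NVVM_SpecialRangeableRegisterOp base class, the ODS-generated builders for these ops take an optional LLVM::ConstantRangeAttr, while the explicit builder above keeps existing no-range call sites compiling. A minimal usage sketch (not from the commit; the concrete bound of 64 is just an example):

// Create a ranged and an unranged special-register read. The three-argument
// form matches the builder call used by the GPU-to-NVVM patterns; the
// two-argument form uses the backwards-compatibility builder.
Type i32 = rewriter.getI32Type();
auto bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
    /*bitWidth=*/32, /*lower=*/0, /*upper=*/64);
Value tidX = rewriter.create<NVVM::ThreadIdXOp>(loc, i32, bounds); // with range
Value lane = rewriter.create<NVVM::LaneIdOp>(loc, i32);            // no range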

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

Lines changed: 38 additions & 16 deletions
@@ -29,6 +29,7 @@
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
@@ -209,7 +210,15 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op->getLoc();
     MLIRContext *context = rewriter.getContext();
-    Value newOp = rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type());
+    LLVM::ConstantRangeAttr bounds = nullptr;
+    if (std::optional<APInt> upperBound = op.getUpperBound())
+      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
+          /*bitWidth=*/32, /*lower=*/0, upperBound->getZExtValue());
+    else
+      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
+          /*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize);
+    Value newOp =
+        rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type(), bounds);
     // Truncate or extend the result depending on the index bitwidth specified
     // by the LLVMTypeConverter options.
     const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
@@ -340,27 +349,40 @@ void mlir::populateGpuSubgroupReduceOpLoweringPattern(
 
 void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                                RewritePatternSet &patterns) {
+  using gpu::index_lowering::IndexKind;
+  using gpu::index_lowering::IntrType;
   populateWithGenerated(patterns);
   patterns.add<GPUPrintfOpToVPrintfLowering>(converter);
   patterns.add<
       gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
-                                      NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>,
+                                      NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
+      converter, IndexKind::Block, IntrType::Id);
+  patterns.add<
      gpu::index_lowering::OpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
-                                      NVVM::BlockDimYOp, NVVM::BlockDimZOp>,
+                                      NVVM::BlockDimYOp, NVVM::BlockDimZOp>>(
+      converter, IndexKind::Block, IntrType::Dim);
+  patterns.add<
      gpu::index_lowering::OpLowering<gpu::ClusterIdOp, NVVM::ClusterIdXOp,
-                                      NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>,
-      gpu::index_lowering::OpLowering<gpu::ClusterDimOp, NVVM::ClusterDimXOp,
-                                      NVVM::ClusterDimYOp, NVVM::ClusterDimZOp>,
-      gpu::index_lowering::OpLowering<
-          gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
-          NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>,
-      gpu::index_lowering::OpLowering<gpu::ClusterDimOp, NVVM::ClusterDimXOp,
-                                      NVVM::ClusterDimYOp, NVVM::ClusterDimZOp>,
-      gpu::index_lowering::OpLowering<gpu::BlockIdOp, NVVM::BlockIdXOp,
-                                      NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
-      gpu::index_lowering::OpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
-                                      NVVM::GridDimYOp, NVVM::GridDimZOp>,
-      GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(converter);
+                                      NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>>(
+      converter, IndexKind::Other, IntrType::Id);
+  patterns.add<gpu::index_lowering::OpLowering<
+      gpu::ClusterDimOp, NVVM::ClusterDimXOp, NVVM::ClusterDimYOp,
+      NVVM::ClusterDimZOp>>(converter, IndexKind::Other, IntrType::Dim);
+  patterns.add<gpu::index_lowering::OpLowering<
+      gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
+      NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>>(
+      converter, IndexKind::Other, IntrType::Id);
+  patterns.add<gpu::index_lowering::OpLowering<
+      gpu::ClusterDimOp, NVVM::ClusterDimXOp, NVVM::ClusterDimYOp,
+      NVVM::ClusterDimZOp>>(converter, IndexKind::Other, IntrType::Dim);
+  patterns.add<gpu::index_lowering::OpLowering<
+      gpu::BlockIdOp, NVVM::BlockIdXOp, NVVM::BlockIdYOp, NVVM::BlockIdZOp>>(
+      converter, IndexKind::Block, IntrType::Id);
+  patterns.add<gpu::index_lowering::OpLowering<
+      gpu::GridDimOp, NVVM::GridDimXOp, NVVM::GridDimYOp, NVVM::GridDimZOp>>(
+      converter, IndexKind::Grid, IntrType::Dim);
+  patterns.add<GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
+      converter);
 
   patterns.add<GPUDynamicSharedMemoryOpLowering>(
       converter, NVVM::kSharedMemoryAlignmentBit);
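populateGpuToNVVMConversionPatterns keeps its signature, so existing callers pick up the ranged lowerings automatically; only the per-pattern constructor arguments (IndexKind, IntrType) are new. For context, a rough sketch of how this pattern set is typically consumed inside a conversion pass, mirroring the in-tree LowerGpuOpsToNVVMOpsPass (options, module, and the surrounding pass boilerplate are assumed to exist):

// Inside a pass' runOnOperation(), roughly:
LLVMTypeConverter converter(module.getContext(), options);
RewritePatternSet patterns(module.getContext());
populateGpuToNVVMConversionPatterns(converter, patterns);

LLVMConversionTarget target(*module.getContext());
configureGpuToNVVMConversionLegality(target);
if (failed(applyPartialConversion(module, target, std::move(patterns))))
  signalPassFailure();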

mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Target/LLVMIR/ModuleImport.h"
 
+#include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/IntrinsicsNVPTX.h"
 
 using namespace mlir;

mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir

Lines changed: 30 additions & 4 deletions
@@ -50,7 +50,7 @@ gpu.module @test_module_0 {
   %gDimZ = gpu.grid_dim z
 
 
-  // CHECK: = nvvm.read.ptx.sreg.laneid : i32
+  // CHECK: = nvvm.read.ptx.sreg.laneid range <i32, 0, 32> : i32
   // CHECK: = llvm.sext %{{.*}} : i32 to i64
   %laneId = gpu.lane_id
 
@@ -699,9 +699,21 @@ gpu.module @test_module_32 {
 }
 
 gpu.module @test_module_33 {
-// CHECK-LABEL: func @kernel_with_block_size()
-// CHECK: attributes {gpu.kernel, gpu.known_block_size = array<i32: 128, 1, 1>, nvvm.kernel, nvvm.maxntid = array<i32: 128, 1, 1>}
-  gpu.func @kernel_with_block_size() kernel attributes {known_block_size = array<i32: 128, 1, 1>} {
+// CHECK-LABEL: func @kernel_with_block_size(
+// CHECK: attributes {gpu.kernel, gpu.known_block_size = array<i32: 32, 4, 2>, nvvm.kernel, nvvm.maxntid = array<i32: 32, 4, 2>}
+  gpu.func @kernel_with_block_size(%arg0: !llvm.ptr) kernel attributes {known_block_size = array<i32: 32, 4, 2>} {
+    // CHECK: = nvvm.read.ptx.sreg.tid.x range <i32, 0, 32> : i32
+    %0 = gpu.thread_id x
+    // CHECK: = nvvm.read.ptx.sreg.tid.y range <i32, 0, 4> : i32
+    %1 = gpu.thread_id y
+    // CHECK: = nvvm.read.ptx.sreg.tid.z range <i32, 0, 2> : i32
+    %2 = gpu.thread_id z
+
+    // Fake usage to prevent dead code elimination
+    %3 = arith.addi %0, %1 : index
+    %4 = arith.addi %3, %2 : index
+    %5 = arith.index_cast %4 : index to i64
+    llvm.store %5, %arg0 : i64, !llvm.ptr
     gpu.return
   }
 }
@@ -917,6 +929,20 @@ gpu.module @test_module_48 {
   }
 }
 
+gpu.module @test_module_49 {
+// CHECK-LABEL: func @explicit_id_bounds()
+  func.func @explicit_id_bounds() -> (index, index, index) {
+    // CHECK: = nvvm.read.ptx.sreg.tid.x range <i32, 0, 32> : i32
+    %0 = gpu.thread_id x upper_bound 32
+    // CHECK: = nvvm.read.ptx.sreg.ntid.x range <i32, 1, 33> : i32
+    %1 = gpu.block_dim x upper_bound 32
+    // CHECK: = nvvm.read.ptx.sreg.laneid range <i32, 0, 16> : i32
+    %2 = gpu.lane_id upper_bound 16
+
+    return %0, %1, %2 : index, index, index
+  }
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
     %gpu_module = transform.structured.match ops{["gpu.module"]} in %toplevel_module
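The CHECK lines in @explicit_id_bounds encode the bound semantics: ID-style registers (tid.*, laneid) get a half-open range [0, upper_bound), while dimension-style registers (ntid.*) are at least 1 and at most upper_bound, which prints as [1, upper_bound + 1). A hypothetical helper, not the actual gpu::index_lowering code, illustrating that arithmetic:

// Illustrative only: the ranges implied by the test expectations above.
static LLVM::ConstantRangeAttr
makeIdOrDimBounds(Builder &b, uint64_t upperBound, bool isDimQuery) {
  uint64_t lower = isDimQuery ? 1 : 0;                       // dims start at 1
  uint64_t upper = isDimQuery ? upperBound + 1 : upperBound; // exclusive end
  return b.getAttr<LLVM::ConstantRangeAttr>(/*bitWidth=*/32, lower, upper);
}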

mlir/test/Target/LLVMIR/Import/nvvmir.ll

Lines changed: 3 additions & 0 deletions
@@ -58,6 +58,9 @@ define i32 @nvvm_special_regs() {
   %27 = call i32 @llvm.nvvm.read.ptx.sreg.cluster.ctarank()
   ; CHECK: = nvvm.read.ptx.sreg.cluster.nctarank : i32
   %28 = call i32 @llvm.nvvm.read.ptx.sreg.cluster.nctarank()
+
+  ; CHECK = nvvm.read.ptx.sreg.tid.x range <0 : i32, 64 : i32> : i32
+  %29 = call range(i32 0, 64) i32 @llvm.nvvm.read.ptx.sreg.tid.x()
   ret i32 %1
 }
 
mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 4 additions & 1 deletion
@@ -62,7 +62,10 @@ llvm.func @nvvm_special_regs() -> i32 {
   %29 = nvvm.read.ptx.sreg.clock : i32
   // CHECK: call i64 @llvm.nvvm.read.ptx.sreg.clock64
   %30 = nvvm.read.ptx.sreg.clock64 : i64
-
+
+  // CHECK: %31 = call range(i32 0, 64) i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+  %31 = nvvm.read.ptx.sreg.tid.x range <i32, 0, 64> : i32
+
   llvm.return %1 : i32
 }