Skip to content

[mlir][GPU] Plumb range information through the NVVM lowerings #107659

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 43 additions & 28 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -123,52 +123,67 @@ class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
let assemblyFormat = "attr-dict `:` type($res)";
}

// Variant of NVVM_SpecialRegisterOp for special-register reads that can carry
// an optional constant range on their result. In the textual form the range is
// written between the op name and the attribute dictionary, e.g.
//   %0 = nvvm.read.ptx.sreg.tid.x range <i32, 0, 64> : i32
// and the range clause is omitted entirely when no range is attached.
class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
NVVM_SpecialRegisterOp<mnemonic, traits> {
// The only argument is the optional range attribute; the op takes no operands.
let arguments = (ins OptionalAttr<LLVM_ConstantRangeAttr>:$range);
// `(...)?` is an optional group anchored on $range (`^`): the `range` keyword
// is printed/parsed only when the attribute is present.
let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)";
// Translation to/from LLVM IR: splice the range-handling snippet between the
// shared builder prologue and coda.
// NOTE(review): baseLlvmBuilder, setRangeRetAttrCode, importRangeRetAttrCode
// and the *Coda fragments are defined outside this view; presumably they
// set/read the `range(...)` return attribute on the intrinsic call -- confirm
// against their definitions.
let llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda;
let mlirBuilder = baseMlirBuilder # importRangeRetAttrCode # baseMlirBuilderCoda;

// Backwards-compatibility builder for an unspecified range.
// Lets existing callers that pass only the result type keep compiling; it
// forwards a default-constructed (null) ConstantRangeAttr, i.e. no range.
let builders = [
OpBuilder<(ins "Type":$resultType), [{
build($_builder, $_state, resultType, ::mlir::LLVM::ConstantRangeAttr{});
}]>
];
}

//===----------------------------------------------------------------------===//
// Lane index and range
def NVVM_LaneIdOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.laneid">;
def NVVM_WarpSizeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.warpsize">;
def NVVM_LaneIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.laneid">;
def NVVM_WarpSizeOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpsize">;

//===----------------------------------------------------------------------===//
// Thread index and range
def NVVM_ThreadIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.x">;
def NVVM_ThreadIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.y">;
def NVVM_ThreadIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.z">;
def NVVM_BlockDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.x">;
def NVVM_BlockDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.y">;
def NVVM_BlockDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.z">;
def NVVM_ThreadIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.tid.x">;
def NVVM_ThreadIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.tid.y">;
def NVVM_ThreadIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.tid.z">;
def NVVM_BlockDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ntid.x">;
def NVVM_BlockDimYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ntid.y">;
def NVVM_BlockDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ntid.z">;

//===----------------------------------------------------------------------===//
// Block index and range
def NVVM_BlockIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.x">;
def NVVM_BlockIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.y">;
def NVVM_BlockIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.z">;
def NVVM_GridDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.x">;
def NVVM_GridDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.y">;
def NVVM_GridDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.z">;
def NVVM_BlockIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.x">;
def NVVM_BlockIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.y">;
def NVVM_BlockIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.z">;
def NVVM_GridDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.x">;
def NVVM_GridDimYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.y">;
def NVVM_GridDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.z">;

//===----------------------------------------------------------------------===//
// CTA Cluster index and range
def NVVM_ClusterIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.clusterid.x">;
def NVVM_ClusterIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.clusterid.y">;
def NVVM_ClusterIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.clusterid.z">;
def NVVM_ClusterDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nclusterid.x">;
def NVVM_ClusterDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nclusterid.y">;
def NVVM_ClusterDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nclusterid.z">;
def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x">;
def NVVM_ClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.y">;
def NVVM_ClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.z">;
def NVVM_ClusterDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.x">;
def NVVM_ClusterDimYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.y">;
def NVVM_ClusterDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.z">;


//===----------------------------------------------------------------------===//
// CTA index and range within Cluster
def NVVM_BlockInClusterIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.ctaid.x">;
def NVVM_BlockInClusterIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.ctaid.y">;
def NVVM_BlockInClusterIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.ctaid.z">;
def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.nctaid.x">;
def NVVM_ClusterDimBlocksYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.nctaid.y">;
def NVVM_ClusterDimBlocksZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.nctaid.z">;
def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x">;
def NVVM_BlockInClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y">;
def NVVM_BlockInClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z">;
def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x">;
def NVVM_ClusterDimBlocksYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.y">;
def NVVM_ClusterDimBlocksZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.z">;

//===----------------------------------------------------------------------===//
// CTA index and range across all Cluster dimensions
def NVVM_ClusterId : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.ctarank">;
def NVVM_ClusterDim : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.nctarank">;
def NVVM_ClusterId : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctarank">;
def NVVM_ClusterDim : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctarank">;

//===----------------------------------------------------------------------===//
// Clock registers
Expand Down
54 changes: 38 additions & 16 deletions mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

Expand Down Expand Up @@ -209,7 +210,15 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
MLIRContext *context = rewriter.getContext();
Value newOp = rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type());
LLVM::ConstantRangeAttr bounds = nullptr;
if (std::optional<APInt> upperBound = op.getUpperBound())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Who is setting the upper bound? I might be missing something.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

User code - I'll have some tests shortly

bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
/*bitWidth=*/32, /*lower=*/0, upperBound->getZExtValue());
else
bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
/*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize);
Value newOp =
rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type(), bounds);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we always use kWarpSize = 32 for the laneId? This is a HW constraint, and it hasn't been changed over the years.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thanks for a good observation about the default

// Truncate or extend the result depending on the index bitwidth specified
// by the LLVMTypeConverter options.
const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
Expand Down Expand Up @@ -340,27 +349,40 @@ void mlir::populateGpuSubgroupReduceOpLoweringPattern(

void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
using gpu::index_lowering::IndexKind;
using gpu::index_lowering::IntrType;
populateWithGenerated(patterns);
patterns.add<GPUPrintfOpToVPrintfLowering>(converter);
patterns.add<
gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>,
NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
converter, IndexKind::Block, IntrType::Id);
patterns.add<
gpu::index_lowering::OpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
NVVM::BlockDimYOp, NVVM::BlockDimZOp>,
NVVM::BlockDimYOp, NVVM::BlockDimZOp>>(
converter, IndexKind::Block, IntrType::Dim);
patterns.add<
gpu::index_lowering::OpLowering<gpu::ClusterIdOp, NVVM::ClusterIdXOp,
NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>,
gpu::index_lowering::OpLowering<gpu::ClusterDimOp, NVVM::ClusterDimXOp,
NVVM::ClusterDimYOp, NVVM::ClusterDimZOp>,
gpu::index_lowering::OpLowering<
gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>,
gpu::index_lowering::OpLowering<gpu::ClusterDimOp, NVVM::ClusterDimXOp,
NVVM::ClusterDimYOp, NVVM::ClusterDimZOp>,
gpu::index_lowering::OpLowering<gpu::BlockIdOp, NVVM::BlockIdXOp,
NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
gpu::index_lowering::OpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
NVVM::GridDimYOp, NVVM::GridDimZOp>,
GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(converter);
NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>>(
converter, IndexKind::Other, IntrType::Id);
patterns.add<gpu::index_lowering::OpLowering<
gpu::ClusterDimOp, NVVM::ClusterDimXOp, NVVM::ClusterDimYOp,
NVVM::ClusterDimZOp>>(converter, IndexKind::Other, IntrType::Dim);
patterns.add<gpu::index_lowering::OpLowering<
gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>>(
converter, IndexKind::Other, IntrType::Id);
patterns.add<gpu::index_lowering::OpLowering<
gpu::ClusterDimOp, NVVM::ClusterDimXOp, NVVM::ClusterDimYOp,
NVVM::ClusterDimZOp>>(converter, IndexKind::Other, IntrType::Dim);
patterns.add<gpu::index_lowering::OpLowering<
gpu::BlockIdOp, NVVM::BlockIdXOp, NVVM::BlockIdYOp, NVVM::BlockIdZOp>>(
converter, IndexKind::Block, IntrType::Id);
patterns.add<gpu::index_lowering::OpLowering<
gpu::GridDimOp, NVVM::GridDimXOp, NVVM::GridDimYOp, NVVM::GridDimZOp>>(
converter, IndexKind::Grid, IntrType::Dim);
patterns.add<GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
converter);

patterns.add<GPUDynamicSharedMemoryOpLowering>(
converter, NVVM::kSharedMemoryAlignmentBit);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Target/LLVMIR/ModuleImport.h"

#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/IntrinsicsNVPTX.h"

using namespace mlir;
Expand Down
34 changes: 30 additions & 4 deletions mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ gpu.module @test_module_0 {
%gDimZ = gpu.grid_dim z


// CHECK: = nvvm.read.ptx.sreg.laneid : i32
// CHECK: = nvvm.read.ptx.sreg.laneid range <i32, 0, 32> : i32
// CHECK: = llvm.sext %{{.*}} : i32 to i64
%laneId = gpu.lane_id

Expand Down Expand Up @@ -699,9 +699,21 @@ gpu.module @test_module_32 {
}

gpu.module @test_module_33 {
// CHECK-LABEL: func @kernel_with_block_size()
// CHECK: attributes {gpu.kernel, gpu.known_block_size = array<i32: 128, 1, 1>, nvvm.kernel, nvvm.maxntid = array<i32: 128, 1, 1>}
gpu.func @kernel_with_block_size() kernel attributes {known_block_size = array<i32: 128, 1, 1>} {
// CHECK-LABEL: func @kernel_with_block_size(
// CHECK: attributes {gpu.kernel, gpu.known_block_size = array<i32: 32, 4, 2>, nvvm.kernel, nvvm.maxntid = array<i32: 32, 4, 2>}
gpu.func @kernel_with_block_size(%arg0: !llvm.ptr) kernel attributes {known_block_size = array<i32: 32, 4, 2>} {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I remember correctly, you added known_block_size to func.func, so I am wondering: is this PR going to work for func.func as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, that code works generally

// CHECK: = nvvm.read.ptx.sreg.tid.x range <i32, 0, 32> : i32
%0 = gpu.thread_id x
// CHECK: = nvvm.read.ptx.sreg.tid.y range <i32, 0, 4> : i32
%1 = gpu.thread_id y
// CHECK: = nvvm.read.ptx.sreg.tid.z range <i32, 0, 2> : i32
%2 = gpu.thread_id z

// Fake usage to prevent dead code elimination
%3 = arith.addi %0, %1 : index
%4 = arith.addi %3, %2 : index
%5 = arith.index_cast %4 : index to i64
llvm.store %5, %arg0 : i64, !llvm.ptr
gpu.return
}
}
Expand Down Expand Up @@ -917,6 +929,20 @@ gpu.module @test_module_48 {
}
}

gpu.module @test_module_49 {
// CHECK-LABEL: func @explicit_id_bounds()
func.func @explicit_id_bounds() -> (index, index, index) {
// CHECK: = nvvm.read.ptx.sreg.tid.x range <i32, 0, 32> : i32
%0 = gpu.thread_id x upper_bound 32
// CHECK: = nvvm.read.ptx.sreg.ntid.x range <i32, 1, 33> : i32
%1 = gpu.block_dim x upper_bound 32
// CHECK: = nvvm.read.ptx.sreg.laneid range <i32, 0, 16> : i32
%2 = gpu.lane_id upper_bound 16

return %0, %1, %2 : index, index, index
}
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
%gpu_module = transform.structured.match ops{["gpu.module"]} in %toplevel_module
Expand Down
3 changes: 3 additions & 0 deletions mlir/test/Target/LLVMIR/Import/nvvmir.ll
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ define i32 @nvvm_special_regs() {
%27 = call i32 @llvm.nvvm.read.ptx.sreg.cluster.ctarank()
; CHECK: = nvvm.read.ptx.sreg.cluster.nctarank : i32
%28 = call i32 @llvm.nvvm.read.ptx.sreg.cluster.nctarank()

; CHECK: = nvvm.read.ptx.sreg.tid.x range <i32, 0, 64> : i32
%29 = call range(i32 0, 64) i32 @llvm.nvvm.read.ptx.sreg.tid.x()
ret i32 %1
}

Expand Down
5 changes: 4 additions & 1 deletion mlir/test/Target/LLVMIR/nvvmir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ llvm.func @nvvm_special_regs() -> i32 {
%29 = nvvm.read.ptx.sreg.clock : i32
// CHECK: call i64 @llvm.nvvm.read.ptx.sreg.clock64
%30 = nvvm.read.ptx.sreg.clock64 : i64


// CHECK: %31 = call range(i32 0, 64) i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%31 = nvvm.read.ptx.sreg.tid.x range <i32, 0, 64> : i32

llvm.return %1 : i32
}

Expand Down
Loading