
[MLIR][NVVM] Add NVVMRequiresSM op traits #126886


Merged
14 changes: 14 additions & 0 deletions mlir/include/mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td
@@ -55,6 +55,20 @@ def GPUTargetAttrInterface : AttrInterface<"TargetAttrInterface"> {
];
}

def GPUTargetAttrVerifyInterface : AttrInterface<"TargetAttrVerifyInterface"> {
let description = [{
Interface for GPU target attributes that verify the target attribute
of a given GPU module.
}];
let cppNamespace = "::mlir::gpu";
let methods = [
InterfaceMethod<[{
Verifies that the target attribute is valid for the given GPU module.
}], "::mlir::LogicalResult", "verifyTarget",
(ins "::mlir::Operation *":$module)>
];
}
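
To illustrate how this interface is meant to be used, here is a minimal C++ sketch of a target attribute implementing verifyTarget. The attribute name MyTargetAttr and the checkOpAgainstTarget helper are placeholders for illustration, not part of this patch:

// Hypothetical target attribute implementing GPUTargetAttrVerifyInterface.
// MyTargetAttr and checkOpAgainstTarget() are illustrative names only.
LogicalResult MyTargetAttr::verifyTarget(Operation *module) {
  auto gpuModule = dyn_cast<gpu::GPUModuleOp>(module);
  if (!gpuModule)
    return module->emitError("expected a gpu.module op");
  // Walk the module and reject any op the target cannot support.
  WalkResult result = gpuModule.walk([&](Operation *op) {
    return failed(checkOpAgainstTarget(op)) ? WalkResult::interrupt()
                                            : WalkResult::advance();
  });
  return failure(result.wasInterrupted());
}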

def GPUTargetAttr :
ConfinedAttr<AnyAttr, [PromisedAttrInterface<GPUTargetAttrInterface>]> {
let description = [{
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1460,6 +1460,8 @@ def GPU_GPUModuleOp : GPU_Op<"module", [
/// Sets the targets of the module.
void setTargets(ArrayRef<TargetAttrInterface> targets);
}];

let hasVerifier = 1;
}

def GPU_BinaryOp : GPU_Op<"binary", [Symbol]>, Arguments<(ins
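The new verifier on gpu.module presumably dispatches to any target attribute that implements the interface above. A sketch of what that hook could look like, assuming the targets accessor is getTargetsAttr(); the real body lives in GPUDialect.cpp and is not part of this diff:

// Sketch only: forward gpu.module verification to each target attribute
// that implements TargetAttrVerifyInterface.
LogicalResult gpu::GPUModuleOp::verify() {
  if (ArrayAttr targets = getTargetsAttr()) {
    for (Attribute target : targets)
      if (auto iface = dyn_cast<gpu::TargetAttrVerifyInterface>(target))
        if (failed(iface.verifyTarget(getOperation())))
          return failure();
  }
  return success();
}
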
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt
@@ -54,6 +54,12 @@ mlir_tablegen(BasicPtxBuilderInterface.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(MLIRBasicPtxBuilderInterfaceIncGen)
add_dependencies(mlir-headers MLIRBasicPtxBuilderInterfaceIncGen)

set(LLVM_TARGET_DEFINITIONS NVVMRequiresSMTraits.td)
mlir_tablegen(NVVMRequiresSMTraits.h.inc -gen-op-interface-decls)
mlir_tablegen(NVVMRequiresSMTraits.cpp.inc -gen-op-interface-defs)
add_public_tablegen_target(MLIRNVVMRequiresSMTraitsIncGen)
add_dependencies(mlir-headers MLIRNVVMRequiresSMTraitsIncGen)

add_mlir_dialect(NVVMOps nvvm)
add_mlir_doc(NVVMOps NVVMDialect Dialects/ -gen-dialect-doc -dialect=nvvm)
set(LLVM_TARGET_DEFINITIONS NVVMOps.td)
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -15,8 +15,10 @@
#define MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
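The newly included NVVMRequiresSMTraits.h carries the trait machinery behind the NVVMRequiresSM annotations used below. As rough orientation only, a parametric MLIR op trait keyed on a minimum SM version usually follows the pattern sketched here; the exact names and layout of that header are assumptions, not copied from the patch:

// Illustrative pattern for a parametric "requires SM >= N" op trait; the
// real code in NVVMRequiresSMTraits.h may differ in names and structure.
namespace mlir {
namespace NVVM {
template <int MinVersion>
struct NVVMRequiresSM {
  template <typename ConcreteOp>
  class Impl : public OpTrait::TraitBase<ConcreteOp, Impl> {
  public:
    static int getRequiredMinSMVersion() { return MinVersion; }
  };
};
} // namespace NVVM
} // namespace mlir
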
78 changes: 42 additions & 36 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -16,6 +16,7 @@
include "mlir/IR/EnumAttr.td"
include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
@@ -138,8 +139,10 @@ class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
let assemblyFormat = "attr-dict `:` type($res)";
}

class NVVM_SpecialRangeableRegisterOp<string mnemonic> :
NVVM_SpecialRegisterOp<mnemonic, [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
NVVM_SpecialRegisterOp<mnemonic,
!listconcat(traits,
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>])> {
let arguments = (ins OptionalAttr<LLVM_ConstantRangeAttr>:$range);
let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)";
let llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda;
@@ -202,7 +205,7 @@ def NVVM_GridDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.z">;

//===----------------------------------------------------------------------===//
// CTA Cluster index and range
def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x">;
def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x", [NVVMRequiresSM<90>]>;
def NVVM_ClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.y">;
def NVVM_ClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.z">;
def NVVM_ClusterDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.x">;
@@ -212,16 +215,16 @@ def NVVM_ClusterDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ncluster

//===----------------------------------------------------------------------===//
// CTA index and range within Cluster
def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x">;
def NVVM_BlockInClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y">;
def NVVM_BlockInClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z">;
def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x">;
def NVVM_ClusterDimBlocksYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.y">;
def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x", [NVVMRequiresSM<90>]>;
def NVVM_BlockInClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y", [NVVMRequiresSM<90>]>;
def NVVM_BlockInClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z", [NVVMRequiresSM<90>]>;
def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x", [NVVMRequiresSM<90>]>;
def NVVM_ClusterDimBlocksYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.y", [NVVMRequiresSM<90>]>;
def NVVM_ClusterDimBlocksZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.z">;

//===----------------------------------------------------------------------===//
// CTA index and across Cluster dimensions
def NVVM_ClusterId : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctarank">;
def NVVM_ClusterId : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctarank", [NVVMRequiresSM<90>]>;
def NVVM_ClusterDim : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctarank">;

//===----------------------------------------------------------------------===//
@@ -273,7 +276,7 @@ def ReduxKind : I32EnumAttr<"ReduxKind", "NVVM redux kind",
def ReduxKindAttr : EnumAttr<NVVM_Dialect, ReduxKind, "redux_kind">;

def NVVM_ReduxOp :
NVVM_Op<"redux.sync">,
NVVM_Op<"redux.sync", [NVVMRequiresSM<80>]>,
Results<(outs LLVM_Type:$res)>,
Arguments<(ins LLVM_Type:$val,
ReduxKindAttr:$kind,
@@ -322,7 +325,7 @@ def NVVM_MBarrierInitOp : NVVM_PTXBuilder_Op<"mbarrier.init">,
}

/// mbarrier.init instruction with shared pointer type
def NVVM_MBarrierInitSharedOp : NVVM_PTXBuilder_Op<"mbarrier.init.shared">,
def NVVM_MBarrierInitSharedOp : NVVM_PTXBuilder_Op<"mbarrier.init.shared", [NVVMRequiresSM<80>, DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
Arguments<(ins LLVM_PointerShared:$addr, I32:$count, PtxPredicate:$predicate)> {
string llvmBuilder = [{
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init_shared, {$addr, $count});
@@ -544,7 +547,7 @@ def NVVM_ClusterArriveOp : NVVM_Op<"cluster.arrive"> {
let assemblyFormat = "attr-dict";
}

def NVVM_ClusterArriveRelaxedOp : NVVM_Op<"cluster.arrive.relaxed"> {
def NVVM_ClusterArriveRelaxedOp : NVVM_Op<"cluster.arrive.relaxed", [NVVMRequiresSM<90>]> {
let arguments = (ins OptionalAttr<UnitAttr>:$aligned);

let summary = "Cluster Barrier Relaxed Arrive Op";
@@ -570,7 +573,7 @@ def NVVM_ClusterArriveRelaxedOp : NVVM_Op<"cluster.arrive.relaxed"> {
let assemblyFormat = "attr-dict";
}

def NVVM_ClusterWaitOp : NVVM_Op<"cluster.wait"> {
def NVVM_ClusterWaitOp : NVVM_Op<"cluster.wait", [NVVMRequiresSM<90>]> {
let arguments = (ins OptionalAttr<UnitAttr>:$aligned);

let summary = "Cluster Barrier Wait Op";
@@ -775,7 +778,7 @@ def ShflKind : I32EnumAttr<"ShflKind", "NVVM shuffle kind",
def ShflKindAttr : EnumAttr<NVVM_Dialect, ShflKind, "shfl_kind">;

def NVVM_ShflOp :
NVVM_Op<"shfl.sync">,
NVVM_Op<"shfl.sync", [NVVMRequiresSM<30>]>,
Results<(outs LLVM_Type:$res)>,
Arguments<(ins I32:$thread_mask,
LLVM_Type:$val,
@@ -2114,7 +2117,7 @@ def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">,
}];
}

def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group">,
def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group", [NVVMRequiresSM<90>]>,
Arguments<(ins
ConfinedAttr<I32Attr, [IntMinValue<0>]>:$group,
OptionalAttr<UnitAttr>:$read)> {
@@ -2144,7 +2147,7 @@ def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group">,
def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global",
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
AttrSizedOperandSegments]>,
AttrSizedOperandSegments, NVVMRequiresSM<90>]>,
Arguments<(ins LLVM_PointerShared:$dstMem,
LLVM_AnyPointer:$tmaDescriptor,
Variadic<I32>:$coordinates,
@@ -2581,7 +2584,7 @@ def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
// NVVM Wgmma Ops
//===----------------------------------------------------------------------===//

def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned"> {
def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned", [NVVMRequiresSMa<[90]>]> {
let arguments = (ins);
let description = [{
Enforce an ordering of register accesses between warpgroup level matrix
@@ -2595,8 +2598,7 @@ def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned"> {
}];
}

def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned">,
Arguments<(ins )> {
def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned", [NVVMRequiresSMa<[90]>]> {
let assemblyFormat = "attr-dict";
let description = [{
Commits all prior uncommitted warpgroup level matrix multiplication operations.
@@ -2608,7 +2610,7 @@ def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned">,
}];
}

def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned">{
def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned", [NVVMRequiresSMa<[90]>]> {
let arguments = (ins I64Attr:$group);
let assemblyFormat = "attr-dict $group";
let description = [{
@@ -2804,7 +2806,7 @@ def NVVM_GriddepcontrolLaunchDependentsOp

def NVVM_MapaOp: NVVM_Op<"mapa",
[TypesMatchWith<"`res` and `a` should have the same type",
"a", "res", "$_self">]> {
"a", "res", "$_self">, NVVMRequiresSM<90>]> {
let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$res);
let arguments = (ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$a, I32:$b);

@@ -2971,7 +2973,7 @@ def Tcgen05WaitKindAttr :
let assemblyFormat = "`<` $value `>`";
}

def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc"> {
def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc", [NVVMRequiresSMa<[100, 101]>]> {
let summary = "Tcgen05 alloc operation";
let description = [{
The `tcgen05.alloc` Op allocates tensor core memory for
@@ -3001,7 +3003,7 @@ def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc"> {
}];
}

def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc"> {
def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc", [NVVMRequiresSMa<[100, 101]>]> {
let summary = "Tcgen05 dealloc operation";
let description = [{
The `tcgen05.dealloc` Op de-allocates the tensor core memory
@@ -3029,7 +3031,7 @@ def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc"> {
}];
}

def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_permit"> {
def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_permit", [NVVMRequiresSMa<[100, 101]>]> {
let summary = "Tcgen05 Op to relinquish the right to allocate";
let description = [{
The `tcgen05.relinquish_alloc_permit` Op specifies that the CTA
@@ -3052,7 +3054,7 @@ def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_perm
}];
}

def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence"> {
def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence", [NVVMRequiresSMa<[100, 101]>]> {
let summary = "Tcgen05 fence operations";
let description = [{
The `tcgen05.fence<before>` orders all prior async tcgen05 operations
@@ -3074,7 +3076,7 @@ def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence"> {
}];
}

def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait"> {
def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait", [NVVMRequiresSMa<[100, 101]>]> {
let summary = "Tcgen05 wait operations";
let description = [{
The `tcgen05.wait<load>` causes the executing thread to block until
@@ -3096,7 +3098,7 @@ def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait"> {
}];
}

def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit"> {
def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit", [NVVMRequiresSMa<[100, 101]>]> {
let summary = "Tcgen05 commit operations";
let description = [{
The `tcgen05.commit` makes the mbarrier object, specified by
@@ -3134,7 +3136,7 @@ def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit"> {
}];
}

def NVVM_Tcgen05ShiftOp : NVVM_Op<"tcgen05.shift"> {
def NVVM_Tcgen05ShiftOp : NVVM_Op<"tcgen05.shift", [NVVMRequiresSMa<[100, 101, 103]>]> {
let summary = "Tcgen05 shift operation";
let description = [{
The `tcgen05.shift` is an asynchronous instruction which initiates
@@ -3200,7 +3202,7 @@ def Tcgen05CpSrcFormatAttr : EnumAttr<NVVM_Dialect, Tcgen05CpSrcFormat, "tcgen05
let assemblyFormat = "`<` $value `>`";
}

def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp"> {
def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp", [NVVMRequiresSMa<[100, 101]>]> {
let summary = "Tcgen05 copy operation";
let description = [{
Instruction tcgen05.cp initiates an asynchronous copy operation from
@@ -3270,7 +3272,7 @@ def Tcgen05LdStShapeAttr: EnumAttr<NVVM_Dialect, Tcgen05LdStShape, "tcgen05_ldst
// NVVM tcgen05.ld Op
//===----------------------------------------------------------------------===//

def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld"> {
def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld", [NVVMRequiresSMa<[100, 101]>]> {
let summary = "tensor memory load instructions";
let arguments = (ins
// Attributes
@@ -3360,7 +3362,7 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st"> {
// NVVM tcgen05.st Op
//===----------------------------------------------------------------------===//

def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st"> {
def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st", [NVVMRequiresSMa<[100, 101]>]> {
let summary = "tensor memory store instructions";
let arguments = (ins
// Attributes
@@ -3512,7 +3514,8 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
// NVVM target attribute.
//===----------------------------------------------------------------------===//

def NVVM_TargetAttr : NVVM_Attr<"NVVMTarget", "target"> {
def NVVM_TargetAttr : NVVM_Attr<"NVVMTarget", "target",
[DeclareAttrInterfaceMethods<GPUTargetAttrVerifyInterface>]> {
let description = [{
GPU target attribute for controlling compilation of NVIDIA targets. All
parameters decay into default values if not present.
@@ -3539,19 +3542,21 @@ def NVVM_TargetAttr : NVVM_Attr<"NVVMTarget", "target"> {
StringRefParameter<"Target chip.", "\"sm_50\"">:$chip,
StringRefParameter<"Target chip features.", "\"+ptx60\"">:$features,
OptionalParameter<"DictionaryAttr", "Target specific flags.">:$flags,
OptionalParameter<"ArrayAttr", "Files to link to the LLVM module.">:$link
OptionalParameter<"ArrayAttr", "Files to link to the LLVM module.">:$link,
DefaultValuedParameter<"bool", "true", "Perform SM version check on Ops.">:$verifyTarget
);
let assemblyFormat = [{
(`<` struct($O, $triple, $chip, $features, $flags, $link)^ `>`)?
(`<` struct($O, $triple, $chip, $features, $flags, $link, $verifyTarget)^ `>`)?
}];
let builders = [
AttrBuilder<(ins CArg<"int", "2">:$optLevel,
CArg<"StringRef", "\"nvptx64-nvidia-cuda\"">:$triple,
CArg<"StringRef", "\"sm_50\"">:$chip,
CArg<"StringRef", "\"+ptx60\"">:$features,
CArg<"DictionaryAttr", "nullptr">:$targetFlags,
CArg<"ArrayAttr", "nullptr">:$linkFiles), [{
return Base::get($_ctxt, optLevel, triple, chip, features, targetFlags, linkFiles);
CArg<"ArrayAttr", "nullptr">:$linkFiles,
CArg<"bool", "true">:$verifyTarget), [{
return Base::get($_ctxt, optLevel, triple, chip, features, targetFlags, linkFiles, verifyTarget);
}]>
];
let skipDefaultBuilders = 1;
@@ -3562,6 +3567,7 @@ def NVVM_TargetAttr : NVVM_Attr<"NVVMTarget", "target"> {
bool hasFtz() const;
bool hasCmdOptions() const;
std::optional<mlir::NamedAttribute> getCmdOptions() const;
LogicalResult verifyTarget(Operation *gpuModule);
}];
let extraClassDefinition = [{
bool $cppClass::hasFlag(StringRef flag) const {
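Putting the NVVMOps.td changes together, a hedged usage sketch: building the extended target attribute from C++ with the new verifyTarget flag, assuming the generated builder keeps the parameter order shown above ("+ptx80" and the helper name are illustrative):

// Sketch: constructing the NVVM target attribute with verification enabled.
// Inside NVVMTargetAttr::verifyTarget(), the chip string (e.g. "sm_90") would
// be compared against each op's NVVMRequiresSM / NVVMRequiresSMa requirement.
static NVVM::NVVMTargetAttr makeSM90Target(MLIRContext *ctx) {
  return NVVM::NVVMTargetAttr::get(
      ctx, /*optLevel=*/2, /*triple=*/"nvptx64-nvidia-cuda",
      /*chip=*/"sm_90", /*features=*/"+ptx80",
      /*targetFlags=*/nullptr, /*linkFiles=*/nullptr,
      /*verifyTarget=*/true);
}

With such a target, a gpu.module would then be rejected at verification time if it contained, say, an nvvm.tcgen05.alloc op, whose NVVMRequiresSMa<[100, 101]> requirement an sm_90 chip cannot satisfy.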