Skip to content

Commit 9a553d3

Browse files
Wolfram70grypp
andauthored
[MLIR][NVVM] Add NVVMRequiresSM op traits (#126886)
Motivation: Currently, the NVVMOps are not verified against the supported SM architectures. This can manifest as an ISel failure in the NVPTX LLVM backend during CodeGen to PTX ISA. This PR addresses this issue by adding verifier checks for Target-SM architectures in the NVVM Dialect itself, thereby catching the errors early on. Summary: * Parametric traits named `NVVMRequiresSM` and `NVVMRequiresSMa` are added to facilitate the version checks for typical and arch-accelerated versions respectively. * These traits can be attached to any NVVM Op to enable the checks for the particular Op. (example shown below) * An attribute interface called named `TargetAttrVerifyInterface` is added to the GPU dialect which any target attribute seeking to perform target-verification on the module can implement. * The checks are performed by the `NVVMTargetAttr` (implementing the `TargetAttrVerifyInterface` interface) when called from the GPU module verifier where it walks through the module and performs the checks for Ops with the `NVVMRequiresSM` traits. * A few Ops in `NVVMOps.td` have been updated to serve as examples. Example Usage: ``` def NVVM_ReduxOp : NVVM_Op<"redux.sync"> {...} ----> def NVVM_ReduxOp : NVVM_Op<"redux.sync", [NVVMRequiresSM<80>]> {...} def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned"> {...} ----> def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned", [NVVMRequiresSMa<[90]>]> {...} ``` --------- Co-authored-by: Guray Ozen <[email protected]>
1 parent 0f2a469 commit 9a553d3

File tree

16 files changed

+436
-38
lines changed

16 files changed

+436
-38
lines changed

mlir/include/mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,20 @@ def GPUTargetAttrInterface : AttrInterface<"TargetAttrInterface"> {
5555
];
5656
}
5757

58+
def GPUTargetAttrVerifyInterface : AttrInterface<"TargetAttrVerifyInterface"> {
59+
let description = [{
60+
Interface for GPU target attributes that verify the target attribute
61+
of a given GPU module.
62+
}];
63+
let cppNamespace = "::mlir::gpu";
64+
let methods = [
65+
InterfaceMethod<[{
66+
Verifies that the target attribute is valid for the given GPU module.
67+
}], "::mlir::LogicalResult", "verifyTarget",
68+
(ins "::mlir::Operation *":$module)>
69+
];
70+
}
71+
5872
def GPUTargetAttr :
5973
ConfinedAttr<AnyAttr, [PromisedAttrInterface<GPUTargetAttrInterface>]> {
6074
let description = [{

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1460,6 +1460,8 @@ def GPU_GPUModuleOp : GPU_Op<"module", [
14601460
/// Sets the targets of the module.
14611461
void setTargets(ArrayRef<TargetAttrInterface> targets);
14621462
}];
1463+
1464+
let hasVerifier = 1;
14631465
}
14641466

14651467
def GPU_BinaryOp : GPU_Op<"binary", [Symbol]>, Arguments<(ins

mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,12 @@ mlir_tablegen(BasicPtxBuilderInterface.cpp.inc -gen-op-interface-defs)
5454
add_public_tablegen_target(MLIRBasicPtxBuilderInterfaceIncGen)
5555
add_dependencies(mlir-headers MLIRBasicPtxBuilderInterfaceIncGen)
5656

57+
set(LLVM_TARGET_DEFINITIONS NVVMRequiresSMTraits.td)
58+
mlir_tablegen(NVVMRequiresSMTraits.h.inc -gen-op-interface-decls)
59+
mlir_tablegen(NVVMRequiresSMTraits.cpp.inc -gen-op-interface-defs)
60+
add_public_tablegen_target(MLIRNVVMRequiresSMTraitsIncGen)
61+
add_dependencies(mlir-headers MLIRNVVMRequiresSMTraitsIncGen)
62+
5763
add_mlir_dialect(NVVMOps nvvm)
5864
add_mlir_doc(NVVMOps NVVMDialect Dialects/ -gen-dialect-doc -dialect=nvvm)
5965
set(LLVM_TARGET_DEFINITIONS NVVMOps.td)

mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
#define MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_
1616

1717
#include "mlir/Bytecode/BytecodeOpInterface.h"
18+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1819
#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
1920
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
21+
#include "mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.h"
2022
#include "mlir/IR/Dialect.h"
2123
#include "mlir/IR/OpDefinition.h"
2224
#include "mlir/Interfaces/InferIntRangeInterface.h"

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 42 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
include "mlir/IR/EnumAttr.td"
1717
include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
1818
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
19+
include "mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td"
1920
include "mlir/Interfaces/SideEffectInterfaces.td"
2021
include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
2122
include "mlir/Interfaces/InferIntRangeInterface.td"
@@ -138,8 +139,10 @@ class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
138139
let assemblyFormat = "attr-dict `:` type($res)";
139140
}
140141

141-
class NVVM_SpecialRangeableRegisterOp<string mnemonic> :
142-
NVVM_SpecialRegisterOp<mnemonic, [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
142+
class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
143+
NVVM_SpecialRegisterOp<mnemonic,
144+
!listconcat(traits,
145+
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>])> {
143146
let arguments = (ins OptionalAttr<LLVM_ConstantRangeAttr>:$range);
144147
let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)";
145148
let llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda;
@@ -202,7 +205,7 @@ def NVVM_GridDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.z">;
202205

203206
//===----------------------------------------------------------------------===//
204207
// CTA Cluster index and range
205-
def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x">;
208+
def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x", [NVVMRequiresSM<90>]>;
206209
def NVVM_ClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.y">;
207210
def NVVM_ClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.z">;
208211
def NVVM_ClusterDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.x">;
@@ -212,16 +215,16 @@ def NVVM_ClusterDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ncluster
212215

213216
//===----------------------------------------------------------------------===//
214217
// CTA index and range within Cluster
215-
def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x">;
216-
def NVVM_BlockInClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y">;
217-
def NVVM_BlockInClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z">;
218-
def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x">;
219-
def NVVM_ClusterDimBlocksYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.y">;
218+
def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x", [NVVMRequiresSM<90>]>;
219+
def NVVM_BlockInClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y", [NVVMRequiresSM<90>]>;
220+
def NVVM_BlockInClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z", [NVVMRequiresSM<90>]>;
221+
def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x", [NVVMRequiresSM<90>]>;
222+
def NVVM_ClusterDimBlocksYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.y", [NVVMRequiresSM<90>]>;
220223
def NVVM_ClusterDimBlocksZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.z">;
221224

222225
//===----------------------------------------------------------------------===//
223226
// CTA index and across Cluster dimensions
224-
def NVVM_ClusterId : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctarank">;
227+
def NVVM_ClusterId : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctarank", [NVVMRequiresSM<90>]>;
225228
def NVVM_ClusterDim : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctarank">;
226229

227230
//===----------------------------------------------------------------------===//
@@ -343,7 +346,7 @@ def ReduxKind : I32EnumAttr<"ReduxKind", "NVVM redux kind",
343346
def ReduxKindAttr : EnumAttr<NVVM_Dialect, ReduxKind, "redux_kind">;
344347

345348
def NVVM_ReduxOp :
346-
NVVM_Op<"redux.sync">,
349+
NVVM_Op<"redux.sync", [NVVMRequiresSM<80>]>,
347350
Results<(outs LLVM_Type:$res)>,
348351
Arguments<(ins LLVM_Type:$val,
349352
ReduxKindAttr:$kind,
@@ -392,7 +395,7 @@ def NVVM_MBarrierInitOp : NVVM_PTXBuilder_Op<"mbarrier.init">,
392395
}
393396

394397
/// mbarrier.init instruction with shared pointer type
395-
def NVVM_MBarrierInitSharedOp : NVVM_PTXBuilder_Op<"mbarrier.init.shared">,
398+
def NVVM_MBarrierInitSharedOp : NVVM_PTXBuilder_Op<"mbarrier.init.shared", [NVVMRequiresSM<80>, DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
396399
Arguments<(ins LLVM_PointerShared:$addr, I32:$count, PtxPredicate:$predicate)> {
397400
string llvmBuilder = [{
398401
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init_shared, {$addr, $count});
@@ -614,7 +617,7 @@ def NVVM_ClusterArriveOp : NVVM_Op<"cluster.arrive"> {
614617
let assemblyFormat = "attr-dict";
615618
}
616619

617-
def NVVM_ClusterArriveRelaxedOp : NVVM_Op<"cluster.arrive.relaxed"> {
620+
def NVVM_ClusterArriveRelaxedOp : NVVM_Op<"cluster.arrive.relaxed", [NVVMRequiresSM<90>]> {
618621
let arguments = (ins OptionalAttr<UnitAttr>:$aligned);
619622

620623
let summary = "Cluster Barrier Relaxed Arrive Op";
@@ -640,7 +643,7 @@ def NVVM_ClusterArriveRelaxedOp : NVVM_Op<"cluster.arrive.relaxed"> {
640643
let assemblyFormat = "attr-dict";
641644
}
642645

643-
def NVVM_ClusterWaitOp : NVVM_Op<"cluster.wait"> {
646+
def NVVM_ClusterWaitOp : NVVM_Op<"cluster.wait", [NVVMRequiresSM<90>]> {
644647
let arguments = (ins OptionalAttr<UnitAttr>:$aligned);
645648

646649
let summary = "Cluster Barrier Wait Op";
@@ -845,7 +848,7 @@ def ShflKind : I32EnumAttr<"ShflKind", "NVVM shuffle kind",
845848
def ShflKindAttr : EnumAttr<NVVM_Dialect, ShflKind, "shfl_kind">;
846849

847850
def NVVM_ShflOp :
848-
NVVM_Op<"shfl.sync">,
851+
NVVM_Op<"shfl.sync", [NVVMRequiresSM<30>]>,
849852
Results<(outs LLVM_Type:$res)>,
850853
Arguments<(ins I32:$thread_mask,
851854
LLVM_Type:$val,
@@ -2184,7 +2187,7 @@ def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">,
21842187
}];
21852188
}
21862189

2187-
def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group">,
2190+
def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group", [NVVMRequiresSM<90>]>,
21882191
Arguments<(ins
21892192
ConfinedAttr<I32Attr, [IntMinValue<0>]>:$group,
21902193
OptionalAttr<UnitAttr>:$read)> {
@@ -2214,7 +2217,7 @@ def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group">,
22142217
def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
22152218
NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global",
22162219
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
2217-
AttrSizedOperandSegments]>,
2220+
AttrSizedOperandSegments, NVVMRequiresSM<90>]>,
22182221
Arguments<(ins LLVM_PointerShared:$dstMem,
22192222
LLVM_AnyPointer:$tmaDescriptor,
22202223
Variadic<I32>:$coordinates,
@@ -2663,7 +2666,7 @@ def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
26632666
// NVVM Wgmma Ops
26642667
//===----------------------------------------------------------------------===//
26652668

2666-
def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned"> {
2669+
def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned", [NVVMRequiresSMa<[90]>]> {
26672670
let arguments = (ins);
26682671
let description = [{
26692672
Enforce an ordering of register accesses between warpgroup level matrix
@@ -2677,8 +2680,7 @@ def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned"> {
26772680
}];
26782681
}
26792682

2680-
def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned">,
2681-
Arguments<(ins )> {
2683+
def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned", [NVVMRequiresSMa<[90]>]> {
26822684
let assemblyFormat = "attr-dict";
26832685
let description = [{
26842686
Commits all prior uncommitted warpgroup level matrix multiplication operations.
@@ -2690,7 +2692,7 @@ def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned">,
26902692
}];
26912693
}
26922694

2693-
def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned">{
2695+
def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned", [NVVMRequiresSMa<[90]>]> {
26942696
let arguments = (ins I64Attr:$group);
26952697
let assemblyFormat = "attr-dict $group";
26962698
let description = [{
@@ -2886,7 +2888,7 @@ def NVVM_GriddepcontrolLaunchDependentsOp
28862888

28872889
def NVVM_MapaOp: NVVM_Op<"mapa",
28882890
[TypesMatchWith<"`res` and `a` should have the same type",
2889-
"a", "res", "$_self">]> {
2891+
"a", "res", "$_self">, NVVMRequiresSM<90>]> {
28902892
let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$res);
28912893
let arguments = (ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$a, I32:$b);
28922894

@@ -3053,7 +3055,7 @@ def Tcgen05WaitKindAttr :
30533055
let assemblyFormat = "`<` $value `>`";
30543056
}
30553057

3056-
def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc"> {
3058+
def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc", [NVVMRequiresSMa<[100, 101]>]> {
30573059
let summary = "Tcgen05 alloc operation";
30583060
let description = [{
30593061
The `tcgen05.alloc` Op allocates tensor core memory for
@@ -3083,7 +3085,7 @@ def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc"> {
30833085
}];
30843086
}
30853087

3086-
def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc"> {
3088+
def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc", [NVVMRequiresSMa<[100, 101]>]> {
30873089
let summary = "Tcgen05 dealloc operation";
30883090
let description = [{
30893091
The `tcgen05.dealloc` Op de-allocates the tensor core memory
@@ -3111,7 +3113,7 @@ def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc"> {
31113113
}];
31123114
}
31133115

3114-
def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_permit"> {
3116+
def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_permit", [NVVMRequiresSMa<[100, 101]>]> {
31153117
let summary = "Tcgen05 Op to relinquish the right to allocate";
31163118
let description = [{
31173119
The `tcgen05.relinquish_alloc_permit` Op specifies that the CTA
@@ -3134,7 +3136,7 @@ def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_perm
31343136
}];
31353137
}
31363138

3137-
def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence"> {
3139+
def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence", [NVVMRequiresSMa<[100, 101]>]> {
31383140
let summary = "Tcgen05 fence operations";
31393141
let description = [{
31403142
The `tcgen05.fence<before>` orders all prior async tcgen05 operations
@@ -3156,7 +3158,7 @@ def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence"> {
31563158
}];
31573159
}
31583160

3159-
def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait"> {
3161+
def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait", [NVVMRequiresSMa<[100, 101]>]> {
31603162
let summary = "Tcgen05 wait operations";
31613163
let description = [{
31623164
The `tcgen05.wait<load>` causes the executing thread to block until
@@ -3178,7 +3180,7 @@ def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait"> {
31783180
}];
31793181
}
31803182

3181-
def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit"> {
3183+
def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit", [NVVMRequiresSMa<[100, 101]>]> {
31823184
let summary = "Tcgen05 commit operations";
31833185
let description = [{
31843186
The `tcgen05.commit` makes the mbarrier object, specified by
@@ -3216,7 +3218,7 @@ def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit"> {
32163218
}];
32173219
}
32183220

3219-
def NVVM_Tcgen05ShiftOp : NVVM_Op<"tcgen05.shift"> {
3221+
def NVVM_Tcgen05ShiftOp : NVVM_Op<"tcgen05.shift", [NVVMRequiresSMa<[100, 101, 103]>]> {
32203222
let summary = "Tcgen05 shift operation";
32213223
let description = [{
32223224
The `tcgen05.shift` is an asynchronous instruction which initiates
@@ -3282,7 +3284,7 @@ def Tcgen05CpSrcFormatAttr : EnumAttr<NVVM_Dialect, Tcgen05CpSrcFormat, "tcgen05
32823284
let assemblyFormat = "`<` $value `>`";
32833285
}
32843286

3285-
def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp"> {
3287+
def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp", [NVVMRequiresSMa<[100, 101]>]> {
32863288
let summary = "Tcgen05 copy operation";
32873289
let description = [{
32883290
Instruction tcgen05.cp initiates an asynchronous copy operation from
@@ -3352,7 +3354,7 @@ def Tcgen05LdStShapeAttr: EnumAttr<NVVM_Dialect, Tcgen05LdStShape, "tcgen05_ldst
33523354
// NVVM tcgen05.ld Op
33533355
//===----------------------------------------------------------------------===//
33543356

3355-
def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld"> {
3357+
def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld", [NVVMRequiresSMa<[100, 101]>]> {
33563358
let summary = "tensor memory load instructions";
33573359
let arguments = (ins
33583360
// Attributes
@@ -3442,7 +3444,7 @@ def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld"> {
34423444
// NVVM tcgen05.st Op
34433445
//===----------------------------------------------------------------------===//
34443446

3445-
def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st"> {
3447+
def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st", [NVVMRequiresSMa<[100, 101]>]> {
34463448
let summary = "tensor memory store instructions";
34473449
let arguments = (ins
34483450
// Attributes
@@ -3594,7 +3596,8 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
35943596
// NVVM target attribute.
35953597
//===----------------------------------------------------------------------===//
35963598

3597-
def NVVM_TargetAttr : NVVM_Attr<"NVVMTarget", "target"> {
3599+
def NVVM_TargettAttr : NVVM_Attr<"NVVMTarget", "target",
3600+
[DeclareAttrInterfaceMethods<GPUTargetAttrVerifyInterface>]> {
35983601
let description = [{
35993602
GPU target attribute for controlling compilation of NVIDIA targets. All
36003603
parameters decay into default values if not present.
@@ -3621,19 +3624,21 @@ def NVVM_TargetAttr : NVVM_Attr<"NVVMTarget", "target"> {
36213624
StringRefParameter<"Target chip.", "\"sm_50\"">:$chip,
36223625
StringRefParameter<"Target chip features.", "\"+ptx60\"">:$features,
36233626
OptionalParameter<"DictionaryAttr", "Target specific flags.">:$flags,
3624-
OptionalParameter<"ArrayAttr", "Files to link to the LLVM module.">:$link
3627+
OptionalParameter<"ArrayAttr", "Files to link to the LLVM module.">:$link,
3628+
DefaultValuedParameter<"bool", "true", "Perform SM version check on Ops.">:$verifyTarget
36253629
);
36263630
let assemblyFormat = [{
3627-
(`<` struct($O, $triple, $chip, $features, $flags, $link)^ `>`)?
3631+
(`<` struct($O, $triple, $chip, $features, $flags, $link, $verifyTarget)^ `>`)?
36283632
}];
36293633
let builders = [
36303634
AttrBuilder<(ins CArg<"int", "2">:$optLevel,
36313635
CArg<"StringRef", "\"nvptx64-nvidia-cuda\"">:$triple,
36323636
CArg<"StringRef", "\"sm_50\"">:$chip,
36333637
CArg<"StringRef", "\"+ptx60\"">:$features,
36343638
CArg<"DictionaryAttr", "nullptr">:$targetFlags,
3635-
CArg<"ArrayAttr", "nullptr">:$linkFiles), [{
3636-
return Base::get($_ctxt, optLevel, triple, chip, features, targetFlags, linkFiles);
3639+
CArg<"ArrayAttr", "nullptr">:$linkFiles,
3640+
CArg<"bool", "true">:$verifyTarget), [{
3641+
return Base::get($_ctxt, optLevel, triple, chip, features, targetFlags, linkFiles, verifyTarget);
36373642
}]>
36383643
];
36393644
let skipDefaultBuilders = 1;
@@ -3644,6 +3649,7 @@ def NVVM_TargetAttr : NVVM_Attr<"NVVMTarget", "target"> {
36443649
bool hasFtz() const;
36453650
bool hasCmdOptions() const;
36463651
std::optional<mlir::NamedAttribute> getCmdOptions() const;
3652+
LogicalResult verifyTarget(Operation *gpuModule);
36473653
}];
36483654
let extraClassDefinition = [{
36493655
bool $cppClass::hasFlag(StringRef flag) const {

0 commit comments

Comments
 (0)