16
16
include "mlir/IR/EnumAttr.td"
17
17
include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
18
18
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
19
+ include "mlir/Dialect/LLVMIR/NVVMRequiresSMTraits.td"
19
20
include "mlir/Interfaces/SideEffectInterfaces.td"
20
21
include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
21
22
include "mlir/Interfaces/InferIntRangeInterface.td"
@@ -138,8 +139,10 @@ class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
138
139
let assemblyFormat = "attr-dict `:` type($res)";
139
140
}
140
141
141
- class NVVM_SpecialRangeableRegisterOp<string mnemonic> :
142
- NVVM_SpecialRegisterOp<mnemonic, [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
142
+ class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
143
+ NVVM_SpecialRegisterOp<mnemonic,
144
+ !listconcat(traits,
145
+ [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>])> {
143
146
let arguments = (ins OptionalAttr<LLVM_ConstantRangeAttr>:$range);
144
147
let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)";
145
148
let llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda;
@@ -202,7 +205,7 @@ def NVVM_GridDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.z">;
202
205
203
206
//===----------------------------------------------------------------------===//
204
207
// CTA Cluster index and range
205
- def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x">;
208
+ def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x", [NVVMRequiresSM<90>] >;
206
209
def NVVM_ClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.y">;
207
210
def NVVM_ClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.z">;
208
211
def NVVM_ClusterDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.x">;
@@ -212,16 +215,16 @@ def NVVM_ClusterDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ncluster
212
215
213
216
//===----------------------------------------------------------------------===//
214
217
// CTA index and range within Cluster
215
- def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x">;
216
- def NVVM_BlockInClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y">;
217
- def NVVM_BlockInClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z">;
218
- def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x">;
219
- def NVVM_ClusterDimBlocksYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.y">;
218
+ def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x", [NVVMRequiresSM<90>] >;
219
+ def NVVM_BlockInClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y", [NVVMRequiresSM<90>] >;
220
+ def NVVM_BlockInClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z", [NVVMRequiresSM<90>] >;
221
+ def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x", [NVVMRequiresSM<90>] >;
222
+ def NVVM_ClusterDimBlocksYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.y", [NVVMRequiresSM<90>] >;
220
223
def NVVM_ClusterDimBlocksZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.z">;
221
224
222
225
//===----------------------------------------------------------------------===//
223
226
// CTA index and across Cluster dimensions
224
- def NVVM_ClusterId : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctarank">;
227
+ def NVVM_ClusterId : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctarank", [NVVMRequiresSM<90>] >;
225
228
def NVVM_ClusterDim : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctarank">;
226
229
227
230
//===----------------------------------------------------------------------===//
@@ -343,7 +346,7 @@ def ReduxKind : I32EnumAttr<"ReduxKind", "NVVM redux kind",
343
346
def ReduxKindAttr : EnumAttr<NVVM_Dialect, ReduxKind, "redux_kind">;
344
347
345
348
def NVVM_ReduxOp :
346
- NVVM_Op<"redux.sync">,
349
+ NVVM_Op<"redux.sync", [NVVMRequiresSM<80>] >,
347
350
Results<(outs LLVM_Type:$res)>,
348
351
Arguments<(ins LLVM_Type:$val,
349
352
ReduxKindAttr:$kind,
@@ -392,7 +395,7 @@ def NVVM_MBarrierInitOp : NVVM_PTXBuilder_Op<"mbarrier.init">,
392
395
}
393
396
394
397
/// mbarrier.init instruction with shared pointer type
395
- def NVVM_MBarrierInitSharedOp : NVVM_PTXBuilder_Op<"mbarrier.init.shared">,
398
+ def NVVM_MBarrierInitSharedOp : NVVM_PTXBuilder_Op<"mbarrier.init.shared", [NVVMRequiresSM<80>, DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>] >,
396
399
Arguments<(ins LLVM_PointerShared:$addr, I32:$count, PtxPredicate:$predicate)> {
397
400
string llvmBuilder = [{
398
401
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_mbarrier_init_shared, {$addr, $count});
@@ -614,7 +617,7 @@ def NVVM_ClusterArriveOp : NVVM_Op<"cluster.arrive"> {
614
617
let assemblyFormat = "attr-dict";
615
618
}
616
619
617
- def NVVM_ClusterArriveRelaxedOp : NVVM_Op<"cluster.arrive.relaxed"> {
620
+ def NVVM_ClusterArriveRelaxedOp : NVVM_Op<"cluster.arrive.relaxed", [NVVMRequiresSM<90>] > {
618
621
let arguments = (ins OptionalAttr<UnitAttr>:$aligned);
619
622
620
623
let summary = "Cluster Barrier Relaxed Arrive Op";
@@ -640,7 +643,7 @@ def NVVM_ClusterArriveRelaxedOp : NVVM_Op<"cluster.arrive.relaxed"> {
640
643
let assemblyFormat = "attr-dict";
641
644
}
642
645
643
- def NVVM_ClusterWaitOp : NVVM_Op<"cluster.wait"> {
646
+ def NVVM_ClusterWaitOp : NVVM_Op<"cluster.wait", [NVVMRequiresSM<90>] > {
644
647
let arguments = (ins OptionalAttr<UnitAttr>:$aligned);
645
648
646
649
let summary = "Cluster Barrier Wait Op";
@@ -845,7 +848,7 @@ def ShflKind : I32EnumAttr<"ShflKind", "NVVM shuffle kind",
845
848
def ShflKindAttr : EnumAttr<NVVM_Dialect, ShflKind, "shfl_kind">;
846
849
847
850
def NVVM_ShflOp :
848
- NVVM_Op<"shfl.sync">,
851
+ NVVM_Op<"shfl.sync", [NVVMRequiresSM<30>] >,
849
852
Results<(outs LLVM_Type:$res)>,
850
853
Arguments<(ins I32:$thread_mask,
851
854
LLVM_Type:$val,
@@ -2184,7 +2187,7 @@ def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">,
2184
2187
}];
2185
2188
}
2186
2189
2187
- def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group">,
2190
+ def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group", [NVVMRequiresSM<90>] >,
2188
2191
Arguments<(ins
2189
2192
ConfinedAttr<I32Attr, [IntMinValue<0>]>:$group,
2190
2193
OptionalAttr<UnitAttr>:$read)> {
@@ -2214,7 +2217,7 @@ def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group">,
2214
2217
def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
2215
2218
NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global",
2216
2219
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
2217
- AttrSizedOperandSegments]>,
2220
+ AttrSizedOperandSegments, NVVMRequiresSM<90> ]>,
2218
2221
Arguments<(ins LLVM_PointerShared:$dstMem,
2219
2222
LLVM_AnyPointer:$tmaDescriptor,
2220
2223
Variadic<I32>:$coordinates,
@@ -2663,7 +2666,7 @@ def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
2663
2666
// NVVM Wgmma Ops
2664
2667
//===----------------------------------------------------------------------===//
2665
2668
2666
- def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned"> {
2669
+ def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned", [NVVMRequiresSMa<[90]>] > {
2667
2670
let arguments = (ins);
2668
2671
let description = [{
2669
2672
Enforce an ordering of register accesses between warpgroup level matrix
@@ -2677,8 +2680,7 @@ def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned"> {
2677
2680
}];
2678
2681
}
2679
2682
2680
- def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned">,
2681
- Arguments<(ins )> {
2683
+ def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned", [NVVMRequiresSMa<[90]>]> {
2682
2684
let assemblyFormat = "attr-dict";
2683
2685
let description = [{
2684
2686
Commits all prior uncommitted warpgroup level matrix multiplication operations.
@@ -2690,7 +2692,7 @@ def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned">,
2690
2692
}];
2691
2693
}
2692
2694
2693
- def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned"> {
2695
+ def NVVM_WgmmaWaitGroupSyncOp : NVVM_Op<"wgmma.wait.group.sync.aligned", [NVVMRequiresSMa<[90]>]> {
2694
2696
let arguments = (ins I64Attr:$group);
2695
2697
let assemblyFormat = "attr-dict $group";
2696
2698
let description = [{
@@ -2886,7 +2888,7 @@ def NVVM_GriddepcontrolLaunchDependentsOp
2886
2888
2887
2889
def NVVM_MapaOp: NVVM_Op<"mapa",
2888
2890
[TypesMatchWith<"`res` and `a` should have the same type",
2889
- "a", "res", "$_self">]> {
2891
+ "a", "res", "$_self">, NVVMRequiresSM<90> ]> {
2890
2892
let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$res);
2891
2893
let arguments = (ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$a, I32:$b);
2892
2894
@@ -3053,7 +3055,7 @@ def Tcgen05WaitKindAttr :
3053
3055
let assemblyFormat = "`<` $value `>`";
3054
3056
}
3055
3057
3056
- def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc"> {
3058
+ def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc", [NVVMRequiresSMa<[100, 101]>] > {
3057
3059
let summary = "Tcgen05 alloc operation";
3058
3060
let description = [{
3059
3061
The `tcgen05.alloc` Op allocates tensor core memory for
@@ -3083,7 +3085,7 @@ def NVVM_Tcgen05AllocOp : NVVM_Op<"tcgen05.alloc"> {
3083
3085
}];
3084
3086
}
3085
3087
3086
- def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc"> {
3088
+ def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc", [NVVMRequiresSMa<[100, 101]>] > {
3087
3089
let summary = "Tcgen05 dealloc operation";
3088
3090
let description = [{
3089
3091
The `tcgen05.dealloc` Op de-allocates the tensor core memory
@@ -3111,7 +3113,7 @@ def NVVM_Tcgen05DeallocOp : NVVM_Op<"tcgen05.dealloc"> {
3111
3113
}];
3112
3114
}
3113
3115
3114
- def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_permit"> {
3116
+ def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_permit", [NVVMRequiresSMa<[100, 101]>] > {
3115
3117
let summary = "Tcgen05 Op to relinquish the right to allocate";
3116
3118
let description = [{
3117
3119
The `tcgen05.relinquish_alloc_permit` Op specifies that the CTA
@@ -3134,7 +3136,7 @@ def NVVM_Tcgen05RelinquishAllocPermitOp : NVVM_Op<"tcgen05.relinquish_alloc_perm
3134
3136
}];
3135
3137
}
3136
3138
3137
- def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence"> {
3139
+ def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence", [NVVMRequiresSMa<[100, 101]>] > {
3138
3140
let summary = "Tcgen05 fence operations";
3139
3141
let description = [{
3140
3142
The `tcgen05.fence<before>` orders all prior async tcgen05 operations
@@ -3156,7 +3158,7 @@ def NVVM_Tcgen05FenceOp : NVVM_Op<"tcgen05.fence"> {
3156
3158
}];
3157
3159
}
3158
3160
3159
- def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait"> {
3161
+ def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait", [NVVMRequiresSMa<[100, 101]>] > {
3160
3162
let summary = "Tcgen05 wait operations";
3161
3163
let description = [{
3162
3164
The `tcgen05.wait<load>` causes the executing thread to block until
@@ -3178,7 +3180,7 @@ def NVVM_Tcgen05WaitOp : NVVM_Op<"tcgen05.wait"> {
3178
3180
}];
3179
3181
}
3180
3182
3181
- def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit"> {
3183
+ def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit", [NVVMRequiresSMa<[100, 101]>] > {
3182
3184
let summary = "Tcgen05 commit operations";
3183
3185
let description = [{
3184
3186
The `tcgen05.commit` makes the mbarrier object, specified by
@@ -3216,7 +3218,7 @@ def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit"> {
3216
3218
}];
3217
3219
}
3218
3220
3219
- def NVVM_Tcgen05ShiftOp : NVVM_Op<"tcgen05.shift"> {
3221
+ def NVVM_Tcgen05ShiftOp : NVVM_Op<"tcgen05.shift", [NVVMRequiresSMa<[100, 101, 103]>] > {
3220
3222
let summary = "Tcgen05 shift operation";
3221
3223
let description = [{
3222
3224
The `tcgen05.shift` is an asynchronous instruction which initiates
@@ -3282,7 +3284,7 @@ def Tcgen05CpSrcFormatAttr : EnumAttr<NVVM_Dialect, Tcgen05CpSrcFormat, "tcgen05
3282
3284
let assemblyFormat = "`<` $value `>`";
3283
3285
}
3284
3286
3285
- def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp"> {
3287
+ def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp", [NVVMRequiresSMa<[100, 101]>] > {
3286
3288
let summary = "Tcgen05 copy operation";
3287
3289
let description = [{
3288
3290
Instruction tcgen05.cp initiates an asynchronous copy operation from
@@ -3352,7 +3354,7 @@ def Tcgen05LdStShapeAttr: EnumAttr<NVVM_Dialect, Tcgen05LdStShape, "tcgen05_ldst
3352
3354
// NVVM tcgen05.ld Op
3353
3355
//===----------------------------------------------------------------------===//
3354
3356
3355
- def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld"> {
3357
+ def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld", [NVVMRequiresSMa<[100, 101]>] > {
3356
3358
let summary = "tensor memory load instructions";
3357
3359
let arguments = (ins
3358
3360
// Attributes
@@ -3442,7 +3444,7 @@ def NVVM_Tcgen05LdOp : NVVM_Op<"tcgen05.ld"> {
3442
3444
// NVVM tcgen05.st Op
3443
3445
//===----------------------------------------------------------------------===//
3444
3446
3445
- def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st"> {
3447
+ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st", [NVVMRequiresSMa<[100, 101]>] > {
3446
3448
let summary = "tensor memory store instructions";
3447
3449
let arguments = (ins
3448
3450
// Attributes
@@ -3594,7 +3596,8 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
3594
3596
// NVVM target attribute.
3595
3597
//===----------------------------------------------------------------------===//
3596
3598
3597
- def NVVM_TargetAttr : NVVM_Attr<"NVVMTarget", "target"> {
3599
+ def NVVM_TargettAttr : NVVM_Attr<"NVVMTarget", "target",
3600
+ [DeclareAttrInterfaceMethods<GPUTargetAttrVerifyInterface>]> {
3598
3601
let description = [{
3599
3602
GPU target attribute for controlling compilation of NVIDIA targets. All
3600
3603
parameters decay into default values if not present.
@@ -3621,19 +3624,21 @@ def NVVM_TargetAttr : NVVM_Attr<"NVVMTarget", "target"> {
3621
3624
StringRefParameter<"Target chip.", "\"sm_50\"">:$chip,
3622
3625
StringRefParameter<"Target chip features.", "\"+ptx60\"">:$features,
3623
3626
OptionalParameter<"DictionaryAttr", "Target specific flags.">:$flags,
3624
- OptionalParameter<"ArrayAttr", "Files to link to the LLVM module.">:$link
3627
+ OptionalParameter<"ArrayAttr", "Files to link to the LLVM module.">:$link,
3628
+ DefaultValuedParameter<"bool", "true", "Perform SM version check on Ops.">:$verifyTarget
3625
3629
);
3626
3630
let assemblyFormat = [{
3627
- (`<` struct($O, $triple, $chip, $features, $flags, $link)^ `>`)?
3631
+ (`<` struct($O, $triple, $chip, $features, $flags, $link, $verifyTarget )^ `>`)?
3628
3632
}];
3629
3633
let builders = [
3630
3634
AttrBuilder<(ins CArg<"int", "2">:$optLevel,
3631
3635
CArg<"StringRef", "\"nvptx64-nvidia-cuda\"">:$triple,
3632
3636
CArg<"StringRef", "\"sm_50\"">:$chip,
3633
3637
CArg<"StringRef", "\"+ptx60\"">:$features,
3634
3638
CArg<"DictionaryAttr", "nullptr">:$targetFlags,
3635
- CArg<"ArrayAttr", "nullptr">:$linkFiles), [{
3636
- return Base::get($_ctxt, optLevel, triple, chip, features, targetFlags, linkFiles);
3639
+ CArg<"ArrayAttr", "nullptr">:$linkFiles,
3640
+ CArg<"bool", "true">:$verifyTarget), [{
3641
+ return Base::get($_ctxt, optLevel, triple, chip, features, targetFlags, linkFiles, verifyTarget);
3637
3642
}]>
3638
3643
];
3639
3644
let skipDefaultBuilders = 1;
@@ -3644,6 +3649,7 @@ def NVVM_TargetAttr : NVVM_Attr<"NVVMTarget", "target"> {
3644
3649
bool hasFtz() const;
3645
3650
bool hasCmdOptions() const;
3646
3651
std::optional<mlir::NamedAttribute> getCmdOptions() const;
3652
+ LogicalResult verifyTarget(Operation *gpuModule);
3647
3653
}];
3648
3654
let extraClassDefinition = [{
3649
3655
bool $cppClass::hasFlag(StringRef flag) const {
0 commit comments