Skip to content

Commit f989ddc

Browse files
committed
[MLIR][NVVM] Add NVVMRequiresSM op trait
This change adds the NVVMRequiresSM op trait to the NVVM dialect to allow tagging NVVM Ops with a minimum required SM version. When a target SM is able to be determined (through NVVMTargetAttr), this allows the verification of SM compatibility with the Op without needing to unnecessarily lower any further down.
1 parent f796bc6 commit f989ddc

File tree

16 files changed

+311
-12
lines changed

16 files changed

+311
-12
lines changed

mlir/include/mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,20 @@ def GPUTargetAttrInterface : AttrInterface<"TargetAttrInterface"> {
5555
];
5656
}
5757

58+
def GPUTargetAttrVerifyInterface : AttrInterface<"TargetAttrVerifyInterface"> {
59+
let description = [{
60+
Interface for GPU target attributes that need to verify the target attribute
61+
for the given GPU module.
62+
}];
63+
let cppNamespace = "::mlir::gpu";
64+
let methods = [
65+
InterfaceMethod<[{
66+
Verifies that the target attribute is valid for the given GPU module.
67+
}], "::mlir::LogicalResult", "verifyTarget",
68+
(ins "::mlir::Operation *":$module)>
69+
];
70+
}
71+
5872
def GPUTargetAttr :
5973
ConfinedAttr<AnyAttr, [PromisedAttrInterface<GPUTargetAttrInterface>]> {
6074
let description = [{

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1460,6 +1460,8 @@ def GPU_GPUModuleOp : GPU_Op<"module", [
14601460
/// Sets the targets of the module.
14611461
void setTargets(ArrayRef<TargetAttrInterface> targets);
14621462
}];
1463+
1464+
let hasVerifier = 1;
14631465
}
14641466

14651467
def GPU_BinaryOp : GPU_Op<"binary", [Symbol]>, Arguments<(ins

mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,12 @@ mlir_tablegen(BasicPtxBuilderInterface.cpp.inc -gen-op-interface-defs)
5454
add_public_tablegen_target(MLIRBasicPtxBuilderInterfaceIncGen)
5555
add_dependencies(mlir-headers MLIRBasicPtxBuilderInterfaceIncGen)
5656

57+
set(LLVM_TARGET_DEFINITIONS NVVMTraits.td)
58+
mlir_tablegen(NVVMTraits.h.inc -gen-op-interface-decls)
59+
mlir_tablegen(NVVMTraits.cpp.inc -gen-op-interface-defs)
60+
add_public_tablegen_target(MLIRNVVMTraitsIncGen)
61+
add_dependencies(mlir-headers MLIRNVVMTraitsIncGen)
62+
5763
add_mlir_dialect(NVVMOps nvvm)
5864
add_mlir_doc(NVVMOps NVVMDialect Dialects/ -gen-dialect-doc -dialect=nvvm)
5965
set(LLVM_TARGET_DEFINITIONS NVVMOps.td)

mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616

1717
#include "mlir/Bytecode/BytecodeOpInterface.h"
1818
#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
19+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1920
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
21+
#include "mlir/Dialect/LLVMIR/NVVMTraits.h"
2022
#include "mlir/IR/Dialect.h"
2123
#include "mlir/IR/OpDefinition.h"
2224
#include "mlir/Interfaces/InferIntRangeInterface.h"

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
include "mlir/IR/EnumAttr.td"
1717
include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
1818
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
19+
include "mlir/Dialect/LLVMIR/NVVMTraits.td"
1920
include "mlir/Interfaces/SideEffectInterfaces.td"
2021
include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
2122
include "mlir/Interfaces/InferIntRangeInterface.td"
@@ -136,8 +137,10 @@ class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
136137
let assemblyFormat = "attr-dict `:` type($res)";
137138
}
138139

139-
class NVVM_SpecialRangeableRegisterOp<string mnemonic> :
140-
NVVM_SpecialRegisterOp<mnemonic, [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
140+
class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
141+
NVVM_SpecialRegisterOp<mnemonic,
142+
!listconcat(traits,
143+
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>])> {
141144
let arguments = (ins OptionalAttr<LLVM_ConstantRangeAttr>:$range);
142145
let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)";
143146
let llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda;
@@ -167,14 +170,14 @@ class NVVM_SpecialRangeableRegisterOp<string mnemonic> :
167170
def NVVM_LaneIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.laneid">;
168171
def NVVM_WarpSizeOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpsize">;
169172
def NVVM_WarpIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpid">;
170-
def NVVM_WarpDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nwarpid">;
173+
def NVVM_WarpDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nwarpid", [NVVMRequiresSM<20>]>;
171174
def NVVM_SmIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.smid">;
172175
def NVVM_SmDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nsmid">;
173176
def NVVM_GridIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.gridid">;
174177

175178
//===----------------------------------------------------------------------===//
176179
// Lane Mask Comparison Ops
177-
def NVVM_LaneMaskEqOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.eq">;
180+
def NVVM_LaneMaskEqOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.eq", [NVVMRequiresSM<20>]>;
178181
def NVVM_LaneMaskLeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.le">;
179182
def NVVM_LaneMaskLtOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.lt">;
180183
def NVVM_LaneMaskGeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.ge">;
@@ -200,7 +203,7 @@ def NVVM_GridDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.z">;
200203

201204
//===----------------------------------------------------------------------===//
202205
// CTA Cluster index and range
203-
def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x">;
206+
def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x", [NVVMRequiresSM<90>]>;
204207
def NVVM_ClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.y">;
205208
def NVVM_ClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.z">;
206209
def NVVM_ClusterDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.x">;
@@ -210,7 +213,7 @@ def NVVM_ClusterDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ncluster
210213

211214
//===----------------------------------------------------------------------===//
212215
// CTA index and range within Cluster
213-
def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x">;
216+
def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x", [NVVMRequiresSM<90>]>;
214217
def NVVM_BlockInClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y">;
215218
def NVVM_BlockInClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z">;
216219
def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x">;
@@ -269,7 +272,7 @@ def ReduxKind : I32EnumAttr<"ReduxKind", "NVVM redux kind",
269272
def ReduxKindAttr : EnumAttr<NVVM_Dialect, ReduxKind, "redux_kind">;
270273

271274
def NVVM_ReduxOp :
272-
NVVM_Op<"redux.sync">,
275+
NVVM_Op<"redux.sync", [NVVMRequiresSM<80>]>,
273276
Results<(outs LLVM_Type:$res)>,
274277
Arguments<(ins LLVM_Type:$val,
275278
ReduxKindAttr:$kind,
@@ -2327,7 +2330,8 @@ def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
23272330
// NVVM Wgmma Ops
23282331
//===----------------------------------------------------------------------===//
23292332

2330-
def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned"> {
2333+
def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned",
2334+
[NVVMRequiresSM<90, /*ArchAccelerated*/"true">]> {
23312335
let arguments = (ins);
23322336
let description = [{
23332337
Enforce an ordering of register accesses between warpgroup level matrix
@@ -2341,8 +2345,8 @@ def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned"> {
23412345
}];
23422346
}
23432347

2344-
def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned">,
2345-
Arguments<(ins )> {
2348+
def NVVM_WgmmaGroupSyncAlignedOp : NVVM_Op<"wgmma.commit.group.sync.aligned",
2349+
[NVVMRequiresSM<90, /*ArchAccelerated*/"true">]> {
23462350
let assemblyFormat = "attr-dict";
23472351
let description = [{
23482352
Commits all prior uncommitted warpgroup level matrix multiplication operations.
@@ -2814,7 +2818,8 @@ def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit"> {
28142818
// NVVM target attribute.
28152819
//===----------------------------------------------------------------------===//
28162820

2817-
def NVVM_TargettAttr : NVVM_Attr<"NVVMTarget", "target"> {
2821+
def NVVM_TargettAttr : NVVM_Attr<"NVVMTarget", "target",
2822+
[DeclareAttrInterfaceMethods<GPUTargetAttrVerifyInterface>]> {
28182823
let description = [{
28192824
GPU target attribute for controlling compilation of NVIDIA targets. All
28202825
parameters decay into default values if not present.
@@ -2862,6 +2867,9 @@ def NVVM_TargettAttr : NVVM_Attr<"NVVMTarget", "target"> {
28622867
bool hasFlag(StringRef flag) const;
28632868
bool hasFastMath() const;
28642869
bool hasFtz() const;
2870+
bool hasCmdOptions() const;
2871+
std::optional<mlir::NamedAttribute> getCmdOptions() const;
2872+
LogicalResult verifyTarget(Operation *gpuModule);
28652873
}];
28662874
let extraClassDefinition = [{
28672875
bool $cppClass::hasFlag(StringRef flag) const {
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
//===--- NVVMTraits.h - NVVM Traits -----------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines op traits for the NVVM Dialect in MLIR
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef NVVM_DIALECT_NVVM_IR_NVVMTRAITS_H_
14+
#define NVVM_DIALECT_NVVM_IR_NVVMTRAITS_H_
15+
16+
#include "mlir/IR/OpDefinition.h"
17+
#include "mlir/IR/StorageUniquerSupport.h"
18+
#include "llvm/ADT/StringExtras.h"
19+
20+
namespace mlir {
21+
22+
namespace NVVM {
23+
24+
struct NVVMCheckSMVersion {
25+
int archVersion;
26+
bool archAccelerated;
27+
std::string archString;
28+
29+
NVVMCheckSMVersion() {}
30+
NVVMCheckSMVersion(StringRef SMVersion) : archString(SMVersion) {
31+
parse(SMVersion);
32+
}
33+
NVVMCheckSMVersion(int archVersion, bool archAccelerated)
34+
: archVersion(archVersion), archAccelerated(archAccelerated) {
35+
archString = (llvm::Twine("sm_") + llvm::Twine(archVersion) +
36+
(archAccelerated ? "a" : "\0"))
37+
.str();
38+
}
39+
40+
const StringRef getArchString() const { return archString; }
41+
42+
void parse(StringRef SMVersion) {
43+
archAccelerated = (SMVersion.back() == 'a');
44+
SMVersion.drop_front(3)
45+
.take_while([](char c) { return llvm::isDigit(c); })
46+
.getAsInteger(10, archVersion);
47+
}
48+
49+
bool isCompatible(const NVVMCheckSMVersion &TargetSM) const {
50+
// for arch-conditional SMs, they should exactly match to be valid
51+
if (archAccelerated || TargetSM.archAccelerated)
52+
return (*this) == TargetSM;
53+
54+
return archVersion <= TargetSM.archVersion;
55+
}
56+
57+
bool operator==(const NVVMCheckSMVersion &Other) const {
58+
return archVersion == Other.archVersion &&
59+
archAccelerated == Other.archAccelerated;
60+
}
61+
};
62+
} // namespace NVVM
63+
} // namespace mlir
64+
65+
#include "mlir/Dialect/LLVMIR/NVVMTraits.h.inc"
66+
67+
namespace mlir {
68+
69+
namespace OpTrait {
70+
71+
template <int Version, bool ArchAccelerated = false>
72+
class NVVMRequiresSM {
73+
public:
74+
template <typename ConcreteOp>
75+
class Impl : public OpTrait::TraitBase<
76+
ConcreteOp, NVVMRequiresSM<Version, ArchAccelerated>::Impl>,
77+
public mlir::NVVM::RequiresSMInterface::Trait<ConcreteOp> {
78+
public:
79+
const NVVM::NVVMCheckSMVersion getRequiredMinSMVersion() const {
80+
return NVVM::NVVMCheckSMVersion(Version, ArchAccelerated);
81+
}
82+
};
83+
};
84+
} // namespace OpTrait
85+
} // namespace mlir
86+
#endif // NVVM_DIALECT_NVVM_IR_NVVMTRAITS_H_
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
//===-- NVVMTraits.td - NVVM Traits ------------------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines traits for the NVVM Dialect in MLIR
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef NVVM_TRAITS
14+
#define NVVM_TRAITS
15+
16+
include "mlir/IR/OpBase.td"
17+
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
18+
19+
// Interface for NVVM Ops with the NVVMRequiresSM parametric trait
20+
def RequiresSMInterface: OpInterface<"RequiresSMInterface"> {
21+
let cppNamespace = "::mlir::NVVM";
22+
let methods = [
23+
InterfaceMethod<
24+
"Get the SM version required by the op from the trait",
25+
"const mlir::NVVM::NVVMCheckSMVersion", "getRequiredMinSMVersion"
26+
>
27+
];
28+
}
29+
30+
class NVVMRequiresSM<int Version, string ArchAccelerated = "false"> :
31+
ParamNativeOpTrait<"NVVMRequiresSM",
32+
!cast<string>(Version) # "," # ArchAccelerated>;
33+
34+
#endif //NVVM_TRAITS

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1793,6 +1793,22 @@ void GPUModuleOp::setTargets(ArrayRef<TargetAttrInterface> targets) {
17931793
targetsAttr = ArrayAttr::get(getContext(), targetsVector);
17941794
}
17951795

1796+
LogicalResult GPUModuleOp::verify() {
1797+
auto targets = getOperation()->getAttrOfType<ArrayAttr>("targets");
1798+
1799+
if (!targets)
1800+
return success();
1801+
1802+
for (auto target : targets) {
1803+
if (auto verifyTargetAttr =
1804+
llvm::dyn_cast<TargetAttrVerifyInterface>(target)) {
1805+
if (verifyTargetAttr.verifyTarget(getOperation()).failed())
1806+
return failure();
1807+
}
1808+
}
1809+
return success();
1810+
}
1811+
17961812
//===----------------------------------------------------------------------===//
17971813
// GPUBinaryOp
17981814
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/LLVMIR/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ add_mlir_dialect_library(MLIRLLVMDialect
4242
add_mlir_dialect_library(MLIRNVVMDialect
4343
IR/NVVMDialect.cpp
4444
IR/BasicPtxBuilderInterface.cpp
45+
IR/NVVMTraits.cpp
4546

4647
ADDITIONAL_HEADER_DIRS
4748
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR
@@ -51,6 +52,7 @@ add_mlir_dialect_library(MLIRNVVMDialect
5152
MLIRNVVMOpsIncGen
5253
MLIRNVVMConversionsIncGen
5354
MLIRBasicPtxBuilderInterfaceIncGen
55+
MLIRNVVMTraitsIncGen
5456
intrinsics_gen
5557

5658
LINK_COMPONENTS
@@ -60,6 +62,7 @@ add_mlir_dialect_library(MLIRNVVMDialect
6062
LINK_LIBS PUBLIC
6163
MLIRIR
6264
MLIRLLVMDialect
65+
MLIRGPUDialect
6366
MLIRSideEffectInterfaces
6467
MLIRInferIntRangeInterface
6568
)

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
2020
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
21+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
2122
#include "mlir/Dialect/Utils/StaticValueUtils.h"
2223
#include "mlir/IR/Builders.h"
2324
#include "mlir/IR/BuiltinAttributes.h"
@@ -1439,6 +1440,25 @@ NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
14391440
return success();
14401441
}
14411442

1443+
LogicalResult NVVMTargetAttr::verifyTarget(Operation *gpuModule) {
1444+
auto gpuModuleOp = llvm::dyn_cast<gpu::GPUModuleOp>(gpuModule);
1445+
if (!gpuModuleOp)
1446+
return emitError(gpuModule->getLoc(),
1447+
"NVVM target attribute must be attached to a GPU module");
1448+
gpuModuleOp->walk([&](Operation *op) {
1449+
if (auto reqOp = llvm::dyn_cast<NVVM::RequiresSMInterface>(op)) {
1450+
auto requirement = reqOp.getRequiredMinSMVersion();
1451+
if (!requirement.isCompatible(NVVMCheckSMVersion(getChip()))) {
1452+
op->emitOpError() << "is not supported on " << getChip();
1453+
return WalkResult::interrupt();
1454+
}
1455+
}
1456+
return WalkResult::advance();
1457+
});
1458+
1459+
return success();
1460+
}
1461+
14421462
#define GET_OP_CLASSES
14431463
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
14441464

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
//===--- NVVMTraits.h - NVVM Traits -----------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines op traits for the NVVM Dialect in MLIR
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/LLVMIR/NVVMTraits.h"
14+
15+
#include "mlir/Dialect/LLVMIR/NVVMTraits.cpp.inc"

0 commit comments

Comments
 (0)