Skip to content

Commit 74d914d

Browse files
committed
[MLIR][NVVM] Add NVVMRequiresSM op trait
This change adds the NVVMRequiresSM op trait to the NVVM dialect so that NVVM ops can be tagged with the minimum SM version they require. When the target SM can be determined (through NVVMTargetAttr), the verifier can check an op's SM compatibility directly, without having to lower the IR any further.
1 parent f796bc6 commit 74d914d

File tree

10 files changed

+223
-8
lines changed

10 files changed

+223
-8
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Bytecode/BytecodeOpInterface.h"
1818
#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
1919
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20+
#include "mlir/Dialect/LLVMIR/NVVMTraits.h"
2021
#include "mlir/IR/Dialect.h"
2122
#include "mlir/IR/OpDefinition.h"
2223
#include "mlir/Interfaces/InferIntRangeInterface.h"

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
include "mlir/IR/EnumAttr.td"
1717
include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
1818
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
19+
include "mlir/Dialect/LLVMIR/NVVMTraits.td"
1920
include "mlir/Interfaces/SideEffectInterfaces.td"
2021
include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td"
2122
include "mlir/Interfaces/InferIntRangeInterface.td"
@@ -136,8 +137,10 @@ class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
136137
let assemblyFormat = "attr-dict `:` type($res)";
137138
}
138139

139-
class NVVM_SpecialRangeableRegisterOp<string mnemonic> :
140-
NVVM_SpecialRegisterOp<mnemonic, [DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
140+
class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
141+
NVVM_SpecialRegisterOp<mnemonic,
142+
!listconcat(traits,
143+
[DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>])> {
141144
let arguments = (ins OptionalAttr<LLVM_ConstantRangeAttr>:$range);
142145
let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)";
143146
let llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda;
@@ -167,14 +170,14 @@ class NVVM_SpecialRangeableRegisterOp<string mnemonic> :
167170
def NVVM_LaneIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.laneid">;
168171
def NVVM_WarpSizeOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpsize">;
169172
def NVVM_WarpIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpid">;
170-
def NVVM_WarpDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nwarpid">;
173+
def NVVM_WarpDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nwarpid", [NVVMRequiresSM<20>]>;
171174
def NVVM_SmIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.smid">;
172175
def NVVM_SmDimOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nsmid">;
173176
def NVVM_GridIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.gridid">;
174177

175178
//===----------------------------------------------------------------------===//
176179
// Lane Mask Comparison Ops
177-
def NVVM_LaneMaskEqOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.eq">;
180+
def NVVM_LaneMaskEqOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.eq", [NVVMRequiresSM<20>]>;
178181
def NVVM_LaneMaskLeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.le">;
179182
def NVVM_LaneMaskLtOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.lt">;
180183
def NVVM_LaneMaskGeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.lanemask.ge">;
@@ -200,7 +203,7 @@ def NVVM_GridDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.z">;
200203

201204
//===----------------------------------------------------------------------===//
202205
// CTA Cluster index and range
203-
def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x">;
206+
def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x", [NVVMRequiresSM<90>]>;
204207
def NVVM_ClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.y">;
205208
def NVVM_ClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.z">;
206209
def NVVM_ClusterDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.x">;
@@ -210,7 +213,7 @@ def NVVM_ClusterDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ncluster
210213

211214
//===----------------------------------------------------------------------===//
212215
// CTA index and range within Cluster
213-
def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x">;
216+
def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x", [NVVMRequiresSM<90>]>;
214217
def NVVM_BlockInClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y">;
215218
def NVVM_BlockInClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z">;
216219
def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x">;
@@ -269,7 +272,7 @@ def ReduxKind : I32EnumAttr<"ReduxKind", "NVVM redux kind",
269272
def ReduxKindAttr : EnumAttr<NVVM_Dialect, ReduxKind, "redux_kind">;
270273

271274
def NVVM_ReduxOp :
272-
NVVM_Op<"redux.sync">,
275+
NVVM_Op<"redux.sync", [NVVMRequiresSM<80>]>,
273276
Results<(outs LLVM_Type:$res)>,
274277
Arguments<(ins LLVM_Type:$val,
275278
ReduxKindAttr:$kind,
@@ -2327,7 +2330,8 @@ def NVVM_CpAsyncBulkSharedCTAToGlobalOp :
23272330
// NVVM Wgmma Ops
23282331
//===----------------------------------------------------------------------===//
23292332

2330-
def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned"> {
2333+
def NVVM_WgmmaFenceAlignedOp : NVVM_Op<"wgmma.fence.aligned",
2334+
[NVVMRequiresSM<90, /*ArchAccelerated*/"true">]> {
23312335
let arguments = (ins);
23322336
let description = [{
23332337
Enforce an ordering of register accesses between warpgroup level matrix
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
//===--- NVVMTraits.h - NVVM Traits -----------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines op traits for the NVVM Dialect in MLIR
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef NVVM_DIALECT_NVVM_IR_NVVMTRAITS_H_
14+
#define NVVM_DIALECT_NVVM_IR_NVVMTRAITS_H_
15+
16+
#include "mlir/IR/OpDefinition.h"
17+
#include "llvm/ADT/StringExtras.h"
18+
19+
namespace mlir {
20+
21+
namespace NVVM {
22+
23+
struct NVVMCheckSMVersion {
24+
int ArchVersion;
25+
bool ArchAccelerated;
26+
std::string ArchString;
27+
28+
NVVMCheckSMVersion() {}
29+
NVVMCheckSMVersion(StringRef SMVersion) : ArchString(SMVersion) {
30+
parse(SMVersion);
31+
}
32+
NVVMCheckSMVersion(int ArchVersion, bool ArchAccelerated)
33+
: ArchVersion(ArchVersion), ArchAccelerated(ArchAccelerated) {
34+
ArchString = (llvm::Twine("sm_") + llvm::Twine(ArchVersion) +
35+
(ArchAccelerated ? "a" : "\0"))
36+
.str();
37+
}
38+
39+
const StringRef getArchString() const { return ArchString; }
40+
41+
void parse(StringRef SMVersion) {
42+
ArchAccelerated = (SMVersion.back() == 'a');
43+
SMVersion.drop_front(3)
44+
.take_while([](char c) { return llvm::isDigit(c); })
45+
.getAsInteger(10, ArchVersion);
46+
}
47+
48+
bool isCompatible(const NVVMCheckSMVersion &TargetSM) const {
49+
// for arch-conditional SMs, they should exactly match to be valid
50+
if (ArchAccelerated || TargetSM.ArchAccelerated)
51+
return (*this) == TargetSM;
52+
53+
return ArchVersion <= TargetSM.ArchVersion;
54+
}
55+
56+
bool operator==(const NVVMCheckSMVersion &Other) const {
57+
return ArchVersion == Other.ArchVersion &&
58+
ArchAccelerated == Other.ArchAccelerated;
59+
}
60+
};
61+
62+
llvm::SmallVector<NVVMCheckSMVersion> getTargetSMVersions(Operation *op);
63+
64+
LogicalResult
65+
verifyOpSMRequirements(Operation *op,
66+
llvm::SmallVector<NVVMCheckSMVersion> TargetSMVersions,
67+
NVVMCheckSMVersion RequiredSMVersion);
68+
} // namespace NVVM
69+
70+
namespace OpTrait {
71+
72+
template <int Version, bool ArchAccelerated = false>
73+
class NVVMRequiresSM {
74+
public:
75+
template <typename ConcreteOp>
76+
class Impl : public OpTrait::TraitBase<
77+
ConcreteOp, NVVMRequiresSM<Version, ArchAccelerated>::Impl> {
78+
public:
79+
static LogicalResult verifyTrait(Operation *op) {
80+
NVVM::NVVMCheckSMVersion RequiredSMVersion(Version, ArchAccelerated);
81+
llvm::SmallVector<NVVM::NVVMCheckSMVersion> TargetSMVersions =
82+
NVVM::getTargetSMVersions(op);
83+
84+
return NVVM::verifyOpSMRequirements(op, TargetSMVersions,
85+
RequiredSMVersion);
86+
}
87+
};
88+
};
89+
} // namespace OpTrait
90+
} // namespace mlir
91+
#endif // NVVM_DIALECT_NVVM_IR_NVVMTRAITS_H_
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//===-- NVVMTraits.td - NVVM Traits ------------------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines traits for the NVVM Dialect in MLIR
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef NVVM_TRAITS
14+
#define NVVM_TRAITS
15+
16+
include "mlir/IR/OpBase.td"
17+
18+
// Tags an op with the SM version it requires. The parameters are forwarded
// as the template arguments "<Version>,<ArchAccelerated>" of the C++ trait
// `NVVMRequiresSM`; `ArchAccelerated` is the literal text "true" or "false".
class NVVMRequiresSM<int Version, string ArchAccelerated = "false">
    : ParamNativeOpTrait<"NVVMRequiresSM",
                         !strconcat(!cast<string>(Version), ",",
                                    ArchAccelerated)>;
21+
22+
#endif //NVVM_TRAITS

mlir/lib/Dialect/LLVMIR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ add_mlir_dialect_library(MLIRNVVMDialect
6060
LINK_LIBS PUBLIC
6161
MLIRIR
6262
MLIRLLVMDialect
63+
MLIRGPUDialect
6364
MLIRSideEffectInterfaces
6465
MLIRInferIntRangeInterface
6566
)

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
2020
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
21+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
2122
#include "mlir/Dialect/Utils/StaticValueUtils.h"
2223
#include "mlir/IR/Builders.h"
2324
#include "mlir/IR/BuiltinAttributes.h"
@@ -1439,6 +1440,36 @@ NVVMTargetAttr::verify(function_ref<InFlightDiagnostic()> emitError,
14391440
return success();
14401441
}
14411442

1443+
//===----------------------------------------------------------------------===//
1444+
// Requires minimum target SM trait helper functions
1445+
//===----------------------------------------------------------------------===//
1446+
/// Collects the SM version of every NVVM target attached to the
/// `gpu.module` that encloses `op`.
///
/// Returns an empty vector when `op` is not nested in a `gpu.module`, when
/// the module has no `targets` attribute, or when that attribute is not an
/// array. Targets that are not `NVVMTargetAttr` are ignored.
llvm::SmallVector<NVVMCheckSMVersion> NVVM::getTargetSMVersions(Operation *op) {
  llvm::SmallVector<NVVMCheckSMVersion> TargetSMVersions;

  auto GPUModule = op->getParentOfType<gpu::GPUModuleOp>();
  if (!GPUModule)
    return TargetSMVersions;

  // `getAttrOfType` yields null both for a missing attribute and for one of
  // the wrong kind, so a single lookup covers both cases and the array is
  // never iterated through a null handle.
  auto Targets = GPUModule->getAttrOfType<ArrayAttr>("targets");
  if (!Targets)
    return TargetSMVersions;

  for (Attribute Target : Targets)
    if (auto NVVMTarget = dyn_cast<NVVMTargetAttr>(Target))
      TargetSMVersions.push_back(NVVMCheckSMVersion(NVVMTarget.getChip()));

  return TargetSMVersions;
}
1458+
1459+
// Helper function to verify the minimum SM requirement of an NVVM Op
1460+
LogicalResult NVVM::verifyOpSMRequirements(
1461+
Operation *op, llvm::SmallVector<NVVMCheckSMVersion> TargetSMVersions,
1462+
NVVMCheckSMVersion RequiredSMVersion) {
1463+
for (auto TargetSMVersion : TargetSMVersions) {
1464+
if (!RequiredSMVersion.isCompatible(TargetSMVersion)) {
1465+
op->emitOpError() << "is not supported on "
1466+
<< TargetSMVersion.getArchString();
1467+
return failure();
1468+
}
1469+
}
1470+
return success();
1471+
}
1472+
14421473
#define GET_OP_CLASSES
14431474
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
14441475

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
2+
3+
// Positive cases: each module below pairs a target chip with an op whose SM
// requirement that chip satisfies, so no diagnostics are expected.
4+
// Target chip (sm_80) equals the required SM version.
gpu.module @check_valid_SM_exact [#nvvm.target<chip = "sm_80">] {
5+
test.nvvm_requires_sm_80
6+
}
7+
8+
// Target chip (sm_86) is newer than the required sm_80.
gpu.module @check_valid_SM_greater_1 [#nvvm.target<chip = "sm_86">] {
9+
test.nvvm_requires_sm_80
10+
}
11+
12+
// Target chip (sm_90) is newer than the required sm_80.
gpu.module @check_valid_SM_greater_2 [#nvvm.target<chip = "sm_90">] {
13+
test.nvvm_requires_sm_80
14+
}
15+
16+
// Arch-accelerated requirement on the identical arch-accelerated chip.
gpu.module @check_valid_SM_arch_acc [#nvvm.target<chip = "sm_90a">] {
17+
test.nvvm_requires_sm_90a
18+
}
19+
20+
// -----
21+
22+
// Negative case: target chip (sm_70) is older than the required sm_80.
gpu.module @check_invalid_SM_lesser_1 [#nvvm.target<chip = "sm_70">] {
23+
// expected-error @below {{is not supported on sm_70}}
24+
test.nvvm_requires_sm_80
25+
}
26+
27+
// -----
28+
29+
// Negative case: target chip (sm_75) is older than the required sm_80.
gpu.module @check_invalid_SM_lesser_2 [#nvvm.target<chip = "sm_75">] {
30+
// expected-error @below {{is not supported on sm_75}}
31+
test.nvvm_requires_sm_80
32+
}
33+
34+
// -----
35+
36+
// Negative case: same version number, but the target lacks the `a` suffix
// that the op requires.
gpu.module @check_invalid_SM_arch_acc_1 [#nvvm.target<chip = "sm_90">] {
37+
// expected-error @below {{is not supported on sm_90}}
38+
test.nvvm_requires_sm_90a
39+
}
40+
41+
// -----
42+
43+
// Negative case: target is neither sm_90 nor arch-accelerated.
gpu.module @check_invalid_SM_arch_acc_2 [#nvvm.target<chip = "sm_80">] {
44+
// expected-error @below {{is not supported on sm_80}}
45+
test.nvvm_requires_sm_90a
46+
}

mlir/test/lib/Dialect/Test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ mlir_target_link_libraries(MLIRTestDialect PUBLIC
8585
MLIRLinalgDialect
8686
MLIRLinalgTransforms
8787
MLIRLLVMDialect
88+
MLIRNVVMDialect
8889
MLIRPass
8990
MLIRPolynomialDialect
9091
MLIRReduce

mlir/test/lib/Dialect/Test/TestOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Dialect/DLTI/DLTI.h"
1717
#include "mlir/Dialect/DLTI/Traits.h"
1818
#include "mlir/Dialect/Func/IR/FuncOps.h"
19+
#include "mlir/Dialect/LLVMIR/NVVMTraits.h"
1920
#include "mlir/Dialect/Linalg/IR/Linalg.h"
2021
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
2122
#include "mlir/Dialect/Traits.h"

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ include "TestDialect.td"
1313
include "TestInterfaces.td"
1414
include "mlir/Dialect/DLTI/DLTIBase.td"
1515
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
16+
include "mlir/Dialect/LLVMIR/NVVMTraits.td"
1617
include "mlir/IR/EnumAttr.td"
1718
include "mlir/Interfaces/FunctionInterfaces.td"
1819
include "mlir/IR/OpBase.td"
@@ -2698,6 +2699,22 @@ def TestLinalgFillOp :
26982699
}];
26992700
}
27002701

2702+
//===----------------------------------------------------------------------===//
2703+
// Test NVVM RequiresSM trait.
2704+
//===----------------------------------------------------------------------===//
2705+
2706+
// Operand-less test op tagged with a plain (not arch-accelerated) SM 80
// requirement.
def TestNVVMRequiresSMOp
    : TEST_Op<"nvvm_requires_sm_80", [NVVMRequiresSM<80>]> {
  let assemblyFormat = "attr-dict";
  let arguments = (ins);
}
2711+
2712+
// Operand-less test op tagged with an arch-accelerated SM 90 requirement
// (the second trait argument is "true").
def TestNVVMRequiresSMArchCondOp
    : TEST_Op<"nvvm_requires_sm_90a", [NVVMRequiresSM<90, "true">]> {
  let assemblyFormat = "attr-dict";
  let arguments = (ins);
}
2717+
27012718
//===----------------------------------------------------------------------===//
27022719
// Test Ops with Default-Valued String Attributes
27032720
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)