Skip to content

Commit 25da115

Browse files
authored
[mlir][nvvm] Move BasicPtxBuilder Interface to Its Own File (NFC) (#68095)
1 parent f445be9 commit 25da115

File tree

9 files changed

+387
-286
lines changed

9 files changed

+387
-286
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -792,10 +792,10 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
792792
//===----------------------------------------------------------------------===//
793793

794794
def ConvertNVVMToLLVMPass : Pass<"convert-nvvm-to-llvm"> {
795-
let summary = "Convert NVVM dialect to LLVM dialect";
795+
let summary = "Convert NVVM to PTX with Inline Assembly in LLVM dialect";
796796
let description = [{
797-
This pass generates inline assembly for the NVVM ops which is not
798-
implemented in LLVM core.
797+
This pass generates PTX instructions using inline assembly for NVVM
798+
operations implements `BasicPtxBuilderInterface`.
799799
}];
800800
let dependentDialects = [
801801
"NVVM::NVVMDialect",
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
//===- BasicPtxBuilderInterface.td - PTX builder interface -*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Defines the interface to build PTX (Parallel Thread Execution) from NVVM Ops
10+
// automatically. It is used by NVVM to LLVM pass.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef NVVM_DIALECT_NVVM_IR_BASICPTXBUILDERINTERFACE_H_
15+
#define NVVM_DIALECT_NVVM_IR_BASICPTXBUILDERINTERFACE_H_
16+
17+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
18+
#include "mlir/IR/BuiltinAttributes.h"
19+
#include "mlir/IR/PatternMatch.h"
20+
#include "mlir/IR/Value.h"
21+
22+
namespace mlir {
23+
namespace NVVM {
24+
/// Register read/write modifier to build constraint string for PTX inline
25+
/// https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html#parameters
26+
enum class PTXRegisterMod {
27+
/// Read register with no modifier
28+
Read = 0,
29+
/// Read register with '+' modifier
30+
Write = 2,
31+
/// Read register with '=' modifier.
32+
/// Note that, this is not natively supported by LLVM, but it is possible to
33+
/// set read and write for the same operand.
34+
ReadWrite = 1,
35+
};
36+
} // namespace NVVM
37+
} // namespace mlir
38+
39+
/// Include the generated interface declarations.
40+
#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h.inc"
41+
42+
namespace mlir {
43+
44+
namespace NVVM {
45+
46+
/// A class to build PTX assembly automatically. It is used by
47+
/// BasicPtxBuilderInterface.
48+
class PtxBuilder {
49+
// The interface op that is used to build the PTX.
50+
BasicPtxBuilderInterface interfaceOp;
51+
// Rewriter to create new operations.
52+
PatternRewriter &rewriter;
53+
// The operands for the PTX instruction
54+
SmallVector<Value> ptxOperands;
55+
// Register constraints (read, write, readwrite) and register data types
56+
std::string registerConstraints;
57+
58+
bool hasResult = false;
59+
60+
public:
61+
/// Single constructor that only initializes members.
62+
PtxBuilder(Operation *op, PatternRewriter &rewriter)
63+
: interfaceOp(op), rewriter(rewriter) {}
64+
65+
/// Add an operand with the read/write input type.
66+
void insertValue(Value v, PTXRegisterMod itype = PTXRegisterMod::Read);
67+
68+
/// Builds the inline assembly Op and returns it. The `insertValue` needs to
69+
/// be called to pass operands before building the PTX.
70+
LLVM::InlineAsmOp build();
71+
72+
/// Shortcut to build the inline assembly Op and replace or erase the original
73+
/// op with
74+
void buildAndReplaceOp();
75+
};
76+
77+
} // namespace NVVM
78+
} // namespace mlir
79+
80+
#endif // NVVM_DIALECT_NVVM_IR_BASICPTXBUILDERINTERFACE_H_
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
//===- BasicPtxBuilderInterface.td - PTX builder interface -*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Defines the interface to build PTX (Parallel Thread Execution) from NVVM Ops
10+
// automatically. It is used by NVVM to LLVM pass.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef BASICPTXBUILDER_OP_INTERFACE
15+
#define BASICPTXBUILDER_OP_INTERFACE
16+
17+
include "mlir/IR/EnumAttr.td"
18+
include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
19+
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
20+
21+
//===----------------------------------------------------------------------===//
22+
// Basic PTX Builder Interface
23+
//===----------------------------------------------------------------------===//
24+
25+
def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
26+
let description = [{
27+
This interface is used to generate inline assembly with PTX for basic
28+
operations. It's utilized in the `convert-nvvm-to-llvm pass` to lower
29+
NVVM Ops that implement this interface to PTX (parallel thread execution)
30+
using inline assembly Ops. Interface methods play a crucial role in this
31+
lowering process.
32+
33+
Here's an example of an Op with the `BasicPtxBuilderOpInterface`:
34+
```tablegen
35+
def NVVM_SpecialOp : NVVM_Op<"special.op",
36+
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
37+
Results<(outs LLVM_Type:$res)>,
38+
Arguments<(ins LLVM_i64ptr_any:$op1, I32:$op2)> {
39+
...
40+
let extraClassDefinition = [{
41+
std::string $cppClass::getPtx() {
42+
return std::string("special.op %0, %1, %2;");
43+
}
44+
} ];
45+
```
46+
47+
In the above NVVM Op example:
48+
```mlir
49+
%0 = nvvm.special.op %1, %2 : !llvm.ptr, i32 -> i32
50+
```
51+
52+
The `convert-nvvm-to-llvm` pass generates the inline assembly like below.
53+
The order of arguments is retained, and the read and write modifiers are
54+
set based on the input and result types:
55+
```mlir
56+
%0 = llvm.inline_asm
57+
has_side_effects
58+
asm_dialect =
59+
att "special.op %0, %1, %2;", "=r,l,r" %arg0, %arg1
60+
: (!llvm.ptr, i32) -> i32
61+
```
62+
}];
63+
let cppNamespace = "::mlir::NVVM";
64+
let methods = [
65+
InterfaceMethod<
66+
/*desc=*/[{ Returns PTX assembly with operand number. }],
67+
/*retType=*/"std::string",
68+
/*methodName=*/"getPtx"
69+
>,
70+
InterfaceMethod<
71+
/*desc=*/[{
72+
This function indicates whether the operation is supported by LLVM
73+
intrinsics. It's particularly useful for operations that have
74+
specific cases with LLVM intrinsic support.
75+
}],
76+
/*retType=*/"bool",
77+
/*methodName=*/"hasIntrinsic",
78+
/*args=*/(ins),
79+
/*methodBody=*/"",
80+
/*defaultImplementation=*/"return false;"
81+
>,
82+
InterfaceMethod<
83+
/*desc=*/[{Return whether the operation has memory side effects.}],
84+
/*retType=*/"bool",
85+
/*methodName=*/"hasSideEffect",
86+
/*args=*/(ins),
87+
/*methodBody=*/"",
88+
/*defaultImplementation=*/"return true;"
89+
>,
90+
91+
InterfaceMethod<
92+
/*desc=*/[{Helper function to generate i32 constant value.}],
93+
/*retType=*/"::mlir::Value",
94+
/*methodName=*/"makeConstantI32",
95+
/*args=*/(ins "::mlir::RewriterBase &":$rewriter, "int" : $val),
96+
/*methodBody=*/"",
97+
/*defaultImpl=*/ [{
98+
mlir::Operation* op = $_op;
99+
return rewriter.create<LLVM::ConstantOp>(
100+
op->getLoc(), rewriter.getIntegerType(32), val);
101+
}]
102+
>,
103+
InterfaceMethod<
104+
/*desc=*/[{
105+
This function supplies the necessary arguments for passing PTX code,
106+
following this order:
107+
1) Adds results
108+
2) Adds operands
109+
3) Adds attributes
110+
}],
111+
/*retType=*/"void",
112+
/*methodName=*/"getAsmValues",
113+
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
114+
"llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>&" : $asmValues),
115+
/*methodBody=*/"",
116+
/*defaultImpl=*/ [{
117+
mlir::Operation* op = $_op;
118+
119+
// Step 1. Add results
120+
for (auto val : op->getResults())
121+
asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Write});
122+
123+
// Step 2. Add operands
124+
for (auto val : op->getOperands())
125+
asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
126+
127+
// Step 3. Add attributes
128+
for (auto attr : op->getAttrs()) {
129+
if (auto intAttr = dyn_cast<mlir::IntegerAttr>(attr.getValue())) {
130+
::mlir::Value val = makeConstantI32(rewriter, intAttr.getInt());
131+
asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
132+
}
133+
}
134+
}]
135+
>
136+
];
137+
}
138+
139+
#endif // BASICPTXBUILDER_OP_INTERFACE

mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,18 @@ mlir_tablegen(LLVMIntrinsicFromLLVMIRConversions.inc -gen-intr-from-llvmir-conve
4646
mlir_tablegen(LLVMConvertibleLLVMIRIntrinsics.inc -gen-convertible-llvmir-intrinsics)
4747
add_public_tablegen_target(MLIRLLVMIntrinsicConversionsIncGen)
4848

49+
set(LLVM_TARGET_DEFINITIONS BasicPtxBuilderInterface.td)
50+
mlir_tablegen(BasicPtxBuilderInterface.h.inc -gen-op-interface-decls)
51+
mlir_tablegen(BasicPtxBuilderInterface.cpp.inc -gen-op-interface-defs)
52+
add_public_tablegen_target(MLIRBasicPtxBuilderInterfaceIncGen)
53+
add_dependencies(mlir-headers MLIRBasicPtxBuilderInterfaceIncGen)
54+
4955
add_mlir_dialect(NVVMOps nvvm)
5056
add_mlir_doc(NVVMOps NVVMDialect Dialects/ -gen-dialect-doc -dialect=nvvm)
5157
set(LLVM_TARGET_DEFINITIONS NVVMOps.td)
5258
mlir_tablegen(NVVMConversions.inc -gen-llvmir-conversions)
5359
mlir_tablegen(NVVMOpsEnums.h.inc -gen-enum-decls)
5460
mlir_tablegen(NVVMOpsEnums.cpp.inc -gen-enum-defs)
55-
mlir_tablegen(NVVMOpsInterface.h.inc -gen-op-interface-decls)
56-
mlir_tablegen(NVVMOpsInterface.cpp.inc -gen-op-interface-defs)
5761
mlir_tablegen(NVVMOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=nvvm)
5862
mlir_tablegen(NVVMOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=nvvm)
5963
add_public_tablegen_target(MLIRNVVMConversionsIncGen)

mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#define MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_
1616

1717
#include "mlir/Bytecode/BytecodeOpInterface.h"
18+
#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
1819
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1920
#include "mlir/IR/Dialect.h"
2021
#include "mlir/IR/OpDefinition.h"
@@ -26,8 +27,6 @@
2627
namespace mlir {
2728
namespace NVVM {
2829

29-
#include "mlir/Dialect/LLVMIR/NVVMOpsInterface.h.inc"
30-
3130
/// NVVM memory space identifiers.
3231
enum NVVMMemorySpace {
3332
/// Global memory space identifier.

0 commit comments

Comments
 (0)