1
+ //===- BasicPtxBuilderInterface.td - PTX builder interface -*- tablegen -*-===//
2
+ //
3
+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
+ // See https://llvm.org/LICENSE.txt for license information.
5
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
+ //
7
+ //===----------------------------------------------------------------------===//
8
+ //
9
+ // Defines the interface to build PTX (Parallel Thread Execution) from NVVM Ops
10
+ // automatically. It is used by NVVM to LLVM pass.
11
+ //
12
+ //===----------------------------------------------------------------------===//
13
+
14
+ #ifndef BASICPTXBUILDER_OP_INTERFACE
15
+ #define BASICPTXBUILDER_OP_INTERFACE
16
+
17
+ include "mlir/IR/EnumAttr.td"
18
+ include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
19
+ include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
20
+
21
+ //===----------------------------------------------------------------------===//
22
+ // Basic PTX Builder Interface
23
+ //===----------------------------------------------------------------------===//
24
+
25
+ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
26
+ let description = [{
27
+ This interface is used to generate inline assembly with PTX for basic
28
+ operations. It's utilized in the `convert-nvvm-to-llvm pass` to lower
29
+ NVVM Ops that implement this interface to PTX (parallel thread execution)
30
+ using inline assembly Ops. Interface methods play a crucial role in this
31
+ lowering process.
32
+
33
+ Here's an example of an Op with the `BasicPtxBuilderOpInterface`:
34
+ ```tablegen
35
+ def NVVM_SpecialOp : NVVM_Op<"special.op",
36
+ [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
37
+ Results<(outs LLVM_Type:$res)>,
38
+ Arguments<(ins LLVM_i64ptr_any:$op1, I32:$op2)> {
39
+ ...
40
+ let extraClassDefinition = [{
41
+ std::string $cppClass::getPtx() {
42
+ return std::string("special.op %0, %1, %2;");
43
+ }
44
+ } ];
45
+ ```
46
+
47
+ In the above NVVM Op example:
48
+ ```mlir
49
+ %0 = nvvm.special.op %1, %2 : !llvm.ptr, i32 -> i32
50
+ ```
51
+
52
+ The `convert-nvvm-to-llvm` pass generates the inline assembly like below.
53
+ The order of arguments is retained, and the read and write modifiers are
54
+ set based on the input and result types:
55
+ ```mlir
56
+ %0 = llvm.inline_asm
57
+ has_side_effects
58
+ asm_dialect =
59
+ att "special.op %0, %1, %2;", "=r,l,r" %arg0, %arg1
60
+ : (!llvm.ptr, i32) -> i32
61
+ ```
62
+ }];
63
+ let cppNamespace = "::mlir::NVVM";
64
+ let methods = [
65
+ InterfaceMethod<
66
+ /*desc=*/[{ Returns PTX assembly with operand number. }],
67
+ /*retType=*/"std::string",
68
+ /*methodName=*/"getPtx"
69
+ >,
70
+ InterfaceMethod<
71
+ /*desc=*/[{
72
+ This function indicates whether the operation is supported by LLVM
73
+ intrinsics. It's particularly useful for operations that have
74
+ specific cases with LLVM intrinsic support.
75
+ }],
76
+ /*retType=*/"bool",
77
+ /*methodName=*/"hasIntrinsic",
78
+ /*args=*/(ins),
79
+ /*methodBody=*/"",
80
+ /*defaultImplementation=*/"return false;"
81
+ >,
82
+ InterfaceMethod<
83
+ /*desc=*/[{Return whether the operation has memory side effects.}],
84
+ /*retType=*/"bool",
85
+ /*methodName=*/"hasSideEffect",
86
+ /*args=*/(ins),
87
+ /*methodBody=*/"",
88
+ /*defaultImplementation=*/"return true;"
89
+ >,
90
+
91
+ InterfaceMethod<
92
+ /*desc=*/[{Helper function to generate i32 constant value.}],
93
+ /*retType=*/"::mlir::Value",
94
+ /*methodName=*/"makeConstantI32",
95
+ /*args=*/(ins "::mlir::RewriterBase &":$rewriter, "int" : $val),
96
+ /*methodBody=*/"",
97
+ /*defaultImpl=*/ [{
98
+ mlir::Operation* op = $_op;
99
+ return rewriter.create<LLVM::ConstantOp>(
100
+ op->getLoc(), rewriter.getIntegerType(32), val);
101
+ }]
102
+ >,
103
+ InterfaceMethod<
104
+ /*desc=*/[{
105
+ This function supplies the necessary arguments for passing PTX code,
106
+ following this order:
107
+ 1) Adds results
108
+ 2) Adds operands
109
+ 3) Adds attributes
110
+ }],
111
+ /*retType=*/"void",
112
+ /*methodName=*/"getAsmValues",
113
+ /*args=*/(ins "::mlir::RewriterBase &":$rewriter,
114
+ "llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>&" : $asmValues),
115
+ /*methodBody=*/"",
116
+ /*defaultImpl=*/ [{
117
+ mlir::Operation* op = $_op;
118
+
119
+ // Step 1. Add results
120
+ for (auto val : op->getResults())
121
+ asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Write});
122
+
123
+ // Step 2. Add operands
124
+ for (auto val : op->getOperands())
125
+ asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
126
+
127
+ // Step 3. Add attributes
128
+ for (auto attr : op->getAttrs()) {
129
+ if (auto intAttr = dyn_cast<mlir::IntegerAttr>(attr.getValue())) {
130
+ ::mlir::Value val = makeConstantI32(rewriter, intAttr.getInt());
131
+ asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
132
+ }
133
+ }
134
+ }]
135
+ >
136
+ ];
137
+ }
138
+
139
+ #endif // BASICPTXBUILDER_OP_INTERFACE
0 commit comments