Skip to content

Commit d08b176

Browse files
gryppjoker-eph
andauthored
[MLIR][NVVM] Add inline_ptx op (#139923)
This op allows using PTX directly within the NVVM dialect, while greatly simplifying llvm.inline_asm generation. **Example 1: Read-only Parameters** Sets `"l,r"` automatically. ``` nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count) : !llvm.ptr, i32 // Lowers to: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.init.b64 [$0], $1;", "l,r" %arg0, %arg2 : (!llvm.ptr, i32) -> () ``` **Example 2: Read-only and Write-only Parameters** Sets `=f,f"` automatically. `=` is set because there is store. ``` %0 = nvvm.inline_ptx "ex2.approx.ftz.f32 $0, $1;" (%input) : f32 -> f32 // Lowers to: %0 = llvm.inline_asm has_side_effects asm_dialect = att "ex2.approx.ftz.f32 $0, $1;", "=f,f" %arg0 : (f32) -> f32 ``` **Example 3: Predicate Usage** Now `@$2` is set automatically for predication. ``` nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count), predicate = %pred : !llvm.ptr, i32, i1 // Lowers to: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.init.b64 [$0], $1;", "l,r,b" %arg0, %arg2, %arg3 : (!llvm.ptr, i32, i1) -> () ``` --------- Co-authored-by: Mehdi Amini <[email protected]>
1 parent 6fc0312 commit d08b176

File tree

2 files changed

+95
-0
lines changed

2 files changed

+95
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,76 @@ foreach index = !range(0, 32) in {
236236
def NVVM_EnvReg # index # Op : NVVM_SpecialRegisterOp<"read.ptx.sreg.envreg" # index>;
237237
}
238238

239+
//===----------------------------------------------------------------------===//
240+
// Inline PTX op definition
241+
//===----------------------------------------------------------------------===//
242+
243+
def NVVM_InlinePtxOp : NVVM_Op<"inline_ptx",
244+
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
245+
AttrSizedOperandSegments]>
246+
{
247+
let summary = "Inline PTX Op";
248+
let description = [{This op allows using PTX directly within the NVVM
249+
dialect, while greatly simplifying llvm.inline_asm generation. It
250+
automatically handles register size selection and sets the correct
251+
read/write access for each operand. The operation leverages the
252+
`BasicPtxBuilderInterface` to abstract away low-level details of
253+
PTX assembly formatting.
254+
255+
The `predicate` attribute is used to specify a predicate for the
256+
PTX instruction.
257+
258+
Example 1: Read-only Parameters
259+
```mlir
260+
nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count) : !llvm.ptr, i32
261+
262+
// Lowers to:
263+
llvm.inline_asm has_side_effects asm_dialect = att
264+
"mbarrier.init.b64 [$0], $1;", "l,r" %arg0, %arg2 : (!llvm.ptr, i32) -> ()
265+
```
266+
267+
Example 2: Read-only and Write-only Parameters
268+
```mlir
269+
%0 = nvvm.inline_ptx "ex2.approx.ftz.f32 $0, $1;" (%input) : f32 -> f32
270+
271+
// Lowers to:
272+
%0 = llvm.inline_asm has_side_effects asm_dialect = att
273+
"ex2.approx.ftz.f32 $0, $1;", "=f,f" %arg0 : (f32) -> f32
274+
```
275+
276+
Example 3: Predicate Usage
277+
```mlir
278+
nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count),
279+
predicate = %pred : !llvm.ptr, i32, i1
280+
281+
// Lowers to:
282+
llvm.inline_asm has_side_effects asm_dialect = att
283+
"@$2 mbarrier.init.b64 [$0], $1;", "l,r,b" %arg0, %arg2, %arg3
284+
: (!llvm.ptr, i32, i1) -> ()
285+
```
286+
}];
287+
288+
let arguments = (ins Variadic<AnyType>:$readOnlyArgs,
289+
StrAttr:$ptxCode,
290+
PtxPredicate:$predicate);
291+
292+
let results = (outs Variadic<AnyType>:$writeOnlyArgs);
293+
294+
let assemblyFormat = [{
295+
$ptxCode `(` $readOnlyArgs `)`
296+
(`,` `predicate` `=` $predicate^)? attr-dict
297+
`:` type(operands)
298+
(`->` type($writeOnlyArgs)^)?
299+
}];
300+
301+
let extraClassDefinition = [{
302+
std::string $cppClass::getPtx() {
303+
StringRef ptxInstStr = getPtxCode();
304+
return std::string(ptxInstStr.data());
305+
}
306+
}];
307+
}
308+
239309
//===----------------------------------------------------------------------===//
240310
// NVVM approximate op definitions
241311
//===----------------------------------------------------------------------===//

mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -680,3 +680,28 @@ llvm.func @llvm_nvvm_barrier_arrive(%barID : i32, %numberOfThreads : i32) {
680680
nvvm.barrier.arrive id = %barID number_of_threads = %numberOfThreads
681681
llvm.return
682682
}
683+
684+
685+
// -----
686+
687+
llvm.func @init_mbarrier(
688+
%barrier_gen : !llvm.ptr,
689+
%barrier : !llvm.ptr<3>,
690+
%count : i32,
691+
%pred : i1) {
692+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.init.b64 [$0], $1;", "l,r"
693+
nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count) : !llvm.ptr, i32
694+
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.init.b64 [$0], $1;", "l,r,b"
695+
nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count), predicate = %pred : !llvm.ptr, i32, i1
696+
llvm.return
697+
}
698+
// -----
699+
700+
llvm.func @ex2(%input : f32, %pred : i1) {
701+
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "ex2.approx.ftz.f32 $0, $1;", "=f,f" %{{.*}} : (f32) -> f32
702+
%0 = nvvm.inline_ptx "ex2.approx.ftz.f32 $0, $1;" (%input) : f32 -> f32
703+
704+
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "@$1 ex2.approx.ftz.f32 $0, $1;", "=f,f,b" %{{.*}}, %{{.*}} : (f32, i1) -> f32
705+
%1 = nvvm.inline_ptx "ex2.approx.ftz.f32 $0, $1;" (%input), predicate = %pred : f32, i1 -> f32
706+
llvm.return
707+
}

0 commit comments

Comments
 (0)