[MLIR][NVVM] Add inline_ptx op #139923
This op allows using PTX directly within the NVVM dialect, while greatly simplifying llvm.inline_asm generation.

Example 1: Read-only Parameters

```mlir
nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count) : !llvm.ptr, i32

// Lowers to:
llvm.inline_asm has_side_effects asm_dialect = att
  "mbarrier.init.b64 [$0], $1;", "l,r" %arg0, %arg2 : (!llvm.ptr, i32) -> ()
```

Example 2: Read-only and Write-only Parameters

```mlir
%0 = nvvm.inline_ptx "ex2.approx.ftz.f32 $0, $1;" (%input) : f32 -> f32

// Lowers to:
%0 = llvm.inline_asm has_side_effects asm_dialect = att
  "ex2.approx.ftz.f32 $0, $1;", "=f,f" %arg0 : (f32) -> f32
```

Example 3: Predicate Usage

```mlir
nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count),
  predicate = %pred : !llvm.ptr, i32, i1

// Lowers to:
llvm.inline_asm has_side_effects asm_dialect = att
  "@$2 mbarrier.init.b64 [$0], $1;", "l,r,b" %arg0, %arg2, %arg3
  : (!llvm.ptr, i32, i1) -> ()
```
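The three examples follow one pattern, which can be sketched in plain C++ without any MLIR dependency. This is only an illustrative model, not the code in `BasicPtxBuilderInterface`: the `Ty` enum, `registerLetter`, `LoweredAsm`, and `lowerInlinePtx` are hypothetical names, and the type-to-letter table covers only the four types that appear in this PR. The predicate placeholder index is assumed to equal the number of read-only operands, which is what the CHECK lines in the PR's test show (`@$2` with two inputs, `@$1` with one input and one result).

```cpp
#include <string>
#include <vector>

// Hypothetical stand-ins for the MLIR types used in the PR's examples.
enum class Ty { Ptr, I32, F32, I1 };

// Constraint letters as they appear in the PR's lowered output.
char registerLetter(Ty t) {
  switch (t) {
  case Ty::Ptr: return 'l'; // generic !llvm.ptr
  case Ty::I32: return 'r';
  case Ty::F32: return 'f';
  case Ty::I1:  return 'b'; // predicate register
  }
  return '?'; // types outside the PR's examples are not modelled
}

struct LoweredAsm {
  std::string asmString;   // PTX text, possibly with a predicate prefix
  std::string constraints; // comma-separated constraint list
};

// Assumption: results come first (prefixed with '='), then read-only
// inputs, then the predicate as a trailing 'b' operand.
LoweredAsm lowerInlinePtx(const std::string &ptx,
                          const std::vector<Ty> &results,
                          const std::vector<Ty> &inputs,
                          bool hasPredicate) {
  LoweredAsm out;
  auto append = [&](const std::string &c) {
    if (!out.constraints.empty())
      out.constraints += ',';
    out.constraints += c;
  };
  for (Ty t : results)
    append(std::string("=") + registerLetter(t)); // write-only
  for (Ty t : inputs)
    append(std::string(1, registerLetter(t)));    // read-only
  out.asmString = ptx;
  if (hasPredicate) {
    append("b");
    // Index chosen to reproduce the PR's CHECK lines (see lead-in).
    out.asmString = "@$" + std::to_string(inputs.size()) + " " + ptx;
  }
  return out;
}
```

Calling `lowerInlinePtx("mbarrier.init.b64 [$0], $1;", {}, {Ty::Ptr, Ty::I32}, true)` in this model yields the asm string and constraint list of Example 3.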
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm

Author: Guray Ozen (grypp)

Changes

This op allows using PTX directly within the NVVM dialect, while greatly simplifying llvm.inline_asm generation.

Example 1: Read-only Parameters. Sets `"l,r"` automatically.

Example 2: Read-only and Write-only Parameters. Sets `"=f,f"` automatically.

Example 3: Predicate Usage. Now `@$2` is set automatically for predication.
Full diff: https://github.com/llvm/llvm-project/pull/139923.diff

2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 654aff71f25be..4ba54fa3c1ca7 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -236,6 +236,76 @@ foreach index = !range(0, 32) in {
def NVVM_EnvReg # index # Op : NVVM_SpecialRegisterOp<"read.ptx.sreg.envreg" # index>;
}
+//===----------------------------------------------------------------------===//
+// Inline PTX op definition
+//===----------------------------------------------------------------------===//
+
+def NVVM_InlinePtxOP : NVVM_Op<"inline_ptx",
+ [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
+ AttrSizedOperandSegments]>
+{
+ let summary = "Inline PTX Op";
+ let description = [{This op allows using PTX directly within the NVVM
+ dialect, while greatly simplifying llvm.inline_asm generation. It
+ automatically handles register size selection and sets the correct
+ read/write access for each operand. The operation leverages the
+ `BasicPtxBuilderInterface` to abstract away low-level details of
+ PTX assembly formatting.
+
+ The `predicate` attribute is used to specify a predicate for the
+ PTX instruction.
+
+ Example 1: Read-only Parameters
+ ```mlir
+ nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count) : !llvm.ptr, i32
+
+ // Lowers to:
+ llvm.inline_asm has_side_effects asm_dialect = att
+ "mbarrier.init.b64 [$0], $1;", "l,r" %arg0, %arg2 : (!llvm.ptr, i32) -> ()
+ ```
+
+ Example 2: Read-only and Write-only Parameters
+ ```mlir
+ %0 = nvvm.inline_ptx "ex2.approx.ftz.f32 $0, $1;" (%input) : f32 -> f32
+
+ // Lowers to:
+ %0 = llvm.inline_asm has_side_effects asm_dialect = att
+ "ex2.approx.ftz.f32 $0, $1;", "=f,f" %arg0 : (f32) -> f32
+ ```
+
+ Example 3: Predicate Usage
+ ```mlir
+ nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count),
+ predicate = %pred : !llvm.ptr, i32, i1
+
+ // Lowers to:
+ llvm.inline_asm has_side_effects asm_dialect = att
+ "@$2 mbarrier.init.b64 [$0], $1;", "l,r,b" %arg0, %arg2, %arg3
+ : (!llvm.ptr, i32, i1) -> ()
+ ```
+ }];
+
+ let arguments = (ins Variadic<AnyType>:$readOnlyArgs,
+ StrAttr:$ptxCode,
+ PtxPredicate:$predicate);
+
+ let results = (outs Variadic<AnyType>:$writeOnlyArgs);
+
+ let assemblyFormat = [{
+ $ptxCode `(` $readOnlyArgs `)`
+ (`,` `predicate` `=` $predicate^)? attr-dict
+ `:` type(operands)
+ (`->` type($writeOnlyArgs)^)?
+ }];
+
+ let extraClassDefinition = [{
+ std::string $cppClass::getPtx() {
+ StringRef eventName = getPtxCode();
+ return std::string(eventName.data());
+ }
+ }];
+}
+
//===----------------------------------------------------------------------===//
// NVVM approximate op definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index c7a6eca158276..1d9164ac94d76 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -680,3 +680,28 @@ llvm.func @llvm_nvvm_barrier_arrive(%barID : i32, %numberOfThreads : i32) {
nvvm.barrier.arrive id = %barID number_of_threads = %numberOfThreads
llvm.return
}
+
+
+// -----
+
+llvm.func @init_mbarrier(
+ %barrier_gen : !llvm.ptr,
+ %barrier : !llvm.ptr<3>,
+ %count : i32,
+ %pred : i1) {
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.init.b64 [$0], $1;", "l,r"
+ nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count) : !llvm.ptr, i32
+ // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.init.b64 [$0], $1;", "l,r,b"
+ nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count), predicate = %pred : !llvm.ptr, i32, i1
+ llvm.return
+}
+// -----
+
+llvm.func @ex2(%input : f32, %pred : i1) {
+ // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "ex2.approx.ftz.f32 $0, $1;", "=f,f" %{{.*}} : (f32) -> f32
+ %0 = nvvm.inline_ptx "ex2.approx.ftz.f32 $0, $1;" (%input) : f32 -> f32
+
+ // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "@$1 ex2.approx.ftz.f32 $0, $1;", "=f,f,b" %{{.*}}, %{{.*}} : (f32, i1) -> f32
+ %1 = nvvm.inline_ptx "ex2.approx.ftz.f32 $0, $1;" (%input), predicate = %pred : f32, i1 -> f32
+ llvm.return
+}
\ No newline at end of file
Pull Request Overview
Adds a new `nvvm.inline_ptx` operation that embeds PTX assembly directly in the NVVM dialect and lowers it to `llvm.inline_asm`.

- Introduces the `inline_ptx` op in the Dialect definition (`NVVMOps.td`).
- Provides C++ hook (`getPtx()`) to retrieve the PTX string.
- Adds FileCheck-based tests in `nvvm-to-llvm.mlir` to verify correct lowering.
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
| --- | --- |
| mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | Define NVVM_InlinePtxOP and its builder |
| mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir | Add tests for nvvm.inline_ptx lowering |
Comments suppressed due to low confidence (1)
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td:243

- The op class name `NVVM_InlinePtxOP` uses uppercase "OP". Per MLIR naming conventions, it should be `NVVM_InlinePtxOp` (capital 'O', lowercase 'p').

    def NVVM_InlinePtxOP : NVVM_Op<"inline_ptx",
The AI bot review seems spot on ;)
So how does it know about "Read-only Parameters" and "Read-only and Write-only Parameters"?
Input arguments are read-only, and results are write-only; it's pure SSA. For read-only arguments, no special marker is used, but for write-only ones, we automatically prepend `=`.

There are also ReadWrite symbols. I support them in the NVVM dialect, for example [...]. It's a little bit complicated to support ReadWrite in LLVM. They are supported by [...].

In LLVM, we handle read-write symbols differently. We mark them as [...]. For example, here the result is read-write, and it is valid asm, but LLVM doesn't support that. Instead, we generate the equivalent as: [...]

This marks [...]
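The read-only versus write-only distinction in the reply above comes down to a per-operand marker, visible in the PR's `"=f,f"` constraint string. A minimal sketch of that rule, assuming a hypothetical `Access` enum and `constraintFor` helper (neither is an MLIR API), could look like this. Read-write operands are deliberately left out of the sketch.

```cpp
#include <string>

// Hypothetical access kinds: op operands are reads, op results are writes.
enum class Access { Read, Write };

// Builds one constraint from a register letter (e.g. 'f', 'r', 'l', 'b').
// Write-only operands get a leading '='; read-only operands get no marker.
std::string constraintFor(char registerLetter, Access access) {
  std::string constraint;
  if (access == Access::Write)
    constraint += '=';
  constraint += registerLetter;
  return constraint;
}
```

For the `ex2.approx.ftz.f32` example, the f32 result maps to `constraintFor('f', Access::Write)` and the f32 input to `constraintFor('f', Access::Read)`; joined with a comma, these give `"=f,f"`.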
Co-authored-by: Mehdi Amini <[email protected]>
// Lowers to:
llvm.inline_asm has_side_effects asm_dialect = att
"mbarrier.init.b64 [$0], $1;", "l,r" %arg0, %arg2 : (!llvm.ptr, i32) -> ()
nit:
In all these given examples, the mlir looks clean.
The lowered llvm-ir still seems to have llvm.ptr, i32 etc. I think this should reflect what we generate, right?
(maybe we could copy-paste from the unit test below).
> In all these given examples, the mlir looks clean.
What do you mean by clean?
> > In all these given examples, the mlir looks clean.
>
> What do you mean by clean?

I meant the mlir version exactly matches the unit test, but the lowered version still refers to (!llvm.ptr, i32) instead of "l, r".
LGTM
This op allows using PTX directly within the NVVM dialect, while greatly simplifying llvm.inline_asm generation.

**Example 1: Read-only Parameters**

Sets `"l,r"` automatically.

```
nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count) : !llvm.ptr, i32

// Lowers to:
llvm.inline_asm has_side_effects asm_dialect = att
  "mbarrier.init.b64 [$0], $1;", "l,r" %arg0, %arg2 : (!llvm.ptr, i32) -> ()
```

**Example 2: Read-only and Write-only Parameters**

Sets `"=f,f"` automatically. `=` is set because there is a store.

```
%0 = nvvm.inline_ptx "ex2.approx.ftz.f32 $0, $1;" (%input) : f32 -> f32

// Lowers to:
%0 = llvm.inline_asm has_side_effects asm_dialect = att
  "ex2.approx.ftz.f32 $0, $1;", "=f,f" %arg0 : (f32) -> f32
```

**Example 3: Predicate Usage**

Now `@$2` is set automatically for predication.

```
nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count),
  predicate = %pred : !llvm.ptr, i32, i1

// Lowers to:
llvm.inline_asm has_side_effects asm_dialect = att
  "@$2 mbarrier.init.b64 [$0], $1;", "l,r,b" %arg0, %arg2, %arg3
  : (!llvm.ptr, i32, i1) -> ()
```

---------

Co-authored-by: Mehdi Amini <[email protected]>