[MLIR][NVVM] Add inline_ptx op #139923


Merged
merged 4 commits into from
May 15, 2025

Conversation

grypp (Member) commented May 14, 2025

This op allows using PTX directly within the NVVM dialect, while greatly simplifying llvm.inline_asm generation.

Example 1: Read-only Parameters

Sets "l,r" automatically.

nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count) : !llvm.ptr, i32

// Lowers to:
llvm.inline_asm has_side_effects asm_dialect = att
      "mbarrier.init.b64 [$0], $1;", "l,r" %arg0, %arg2 : (!llvm.ptr, i32) -> ()

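The `"l,r"` string above is derived from the operand types. A minimal sketch of that selection follows, assuming types are given as plain strings and using a hypothetical helper `ptxConstraintFor` (not the op's actual implementation). The `l`, `r`, `f`, and `b` letters appear in this PR; the `h` and `d` entries follow the usual inline PTX constraint letters and are not shown here.

```cpp
#include <optional>
#include <string>

// Hypothetical helper: pick the PTX register constraint letter for a type name.
// Returns std::nullopt for types this sketch does not cover.
std::optional<char> ptxConstraintFor(const std::string &type) {
  if (type == "i1")
    return 'b'; // predicate register
  if (type == "i16")
    return 'h'; // 16-bit register (assumption: not shown in this PR)
  if (type == "i32")
    return 'r'; // 32-bit register
  if (type == "i64" || type == "!llvm.ptr")
    return 'l'; // 64-bit register, also used for generic pointers
  if (type == "f32")
    return 'f'; // 32-bit float register
  if (type == "f64")
    return 'd'; // 64-bit float register (assumption: not shown in this PR)
  return std::nullopt;
}
```

With this mapping, the operand types `!llvm.ptr, i32` of Example 1 yield `l` and `r`.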
Example 2: Read-only and Write-only Parameters

Sets "=f,f" automatically. The = prefix is added because the result is written by the instruction.

%0 = nvvm.inline_ptx "ex2.approx.ftz.f32 $0, $1;" (%input) : f32 -> f32

// Lowers to:
%0 = llvm.inline_asm has_side_effects asm_dialect = att "ex2.approx.ftz.f32 $0, $1;", "=f,f" %arg0 : (f32) -> f32
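How the `=` prefix ends up in the constraint string can be sketched as follows. This is an illustration under stated assumptions, not the lowering code itself: `joinConstraints` is a hypothetical name, and the register letters are taken as already chosen.

```cpp
#include <string>
#include <vector>

// Hypothetical helper: results come first and get a '=' prefix (write-only),
// read-only inputs follow without any marker.
std::string joinConstraints(const std::vector<char> &results,
                            const std::vector<char> &inputs) {
  std::string out;
  auto append = [&out](const std::string &item) {
    if (!out.empty())
      out += ',';
    out += item;
  };
  for (char reg : results)
    append(std::string("=") + reg);
  for (char reg : inputs)
    append(std::string(1, reg));
  return out;
}
```

For Example 2, one `f` result and one `f` input give `"=f,f"`; for Example 1, no results and inputs `l`, `r` give `"l,r"`.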

Example 3: Predicate Usage

Now @$2 is set automatically for predication.

nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count), predicate = %pred : !llvm.ptr, i32, i1

// Lowers to:
llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.init.b64 [$0], $1;", "l,r,b" %arg0, %arg2, %arg3 : (!llvm.ptr, i32, i1) -> ()
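The predicate rewrite can be sketched as a small string transformation. The names `PredicatedAsm` and `addPredicate` are hypothetical. The index rule is an assumption read off the tests in this PR, where the predicate index equals the number of read-only arguments (`@$2` for the two-argument mbarrier case, `@$1` for the one-argument ex2 case).

```cpp
#include <cstddef>
#include <string>

// Result of adding a predicate: the guarded PTX text and the extended constraints.
struct PredicatedAsm {
  std::string ptx;
  std::string constraints;
};

// Hypothetical helper: prefix the PTX with "@$N " and append the 'b'
// (predicate register) constraint. N is the number of read-only arguments,
// mirroring the CHECK lines in this PR's tests.
PredicatedAsm addPredicate(const std::string &ptx,
                           const std::string &constraints,
                           std::size_t numReadOnlyArgs) {
  PredicatedAsm result;
  result.ptx = "@$" + std::to_string(numReadOnlyArgs) + " " + ptx;
  result.constraints = constraints.empty() ? "b" : constraints + ",b";
  return result;
}
```

Calling `addPredicate("mbarrier.init.b64 [$0], $1;", "l,r", 2)` reproduces the asm string and constraints of Example 3.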

llvmbot (Member) commented May 14, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Guray Ozen (grypp)

Full diff: https://github.com/llvm/llvm-project/pull/139923.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+70)
  • (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (+25)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 654aff71f25be..4ba54fa3c1ca7 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -236,6 +236,76 @@ foreach index = !range(0, 32) in {
   def NVVM_EnvReg # index # Op : NVVM_SpecialRegisterOp<"read.ptx.sreg.envreg" # index>;
 }
 
+//===----------------------------------------------------------------------===//
+// Inline PTX op definition
+//===----------------------------------------------------------------------===//
+
+def NVVM_InlinePtxOP : NVVM_Op<"inline_ptx", 
+  [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>, 
+    AttrSizedOperandSegments]>
+{
+  let summary = "Inline PTX Op";
+  let description = [{This op allows using PTX directly within the NVVM 
+    dialect, while greatly simplifying llvm.inline_asm generation. It 
+    automatically handles register size selection and sets the correct 
+    read/write access for each operand. The operation leverages the 
+    `BasicPtxBuilderInterface` to abstract away low-level details of 
+    PTX assembly formatting.
+
+    The `predicate` attribute is used to specify a predicate for the 
+    PTX instruction.
+
+    Example 1: Read-only Parameters
+    ```mlir
+    nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count) : !llvm.ptr, i32
+
+    // Lowers to:
+    llvm.inline_asm has_side_effects asm_dialect = att 
+      "mbarrier.init.b64 [$0], $1;", "l,r" %arg0, %arg2 : (!llvm.ptr, i32) -> ()
+    ```
+
+    Example 2: Read-only and Write-only Parameters
+    ```mlir
+    %0 = nvvm.inline_ptx "ex2.approx.ftz.f32 $0, $1;" (%input) : f32 -> f32
+
+    // Lowers to:
+    %0 = llvm.inline_asm has_side_effects asm_dialect = att 
+      "ex2.approx.ftz.f32 $0, $1;", "=f,f" %arg0 : (f32) -> f32
+    ```
+
+    Example 3: Predicate Usage
+    ```mlir
+    nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count), 
+      predicate = %pred : !llvm.ptr, i32, i1
+
+    // Lowers to:
+    llvm.inline_asm has_side_effects asm_dialect = att 
+      "@$2 mbarrier.init.b64 [$0], $1;", "l,r,b" %arg0, %arg2, %arg3 
+      : (!llvm.ptr, i32, i1) -> ()
+    ```
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$readOnlyArgs, 
+                       StrAttr:$ptxCode,
+                       PtxPredicate:$predicate);
+                      
+  let results = (outs Variadic<AnyType>:$writeOnlyArgs);
+
+  let assemblyFormat = [{ 
+    $ptxCode `(` $readOnlyArgs `)` 
+    (`,` `predicate` `=` $predicate^)? attr-dict 
+    `:` type(operands) 
+    (`->` type($writeOnlyArgs)^)?
+  }];
+  
+  let extraClassDefinition = [{
+    std::string $cppClass::getPtx() {
+      StringRef eventName = getPtxCode();
+      return std::string(eventName.data());
+    }
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM approximate op definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index c7a6eca158276..1d9164ac94d76 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -680,3 +680,28 @@ llvm.func @llvm_nvvm_barrier_arrive(%barID : i32, %numberOfThreads : i32) {
   nvvm.barrier.arrive id = %barID number_of_threads = %numberOfThreads
   llvm.return
 }
+
+
+// -----
+
+llvm.func @init_mbarrier(
+    %barrier_gen : !llvm.ptr, 
+    %barrier : !llvm.ptr<3>, 
+    %count : i32, 
+    %pred : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.init.b64 [$0], $1;", "l,r" 
+  nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count)  : !llvm.ptr, i32 
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.init.b64 [$0], $1;", "l,r,b"
+  nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count), predicate = %pred : !llvm.ptr, i32, i1
+  llvm.return
+}
+// -----
+
+llvm.func @ex2(%input : f32, %pred : i1) {
+  // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "ex2.approx.ftz.f32 $0, $1;", "=f,f" %{{.*}} : (f32) -> f32
+  %0 = nvvm.inline_ptx "ex2.approx.ftz.f32 $0, $1;" (%input) : f32 -> f32
+  
+  // CHECK: %{{.*}} =  llvm.inline_asm has_side_effects asm_dialect = att "@$1 ex2.approx.ftz.f32 $0, $1;", "=f,f,b" %{{.*}}, %{{.*}} : (f32, i1) -> f32
+  %1 = nvvm.inline_ptx "ex2.approx.ftz.f32 $0, $1;" (%input), predicate = %pred  : f32, i1 -> f32
+  llvm.return
+}
\ No newline at end of file
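
The `getPtx()` hook in the diff builds a `std::string` from `eventName.data()`. `llvm::StringRef` does not promise a null terminator in general, so the pointer-only constructor relies on how the attribute stores its text. The sketch below uses `std::string_view` as a stand-in for `StringRef` (an assumption, to keep the example free of LLVM headers) and contrasts copying from the raw pointer with copying the known length. With `StringRef`, the length-aware form would be `eventName.str()`.

```cpp
#include <string>
#include <string_view>

// Copies characters until the first '\0' after the pointer; the view's
// length is ignored, so text beyond the view can leak into the result.
std::string copyFromPointer(std::string_view text) {
  return std::string(text.data());
}

// Copies exactly text.size() characters, independent of any terminator.
std::string copyWithLength(std::string_view text) {
  return std::string(text);
}
```

When the view covers only a prefix of a longer buffer, the two functions return different strings; when the view spans a whole null-terminated buffer, they agree.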

grypp changed the title from "[MLIR]NVVM] Add inline_ptx op" to "[MLIR][NVVM] Add inline_ptx op" on May 14, 2025
Copilot AI (Contributor) left a comment

Pull Request Overview

Adds a new nvvm.inline_ptx operation that embeds PTX assembly directly in the NVVM dialect and lowers it to llvm.inline_asm.

  • Introduces the inline_ptx op in the Dialect definition (NVVMOps.td).
  • Provides C++ hook (getPtx()) to retrieve the PTX string.
  • Adds FileCheck-based tests in nvvm-to-llvm.mlir to verify correct lowering.

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

Files:

  • mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td: Define NVVM_InlinePtxOP and its builder
  • mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir: Add tests for nvvm.inline_ptx lowering

Comments suppressed due to low confidence (1)

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td:243

  • The op class name NVVM_InlinePtxOP uses uppercase "OP". Per MLIR naming conventions, it should be NVVM_InlinePtxOp (capital ‘O’, lowercase ‘p’).
def NVVM_InlinePtxOP : NVVM_Op<"inline_ptx",

joker-eph (Collaborator) commented

The AI bot review seems spot on ;)

joker-eph (Collaborator) commented

So how does it know about "Read-only Parameters" and "Read-only and Write-only Parameters"?

grypp (Member, Author) commented May 15, 2025

Input arguments are read-only and results are write-only; it's pure SSA. Read-only arguments get no special marker, while write-only ones automatically get a = prefix.

There are also read-write operands. I support them elsewhere in the NVVM dialect, for example in wgmma.mma_async, but I didn't add them in this PR.

Supporting read-write operands in LLVM is a little complicated. In C-style inline asm they are written with +, but LLVM doesn't support that, I guess because it doesn't align with SSA semantics.

In LLVM, we handle read-write operands differently: we mark them with =, indicating write access, and then explicitly pass the same value as a read operand in the input list.

For example, here result is read-write, which is valid asm, but LLVM doesn't support it:

asm ("OPCODE %0, %1, %2" : "+f" (result));

Instead, we generate the equivalent as:

asm ("OPCODE %0, %1, %2" : "=f" (result) : "0" (result));

This marks result as both written (=f) and read ("0").
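
The tied-operand scheme described above can be sketched as a constraint builder. This is a simplified illustration with hypothetical names (`Access`, `Operand`, `buildConstraints`), not the code used by the NVVM lowering: every written operand becomes an `=` output, and a read-write operand additionally contributes an input that names the index of its output.

```cpp
#include <cstddef>
#include <string>
#include <vector>

enum class Access { Read, Write, ReadWrite };

struct Operand {
  char reg;      // register constraint letter, e.g. 'f' or 'r'
  Access access; // how the instruction uses the operand
};

// Hypothetical helper: outputs are listed first, then inputs. A read-write
// operand is emitted as "=<reg>" plus a tied input holding its output index.
std::string buildConstraints(const std::vector<Operand> &operands) {
  std::vector<std::string> outputs;
  std::vector<std::string> inputs;
  for (const Operand &operand : operands) {
    if (operand.access == Access::Read) {
      inputs.push_back(std::string(1, operand.reg));
      continue;
    }
    std::size_t outputIndex = outputs.size();
    outputs.push_back(std::string("=") + operand.reg);
    if (operand.access == Access::ReadWrite)
      inputs.push_back(std::to_string(outputIndex)); // tie input to output
  }
  std::string out;
  for (const std::vector<std::string> *group : {&outputs, &inputs}) {
    for (const std::string &item : *group) {
      if (!out.empty())
        out += ',';
      out += item;
    }
  }
  return out;
}
```

A single read-write `f` operand gives `"=f,0"`, which corresponds to the `"=f" (result) : "0" (result)` form shown above.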


// Lowers to:
llvm.inline_asm has_side_effects asm_dialect = att
"mbarrier.init.b64 [$0], $1;", "l,r" %arg0, %arg2 : (!llvm.ptr, i32) -> ()
Contributor

nit:
In all these given examples, the MLIR looks clean.
The lowered LLVM IR still seems to have llvm.ptr, i32, etc. I think this should reflect what we generate, right?
(Maybe we could copy-paste from the unit test below.)

grypp (Member, Author), May 15, 2025

In all these given examples, the mlir looks clean.

What do you mean by clean?

Contributor

I meant that the MLIR version exactly matches the unit test, but the lowered version still refers to (!llvm.ptr, i32) instead of "l, r".

durga4github (Contributor) left a comment

LGTM

@grypp grypp merged commit d08b176 into llvm:main May 15, 2025
11 checks passed
TIFitis pushed a commit to TIFitis/llvm-project that referenced this pull request May 19, 2025
Co-authored-by: Mehdi Amini <[email protected]>