Skip to content

[MLIR][NVVM] Migrate CpAsyncOp to intrinsics #123789

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

durga4github
Copy link
Contributor

Intrinsics are available for the 'cpSize'
variants also. So, this patch migrates the Op
to lower to the intrinsics for all cases.

  • Update the existing tests to check the lowering to intrinsics.
  • Add newer cp_async_zfill tests to verify the lowering for the 'cpSize' variants.
  • Tidy-up CHECK lines in cp_async() function in nvvmir.mlir (NFC)

PTX spec link:
https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async

@llvmbot
Copy link
Member

llvmbot commented Jan 21, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Durgadoss R (durga4github)

Changes

Intrinsics are available for the 'cpSize'
variants also. So, this patch migrates the Op
to lower to the intrinsics for all cases.

  • Update the existing tests to check the lowering to intrinsics.
  • Add newer cp_async_zfill tests to verify the lowering for the 'cpSize' variants.
  • Tidy-up CHECK lines in cp_async() function in nvvmir.mlir (NFC)

PTX spec link:
https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async


Full diff: https://github.com/llvm/llvm-project/pull/123789.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+15-40)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+23)
  • (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (+2-6)
  • (modified) mlir/test/Target/LLVMIR/nvvmir.mlir (+20-6)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 797a0067081314..dc4295926a8ce5 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -849,55 +849,30 @@ def LoadCacheModifierKind : I32EnumAttr<"LoadCacheModifierKind",
 
 def LoadCacheModifierAttr : EnumAttr<NVVM_Dialect, LoadCacheModifierKind, "load_cache_modifier">;
 
-def NVVM_CpAsyncOp : NVVM_PTXBuilder_Op<"cp.async.shared.global">,
+def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global">,
   Arguments<(ins LLVM_PointerShared:$dst,
                  LLVM_PointerGlobal:$src,
                  I32Attr:$size,
                  LoadCacheModifierAttr:$modifier,
                  Optional<LLVM_Type>:$cpSize)> {
-  string llvmBuilder = [{
-      llvm::Intrinsic::ID id;
-      switch ($size) {
-        case 4:
-          id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_4;
-          break;
-        case 8:
-          id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_8;
-          break;
-        case 16:
-          if($modifier == NVVM::LoadCacheModifierKind::CG)
-            id = llvm::Intrinsic::nvvm_cp_async_cg_shared_global_16;
-          else if($modifier == NVVM::LoadCacheModifierKind::CA)
-            id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_16;
-          else 
-            llvm_unreachable("unsupported cache modifier");
-          break;
-        default:
-          llvm_unreachable("unsupported async copy size");
-      }
-      createIntrinsicCall(builder, id, {$dst, $src});
-  }];
   let assemblyFormat = "$dst `,` $src `,` $size `,` `cache` `=` $modifier (`,` $cpSize^)? attr-dict `:` type(operands)";
   let hasVerifier = 1;
   let extraClassDeclaration = [{
-    bool hasIntrinsic() { if(getCpSize()) return false; return true; }
-
-    void getAsmValues(RewriterBase &rewriter, 
-        llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &asmValues) {
-      asmValues.push_back({getDst(), PTXRegisterMod::Read});
-      asmValues.push_back({getSrc(), PTXRegisterMod::Read});
-      asmValues.push_back({makeConstantI32(rewriter, getSize()), PTXRegisterMod::Read});
-      asmValues.push_back({getCpSize(), PTXRegisterMod::Read});
-    }        
+    static llvm::Intrinsic::ID getIntrinsicID(int size,
+                                              NVVM::LoadCacheModifierKind kind,
+                                              bool hasCpSize);
   }];
-  let extraClassDefinition = [{        
-    std::string $cppClass::getPtx() { 
-      if(getModifier() == NVVM::LoadCacheModifierKind::CG)
-        return std::string("cp.async.cg.shared.global [%0], [%1], %2, %3;\n");
-      if(getModifier() == NVVM::LoadCacheModifierKind::CA)
-        return std::string("cp.async.ca.shared.global [%0], [%1], %2, %3;\n");
-      llvm_unreachable("unsupported cache modifier");      
-    }
+  string llvmBuilder = [{
+    bool hasCpSize = op.getCpSize() ? true : false;
+
+    llvm::SmallVector<llvm::Value *> translatedOperands;
+    translatedOperands.push_back($dst);
+    translatedOperands.push_back($src);
+    if (hasCpSize)
+      translatedOperands.push_back($cpSize);
+
+    auto id = NVVM::CpAsyncOp::getIntrinsicID($size, $modifier, hasCpSize);
+    createIntrinsicCall(builder, id, translatedOperands);
   }];
 }
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index ccb5ad05f0bf72..2c45753d52da9c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1110,6 +1110,29 @@ LogicalResult NVVM::BarrierOp::verify() {
   return success();
 }
 
+#define CP_ASYNC_ID_IMPL(mod, size, suffix)                                    \
+  llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
+
+#define GET_CP_ASYNC_ID(mod, size, has_cpsize)                                 \
+  has_cpsize ? CP_ASYNC_ID_IMPL(mod, size, _s) : CP_ASYNC_ID_IMPL(mod, size, )
+
+llvm::Intrinsic::ID
+CpAsyncOp::getIntrinsicID(int size, NVVM::LoadCacheModifierKind cacheMod,
+                          bool hasCpSize) {
+  switch (size) {
+  case 4:
+    return GET_CP_ASYNC_ID(ca, 4, hasCpSize);
+  case 8:
+    return GET_CP_ASYNC_ID(ca, 8, hasCpSize);
+  case 16:
+    return (cacheMod == NVVM::LoadCacheModifierKind::CG)
+               ? GET_CP_ASYNC_ID(cg, 16, hasCpSize)
+               : GET_CP_ASYNC_ID(ca, 16, hasCpSize);
+  default:
+    llvm_unreachable("Invalid copy size in CpAsyncOp.");
+  }
+}
+
 llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
                                                                 bool isIm2Col) {
   switch (tensorDims) {
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 84ea55ceb5acc2..c7a6eca1582768 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -74,13 +74,9 @@ func.func @async_cp(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>) {
 
 // CHECK-LABEL: @async_cp_zfill
 func.func @async_cp_zfill(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>, %cpSize: i32) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att 
-  // CHECK-SAME: "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", 
-  // CHECK-SAME: "r,l,n,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> ()
+  // CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16, cache =  cg, %{{.*}} : !llvm.ptr<3>, !llvm.ptr<1>, i32
   nvvm.cp.async.shared.global %dst, %src, 16, cache =  cg, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att 
-  // CHECK-SAME: "cp.async.ca.shared.global [$0], [$1], $2, $3;\0A", 
-  // CHECK-SAME: "r,l,n,r" %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> ()
+  // CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 4, cache =  ca, %{{.*}} : !llvm.ptr<3>, !llvm.ptr<1>, i32
   nvvm.cp.async.shared.global %dst, %src, 4, cache =  ca, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
   return
 }
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 09e98765413f0c..7dad9a403def0e 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -488,21 +488,35 @@ llvm.func @nvvm_wmma_mma(%0 : i32, %1 : i32, %2 : i32, %3 : i32, %4 : i32, %5 :
 
 // CHECK-LABEL: @cp_async
 llvm.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) {
-// CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.4(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
+  // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.4(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
   nvvm.cp.async.shared.global %arg0, %arg1, 4, cache =  ca : !llvm.ptr<3>, !llvm.ptr<1>
-// CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.8(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
+  // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.8(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
   nvvm.cp.async.shared.global %arg0, %arg1, 8, cache =  ca : !llvm.ptr<3>, !llvm.ptr<1>
-// CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.16(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
+  // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.16(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
   nvvm.cp.async.shared.global %arg0, %arg1, 16, cache =  ca : !llvm.ptr<3>, !llvm.ptr<1>
-// CHECK: call void @llvm.nvvm.cp.async.cg.shared.global.16(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
+  // CHECK: call void @llvm.nvvm.cp.async.cg.shared.global.16(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}})
   nvvm.cp.async.shared.global %arg0, %arg1, 16, cache =  cg : !llvm.ptr<3>, !llvm.ptr<1>
-// CHECK: call void @llvm.nvvm.cp.async.commit.group()
+
+  // CHECK: call void @llvm.nvvm.cp.async.commit.group()
   nvvm.cp.async.commit.group
-// CHECK: call void @llvm.nvvm.cp.async.wait.group(i32 0)
+  // CHECK: call void @llvm.nvvm.cp.async.wait.group(i32 0)
   nvvm.cp.async.wait.group 0
   llvm.return
 }
 
+// CHECK-LABEL: @async_cp_zfill
+llvm.func @async_cp_zfill(%dst: !llvm.ptr<3>, %src: !llvm.ptr<1>, %cpSize: i32) {
+  // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.4.s(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}}, i32 %{{.*}})
+  nvvm.cp.async.shared.global %dst, %src, 4, cache =  ca, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
+  // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.8.s(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}}, i32 %{{.*}})
+  nvvm.cp.async.shared.global %dst, %src, 8, cache =  ca, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
+  // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.16.s(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}}, i32 %{{.*}})
+  nvvm.cp.async.shared.global %dst, %src, 16, cache =  ca, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
+  // CHECK: call void @llvm.nvvm.cp.async.cg.shared.global.16.s(ptr addrspace(3) %{{.*}}, ptr addrspace(1) %{{.*}}, i32 %{{.*}})
+  nvvm.cp.async.shared.global %dst, %src, 16, cache =  cg, %cpSize : !llvm.ptr<3>, !llvm.ptr<1>, i32
+  llvm.return
+}
+
 // CHECK-LABEL: @cp_async_mbarrier_arrive
 llvm.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.ptr) {
   // CHECK: call void @llvm.nvvm.cp.async.mbarrier.arrive(ptr %{{.*}})

Intrinsics are available for the 'cpSize'
variants also. So, this patch migrates the Op
to lower to the intrinsics for all cases.

* Update the existing tests to check the
  lowering to intrinsics.
* Add newer cp_async_zfill tests to verify the
  lowering for the 'cpSize' variants.
* Tidy-up CHECK lines in cp_async() function
  in nvvmir.mlir (NFC)

PTX spec link:
https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cp-async

Signed-off-by: Durgadoss R <[email protected]>
@durga4github durga4github force-pushed the durgadossr/mlir_cp_async_migrate branch from ac5d431 to ef684fc Compare January 22, 2025 18:13
@durga4github durga4github merged commit 2e6cc79 into llvm:main Jan 23, 2025
8 checks passed
@durga4github durga4github deleted the durgadossr/mlir_cp_async_migrate branch January 23, 2025 10:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants