[mlir][GPU] Plumb range information through the NVVM lowerings #107659

Merged
merged 4 commits into from
Sep 13, 2024

krzysz00
Contributor

@krzysz00 krzysz00 commented Sep 6, 2024

Update the GPU to NVVM lowerings to correctly propagate range
information on IDs and dimension queries, either from
known_{block,grid}_size attributes or from upperBound annotations on
the operations themselves.
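
As a rough illustration of the idea (not the code in this patch), the standalone C++ sketch below models how a range could be derived for an ID or dimension query. The `Range` struct, the `QueryKind` enum, both function names, and the rule that an op-level `upperBound` overrides a function-level known size are assumptions made for the example. Ranges are half-open, in the style of LLVM's `ConstantRange`.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical half-open range [lower, upper).
struct Range {
  uint32_t lower;
  uint32_t upper;
};

// Hypothetical: whether the query returns an index (0-based) or a size (>= 1).
enum class QueryKind { Id, Dim };

// Pick the bound for one dimension (0 = x, 1 = y, 2 = z).
// Assumption: a bound on the op itself takes priority over a known size
// attached to the enclosing function.
std::optional<uint32_t>
pickBound(std::optional<uint32_t> opUpperBound,
          std::optional<std::array<uint32_t, 3>> knownSize, unsigned dim) {
  assert(dim < 3 && "dimension must be x, y, or z");
  if (opUpperBound)
    return opUpperBound;
  if (knownSize)
    return knownSize->at(dim);
  return std::nullopt;
}

// Turn a bound into a range: IDs lie in [0, bound), sizes in [1, bound + 1).
// A real implementation would also guard against overflow of bound + 1.
std::optional<Range> makeRange(QueryKind kind, std::optional<uint32_t> bound) {
  if (!bound)
    return std::nullopt; // No information: leave the op unannotated.
  if (kind == QueryKind::Id)
    return Range{0, *bound};
  return Range{1, *bound + 1};
}
```

Under these assumptions, a kernel with `known_block_size = array<i32: 32, 4, 2>` would give the x thread ID the range [0, 32) and the x block dimension the range [1, 33).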

Contributor Author

krzysz00 commented Sep 6, 2024

Member

llvmbot commented Sep 7, 2024

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Krzysztof Drewniak (krzysz00)

Changes

Update the GPU to NVVM lowerings to correctly propagate range
information on IDs and dimension queries, either from
known_{block,grid}_size attributes or from upperBound annotations on
the operations themselves.


Patch is 37.14 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/107659.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+149-134)
  • (modified) mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp (+34-16)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp (+1)
  • (modified) mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir (+15-3)
  • (modified) mlir/test/Target/LLVMIR/Import/nvvmir.ll (+3)
  • (modified) mlir/test/Target/LLVMIR/nvvmir.mlir (+5-2)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 709dd922b8fa2f..66ac9f289d233b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -86,8 +86,8 @@ class NVVM_Op<string mnemonic, list<Trait> traits = []> :
   LLVM_OpBase<NVVM_Dialect, mnemonic, traits> {
 }
 
-/// Base class that defines BasicPtxBuilderOpInterface. 
-class NVVM_PTXBuilder_Op<string mnemonic, 
+/// Base class that defines BasicPtxBuilderOpInterface.
+class NVVM_PTXBuilder_Op<string mnemonic,
   list<Trait> traits = [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> :
   LLVM_OpBase<NVVM_Dialect, mnemonic, traits> {
 }
@@ -123,52 +123,67 @@ class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
+class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
+  NVVM_SpecialRegisterOp<mnemonic, traits> {
+  let arguments = (ins OptionalAttr<LLVM_ConstantRangeAttr>:$range);
+  let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)";
+  let llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda;
+  let mlirBuilder = baseMlirBuilder # importRangeRetAttrCode # baseMlirBuilderCoda;
+
+  // Backwards-compatibility builder for an unspecified range.
+  let builders = [
+    OpBuilder<(ins "Type":$resultType), [{
+      build($_builder, $_state, resultType, ::mlir::LLVM::ConstantRangeAttr{});
+    }]>
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // Lane index and range
-def NVVM_LaneIdOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.laneid">;
-def NVVM_WarpSizeOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.warpsize">;
+def NVVM_LaneIdOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.laneid">;
+def NVVM_WarpSizeOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.warpsize">;
 
 //===----------------------------------------------------------------------===//
 // Thread index and range
-def NVVM_ThreadIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.x">;
-def NVVM_ThreadIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.y">;
-def NVVM_ThreadIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.z">;
-def NVVM_BlockDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.x">;
-def NVVM_BlockDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.y">;
-def NVVM_BlockDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.z">;
+def NVVM_ThreadIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.tid.x">;
+def NVVM_ThreadIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.tid.y">;
+def NVVM_ThreadIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.tid.z">;
+def NVVM_BlockDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ntid.x">;
+def NVVM_BlockDimYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ntid.y">;
+def NVVM_BlockDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ntid.z">;
 
 //===----------------------------------------------------------------------===//
 // Block index and range
-def NVVM_BlockIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.x">;
-def NVVM_BlockIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.y">;
-def NVVM_BlockIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.z">;
-def NVVM_GridDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.x">;
-def NVVM_GridDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.y">;
-def NVVM_GridDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.z">;
+def NVVM_BlockIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.x">;
+def NVVM_BlockIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.y">;
+def NVVM_BlockIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.ctaid.z">;
+def NVVM_GridDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.x">;
+def NVVM_GridDimYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.y">;
+def NVVM_GridDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nctaid.z">;
 
 //===----------------------------------------------------------------------===//
 // CTA Cluster index and range
-def NVVM_ClusterIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.clusterid.x">;
-def NVVM_ClusterIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.clusterid.y">;
-def NVVM_ClusterIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.clusterid.z">;
-def NVVM_ClusterDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nclusterid.x">;
-def NVVM_ClusterDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nclusterid.y">;
-def NVVM_ClusterDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nclusterid.z">;
+def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x">;
+def NVVM_ClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.y">;
+def NVVM_ClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.z">;
+def NVVM_ClusterDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.x">;
+def NVVM_ClusterDimYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.y">;
+def NVVM_ClusterDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.z">;
 
 
 //===----------------------------------------------------------------------===//
 // CTA index and range within Cluster
-def NVVM_BlockInClusterIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.ctaid.x">;
-def NVVM_BlockInClusterIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.ctaid.y">;
-def NVVM_BlockInClusterIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.ctaid.z">;
-def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.nctaid.x">;
-def NVVM_ClusterDimBlocksYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.nctaid.y">;
-def NVVM_ClusterDimBlocksZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.nctaid.z">;
+def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x">;
+def NVVM_BlockInClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y">;
+def NVVM_BlockInClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z">;
+def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x">;
+def NVVM_ClusterDimBlocksYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.y">;
+def NVVM_ClusterDimBlocksZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.z">;
 
 //===----------------------------------------------------------------------===//
 // CTA index and across Cluster dimensions
-def NVVM_ClusterId : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.ctarank">;
-def NVVM_ClusterDim : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.nctarank">;
+def NVVM_ClusterId : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctarank">;
+def NVVM_ClusterDim : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctarank">;
 
 //===----------------------------------------------------------------------===//
 // Clock registers
@@ -197,11 +212,11 @@ def ReduxKindMin  : I32EnumAttrCase<"MIN", 4, "min">;
 def ReduxKindOr   : I32EnumAttrCase<"OR", 5, "or">;
 def ReduxKindUmax : I32EnumAttrCase<"UMAX", 6, "umax">;
 def ReduxKindUmin : I32EnumAttrCase<"UMIN", 7, "umin">;
-def ReduxKindXor  : I32EnumAttrCase<"XOR", 8, "xor">; 
+def ReduxKindXor  : I32EnumAttrCase<"XOR", 8, "xor">;
 
 /// Enum attribute of the different kinds.
 def ReduxKind : I32EnumAttr<"ReduxKind", "NVVM redux kind",
-  [ReduxKindAdd, ReduxKindAnd, ReduxKindMax, ReduxKindMin, ReduxKindOr, 
+  [ReduxKindAdd, ReduxKindAnd, ReduxKindMax, ReduxKindMin, ReduxKindOr,
     ReduxKindUmax, ReduxKindUmin, ReduxKindXor]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::NVVM";
@@ -221,7 +236,7 @@ def NVVM_ReduxOp :
   }];
   let assemblyFormat = [{
     $kind $val `,` $mask_and_clamp  attr-dict `:` type($val) `->` type($res)
-   }];   
+   }];
 }
 
 //===----------------------------------------------------------------------===//
@@ -308,7 +323,7 @@ def NVVM_MBarrierArriveNocompleteSharedOp : NVVM_Op<"mbarrier.arrive.nocomplete.
   let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands) `->` type($res)";
 }
 
-def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx">,  
+def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx">,
   Arguments<(ins LLVM_AnyPointer:$addr, I32:$txcount, PtxPredicate:$predicate)> {
   let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
   let extraClassDefinition = [{
@@ -316,16 +331,16 @@ def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_t
   }];
 }
 
-def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx.shared">,  
-  Arguments<(ins LLVM_PointerShared:$addr, I32:$txcount, PtxPredicate:$predicate)> {    
+def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx.shared">,
+  Arguments<(ins LLVM_PointerShared:$addr, I32:$txcount, PtxPredicate:$predicate)> {
   let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
   let extraClassDefinition = [{
     std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"); }
   }];
 }
 
-def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">,  
-  Arguments<(ins LLVM_AnyPointer:$addr, I32:$phase, I32:$ticks)> {  
+def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">,
+  Arguments<(ins LLVM_AnyPointer:$addr, I32:$phase, I32:$ticks)> {
   let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
   let extraClassDefinition = [{
     std::string $cppClass::getPtx() {
@@ -338,13 +353,13 @@ def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity"
         "bra.uni     LAB_WAIT; \n\t"
         "DONE: \n\t"
         "}"
-      ); 
+      );
     }
   }];
 }
 
-def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">,  
-  Arguments<(ins LLVM_PointerShared:$addr, I32:$phase, I32:$ticks)> {  
+def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">,
+  Arguments<(ins LLVM_PointerShared:$addr, I32:$phase, I32:$ticks)> {
   let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
   let extraClassDefinition = [{
     std::string $cppClass::getPtx() {
@@ -357,7 +372,7 @@ def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.p
         "bra.uni     LAB_WAIT; \n\t"
         "DONE: \n\t"
         "}"
-      ); 
+      );
     }
   }];
 }
@@ -392,7 +407,7 @@ def NVVM_Barrier0Op : NVVM_Op<"barrier0"> {
 }
 
 def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
-  let arguments = (ins     
+  let arguments = (ins
     Optional<I32>:$barrierId,
     Optional<I32>:$numberOfThreads);
   string llvmBuilder = [{
@@ -401,7 +416,7 @@ def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
                 {$barrierId, $numberOfThreads});
     } else if($barrierId) {
       createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier_n,
-                {$barrierId});   
+                {$barrierId});
     } else {
       createIntrinsicCall(builder, llvm::Intrinsic::nvvm_barrier0);
     }
@@ -410,27 +425,27 @@ def NVVM_BarrierOp : NVVM_Op<"barrier", [AttrSizedOperandSegments]> {
   let assemblyFormat = "(`id` `=` $barrierId^)? (`number_of_threads` `=` $numberOfThreads^)? attr-dict";
 }
 
-def NVVM_BarrierArriveOp : NVVM_PTXBuilder_Op<"barrier.arrive"> 
+def NVVM_BarrierArriveOp : NVVM_PTXBuilder_Op<"barrier.arrive">
 {
   let arguments = (ins Optional<I32>:$barrierId, I32:$numberOfThreads);
 
   let description = [{
-    Thread that executes this op announces their arrival at the barrier with 
+    Thread that executes this op announces their arrival at the barrier with
     given id and continue their execution.
 
-    The default barrier id is 0 that is similar to `nvvm.barrier` Op. When 
-    `barrierId` is not present, the default barrier id is used. 
+    The default barrier id is 0 that is similar to `nvvm.barrier` Op. When
+    `barrierId` is not present, the default barrier id is used.
 
     [For more information, see PTX ISA]
     (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar)
   }];
-  
+
   let assemblyFormat = "(`id` `=` $barrierId^)? `number_of_threads` `=` $numberOfThreads attr-dict";
 
   let extraClassDefinition = [{
     std::string $cppClass::getPtx() {
       std::string ptx = "bar.arrive ";
-      if (getBarrierId()) { ptx += "%0, %1;"; } 
+      if (getBarrierId()) { ptx += "%0, %1;"; }
       else { ptx += "0, %0;"; }
       return ptx;
     }
@@ -553,7 +568,7 @@ def NVVM_FenceProxyOp : NVVM_PTXBuilder_Op<"fence.proxy">,
     [For more information, see PTX ISA]
     (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar)
   }];
-  
+
   let assemblyFormat = "attr-dict";
   let extraClassDefinition = [{
     std::string $cppClass::getPtx() {
@@ -671,9 +686,9 @@ def NVVM_FenceMbarrierInitOp : NVVM_PTXBuilder_Op<"fence.mbarrier.init"> {
     [For more information, see PTX ISA]
     (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar)
   }];
-  
+
   let assemblyFormat = "attr-dict";
-  let extraClassDefinition = [{        
+  let extraClassDefinition = [{
     std::string $cppClass::getPtx() {
       return std::string("fence.mbarrier_init.release.cluster;");
     }
@@ -749,13 +764,13 @@ def NVVM_SyncWarpOp :
 }
 
 
-def NVVM_ElectSyncOp : NVVM_Op<"elect.sync", 
+def NVVM_ElectSyncOp : NVVM_Op<"elect.sync",
                   [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>
-{  
+{
   let results = (outs I1:$pred);
-  let assemblyFormat = "attr-dict `->` type(results)";  
-  let extraClassDefinition = [{        
-    std::string $cppClass::getPtx() { 
+  let assemblyFormat = "attr-dict `->` type(results)";
+  let extraClassDefinition = [{
+    std::string $cppClass::getPtx() {
       return std::string(
         "{                                  \n"
         ".reg .u32 rx;                      \n"
@@ -764,7 +779,7 @@ def NVVM_ElectSyncOp : NVVM_Op<"elect.sync",
         "    elect.sync rx | px, 0xFFFFFFFF;\n"
         "@px mov.pred %0, 1;                \n"
         "}\n"
-      ); 
+      );
     }
   }];
 }
@@ -776,16 +791,16 @@ def LoadCacheModifierLU : I32EnumAttrCase<"LU", 3, "lu">;
 def LoadCacheModifierCV : I32EnumAttrCase<"CV", 4, "cv">;
 
 /// Enum attribute of the different kinds.
-def LoadCacheModifierKind : I32EnumAttr<"LoadCacheModifierKind", 
+def LoadCacheModifierKind : I32EnumAttr<"LoadCacheModifierKind",
                                 "NVVM load cache modifier kind",
-  [LoadCacheModifierCA, LoadCacheModifierCG, LoadCacheModifierCS, 
+  [LoadCacheModifierCA, LoadCacheModifierCG, LoadCacheModifierCS,
     LoadCacheModifierLU, LoadCacheModifierCV]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::NVVM";
   let description = [{
     Enum attribute of the different kinds of cache operators for load instructions.
 
-    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#id62)    
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#id62)
   }];
 }
 
@@ -811,7 +826,7 @@ def NVVM_CpAsyncOp : NVVM_PTXBuilder_Op<"cp.async.shared.global">,
             id = llvm::Intrinsic::nvvm_cp_async_cg_shared_global_16;
           else if($modifier == NVVM::LoadCacheModifierKind::CA)
             id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_16;
-          else 
+          else
             llvm_unreachable("unsupported cache modifier");
           break;
         default:
@@ -824,21 +839,21 @@ def NVVM_CpAsyncOp : NVVM_PTXBuilder_Op<"cp.async.shared.global">,
   let extraClassDeclaration = [{
     bool hasIntrinsic() { if(getCpSize()) return false; return true; }
 
-    void getAsmValues(RewriterBase &rewriter, 
+    void getAsmValues(RewriterBase &rewriter,
         llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &asmValues) {
       asmValues.push_back({getDst(), PTXRegisterMod::Read});
       asmValues.push_back({getSrc(), PTXRegisterMod::Read});
       asmValues.push_back({makeConstantI32(rewriter, getSize()), PTXRegisterMod::Read});
       asmValues.push_back({getCpSize(), PTXRegisterMod::Read});
-    }        
+    }
   }];
-  let extraClassDefinition = [{        
-    std::string $cppClass::getPtx() { 
+  let extraClassDefinition = [{
+    std::string $cppClass::getPtx() {
       if(getModifier() == NVVM::LoadCacheModifierKind::CG)
         return std::string("cp.async.cg.shared.global [%0], [%1], %2, %3;\n");
       if(getModifier() == NVVM::LoadCacheModifierKind::CA)
         return std::string("cp.async.ca.shared.global [%0], [%1], %2, %3;\n");
-      llvm_unreachable("unsupported cache modifier");      
+      llvm_unreachable("unsupported cache modifier");
     }
   }];
 }
@@ -1526,9 +1541,9 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
   let hasVerifier = 1;
 }
 
-def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">, 
-  Arguments<(ins LLVM_PointerShared:$ptr, 
-                 Variadic<I32>:$sources, 
+def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">,
+  Arguments<(ins LLVM_PointerShared:$ptr,
+                 Variadic<I32>:$sources,
                  MMALayoutAttr:$layout)> {
   let summary = "cooperative matrix store";
   let description = [{
@@ -1537,7 +1552,7 @@ def NVVM_StMatrixOp: NVVM_PTXBuilder_Op<"stmatrix">,
     [For more information, see PTX ISA]
     (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix)
   }];
-  
+
   let assemblyFormat = "$ptr `,` $sources attr-dict `:` type(operands)";
   let extraClassDefinition = [{
     std::string $cppClass::getPtx() {
@@ -1757,25 +1772,25 @@ def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">,
 }
 
 def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group">,
-  Arguments<(ins 
-    ConfinedAttr<I32Attr, [IntMinValue<0>]>:$group, 
+  Arguments<(ins
+    ConfinedAttr<I32Attr, [IntMinValue<0>]>:$group,
     OptionalAttr<UnitAttr>:$read)> {
   let assemblyFormat = "$group attr-dict";
   let description = [{
     Op waits for completion of the most recent bulk async-groups.
 
     The `$group` operand tells waiting has to be done until for $group or fewer
-    of the most recent bulk async-groups. If `$group` is 0, the op wait until 
+    of the most recent bulk async-groups. If `$group` is 0, the op wait until
     all the most recent bulk async-groups have completed.
 
-    The `$read` indicates that the waiting has to be done until all the bulk 
-    async operations in the specified bulk async-group have completed reading 
+    The `$read` indicates that the waiting has to be done until all the bulk
+    async operations in the specified bulk async-group have completed reading
     from their source locations.
 
     [For more information, see PTX ISA]
     (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group)
   }];
-  
+
   string llvmBuilder = [{
     auto intId = op.getRead() ?
       llvm::Intrinsic::nvvm_cp_async_bulk_wait_group_read :
@@ -1784,53 +1799,53 @@ def NVVM_CpAsyncBulkWaitGroupOp : NVVM_Op<"cp.async.bulk.wait_group">,
   }];
 }
 
-def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp : 
-  NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global", 
-  [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>, 
+def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
+  NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global",
+  [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>,
   AttrSizedOperandSegments]>,
   Arguments<(ins  LLVM_PointerShared:$dstMem,
                   LLVM_AnyPointer:$tmaDescriptor,
                   Variadic<I32>:$coordinates,
-                  LLVM_PointerShared:$mbar,                  
+                  LLVM_PointerShared:$mbar,
                   Variadic<I16>:$im2colOffsets,
                   Optional<I16>:$multicastMask,
                   Optional<I64>:$l2CacheHint,
       ...
[truncated]

Member

@grypp grypp left a comment

Thanks for doing that. I think this is going to be very useful. I left some comments.

LLVM::ConstantRangeAttr bounds = nullptr;
if (std::optional<APInt> upperBound = op.getUpperBound())
  bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
      32, 0, upperBound->getZExtValue());
Member

So 32 is the bit width and 0 is the lower limit, right? Maybe we can create symbols to name them.

Variadic<I16>:$im2colOffsets,
Optional<I16>:$multicastMask,
Optional<I64>:$l2CacheHint,
PtxPredicate:$predicate)> {
let description = [{
Initiates an asynchronous copy operation on the tensor data from global
memory to shared memory.
Initiates an asynchronous copy operation on the tensor data from global
Member

I see that you are removing the trailing spaces. It's unrelated. Can we remove it from this PR?

Contributor Author

Darn editor settings, done

@@ -84,7 +87,7 @@ llvm.func @llvm_nvvm_barrier0() {
// CHECK-SAME: i32 %[[barId:.*]], i32 %[[numThreads:.*]])
llvm.func @llvm_nvvm_barrier(%barID : i32, %numberOfThreads : i32) {
// CHECK: call void @llvm.nvvm.barrier0()
nvvm.barrier
nvvm.barrier
Member

Let's remove the unrelated whitespace cleanup from here as well.

gpu.func @kernel_with_block_size() kernel attributes {known_block_size = array<i32: 128, 1, 1>} {
// CHECK-LABEL: func @kernel_with_block_size(
// CHECK: attributes {gpu.kernel, gpu.known_block_size = array<i32: 32, 4, 2>, nvvm.kernel, nvvm.maxntid = array<i32: 32, 4, 2>}
gpu.func @kernel_with_block_size(%arg0: !llvm.ptr) kernel attributes {known_block_size = array<i32: 32, 4, 2>} {
Member

If I remember correctly, you added known_block_size to func.func. So I am wondering: is this PR going to work for func.func?

Contributor Author

Yep, that code works generally

@@ -209,7 +209,12 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
MLIRContext *context = rewriter.getContext();
Value newOp = rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type());
LLVM::ConstantRangeAttr bounds = nullptr;
if (std::optional<APInt> upperBound = op.getUpperBound())
Member

Who is setting the upper bound? I might be missing something.

Contributor Author

User code - I'll have some tests shortly

Member

grypp commented Sep 8, 2024

nit: lowterings typo

@krzysz00 krzysz00 force-pushed the users/krzysz00/nvvm-range-plumbing branch from f50dcd3 to d7a2149 Compare September 9, 2024 23:16
  bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
      /*bitWidth=*/32, /*lower=*/0, upperBound->getZExtValue());
Value newOp =
    rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type(), bounds);
Member

Can we always use kWarpSize = 32 for the laneId? This is a HW constraint, and it hasn't been changed over the years.

Contributor Author

Done, thanks for a good observation about the default
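
The two suggestions in this thread (naming the literals and defaulting the lane ID bound to the warp size) can be pictured with the standalone C++ sketch below. It is only a model under stated assumptions: `LaneRange`, `laneIdRange`, `kIndexBitWidth`, and `kLowerBound` are hypothetical names for illustration, and the snippet does not use the MLIR API or reproduce the merged code.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Named constants instead of bare literals (names are illustrative).
constexpr unsigned kIndexBitWidth = 32; // lane IDs are i32 values
constexpr uint32_t kLowerBound = 0;     // lane IDs start at zero
constexpr uint32_t kWarpSize = 32;      // warp size, as noted in the review

// Hypothetical stand-in for a constant range attribute: [lower, upper).
struct LaneRange {
  unsigned bitWidth;
  uint32_t lower;
  uint32_t upper;
};

// Use the user-supplied upper bound when present; otherwise fall back to
// the warp size, so the lane ID always carries some range.
LaneRange laneIdRange(std::optional<uint32_t> upperBound) {
  uint32_t upper = upperBound.value_or(kWarpSize);
  assert(upper > kLowerBound && "range must be non-empty");
  return LaneRange{kIndexBitWidth, kLowerBound, upper};
}
```

With no bound supplied, `laneIdRange(std::nullopt)` yields the 32-bit range [0, 32); passing a smaller bound narrows the upper end accordingly.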

Member

grypp commented Sep 10, 2024

It looks good in general. Let's wait for the main PR to land, and then you can land this one as well.

@krzysz00 krzysz00 changed the title [mlir][GPU] Plumb range information through the NVVM lowterings [mlir][GPU] Plumb range information through the NVVM lowerings Sep 10, 2024
Base automatically changed from users/krzysz00/refactor-range-attributes-rocdl to main September 12, 2024 14:46
@krzysz00 krzysz00 requested a review from Mogball as a code owner September 12, 2024 14:46
Contributor Author

@grypp Can I get a review now?

Update the GPU to NVVM lowerings to correctly propagate range
information on IDs and dimension queries, either from
known_{block,grid}_size attributes or from `upperBound` annotations on
the operations themselves.
Member

grypp commented Sep 13, 2024

It looks good to me. I would also wait for @durga4github to review.

@grypp grypp requested a review from durga4github September 13, 2024 06:01
auto *inst = LLVM::detail::createIntrinsicCall(
builder, moduleTranslation, &opInst, llvm::Intrinsic::}] # !interleave([
enumName, "" # numResults, overloadedResultsCpp, overloadedOperandsCpp,
immArgPositionsCpp, immArgAttrNamesCpp], ",") # [{);
(void) inst;
}] # !if(!gt(requiresAccessGroup, 0), setAccessGroupsMetadataCode, "")
}];
string baseLlvmBuilderCoda = !if(!gt(numResults, 0), "$res = inst;", "");
Contributor

I believe we want this to be 'baseLlvmBuilderCode' and not '..coda'?

Contributor Author

No, I did mean "coda"

Contributor

@durga4github durga4github Sep 13, 2024

Oh, please excuse my ignorance.
Could you please clarify the intent behind this?

Contributor Author

(I'll also note that this is from a previous PR)

The point here is to factor out the LLVM builder so you can add your own adjustments in the middle. I needed a name for the small bit that goes at the end to wrap everything up, so I went with "coda" by analogy to the musical concept.
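
The base / middle / coda split described here mirrors the TableGen concatenation `baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda` seen in the diff. The standalone C++ sketch below models that composition with plain strings. `composeBuilder` is a hypothetical helper, and the string contents in the usage note are placeholders rather than the real builder text.

```cpp
#include <cassert>
#include <string>

// Hypothetical model of a builder assembled from reusable pieces:
// a base (set-up), an optional middle adjustment, and a coda (wrap-up).
std::string composeBuilder(const std::string &base, const std::string &middle,
                           const std::string &coda) {
  assert(!base.empty() && "the base part is always required");
  // An op without extra adjustments passes an empty middle.
  return base + middle + coda;
}
```

For example, `composeBuilder("auto *inst = call(); ", "setRange(inst); ", "$res = inst;")` splices a range-setting step between the shared set-up and the shared wrap-up, while an op with no adjustment passes `""` as the middle.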

Contributor

LGTM, except a nit.

I am also seeking clarification in one place (mostly for my understanding).

@krzysz00 krzysz00 force-pushed the users/krzysz00/nvvm-range-plumbing branch from b974c8b to b2cfba7 Compare September 13, 2024 14:38
@krzysz00 krzysz00 merged commit a953982 into main Sep 13, 2024
8 checks passed
@krzysz00 krzysz00 deleted the users/krzysz00/nvvm-range-plumbing branch September 13, 2024 17:07
4 participants