
[AMDGPU] Add parameterization for optimized shared memory variables #82508


Merged: 10 commits merged into llvm:main on Feb 28, 2024

Conversation

@erman-gurses (Contributor) commented Feb 21, 2024

  • This PR adds parameterization for the shared memory variables used in the optimization: kSharedMemoryLineSizeBytes and kDefaultVectorSizeBits.
  • Both variables default to 128, a combination that yields zero bank conflicts (see the usage sketch below).
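As a quick orientation before the diff, here is a minimal sketch of driving the new knobs from the transform dialect; it mirrors the updated transform_optimize_shmem_reads_writes.mlir test in this PR, and both attributes are optional:

  module attributes { transform.with_named_sequence } {
    transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
      // Explicitly request 128-byte shared memory lines and 128-bit vector
      // accesses (these also happen to be the defaults).
      transform.amdgpu.optimize_shared_memory_reads_and_writes %0
          {kSharedMemoryLineSizeBytes = 128, kDefaultVectorSizeBits = 128} : (!transform.any_op) -> ()
      transform.yield
    }
  }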


github-actions bot commented Feb 21, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@erman-gurses (Contributor, Author) commented

It is not ready yet. I need to make a couple of changes.

@llvmbot (Member) commented Feb 26, 2024

@llvm/pr-subscribers-mlir-amdgpu
@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-backend-amdgpu

Author: None (erman-gurses)

Changes
  • This PR adds parameterization for the shared memory variables used in the optimization: kSharedMemoryLineSizeBytes and kDefaultVectorSizeBits.
  • Both variables default to 128, a combination that yields zero bank conflicts.

Full diff: https://github.com/llvm/llvm-project/pull/82508.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td (+4-2)
  • (modified) mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h (+3-1)
  • (modified) mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp (+2-1)
  • (modified) mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp (+34-5)
  • (modified) mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir (+22-28)
  • (modified) mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir (+23-23)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td b/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td
index 23873d86b495c6..9419c8b14069e2 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td
@@ -13,8 +13,8 @@ include "mlir/Dialect/Transform/IR/TransformAttrs.td"
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
 include "mlir/Dialect/Transform/IR/TransformTypes.td"
-include "mlir/Interfaces/SideEffectInterfaces.td"
 
+include "mlir/Interfaces/SideEffectInterfaces.td"
 //===----------------------------------------------------------------------===//
 // ApplyOptimizeSharedMemoryReadsAndWritesOp
 //===----------------------------------------------------------------------===//
@@ -28,7 +28,9 @@ def ApplyOptimizeSharedMemoryReadsAndWritesOp :
     reads/writes with the goal of avoiding bank conflicts.
   }];
 
-  let arguments = (ins TransformHandleTypeInterface:$target);
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                    DefaultValuedOptionalAttr<I64Attr, "128">:$kSharedMemoryLineSizeBytes,
+                    DefaultValuedOptionalAttr<I64Attr, "128">:$kDefaultVectorSizeBits);
   let results = (outs);
 
   let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
index 79f9ab71a2b430..bb234d3a285e97 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
@@ -49,7 +49,9 @@ LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
                                                  Value memrefValue);
 
 std::optional<LogicalResult>
-optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp);
+optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp,
+                                     int64_t kSharedMemoryLineSizeBytes,
+                                     int64_t kDefaultVectorSizeBits);
 
 } // namespace amdgpu
 } // namespace mlir
diff --git a/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp b/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp
index ff29f9f6938535..08b57f7c8182f4 100644
--- a/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp
@@ -27,7 +27,8 @@ DiagnosedSilenceableFailure
 ApplyOptimizeSharedMemoryReadsAndWritesOp::applyToOne(
     TransformRewriter &rewriter, FuncOp funcOp, ApplyToEachResultList &results,
     TransformState &state) {
-  optimizeSharedMemoryReadsAndWritesOp(funcOp);
+  optimizeSharedMemoryReadsAndWritesOp(funcOp, getKSharedMemoryLineSizeBytes(),
+                                       getKDefaultVectorSizeBits());
   return DiagnosedSilenceableFailure::success();
 }
 
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index 6bd03ed833898d..c48a9e1a9a6422 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -37,11 +37,18 @@ using namespace mlir::amdgpu;
 
 /// The size of a shared memory line according to AMD documentation.
 /// https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/instinct-mi200-cdna2-instruction-set-architecture.pdf
-constexpr int64_t kSharedMemoryLineSizeBytes = 64;
-/// We optimize for 64bit accesses, but this can be made an argument in the
+int64_t kSharedMemoryLineSizeBytes;
+/// We optimize for 128 bit accesses, but this can be made an argument in the
 /// future.
-constexpr int64_t kDefaultVectorSizeBits = 64;
+int64_t kDefaultVectorSizeBits;
 
+void setMemoryLineSize(int64_t _kSharedMemoryLineSizeBytes) {
+  kSharedMemoryLineSizeBytes = _kSharedMemoryLineSizeBytes;
+}
+
+void setDefaultVectorSize(int64_t _kDefaultVectorSizeBits) {
+  kDefaultVectorSizeBits = _kDefaultVectorSizeBits;
+}
 /// Uses `srcIndexValue` to permute `tgtIndexValue` via
 /// `result = xor(floordiv(srcIdxVal,permuteEveryN),
 ///               floordiv(tgtIdxVal,vectorSize)))
@@ -151,6 +158,7 @@ getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
 
 LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
                                                          Value memrefValue) {
+
   auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
   if (!memRefType ||
       !amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(memRefType))
@@ -218,7 +226,11 @@ LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
 }
 
 std::optional<LogicalResult>
-amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
+amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp,
+                                             int64_t kSharedMemoryLineSizeBytes,
+                                             int64_t kDefaultVectorSizeBits) {
+  setMemoryLineSize(kSharedMemoryLineSizeBytes);
+  setDefaultVectorSize(kDefaultVectorSizeBits);
   SmallVector<memref::AllocOp> shmAllocOps;
   funcOp.walk([&](memref::AllocOp allocOp) {
     if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
@@ -235,10 +247,23 @@ amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
 
 struct OptimizeSharedMemoryPass
     : public amdgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
+
 public:
-  OptimizeSharedMemoryPass() = default;
+  OptimizeSharedMemoryPass()
+      : OptimizeSharedMemoryBase(),
+        _kSharedMemoryLineSizeBytes(kSharedMemoryLineSizeBytes = 128),
+        _kDefaultVectorSizeBits(kDefaultVectorSizeBits = 128){};
+
+  OptimizeSharedMemoryPass(int64_t kSharedMemoryLineSizeBytes,
+                           int64_t kDefaultVectorSizeBits)
+      : OptimizeSharedMemoryBase(),
+        _kSharedMemoryLineSizeBytes(kSharedMemoryLineSizeBytes),
+        _kDefaultVectorSizeBits(kDefaultVectorSizeBits){};
 
   void runOnOperation() override {
+    setMemoryLineSize(_kSharedMemoryLineSizeBytes);
+    setDefaultVectorSize(_kDefaultVectorSizeBits);
+
     Operation *op = getOperation();
     SmallVector<memref::AllocOp> shmAllocOps;
     op->walk([&](memref::AllocOp allocOp) {
@@ -253,4 +278,8 @@ struct OptimizeSharedMemoryPass
         return;
     }
   }
+
+private:
+  int64_t _kSharedMemoryLineSizeBytes;
+  int64_t _kDefaultVectorSizeBits;
 };
diff --git a/mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir b/mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir
index a1de1ff87c229f..983eee732e2afe 100644
--- a/mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir
+++ b/mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir
@@ -1,13 +1,13 @@
-// RUN: mlir-opt  %s --pass-pipeline='builtin.module(func.func(amdgpu-optimize-shared-memory))' | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(amdgpu-optimize-shared-memory))' | FileCheck %s
   
   // CHECK: @optimize_shmem([[arg0:%.+]]: memref<{{.*}}>, [[readRow:%.+]]: index, [[readCol:%.+]]: index, [[writeRow:%.+]]: index, [[writeCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index, [[fragColPerm:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index)
-  func.func @optimize_shmem(%arg0: memref<4096x4096xf16>, 
+  func.func @optimize_shmem(%arg0: memref<4096x4096xf16>,
                     %readRow: index, %readCol: index,
                     %writeRow: index, %writeCol: index,
-                    %fragRow: index, %fragCol: index, 
+                    %fragRow: index, %fragCol: index,
                     %fragColPerm: index,
                     %stRow: index, %stCol: index) {
-    // CHECK:    %[[cst:.+]] = arith.constant 0.000000e+00 : f16                  
+    // CHECK:    %[[cst:.+]] = arith.constant 0.000000e+00 : f16
     %cst = arith.constant 0.000000e+00 : f16
 
     // CHECK: [[shmA:%.+]] = memref.alloc
@@ -15,42 +15,36 @@
     %shmA = memref.alloc() {alignment = 64 : i64} : memref<128x32xf16, 3>
     %shmB = memref.alloc() {alignment = 64 : i64} : memref<256x32xf16, 3>
 
-    // CHECK: %[[D0:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
     %0 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
-    // CHECK: [[c7:%.+]] = arith.constant 7 : index                  
-    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]       
-    // CHECK: [[c2:%.+]] = arith.constant 2 : index                 
-    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]     
-    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]  
-    // CHECK: vector.transfer_write %[[D0:.+]], [[shmB]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
+    // CHECK: [[c6:%.+]] = arith.constant 6 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
     vector.transfer_write %0, %shmB[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
     gpu.barrier
     gpu.barrier
-    // CHECK: [[c7:%.+]] = arith.constant 7 : index                     
-    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]     
-    // CHECK: [[c2:%.+]] = arith.constant 2 : index                 
-    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]       
+    // CHECK: [[c6:%.+]] = arith.constant 6 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
     // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]] 
-    // CHECK:  vector.load [[shmB:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<256x32xf16, 3>, vector<8xf16>
     %1 = vector.load %shmB[%fragRow, %fragColPerm] : memref<256x32xf16, 3>, vector<8xf16>
 
-    // CHECK: %[[D2:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
     %2 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
-    // CHECK: [[c7:%.+]] = arith.constant 7 : index                  
-    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]       
-    // CHECK: [[c2:%.+]] = arith.constant 2 : index                 
-    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]     
-    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]  
-    // CHECK: vector.transfer_write %[[D2:.+]], [[shmA:%.+]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
+    // CHECK: [[c6:%.+]] = arith.constant 6 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
     vector.transfer_write %2, %shmA[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
     gpu.barrier
     gpu.barrier
-    // CHECK: [[c7:%.+]] = arith.constant 7 : index                     
-    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]          
-    // CHECK: [[c2:%.+]] = arith.constant 2 : index                     
-    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]] 
+    // CHECK: [[c6:%.+]] = arith.constant 6 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
     // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
-    // CHECK:  vector.load [[shmA:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<128x32xf16, 3>, vector<8xf16>
     %3 = vector.load %shmA[%fragRow, %fragColPerm] : memref<128x32xf16, 3>, vector<8xf16>
     return
   }
diff --git a/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir b/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
index 143e7c2d270952..83fcc2520f3ce7 100644
--- a/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
+++ b/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
@@ -1,10 +1,10 @@
-// RUN: mlir-opt  %s -transform-interpreter  | FileCheck %s
+// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
 
   // CHECK: @optimize_shmem([[arg0:%.+]]: memref<{{.*}}>, [[readRow:%.+]]: index, [[readCol:%.+]]: index, [[writeRow:%.+]]: index, [[writeCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index, [[fragColPerm:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index)
-  func.func @optimize_shmem(%arg0: memref<4096x4096xf16>, 
+  func.func @optimize_shmem(%arg0: memref<4096x4096xf16>,
                     %readRow: index, %readCol: index,
                     %writeRow: index, %writeCol: index,
-                    %fragRow: index, %fragCol: index, 
+                    %fragRow: index, %fragCol: index,
                     %fragColPerm: index,
                     %stRow: index, %stCol: index) {
     %cst = arith.constant 0.000000e+00 : f16
@@ -13,33 +13,33 @@
     %shmB = memref.alloc() {alignment = 64 : i64} : memref<256x32xf16, 3>
 
     %0 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
-    // CHECK: [[c7:%.+]] = arith.constant 7 : index                  
-    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]       
-    // CHECK: [[c2:%.+]] = arith.constant 2 : index                 
-    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]     
-    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]  
+    // CHECK: [[c6:%.+]] = arith.constant 6 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
     vector.transfer_write %0, %shmB[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
     gpu.barrier
     gpu.barrier
-    // CHECK: [[c7:%.+]] = arith.constant 7 : index                     
-    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]     
-    // CHECK: [[c2:%.+]] = arith.constant 2 : index                 
-    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]       
-    // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]] 
+    // CHECK: [[c6:%.+]] = arith.constant 6 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+    // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
     %1 = vector.load %shmB[%fragRow, %fragColPerm] : memref<256x32xf16, 3>, vector<8xf16>
     %2 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
-    // CHECK: [[c7:%.+]] = arith.constant 7 : index                  
-    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]       
-    // CHECK: [[c2:%.+]] = arith.constant 2 : index                 
-    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]     
-    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]  
+    // CHECK: [[c6:%.+]] = arith.constant 6 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
     vector.transfer_write %2, %shmA[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
     gpu.barrier
     gpu.barrier
-    // CHECK: [[c7:%.+]] = arith.constant 7 : index                     
-    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]          
-    // CHECK: [[c2:%.+]] = arith.constant 2 : index                     
-    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]] 
+    // CHECK: [[c6:%.+]] = arith.constant 6 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
     // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
     %3 = vector.load %shmA[%fragRow, %fragColPerm] : memref<128x32xf16, 3>, vector<8xf16>
     return
@@ -48,7 +48,7 @@
 module attributes { transform.with_named_sequence } {
   transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
-    transform.amdgpu.optimize_shared_memory_reads_and_writes %0 : (!transform.any_op) -> ()
+    transform.amdgpu.optimize_shared_memory_reads_and_writes %0 {kSharedMemoryLineSizeBytes = 128, kDefaultVectorSizeBits = 128}: (!transform.any_op) -> ()
     transform.yield
   } // @__transform_main
 } // module
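Usage note: because both attributes are declared with DefaultValuedOptionalAttr in AMDGPUTransformOps.td, they can also be omitted entirely; the pre-existing spelling from the old test should then behave the same as the explicit 128/128 form above:

    // Omitting both optional attributes falls back to the 128/128 defaults.
    transform.amdgpu.optimize_shared_memory_reads_and_writes %0 : (!transform.any_op) -> ()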

@erman-gurses requested a review from ftynse on February 26, 2024
@erman-gurses changed the title from "Add parameterization for optimized shared memory variables" to "[AMDGPU] Add parameterization for optimized shared memory variables" on February 26, 2024
@erman-gurses removed the request for review from harsh-nod on February 27, 2024
@harsh-amd commented

Looks good to me, but will let Mehdi approve based on his comments.

@joker-eph (Collaborator) left a comment

Thanks for the updates! No concerns from my side, but probably best to leave the approval to someone else here.

@joker-eph (Collaborator) left a comment

LG since @harsh-amd approved!

@erman-gurses (Contributor, Author) commented

> Thanks for the updates! No concerns from my side, but probably best to leave the approval to someone else here.

Thank you for the comments.

@erman-gurses merged commit 87c0260 into llvm:main on Feb 28, 2024