[AMDGPU] Add parameterization for optimized shared memory variables #82508
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter.
It is not ready yet. I need to make a couple of changes.
@llvm/pr-subscribers-mlir-amdgpu @llvm/pr-subscribers-backend-amdgpu
Author: None (erman-gurses)
Changes
Full diff: https://github.com/llvm/llvm-project/pull/82508.diff
6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td b/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td
index 23873d86b495c6..9419c8b14069e2 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td
@@ -13,8 +13,8 @@ include "mlir/Dialect/Transform/IR/TransformAttrs.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformTypes.td"
-include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
//===----------------------------------------------------------------------===//
// ApplyOptimizeSharedMemoryReadsAndWritesOp
//===----------------------------------------------------------------------===//
@@ -28,7 +28,9 @@ def ApplyOptimizeSharedMemoryReadsAndWritesOp :
reads/writes with the goal of avoiding bank conflicts.
}];
- let arguments = (ins TransformHandleTypeInterface:$target);
+ let arguments = (ins TransformHandleTypeInterface:$target,
+ DefaultValuedOptionalAttr<I64Attr, "128">:$kSharedMemoryLineSizeBytes,
+ DefaultValuedOptionalAttr<I64Attr, "128">:$kDefaultVectorSizeBits);
let results = (outs);
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
index 79f9ab71a2b430..bb234d3a285e97 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
@@ -49,7 +49,9 @@ LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
Value memrefValue);
std::optional<LogicalResult>
-optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp);
+optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp,
+ int64_t kSharedMemoryLineSizeBytes,
+ int64_t kDefaultVectorSizeBits);
} // namespace amdgpu
} // namespace mlir
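For illustration, here is a minimal sketch of calling the updated helper from C++ with explicit sizes. Only the new signature above comes from this PR; the surrounding function, includes, and diagnostic handling are assumptions.

#include "mlir/Dialect/AMDGPU/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include <optional>

// Hypothetical call site: run the shared-memory optimization on one function,
// passing the line size (bytes) and vector width (bits) explicitly instead of
// relying on the previously hard-coded constants.
static void runSharedMemOptimization(mlir::func::FuncOp funcOp) {
  std::optional<mlir::LogicalResult> result =
      mlir::amdgpu::optimizeSharedMemoryReadsAndWritesOp(
          funcOp, /*kSharedMemoryLineSizeBytes=*/128,
          /*kDefaultVectorSizeBits=*/128);
  if (result && mlir::failed(*result))
    funcOp.emitWarning("shared-memory optimization did not apply");
}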
diff --git a/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp b/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp
index ff29f9f6938535..08b57f7c8182f4 100644
--- a/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp
@@ -27,7 +27,8 @@ DiagnosedSilenceableFailure
ApplyOptimizeSharedMemoryReadsAndWritesOp::applyToOne(
TransformRewriter &rewriter, FuncOp funcOp, ApplyToEachResultList &results,
TransformState &state) {
- optimizeSharedMemoryReadsAndWritesOp(funcOp);
+ optimizeSharedMemoryReadsAndWritesOp(funcOp, getKSharedMemoryLineSizeBytes(),
+ getKDefaultVectorSizeBits());
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index 6bd03ed833898d..c48a9e1a9a6422 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -37,11 +37,18 @@ using namespace mlir::amdgpu;
/// The size of a shared memory line according to AMD documentation.
/// https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/instinct-mi200-cdna2-instruction-set-architecture.pdf
-constexpr int64_t kSharedMemoryLineSizeBytes = 64;
-/// We optimize for 64bit accesses, but this can be made an argument in the
+int64_t kSharedMemoryLineSizeBytes;
+/// We optimize for 128 bit accesses, but this can be made an argument in the
/// future.
-constexpr int64_t kDefaultVectorSizeBits = 64;
+int64_t kDefaultVectorSizeBits;
+void setMemoryLineSize(int64_t _kSharedMemoryLineSizeBytes) {
+ kSharedMemoryLineSizeBytes = _kSharedMemoryLineSizeBytes;
+}
+
+void setDefaultVectorSize(int64_t _kDefaultVectorSizeBits) {
+ kDefaultVectorSizeBits = _kDefaultVectorSizeBits;
+}
/// Uses `srcIndexValue` to permute `tgtIndexValue` via
/// `result = xor(floordiv(srcIdxVal,permuteEveryN),
/// floordiv(tgtIdxVal,vectorSize)))
@@ -151,6 +158,7 @@ getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
Value memrefValue) {
+
auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
if (!memRefType ||
!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(memRefType))
@@ -218,7 +226,11 @@ LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
}
std::optional<LogicalResult>
-amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
+amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp,
+ int64_t kSharedMemoryLineSizeBytes,
+ int64_t kDefaultVectorSizeBits) {
+ setMemoryLineSize(kSharedMemoryLineSizeBytes);
+ setDefaultVectorSize(kDefaultVectorSizeBits);
SmallVector<memref::AllocOp> shmAllocOps;
funcOp.walk([&](memref::AllocOp allocOp) {
if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
@@ -235,10 +247,23 @@ amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
struct OptimizeSharedMemoryPass
: public amdgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
+
public:
- OptimizeSharedMemoryPass() = default;
+ OptimizeSharedMemoryPass()
+ : OptimizeSharedMemoryBase(),
+ _kSharedMemoryLineSizeBytes(kSharedMemoryLineSizeBytes = 128),
+ _kDefaultVectorSizeBits(kDefaultVectorSizeBits = 128){};
+
+ OptimizeSharedMemoryPass(int64_t kSharedMemoryLineSizeBytes,
+ int64_t kDefaultVectorSizeBits)
+ : OptimizeSharedMemoryBase(),
+ _kSharedMemoryLineSizeBytes(kSharedMemoryLineSizeBytes),
+ _kDefaultVectorSizeBits(kDefaultVectorSizeBits){};
void runOnOperation() override {
+ setMemoryLineSize(_kSharedMemoryLineSizeBytes);
+ setDefaultVectorSize(_kDefaultVectorSizeBits);
+
Operation *op = getOperation();
SmallVector<memref::AllocOp> shmAllocOps;
op->walk([&](memref::AllocOp allocOp) {
@@ -253,4 +278,8 @@ struct OptimizeSharedMemoryPass
return;
}
}
+
+private:
+ int64_t _kSharedMemoryLineSizeBytes;
+ int64_t _kDefaultVectorSizeBits;
};
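As a hedged sketch of how the new parameterized constructor could be used when assembling a pipeline: the direct construction and the OpPassManager plumbing are assumptions (in-tree this would normally go through the generated pass options or a create function); only the constructor itself comes from this diff.

#include "mlir/Pass/PassManager.h"
#include <memory>

// Hypothetical: add the pass with non-default values (e.g. the previous
// 64-byte / 64-bit configuration) instead of the new 128/128 defaults.
void addSharedMemOptimization(mlir::OpPassManager &pm) {
  pm.addPass(std::make_unique<OptimizeSharedMemoryPass>(
      /*kSharedMemoryLineSizeBytes=*/64, /*kDefaultVectorSizeBits=*/64));
}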
diff --git a/mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir b/mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir
index a1de1ff87c229f..983eee732e2afe 100644
--- a/mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir
+++ b/mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir
@@ -1,13 +1,13 @@
-// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(amdgpu-optimize-shared-memory))' | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(amdgpu-optimize-shared-memory))' | FileCheck %s
// CHECK: @optimize_shmem([[arg0:%.+]]: memref<{{.*}}>, [[readRow:%.+]]: index, [[readCol:%.+]]: index, [[writeRow:%.+]]: index, [[writeCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index, [[fragColPerm:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index)
- func.func @optimize_shmem(%arg0: memref<4096x4096xf16>,
+ func.func @optimize_shmem(%arg0: memref<4096x4096xf16>,
%readRow: index, %readCol: index,
%writeRow: index, %writeCol: index,
- %fragRow: index, %fragCol: index,
+ %fragRow: index, %fragCol: index,
%fragColPerm: index,
%stRow: index, %stCol: index) {
- // CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f16
+ // CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f16
%cst = arith.constant 0.000000e+00 : f16
// CHECK: [[shmA:%.+]] = memref.alloc
@@ -15,42 +15,36 @@
%shmA = memref.alloc() {alignment = 64 : i64} : memref<128x32xf16, 3>
%shmB = memref.alloc() {alignment = 64 : i64} : memref<256x32xf16, 3>
- // CHECK: %[[D0:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
%0 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
- // CHECK: [[c7:%.+]] = arith.constant 7 : index
- // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
- // CHECK: [[c2:%.+]] = arith.constant 2 : index
- // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
- // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
- // CHECK: vector.transfer_write %[[D0:.+]], [[shmB]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
+ // CHECK: [[c6:%.+]] = arith.constant 6 : index
+ // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
+ // CHECK: [[c2:%.+]] = arith.constant 2 : index
+ // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+ // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
vector.transfer_write %0, %shmB[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
gpu.barrier
gpu.barrier
- // CHECK: [[c7:%.+]] = arith.constant 7 : index
- // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]
- // CHECK: [[c2:%.+]] = arith.constant 2 : index
- // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+ // CHECK: [[c6:%.+]] = arith.constant 6 : index
+ // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
+ // CHECK: [[c2:%.+]] = arith.constant 2 : index
+ // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
// CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
- // CHECK: vector.load [[shmB:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<256x32xf16, 3>, vector<8xf16>
%1 = vector.load %shmB[%fragRow, %fragColPerm] : memref<256x32xf16, 3>, vector<8xf16>
- // CHECK: %[[D2:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
%2 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
- // CHECK: [[c7:%.+]] = arith.constant 7 : index
- // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
- // CHECK: [[c2:%.+]] = arith.constant 2 : index
- // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
- // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
- // CHECK: vector.transfer_write %[[D2:.+]], [[shmA:%.+]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
+ // CHECK: [[c6:%.+]] = arith.constant 6 : index
+ // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
+ // CHECK: [[c2:%.+]] = arith.constant 2 : index
+ // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+ // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
vector.transfer_write %2, %shmA[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
gpu.barrier
gpu.barrier
- // CHECK: [[c7:%.+]] = arith.constant 7 : index
- // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]
- // CHECK: [[c2:%.+]] = arith.constant 2 : index
- // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+ // CHECK: [[c6:%.+]] = arith.constant 6 : index
+ // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
+ // CHECK: [[c2:%.+]] = arith.constant 2 : index
+ // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
// CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
- // CHECK: vector.load [[shmA:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<128x32xf16, 3>, vector<8xf16>
%3 = vector.load %shmA[%fragRow, %fragColPerm] : memref<128x32xf16, 3>, vector<8xf16>
return
}
diff --git a/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir b/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
index 143e7c2d270952..83fcc2520f3ce7 100644
--- a/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
+++ b/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
@@ -1,10 +1,10 @@
-// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
+// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
// CHECK: @optimize_shmem([[arg0:%.+]]: memref<{{.*}}>, [[readRow:%.+]]: index, [[readCol:%.+]]: index, [[writeRow:%.+]]: index, [[writeCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index, [[fragColPerm:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index)
- func.func @optimize_shmem(%arg0: memref<4096x4096xf16>,
+ func.func @optimize_shmem(%arg0: memref<4096x4096xf16>,
%readRow: index, %readCol: index,
%writeRow: index, %writeCol: index,
- %fragRow: index, %fragCol: index,
+ %fragRow: index, %fragCol: index,
%fragColPerm: index,
%stRow: index, %stCol: index) {
%cst = arith.constant 0.000000e+00 : f16
@@ -13,33 +13,33 @@
%shmB = memref.alloc() {alignment = 64 : i64} : memref<256x32xf16, 3>
%0 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
- // CHECK: [[c7:%.+]] = arith.constant 7 : index
- // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
- // CHECK: [[c2:%.+]] = arith.constant 2 : index
- // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
- // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
+ // CHECK: [[c6:%.+]] = arith.constant 6 : index
+ // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
+ // CHECK: [[c2:%.+]] = arith.constant 2 : index
+ // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+ // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
vector.transfer_write %0, %shmB[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
gpu.barrier
gpu.barrier
- // CHECK: [[c7:%.+]] = arith.constant 7 : index
- // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]
- // CHECK: [[c2:%.+]] = arith.constant 2 : index
- // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
- // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
+ // CHECK: [[c6:%.+]] = arith.constant 6 : index
+ // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
+ // CHECK: [[c2:%.+]] = arith.constant 2 : index
+ // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+ // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
%1 = vector.load %shmB[%fragRow, %fragColPerm] : memref<256x32xf16, 3>, vector<8xf16>
%2 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
- // CHECK: [[c7:%.+]] = arith.constant 7 : index
- // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
- // CHECK: [[c2:%.+]] = arith.constant 2 : index
- // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
- // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
+ // CHECK: [[c6:%.+]] = arith.constant 6 : index
+ // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
+ // CHECK: [[c2:%.+]] = arith.constant 2 : index
+ // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+ // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
vector.transfer_write %2, %shmA[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
gpu.barrier
gpu.barrier
- // CHECK: [[c7:%.+]] = arith.constant 7 : index
- // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]
- // CHECK: [[c2:%.+]] = arith.constant 2 : index
- // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+ // CHECK: [[c6:%.+]] = arith.constant 6 : index
+ // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
+ // CHECK: [[c2:%.+]] = arith.constant 2 : index
+ // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
// CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
%3 = vector.load %shmA[%fragRow, %fragColPerm] : memref<128x32xf16, 3>, vector<8xf16>
return
@@ -48,7 +48,7 @@
module attributes { transform.with_named_sequence } {
transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
- transform.amdgpu.optimize_shared_memory_reads_and_writes %0 : (!transform.any_op) -> ()
+ transform.amdgpu.optimize_shared_memory_reads_and_writes %0 {kSharedMemoryLineSizeBytes = 128, kDefaultVectorSizeBits = 128}: (!transform.any_op) -> ()
transform.yield
} // @__transform_main
} // module
Looks good to me, but will let Mehdi approve based on his comments.
Thanks for the updates! No concerns from my side, but probably best to leave the approval to someone else here.
LG since @harsh-amd approved!
Thank you for the comments.
sharedMemoryLineSizeBytes and defaultVectorSizeBits.