Skip to content

Commit e013d44

Browse files
committed
Add parameterization for optimized shared memory variables
1 parent cd160a6 commit e013d44

File tree

4 files changed

+79
-56
lines changed

4 files changed

+79
-56
lines changed

mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ include "mlir/Dialect/Transform/IR/TransformAttrs.td"
1313
include "mlir/Dialect/Transform/IR/TransformDialect.td"
1414
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
1515
include "mlir/Dialect/Transform/IR/TransformTypes.td"
16+
include "mlir/IR/EnumAttr.td"
1617
include "mlir/Interfaces/SideEffectInterfaces.td"
17-
1818
//===----------------------------------------------------------------------===//
1919
// ApplyOptimizeSharedMemoryReadsAndWritesOp
2020
//===----------------------------------------------------------------------===//
@@ -28,7 +28,9 @@ def ApplyOptimizeSharedMemoryReadsAndWritesOp :
2828
reads/writes with the goal of avoiding bank conflicts.
2929
}];
3030

31-
let arguments = (ins TransformHandleTypeInterface:$target);
31+
let arguments = (ins TransformHandleTypeInterface:$target,
32+
DefaultValuedOptionalAttr<I64Attr, "128">:$kSharedMemoryLineSizeBytes,
33+
DefaultValuedOptionalAttr<I64Attr, "128">:$kDefaultVectorSizeBits);
3234
let results = (outs);
3335

3436
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";

mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,18 @@ using namespace mlir::amdgpu;
3737

3838
/// The size of a shared memory line according to AMD documentation.
3939
/// https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/instinct-mi200-cdna2-instruction-set-architecture.pdf
40-
constexpr int64_t kSharedMemoryLineSizeBytes = 64;
40+
int64_t kSharedMemoryLineSizeBytes;
4141
/// We optimize for 64bit accesses, but this can be made an argument in the
4242
/// future.
43-
constexpr int64_t kDefaultVectorSizeBits = 64;
43+
int64_t kDefaultVectorSizeBits;
4444

45+
void setMemoryLineSize(int64_t _kSharedMemoryLineSizeBytes) {
46+
kSharedMemoryLineSizeBytes = _kSharedMemoryLineSizeBytes;
47+
}
48+
49+
void setDefaultVectorSize(int64_t _kDefaultVectorSizeBits) {
50+
kDefaultVectorSizeBits = _kDefaultVectorSizeBits;
51+
}
4552
/// Uses `srcIndexValue` to permute `tgtIndexValue` via
4653
/// `result = xor(floordiv(srcIdxVal,permuteEveryN),
4754
/// floordiv(tgtIdxVal,vectorSize)))
@@ -151,6 +158,7 @@ getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
151158

152159
LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
153160
Value memrefValue) {
161+
154162
auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
155163
if (!memRefType ||
156164
!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(memRefType))
@@ -219,6 +227,8 @@ LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
219227

220228
std::optional<LogicalResult>
221229
amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
230+
//setMemoryLineSize(_kSharedMemoryLineSizeBytes);
231+
//setDefaultVectorSize(_kDefaultVectorSizeBits);
222232
SmallVector<memref::AllocOp> shmAllocOps;
223233
funcOp.walk([&](memref::AllocOp allocOp) {
224234
if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
@@ -235,10 +245,23 @@ amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
235245

236246
struct OptimizeSharedMemoryPass
237247
: public amdgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
248+
238249
public:
239-
OptimizeSharedMemoryPass() = default;
250+
OptimizeSharedMemoryPass()
251+
: OptimizeSharedMemoryBase(),
252+
_kSharedMemoryLineSizeBytes(kSharedMemoryLineSizeBytes = 128),
253+
_kDefaultVectorSizeBits(kDefaultVectorSizeBits = 128){};
254+
255+
OptimizeSharedMemoryPass(int64_t kSharedMemoryLineSizeBytes,
256+
int64_t kDefaultVectorSizeBits)
257+
: OptimizeSharedMemoryBase(),
258+
_kSharedMemoryLineSizeBytes(kSharedMemoryLineSizeBytes),
259+
_kDefaultVectorSizeBits(kDefaultVectorSizeBits){};
240260

241261
void runOnOperation() override {
262+
setMemoryLineSize(_kSharedMemoryLineSizeBytes);
263+
setDefaultVectorSize(_kDefaultVectorSizeBits);
264+
242265
Operation *op = getOperation();
243266
SmallVector<memref::AllocOp> shmAllocOps;
244267
op->walk([&](memref::AllocOp allocOp) {
@@ -253,4 +276,8 @@ struct OptimizeSharedMemoryPass
253276
return;
254277
}
255278
}
279+
280+
private:
281+
int64_t _kSharedMemoryLineSizeBytes;
282+
int64_t _kDefaultVectorSizeBits;
256283
};
Lines changed: 22 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,50 @@
1-
// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(amdgpu-optimize-shared-memory))' | FileCheck %s
1+
// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(amdgpu-optimize-shared-memory))' | FileCheck %s
22

33
// CHECK: @optimize_shmem([[arg0:%.+]]: memref<{{.*}}>, [[readRow:%.+]]: index, [[readCol:%.+]]: index, [[writeRow:%.+]]: index, [[writeCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index, [[fragColPerm:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index)
4-
func.func @optimize_shmem(%arg0: memref<4096x4096xf16>,
4+
func.func @optimize_shmem(%arg0: memref<4096x4096xf16>,
55
%readRow: index, %readCol: index,
66
%writeRow: index, %writeCol: index,
7-
%fragRow: index, %fragCol: index,
7+
%fragRow: index, %fragCol: index,
88
%fragColPerm: index,
99
%stRow: index, %stCol: index) {
10-
// CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f16
10+
// CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f16
1111
%cst = arith.constant 0.000000e+00 : f16
1212

1313
// CHECK: [[shmA:%.+]] = memref.alloc
1414
// CHECK: [[shmB:%.+]] = memref.alloc
1515
%shmA = memref.alloc() {alignment = 64 : i64} : memref<128x32xf16, 3>
1616
%shmB = memref.alloc() {alignment = 64 : i64} : memref<256x32xf16, 3>
1717

18-
// CHECK: %[[D0:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
1918
%0 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
20-
// CHECK: [[c7:%.+]] = arith.constant 7 : index
21-
// CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
22-
// CHECK: [[c2:%.+]] = arith.constant 2 : index
23-
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
24-
// CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
25-
// CHECK: vector.transfer_write %[[D0:.+]], [[shmB]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
19+
// CHECK: [[c6:%.+]] = arith.constant 6 : index
20+
// CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
21+
// CHECK: [[c2:%.+]] = arith.constant 2 : index
22+
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
23+
// CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
2624
vector.transfer_write %0, %shmB[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
2725
gpu.barrier
2826
gpu.barrier
29-
// CHECK: [[c7:%.+]] = arith.constant 7 : index
30-
// CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]
31-
// CHECK: [[c2:%.+]] = arith.constant 2 : index
32-
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
27+
// CHECK: [[c6:%.+]] = arith.constant 6 : index
28+
// CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
29+
// CHECK: [[c2:%.+]] = arith.constant 2 : index
30+
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
3331
// CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
34-
// CHECK: vector.load [[shmB:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<256x32xf16, 3>, vector<8xf16>
3532
%1 = vector.load %shmB[%fragRow, %fragColPerm] : memref<256x32xf16, 3>, vector<8xf16>
3633

37-
// CHECK: %[[D2:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
3834
%2 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
39-
// CHECK: [[c7:%.+]] = arith.constant 7 : index
40-
// CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
41-
// CHECK: [[c2:%.+]] = arith.constant 2 : index
42-
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
43-
// CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
44-
// CHECK: vector.transfer_write %[[D2:.+]], [[shmA:%.+]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
35+
// CHECK: [[c6:%.+]] = arith.constant 6 : index
36+
// CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
37+
// CHECK: [[c2:%.+]] = arith.constant 2 : index
38+
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
39+
// CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
4540
vector.transfer_write %2, %shmA[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
4641
gpu.barrier
4742
gpu.barrier
48-
// CHECK: [[c7:%.+]] = arith.constant 7 : index
49-
// CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]
50-
// CHECK: [[c2:%.+]] = arith.constant 2 : index
51-
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
43+
// CHECK: [[c6:%.+]] = arith.constant 6 : index
44+
// CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
45+
// CHECK: [[c2:%.+]] = arith.constant 2 : index
46+
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
5247
// CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
53-
// CHECK: vector.load [[shmA:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<128x32xf16, 3>, vector<8xf16>
5448
%3 = vector.load %shmA[%fragRow, %fragColPerm] : memref<128x32xf16, 3>, vector<8xf16>
5549
return
5650
}
Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
1+
// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
22

33
// CHECK: @optimize_shmem([[arg0:%.+]]: memref<{{.*}}>, [[readRow:%.+]]: index, [[readCol:%.+]]: index, [[writeRow:%.+]]: index, [[writeCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index, [[fragColPerm:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index)
4-
func.func @optimize_shmem(%arg0: memref<4096x4096xf16>,
4+
func.func @optimize_shmem(%arg0: memref<4096x4096xf16>,
55
%readRow: index, %readCol: index,
66
%writeRow: index, %writeCol: index,
7-
%fragRow: index, %fragCol: index,
7+
%fragRow: index, %fragCol: index,
88
%fragColPerm: index,
99
%stRow: index, %stCol: index) {
1010
%cst = arith.constant 0.000000e+00 : f16
@@ -13,33 +13,33 @@
1313
%shmB = memref.alloc() {alignment = 64 : i64} : memref<256x32xf16, 3>
1414

1515
%0 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
16-
// CHECK: [[c7:%.+]] = arith.constant 7 : index
17-
// CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
18-
// CHECK: [[c2:%.+]] = arith.constant 2 : index
19-
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
20-
// CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
16+
// CHECK: [[c6:%.+]] = arith.constant 6 : index
17+
// CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
18+
// CHECK: [[c2:%.+]] = arith.constant 2 : index
19+
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
20+
// CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
2121
vector.transfer_write %0, %shmB[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
2222
gpu.barrier
2323
gpu.barrier
24-
// CHECK: [[c7:%.+]] = arith.constant 7 : index
25-
// CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]
26-
// CHECK: [[c2:%.+]] = arith.constant 2 : index
27-
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
28-
// CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
24+
// CHECK: [[c6:%.+]] = arith.constant 6 : index
25+
// CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
26+
// CHECK: [[c2:%.+]] = arith.constant 2 : index
27+
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
28+
// CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
2929
%1 = vector.load %shmB[%fragRow, %fragColPerm] : memref<256x32xf16, 3>, vector<8xf16>
3030
%2 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
31-
// CHECK: [[c7:%.+]] = arith.constant 7 : index
32-
// CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
33-
// CHECK: [[c2:%.+]] = arith.constant 2 : index
34-
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
35-
// CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
31+
// CHECK: [[c6:%.+]] = arith.constant 6 : index
32+
// CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
33+
// CHECK: [[c2:%.+]] = arith.constant 2 : index
34+
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
35+
// CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
3636
vector.transfer_write %2, %shmA[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
3737
gpu.barrier
3838
gpu.barrier
39-
// CHECK: [[c7:%.+]] = arith.constant 7 : index
40-
// CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]
41-
// CHECK: [[c2:%.+]] = arith.constant 2 : index
42-
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
39+
// CHECK: [[c6:%.+]] = arith.constant 6 : index
40+
// CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
41+
// CHECK: [[c2:%.+]] = arith.constant 2 : index
42+
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
4343
// CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
4444
%3 = vector.load %shmA[%fragRow, %fragColPerm] : memref<128x32xf16, 3>, vector<8xf16>
4545
return
@@ -48,7 +48,7 @@
4848
module attributes { transform.with_named_sequence } {
4949
transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
5050
%0 = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.any_op
51-
transform.amdgpu.optimize_shared_memory_reads_and_writes %0 : (!transform.any_op) -> ()
51+
transform.amdgpu.optimize_shared_memory_reads_and_writes %0 {kSharedMemoryLineSizeBytes = 128, kDefaultVectorSizeBits = 128}: (!transform.any_op) -> ()
5252
transform.yield
5353
} // @__transform_main
5454
} // module

0 commit comments

Comments (0)