
Commit 04381c1

[MLIR][AMDGPU] Add refactoring for shared-mem optimization (#81791)

Addresses the issues raised in #81550.
1 parent b57ba8e commit 04381c1

4 files changed, +28 -37 lines changed

mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h

Lines changed: 5 additions & 5 deletions

@@ -1,5 +1,4 @@
-//===- Transforms.h - AMDGPU Dialect transformations --------------*-
-// C++-*-===//
+//===- Transforms.h - AMDGPU Dialect transformations -------------*- C++-*-===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -46,10 +45,11 @@ namespace amdgpu {
 /// function that depends on the row Index. The permutation function is chosen
 /// to ensure that sequential distributed+vectorized reads/writes down a single
 /// dimension of the memref have minimal conflicts.
-mlir::LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
-                                                       Value memrefValue);
+LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
+                                                 Value memrefValue);

-void optimizeSharedMemoryReadsAndWritesOp(mlir::func::FuncOp funcOp);
+std::optional<LogicalResult>
+optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp);

 } // namespace amdgpu
 } // namespace mlir
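For context, a minimal sketch of how a caller might consume the revised entry point. Only the `amdgpu::optimizeSharedMemoryReadsAndWritesOp` declaration above comes from this commit; the `runOnFunc` wrapper and the remark text are illustrative assumptions.

#include <optional>

#include "mlir/Dialect/AMDGPU/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"

using namespace mlir;

// Hypothetical caller, e.g. inside a pass's runOnOperation().
static void runOnFunc(func::FuncOp funcOp) {
  // The per-function driver now reports whether every shared-memory
  // allocation in the function was optimized, instead of returning void.
  std::optional<LogicalResult> result =
      amdgpu::optimizeSharedMemoryReadsAndWritesOp(funcOp);
  if (result && failed(*result))
    funcOp.emitRemark("shared-memory swizzling was not applied");
}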
mlir/lib/Dialect/AMDGPU/CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
 add_subdirectory(IR)
-add_subdirectory(Utils)
 add_subdirectory(TransformOps)
 add_subdirectory(Transforms)
+add_subdirectory(Utils)

mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp

Lines changed: 22 additions & 21 deletions
@@ -50,12 +50,12 @@ constexpr int64_t kDefaultVectorSizeBits = 64;
 static Value permuteVectorOffset(OpBuilder &b, Location loc,
                                  ArrayRef<Value> indices, MemRefType memrefTy,
                                  int64_t srcDim, int64_t tgtDim) {
-  /// Adjust the src index to change how often the permutation changes
-  /// if necessary.
+  // Adjust the src index to change how often the permutation changes
+  // if necessary.
   Value src = indices[srcDim];

-  /// We only want to permute every N iterations of the target dim where N is
-  /// ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
+  // We only want to permute every N iterations of the target dim where N is
+  // ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
   const int64_t permuteEveryN = std::max<int64_t>(
       1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
                                         memrefTy.getElementTypeBitWidth()) /
@@ -110,8 +110,8 @@ static void transformIndices(OpBuilder &builder, Location loc,
     permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
 }

-/// Return all operations within `parentOp` that read from or write to
-/// `shmMemRef`.
+// Return all operations within `parentOp` that read from or write to
+// `shmMemRef`.
 static LogicalResult
 getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
                       SmallVector<Operation *, 16> &readOps,
@@ -131,8 +131,8 @@ getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
     writeOps.push_back(op);
   });

-  /// Restrict to a supported set of ops. We also require at least 2D access,
-  /// although this could be relaxed.
+  // Restrict to a supported set of ops. We also require at least 2D access,
+  // although this could be relaxed.
   if (llvm::any_of(readOps, [](Operation *op) {
         return !isa<memref::LoadOp, vector::LoadOp, vector::TransferReadOp>(
                    op) ||
@@ -149,23 +149,22 @@ getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
   return success();
 }

-mlir::LogicalResult
-mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
-                                                 Value memrefValue) {
+LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
+                                                         Value memrefValue) {
   auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
   if (!memRefType ||
       !amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(memRefType))
     return failure();

-  /// Abort if the given value has any sub-views; we do not do any alias
-  /// analysis.
+  // Abort if the given value has any sub-views; we do not do any alias
+  // analysis.
   bool hasSubView = false;
   parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
   if (hasSubView)
     return failure();

-  /// Check if this is necessary given the assumption of 128b accesses:
-  /// If dim[rank-1] is small enough to fit 8 rows in a 128B line.
+  // Check if this is necessary given the assumption of 128b accesses:
+  // If dim[rank-1] is small enough to fit 8 rows in a 128B line.
   const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
   const int64_t rowsPerLine =
       (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
@@ -175,8 +174,8 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
   if (rowsPerLine >= threadGroupSize)
     return failure();

-  /// Get sets of operations within the function that read/write to shared
-  /// memory.
+  // Get sets of operations within the function that read/write to shared
+  // memory.
   SmallVector<Operation *, 16> shmReadOps;
   SmallVector<Operation *, 16> shmWriteOps;
   if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
@@ -191,7 +190,7 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
   int64_t tgtDim = memRefType.getRank() - 1;
   int64_t srcDim = memRefType.getRank() - 2;

-  /// Transform indices for the ops writing to shared memory.
+  // Transform indices for the ops writing to shared memory.
   while (!shmWriteOps.empty()) {
     Operation *shmWriteOp = shmWriteOps.pop_back_val();
     builder.setInsertionPoint(shmWriteOp);
@@ -203,7 +202,7 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
     amdgpu::setIndices(shmWriteOp, transformedIndices);
   }

-  /// Transform indices for the ops reading from shared memory.
+  // Transform indices for the ops reading from shared memory.
   while (!shmReadOps.empty()) {
     Operation *shmReadOp = shmReadOps.pop_back_val();
     builder.setInsertionPoint(shmReadOp);
@@ -218,7 +217,8 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
   return success();
 }

-void amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
+std::optional<LogicalResult>
+amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
   SmallVector<memref::AllocOp> shmAllocOps;
   funcOp.walk([&](memref::AllocOp allocOp) {
     if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
@@ -228,8 +228,9 @@ void amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
   for (auto allocOp : shmAllocOps) {
     if (failed(amdgpu::optimizeSharedMemoryReadsAndWrites(funcOp,
                                                           allocOp.getMemref())))
-      return;
+      return failure();
   }
+  return success();
 }

 struct OptimizeSharedMemoryPass

mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir

Lines changed: 0 additions & 10 deletions
@@ -7,22 +7,17 @@
       %fragRow: index, %fragCol: index,
       %fragColPerm: index,
       %stRow: index, %stCol: index) {
-    // CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f16
     %cst = arith.constant 0.000000e+00 : f16

-    // CHECK: [[shmA:%.+]] = memref.alloc
-    // CHECK: [[shmB:%.+]] = memref.alloc
     %shmA = memref.alloc() {alignment = 64 : i64} : memref<128x32xf16, 3>
     %shmB = memref.alloc() {alignment = 64 : i64} : memref<256x32xf16, 3>

-    // CHECK: %[[D0:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
     %0 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
     // CHECK: [[c7:%.+]] = arith.constant 7 : index
     // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
     // CHECK: [[c2:%.+]] = arith.constant 2 : index
     // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
     // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
-    // CHECK: vector.transfer_write %[[D0:.+]], [[shmB]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
     vector.transfer_write %0, %shmB[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
     gpu.barrier
     gpu.barrier
@@ -31,17 +26,13 @@
     // CHECK: [[c2:%.+]] = arith.constant 2 : index
     // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
     // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
-    // CHECK: vector.load [[shmB:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<256x32xf16, 3>, vector<8xf16>
     %1 = vector.load %shmB[%fragRow, %fragColPerm] : memref<256x32xf16, 3>, vector<8xf16>
-
-    // CHECK: %[[D2:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
     %2 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
     // CHECK: [[c7:%.+]] = arith.constant 7 : index
     // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
     // CHECK: [[c2:%.+]] = arith.constant 2 : index
     // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
     // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
-    // CHECK: vector.transfer_write %[[D2:.+]], [[shmA:%.+]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
     vector.transfer_write %2, %shmA[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
     gpu.barrier
     gpu.barrier
@@ -50,7 +41,6 @@
     // CHECK: [[c2:%.+]] = arith.constant 2 : index
     // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
     // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
-    // CHECK: vector.load [[shmA:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<128x32xf16, 3>, vector<8xf16>
     %3 = vector.load %shmA[%fragRow, %fragColPerm] : memref<128x32xf16, 3>, vector<8xf16>
     return
   }
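The retained CHECK lines pin down the swizzle the pass emits: the permuted column index is `col ^ ((row & 7) << 2)`. The mask of 7 groups eight rows, and the shift of 2 appears to move those bits to 4-element granules (with `kDefaultVectorSizeBits = 64` from the hunk above and 16-bit f16 elements, one vector covers 4 elements). A plain-C++ restatement of that arithmetic, as an illustrative sketch (the helper name is ours, not the pass's):

#include <cassert>
#include <cstdint>

// The swizzle verified by the CHECK lines above.
static int64_t permuteCol(int64_t row, int64_t col) {
  int64_t srcBits = row & 7;      // arith.andi %row, %c7
  int64_t xorBits = srcBits << 2; // arith.shli %srcBits, %c2
  return col ^ xorBits;           // arith.xori %col, %xorBits
}

int main() {
  // Successive rows accessing column 0 land on distinct 4-element granules:
  assert(permuteCol(0, 0) == 0);
  assert(permuteCol(1, 0) == 4);
  assert(permuteCol(7, 0) == 28);
  // and the pattern repeats every 8 rows.
  assert(permuteCol(8, 0) == permuteCol(0, 0));
  return 0;
}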
