
Commit 87c0260

[AMDGPU] Add parameterization for optimized shared memory variables (#82508)
- This PR adds parameterization for the shared memory variables used by the optimization: `sharedMemoryLineSizeBytes` and `defaultVectorSizeBits`.
- The default values are set to 128 for both variables, since this configuration gives zero bank conflicts.
Parent commit: 9d56be0

7 files changed: +97 −84 lines


mlir/include/mlir/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.td

Lines changed: 4 additions & 2 deletions
@@ -13,8 +13,8 @@ include "mlir/Dialect/Transform/IR/TransformAttrs.td"
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
 include "mlir/Dialect/Transform/IR/TransformTypes.td"
-include "mlir/Interfaces/SideEffectInterfaces.td"
 
+include "mlir/Interfaces/SideEffectInterfaces.td"
 //===----------------------------------------------------------------------===//
 // ApplyOptimizeSharedMemoryReadsAndWritesOp
 //===----------------------------------------------------------------------===//
@@ -28,7 +28,9 @@ def ApplyOptimizeSharedMemoryReadsAndWritesOp :
     reads/writes with the goal of avoiding bank conflicts.
   }];
 
-  let arguments = (ins TransformHandleTypeInterface:$target);
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       DefaultValuedOptionalAttr<I64Attr, "128">:$sharedMemoryLineSizeBytes,
+                       DefaultValuedOptionalAttr<I64Attr, "128">:$defaultVectorSizeBits);
   let results = (outs);
 
   let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
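
For reference, a minimal transform script setting the new attributes might look like the sketch below. The op mnemonic (`transform.amdgpu.optimize_shared_memory_reads_and_writes`) and the sequence scaffolding are assumed from the dialect's conventions; the excerpt above only shows the argument declarations.

// Illustrative only, not part of this commit. Omitting the attributes
// picks up the defaults (128 and 128).
transform.sequence failures(propagate) {
^bb0(%root: !transform.any_op):
  %f = transform.structured.match ops{["func.func"]} in %root
    : (!transform.any_op) -> !transform.any_op
  transform.amdgpu.optimize_shared_memory_reads_and_writes %f
    {sharedMemoryLineSizeBytes = 128, defaultVectorSizeBits = 128}
    : (!transform.any_op) -> ()
}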

mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td

Lines changed: 8 additions & 1 deletion
@@ -37,10 +37,17 @@ def OptimizeSharedMemory : Pass<"amdgpu-optimize-shared-memory"> {
     attempts to optimize reads/writes from a memref representing GPU shared
     memory in order to avoid bank conflicts.
   }];
-
   let dependentDialects = [
     "memref::MemRefDialect", "vector::VectorDialect"
   ];
+  let options = [
+    Option<"sharedMemoryLineSizeBytes", "shared-memory-line-size-bytes", "int64_t",
+           /*default=*/"128",
+           "Shared memory line size in bytes">,
+    Option<"defaultVectorSizeBits", "default-vector-size-bits", "int64_t",
+           /*default=*/"128",
+           "Default vector size in bits">,
+  ];
 }
 
 #endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
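
These options use the standard MLIR pass-option syntax, so the pass can be driven from a test RUN line roughly as follows (an illustrative invocation, not part of the commit; the values shown just make the defaults explicit):

// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(amdgpu-optimize-shared-memory{shared-memory-line-size-bytes=128 default-vector-size-bits=128}))' | FileCheck %s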

mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h

Lines changed: 7 additions & 3 deletions
@@ -45,11 +45,15 @@ namespace amdgpu {
 /// function that depends on the row Index. The permutation function is chosen
 /// to ensure that sequential distributed+vectorized reads/writes down a single
 /// dimension of the memref have minimal conflicts.
-LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
-                                                 Value memrefValue);
+LogicalResult
+optimizeSharedMemoryReadsAndWrites(Operation *parentOp, Value memrefValue,
+                                   int64_t sharedMemoryLineSizeBytes,
+                                   int64_t defaultVectorSizeBits);
 
 std::optional<LogicalResult>
-optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp);
+optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp,
+                                     int64_t sharedMemoryLineSizeBytes,
+                                     int64_t defaultVectorSizeBits);
 
 } // namespace amdgpu
 } // namespace mlir

mlir/lib/Dialect/AMDGPU/TransformOps/AMDGPUTransformOps.cpp

Lines changed: 2 additions & 1 deletion
@@ -27,7 +27,8 @@ DiagnosedSilenceableFailure
 ApplyOptimizeSharedMemoryReadsAndWritesOp::applyToOne(
     TransformRewriter &rewriter, FuncOp funcOp, ApplyToEachResultList &results,
     TransformState &state) {
-  optimizeSharedMemoryReadsAndWritesOp(funcOp);
+  optimizeSharedMemoryReadsAndWritesOp(funcOp, getSharedMemoryLineSizeBytes(),
+                                       getDefaultVectorSizeBits());
   return DiagnosedSilenceableFailure::success();
 }

mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp

Lines changed: 31 additions & 26 deletions
@@ -35,31 +35,26 @@ namespace amdgpu {
 using namespace mlir;
 using namespace mlir::amdgpu;
 
-/// The size of a shared memory line according to AMD documentation.
-/// https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/instinct-mi200-cdna2-instruction-set-architecture.pdf
-constexpr int64_t kSharedMemoryLineSizeBytes = 64;
-/// We optimize for 64bit accesses, but this can be made an argument in the
-/// future.
-constexpr int64_t kDefaultVectorSizeBits = 64;
-
 /// Uses `srcIndexValue` to permute `tgtIndexValue` via
 ///   `result = xor(floordiv(srcIdxVal,permuteEveryN),
 ///                 floordiv(tgtIdxVal,vectorSize)))
 ///             + tgtIdxVal % vectorSize`
 /// This is done using an optimized sequence of `arith` operations.
 static Value permuteVectorOffset(OpBuilder &b, Location loc,
                                  ArrayRef<Value> indices, MemRefType memrefTy,
-                                 int64_t srcDim, int64_t tgtDim) {
+                                 int64_t srcDim, int64_t tgtDim,
+                                 int64_t sharedMemoryLineSizeBytes,
+                                 int64_t defaultVectorSizeBits) {
   // Adjust the src index to change how often the permutation changes
   // if necessary.
   Value src = indices[srcDim];
 
   // We only want to permute every N iterations of the target dim where N is
   // ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
   const int64_t permuteEveryN = std::max<int64_t>(
-      1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
-                                        memrefTy.getElementTypeBitWidth()) /
-                                       8));
+      1, sharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
                                       memrefTy.getElementTypeBitWidth()) /
+                                      8));
 
   // clang-format off
   // Index bit representation (b0 = least significant bit) for dim(1)
@@ -71,7 +66,7 @@ static Value permuteVectorOffset(OpBuilder &b, Location loc,
   // bits[N:M] = vector index
   // clang-format on
   int64_t n =
-      llvm::Log2_64(kDefaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
+      llvm::Log2_64(defaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
   int64_t m = llvm::Log2_64(memrefTy.getDimSize(tgtDim));
 
   // Capture bits[0:(M-N)] of src by first creating a (M-N) mask.
@@ -105,9 +100,11 @@ static Value permuteVectorOffset(OpBuilder &b, Location loc,
 static void transformIndices(OpBuilder &builder, Location loc,
                              SmallVector<Value, 4> &indices,
                              MemRefType memrefTy, int64_t srcDim,
-                             int64_t tgtDim) {
+                             int64_t tgtDim, int64_t sharedMemoryLineSizeBytes,
+                             int64_t defaultVectorSizeBits) {
   indices[tgtDim] =
-      permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
+      permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim,
+                          sharedMemoryLineSizeBytes, defaultVectorSizeBits);
 }
 
 // Return all operations within `parentOp` that read from or write to
@@ -149,8 +146,9 @@ getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
   return success();
 }
 
-LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
-                                                         Value memrefValue) {
+LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(
+    Operation *parentOp, Value memrefValue, int64_t sharedMemoryLineSizeBytes,
+    int64_t defaultVectorSizeBits) {
   auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
   if (!memRefType ||
       !amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(memRefType))
@@ -167,10 +165,10 @@ LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
   // If dim[rank-1] is small enough to fit 8 rows in a 128B line.
   const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
   const int64_t rowsPerLine =
-      (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
+      (8 * sharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
       rowSize;
   const int64_t threadGroupSize =
-      1LL << (7 - llvm::Log2_64(kDefaultVectorSizeBits / 8));
+      1LL << (7 - llvm::Log2_64(defaultVectorSizeBits / 8));
   if (rowsPerLine >= threadGroupSize)
     return failure();
 
@@ -198,7 +196,8 @@ LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
     auto indices = amdgpu::getIndices(shmWriteOp);
     SmallVector<Value, 4> transformedIndices(indices->begin(), indices->end());
     transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
-                     memRefType, srcDim, tgtDim);
+                     memRefType, srcDim, tgtDim, sharedMemoryLineSizeBytes,
+                     defaultVectorSizeBits);
     amdgpu::setIndices(shmWriteOp, transformedIndices);
   }
 
@@ -210,24 +209,28 @@ LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
     auto indices = amdgpu::getIndices(shmReadOp);
     SmallVector<Value, 4> transformedIndices(indices->begin(), indices->end());
     transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
-                     memRefType, srcDim, tgtDim);
+                     memRefType, srcDim, tgtDim, sharedMemoryLineSizeBytes,
+                     defaultVectorSizeBits);
     amdgpu::setIndices(shmReadOp, transformedIndices);
   }
 
   return success();
 }
 
 std::optional<LogicalResult>
-amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
+amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp,
+                                             int64_t sharedMemoryLineSizeBytes,
+                                             int64_t defaultVectorSizeBits) {
   SmallVector<memref::AllocOp> shmAllocOps;
   funcOp.walk([&](memref::AllocOp allocOp) {
     if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
       return;
     shmAllocOps.push_back(allocOp);
   });
   for (auto allocOp : shmAllocOps) {
-    if (failed(amdgpu::optimizeSharedMemoryReadsAndWrites(funcOp,
-                                                          allocOp.getMemref())))
+    if (failed(amdgpu::optimizeSharedMemoryReadsAndWrites(
+            funcOp, allocOp.getMemref(), sharedMemoryLineSizeBytes,
+            defaultVectorSizeBits)))
       return failure();
   }
   return success();
@@ -237,7 +240,8 @@ struct OptimizeSharedMemoryPass
     : public amdgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
 public:
   OptimizeSharedMemoryPass() = default;
-
+  OptimizeSharedMemoryPass(const OptimizeSharedMemoryOptions &options)
+      : OptimizeSharedMemoryBase(options) {}
   void runOnOperation() override {
     Operation *op = getOperation();
     SmallVector<memref::AllocOp> shmAllocOps;
@@ -248,8 +252,9 @@ struct OptimizeSharedMemoryPass
       shmAllocOps.push_back(allocOp);
     });
     for (auto allocOp : shmAllocOps) {
-      if (failed(optimizeSharedMemoryReadsAndWrites(getOperation(),
-                                                    allocOp.getMemref())))
+      if (failed(optimizeSharedMemoryReadsAndWrites(op, allocOp.getMemref(),
+                                                    sharedMemoryLineSizeBytes,
+                                                    defaultVectorSizeBits)))
        return;
    }
  }
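
To make the heuristics concrete, here is the arithmetic for the `memref<256x32xf16, 3>` allocation used in the test below, with the new 128/128 defaults; every value follows directly from the code above:

  permuteEveryN   = max(1, 128 / ((32 * 16) / 8)) = max(1, 128 / 64) = 2
  n               = log2(128 / 16)                = 3   (vector-index bits)
  m               = log2(32)                      = 5   (row-index bits)
  rowsPerLine     = (8 * 128 / 16) / 32           = 2
  threadGroupSize = 1 << (7 - log2(128 / 8))      = 1 << 3 = 8

Since rowsPerLine (2) is below threadGroupSize (8), the early `rowsPerLine >= threadGroupSize` bail-out is not taken and the permutation rewrite is applied.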
Lines changed: 22 additions & 28 deletions
@@ -1,56 +1,50 @@
-// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(amdgpu-optimize-shared-memory))' | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(amdgpu-optimize-shared-memory))' | FileCheck %s
 
 // CHECK: @optimize_shmem([[arg0:%.+]]: memref<{{.*}}>, [[readRow:%.+]]: index, [[readCol:%.+]]: index, [[writeRow:%.+]]: index, [[writeCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index, [[fragColPerm:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index)
-func.func @optimize_shmem(%arg0: memref<4096x4096xf16>,
+func.func @optimize_shmem(%arg0: memref<4096x4096xf16>,
                           %readRow: index, %readCol: index,
                           %writeRow: index, %writeCol: index,
-                          %fragRow: index, %fragCol: index,
+                          %fragRow: index, %fragCol: index,
                           %fragColPerm: index,
                           %stRow: index, %stCol: index) {
-  // CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f16
+  // CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f16
   %cst = arith.constant 0.000000e+00 : f16
 
   // CHECK: [[shmA:%.+]] = memref.alloc
   // CHECK: [[shmB:%.+]] = memref.alloc
   %shmA = memref.alloc() {alignment = 64 : i64} : memref<128x32xf16, 3>
   %shmB = memref.alloc() {alignment = 64 : i64} : memref<256x32xf16, 3>
 
-  // CHECK: %[[D0:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
   %0 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
-  // CHECK: [[c7:%.+]] = arith.constant 7 : index
-  // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
-  // CHECK: [[c2:%.+]] = arith.constant 2 : index
-  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
-  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
-  // CHECK: vector.transfer_write %[[D0:.+]], [[shmB]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
+  // CHECK: [[c6:%.+]] = arith.constant 6 : index
+  // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
+  // CHECK: [[c2:%.+]] = arith.constant 2 : index
+  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
   vector.transfer_write %0, %shmB[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
   gpu.barrier
   gpu.barrier
-  // CHECK: [[c7:%.+]] = arith.constant 7 : index
-  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]
-  // CHECK: [[c2:%.+]] = arith.constant 2 : index
-  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+  // CHECK: [[c6:%.+]] = arith.constant 6 : index
+  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
+  // CHECK: [[c2:%.+]] = arith.constant 2 : index
+  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
   // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
-  // CHECK: vector.load [[shmB:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<256x32xf16, 3>, vector<8xf16>
   %1 = vector.load %shmB[%fragRow, %fragColPerm] : memref<256x32xf16, 3>, vector<8xf16>
 
-  // CHECK: %[[D2:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
   %2 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
-  // CHECK: [[c7:%.+]] = arith.constant 7 : index
-  // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
-  // CHECK: [[c2:%.+]] = arith.constant 2 : index
-  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
-  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
-  // CHECK: vector.transfer_write %[[D2:.+]], [[shmA:%.+]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
+  // CHECK: [[c6:%.+]] = arith.constant 6 : index
+  // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
+  // CHECK: [[c2:%.+]] = arith.constant 2 : index
+  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+  // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
   vector.transfer_write %2, %shmA[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
   gpu.barrier
   gpu.barrier
-  // CHECK: [[c7:%.+]] = arith.constant 7 : index
-  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]
-  // CHECK: [[c2:%.+]] = arith.constant 2 : index
-  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+  // CHECK: [[c6:%.+]] = arith.constant 6 : index
+  // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
+  // CHECK: [[c2:%.+]] = arith.constant 2 : index
+  // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
   // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
-  // CHECK: vector.load [[shmA:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<128x32xf16, 3>, vector<8xf16>
   %3 = vector.load %shmA[%fragRow, %fragColPerm] : memref<128x32xf16, 3>, vector<8xf16>
   return
 }
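
The `c7` to `c6` change in the CHECK lines falls out of the same arithmetic. A sketch of the derivation for the `memref<256x32xf16, 3>` accesses (the exact bit placement is inferred from `permuteVectorOffset`, so treat the layout details as an assumption):

  old defaults (line = 64 B, vector = 64 b):
    permuteEveryN = max(1, 64 / 64)  = 1
    n = log2(64 / 16)  = 2,  m = log2(32) = 5
    src mask covers bits [0, m-n) = 3 bits        -> 0b111 = 7,  shift = n = 2
  new defaults (line = 128 B, vector = 128 b):
    permuteEveryN = max(1, 128 / 64) = 2
    n = log2(128 / 16) = 3,  m = log2(32) = 5
    src mask covers 2 bits, starting at bit
    log2(permuteEveryN) = 1                       -> 0b110 = 6,  shift = 2

This matches the `arith.andi ..., 6` / `arith.shli ..., 2` sequence the updated test now expects.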
