[AMDGPU] Add parameterization for optimized shared memory variables #82508

Merged · 10 commits · Feb 28, 2024
@@ -13,8 +13,8 @@ include "mlir/Dialect/Transform/IR/TransformAttrs.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

include "mlir/Interfaces/SideEffectInterfaces.td"
//===----------------------------------------------------------------------===//
// ApplyOptimizeSharedMemoryReadsAndWritesOp
//===----------------------------------------------------------------------===//
@@ -28,7 +28,9 @@ def ApplyOptimizeSharedMemoryReadsAndWritesOp :
reads/writes with the goal of avoiding bank conflicts.
}];

let arguments = (ins TransformHandleTypeInterface:$target);
let arguments = (ins TransformHandleTypeInterface:$target,
DefaultValuedOptionalAttr<I64Attr, "128">:$sharedMemoryLineSizeBytes,
DefaultValuedOptionalAttr<I64Attr, "128">:$defaultVectorSizeBits);
let results = (outs);

let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
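For reference, a minimal transform-script sketch of how the new attributes can be supplied. This is a hedged example: the op mnemonic `transform.amdgpu.optimize_shared_memory_reads_and_writes` and the surrounding `transform.sequence`/`transform.structured.match` scaffolding are assumed rather than taken from this diff, and the 64/64 values simply illustrate overriding the 128/128 defaults with the previously hard-coded constants.

    transform.sequence failures(propagate) {
    ^bb0(%root: !transform.any_op):
      // Collect every func.func in the payload IR.
      %funcs = transform.structured.match ops{["func.func"]} in %root
        : (!transform.any_op) -> !transform.any_op
      // Both attributes are optional and default to 128 when omitted.
      transform.amdgpu.optimize_shared_memory_reads_and_writes %funcs
        {sharedMemoryLineSizeBytes = 64, defaultVectorSizeBits = 64}
        : (!transform.any_op) -> ()
    }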
9 changes: 8 additions & 1 deletion mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -37,10 +37,17 @@ def OptimizeSharedMemory : Pass<"amdgpu-optimize-shared-memory"> {
attempts to optimize reads/writes from a memref representing GPU shared
memory in order to avoid bank conflicts.
}];

let dependentDialects = [
"memref::MemRefDialect", "vector::VectorDialect"
];
let options = [
Option<"sharedMemoryLineSizeBytes", "shared-memory-line-size-bytes", "int64_t",
/*default=*/"128",
"Shared memory line size in bytes">,
Option<"defaultVectorSizeBits", "default-vector-size-bits", "int64_t",
/*default=*/"128",
"Default vector size in bits">,
];
}

#endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
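As a usage note, the new options plug into the standard textual pass-pipeline syntax, so the existing test invocation can be extended as sketched below (the 64/64 values are illustrative overrides of the 128/128 defaults; the pass and option names are taken from the declarations above).

    // RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(amdgpu-optimize-shared-memory{shared-memory-line-size-bytes=64 default-vector-size-bits=64}))' | FileCheck %s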
10 changes: 7 additions & 3 deletions mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
@@ -45,11 +45,15 @@ namespace amdgpu {
/// function that depends on the row Index. The permutation function is chosen
/// to ensure that sequential distributed+vectorized reads/writes down a single
/// dimension of the memref have minimal conflicts.
LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
Value memrefValue);
LogicalResult
optimizeSharedMemoryReadsAndWrites(Operation *parentOp, Value memrefValue,
int64_t sharedMemoryLineSizeBytes,
int64_t defaultVectorSizeBits);

std::optional<LogicalResult>
optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp);
optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp,
int64_t sharedMemoryLineSizeBytes,
int64_t defaultVectorSizeBits);

} // namespace amdgpu
} // namespace mlir
@@ -27,7 +27,8 @@ DiagnosedSilenceableFailure
ApplyOptimizeSharedMemoryReadsAndWritesOp::applyToOne(
TransformRewriter &rewriter, FuncOp funcOp, ApplyToEachResultList &results,
TransformState &state) {
optimizeSharedMemoryReadsAndWritesOp(funcOp);
optimizeSharedMemoryReadsAndWritesOp(funcOp, getSharedMemoryLineSizeBytes(),
getDefaultVectorSizeBits());
return DiagnosedSilenceableFailure::success();
}

57 changes: 31 additions & 26 deletions mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -35,31 +35,26 @@ namespace amdgpu {
using namespace mlir;
using namespace mlir::amdgpu;

/// The size of a shared memory line according to AMD documentation.
/// https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/instinct-mi200-cdna2-instruction-set-architecture.pdf
constexpr int64_t kSharedMemoryLineSizeBytes = 64;
/// We optimize for 64bit accesses, but this can be made an argument in the
/// future.
constexpr int64_t kDefaultVectorSizeBits = 64;

/// Uses `srcIndexValue` to permute `tgtIndexValue` via
/// `result = xor(floordiv(srcIdxVal,permuteEveryN),
///               floordiv(tgtIdxVal,vectorSize))
///           + tgtIdxVal % vectorSize`
/// This is done using an optimized sequence of `arith` operations.
static Value permuteVectorOffset(OpBuilder &b, Location loc,
ArrayRef<Value> indices, MemRefType memrefTy,
int64_t srcDim, int64_t tgtDim) {
int64_t srcDim, int64_t tgtDim,
int64_t sharedMemoryLineSizeBytes,
int64_t defaultVectorSizeBits) {
// Adjust the src index to change how often the permutation changes
// if necessary.
Value src = indices[srcDim];

// We only want to permute every N iterations of the target dim where N is
// ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
const int64_t permuteEveryN = std::max<int64_t>(
1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
memrefTy.getElementTypeBitWidth()) /
8));
1, sharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
memrefTy.getElementTypeBitWidth()) /
8));

// clang-format off
// Index bit representation (b0 = least significant bit) for dim(1)
@@ -71,7 +66,7 @@ static Value permuteVectorOffset(OpBuilder &b, Location loc,
// bits[N:M] = vector index
// clang-format on
int64_t n =
llvm::Log2_64(kDefaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
llvm::Log2_64(defaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
int64_t m = llvm::Log2_64(memrefTy.getDimSize(tgtDim));

// Capture bits[0:(M-N)] of src by first creating a (M-N) mask.
@@ -105,9 +100,11 @@ static Value permuteVectorOffset(OpBuilder &b, Location loc,
static void transformIndices(OpBuilder &builder, Location loc,
SmallVector<Value, 4> &indices,
MemRefType memrefTy, int64_t srcDim,
int64_t tgtDim) {
int64_t tgtDim, int64_t sharedMemoryLineSizeBytes,
int64_t defaultVectorSizeBits) {
indices[tgtDim] =
permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim,
sharedMemoryLineSizeBytes, defaultVectorSizeBits);
}

// Return all operations within `parentOp` that read from or write to
@@ -149,8 +146,9 @@ getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
return success();
}

LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
Value memrefValue) {
LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(
Operation *parentOp, Value memrefValue, int64_t sharedMemoryLineSizeBytes,
int64_t defaultVectorSizeBits) {
auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
if (!memRefType ||
!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(memRefType))
@@ -167,10 +165,10 @@ LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
// If dim[rank-1] is small enough to fit 8 rows in a 128B line.
const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
const int64_t rowsPerLine =
(8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
(8 * sharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
rowSize;
const int64_t threadGroupSize =
1LL << (7 - llvm::Log2_64(kDefaultVectorSizeBits / 8));
1LL << (7 - llvm::Log2_64(defaultVectorSizeBits / 8));
if (rowsPerLine >= threadGroupSize)
return failure();

@@ -198,7 +196,8 @@ LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
auto indices = amdgpu::getIndices(shmWriteOp);
SmallVector<Value, 4> transformedIndices(indices->begin(), indices->end());
transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
memRefType, srcDim, tgtDim);
memRefType, srcDim, tgtDim, sharedMemoryLineSizeBytes,
defaultVectorSizeBits);
amdgpu::setIndices(shmWriteOp, transformedIndices);
}

@@ -210,24 +209,28 @@ LogicalResult amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
auto indices = amdgpu::getIndices(shmReadOp);
SmallVector<Value, 4> transformedIndices(indices->begin(), indices->end());
transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
memRefType, srcDim, tgtDim);
memRefType, srcDim, tgtDim, sharedMemoryLineSizeBytes,
defaultVectorSizeBits);
amdgpu::setIndices(shmReadOp, transformedIndices);
}

return success();
}

std::optional<LogicalResult>
amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp,
int64_t sharedMemoryLineSizeBytes,
int64_t defaultVectorSizeBits) {
SmallVector<memref::AllocOp> shmAllocOps;
funcOp.walk([&](memref::AllocOp allocOp) {
if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
return;
shmAllocOps.push_back(allocOp);
});
for (auto allocOp : shmAllocOps) {
if (failed(amdgpu::optimizeSharedMemoryReadsAndWrites(funcOp,
allocOp.getMemref())))
if (failed(amdgpu::optimizeSharedMemoryReadsAndWrites(
funcOp, allocOp.getMemref(), sharedMemoryLineSizeBytes,
defaultVectorSizeBits)))
return failure();
}
return success();
@@ -237,7 +240,8 @@ struct OptimizeSharedMemoryPass
: public amdgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
public:
OptimizeSharedMemoryPass() = default;

OptimizeSharedMemoryPass(const OptimizeSharedMemoryOptions &options)
: OptimizeSharedMemoryBase(options) {}
void runOnOperation() override {
Operation *op = getOperation();
SmallVector<memref::AllocOp> shmAllocOps;
@@ -248,8 +252,9 @@ struct OptimizeSharedMemoryPass
shmAllocOps.push_back(allocOp);
});
for (auto allocOp : shmAllocOps) {
if (failed(optimizeSharedMemoryReadsAndWrites(getOperation(),
allocOp.getMemref())))
if (failed(optimizeSharedMemoryReadsAndWrites(op, allocOp.getMemref(),
sharedMemoryLineSizeBytes,
defaultVectorSizeBits)))
return;
}
}
50 changes: 22 additions & 28 deletions mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir
@@ -1,56 +1,50 @@
// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(amdgpu-optimize-shared-memory))' | FileCheck %s
// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(amdgpu-optimize-shared-memory))' | FileCheck %s

// CHECK: @optimize_shmem([[arg0:%.+]]: memref<{{.*}}>, [[readRow:%.+]]: index, [[readCol:%.+]]: index, [[writeRow:%.+]]: index, [[writeCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index, [[fragColPerm:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index)
func.func @optimize_shmem(%arg0: memref<4096x4096xf16>,
func.func @optimize_shmem(%arg0: memref<4096x4096xf16>,
%readRow: index, %readCol: index,
%writeRow: index, %writeCol: index,
%fragRow: index, %fragCol: index,
%fragRow: index, %fragCol: index,
%fragColPerm: index,
%stRow: index, %stCol: index) {
// CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f16
// CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f16
%cst = arith.constant 0.000000e+00 : f16

// CHECK: [[shmA:%.+]] = memref.alloc
// CHECK: [[shmB:%.+]] = memref.alloc
%shmA = memref.alloc() {alignment = 64 : i64} : memref<128x32xf16, 3>
%shmB = memref.alloc() {alignment = 64 : i64} : memref<256x32xf16, 3>

// CHECK: %[[D0:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
%0 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
// CHECK: [[c7:%.+]] = arith.constant 7 : index
// CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
// CHECK: [[c2:%.+]] = arith.constant 2 : index
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
// CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
// CHECK: vector.transfer_write %[[D0:.+]], [[shmB]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
// CHECK: [[c6:%.+]] = arith.constant 6 : index
// CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
// CHECK: [[c2:%.+]] = arith.constant 2 : index
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
// CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
vector.transfer_write %0, %shmB[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
gpu.barrier
gpu.barrier
// CHECK: [[c7:%.+]] = arith.constant 7 : index
// CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]
// CHECK: [[c2:%.+]] = arith.constant 2 : index
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
// CHECK: [[c6:%.+]] = arith.constant 6 : index
// CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
// CHECK: [[c2:%.+]] = arith.constant 2 : index
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
// CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
// CHECK: vector.load [[shmB:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<256x32xf16, 3>, vector<8xf16>
%1 = vector.load %shmB[%fragRow, %fragColPerm] : memref<256x32xf16, 3>, vector<8xf16>

// CHECK: %[[D2:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
%2 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
// CHECK: [[c7:%.+]] = arith.constant 7 : index
// CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
// CHECK: [[c2:%.+]] = arith.constant 2 : index
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
// CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
// CHECK: vector.transfer_write %[[D2:.+]], [[shmA:%.+]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
// CHECK: [[c6:%.+]] = arith.constant 6 : index
// CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c6]]
// CHECK: [[c2:%.+]] = arith.constant 2 : index
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
// CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
vector.transfer_write %2, %shmA[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
gpu.barrier
gpu.barrier
// CHECK: [[c7:%.+]] = arith.constant 7 : index
// CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]
// CHECK: [[c2:%.+]] = arith.constant 2 : index
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
// CHECK: [[c6:%.+]] = arith.constant 6 : index
// CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c6]]
// CHECK: [[c2:%.+]] = arith.constant 2 : index
// CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
// CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
// CHECK: vector.load [[shmA:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<128x32xf16, 3>, vector<8xf16>
%3 = vector.load %shmA[%fragRow, %fragColPerm] : memref<128x32xf16, 3>, vector<8xf16>
return
}