[MLIR][AMDGPU] Adding dynamic size check to avoid subword buffer load #135014
Conversation
@llvm/pr-subscribers-mlir-amdgpu @llvm/pr-subscribers-mlir-gpu

Author: Zhuoran Yin (jerryyin)

Changes

Motivation: the amdgpu buffer load instruction returns all zeros when loading sub-word values. For example, assuming the buffer size is exactly one word and we attempt to invoke `llvm.amdgcn.raw.ptr.buffer.load.v2i32` starting from byte 2 of the word, we will not receive the actual value of the buffer but all zeros for the first word. This is because the boundary of the first word has been crossed.

This PR fixes the problem by creating a bounds check around the buffer load. It compares the offset plus the vector size to see whether the upper bound of the access exceeds the buffer size. If the access stays within the buffer, the masked transfer read is optimized to `vector.load` + `arith.select`; otherwise, it falls back to the default lowering of the masked vector load.

Full diff: https://github.com/llvm/llvm-project/pull/135014.diff

5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index 761caa448a57c..0e858108acf35 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -54,15 +54,20 @@ def AmdgpuResolveStridedMetadataPass : Pass<"amdgpu-resolve-strided-metadata"> {
def AmdgpuTransferReadToLoadPass : Pass<"amdgpu-transfer-read-to-load"> {
let summary = "Lower the operations from the vector transfer_read to vector load";
let description = [{
- This pass creates a transfer read op lowering. A vector trasfer read op
- will be lowered to a combination of vector.load, arith.select and
- vector.broadcast.
+ This pass creates a transfer read op lowering optimization. The lowering
+ will produce a conditional check at runtime. If within bounds, a vector
+ trasfer read op will be lowered to a combination of vector.load, arith.select
+ and vector.broadcast. If not, it will fallback to the default lowering
+ of the transfer_read op.
This pattern will make it possible for masked transfer_read to be lowered
towards buffer load with bounds check, allowing a more optimized global
load accessing pattern compared with existing implementation of
llvm.intr.masked.load on vectors.
}];
- let dependentDialects = [];
+ let dependentDialects = [
+ "scf::SCFDialect",
+ "memref::MemRefDialect"
+ ];
}
#endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
index bc5b6e9186449..8709a27e0168e 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRAMDGPUTransforms
MLIRAMDGPUUtils
MLIRArithDialect
MLIRMemRefDialect
+ MLIRSCFDialect
MLIRVectorDialect
MLIRControlFlowDialect
MLIRFuncDialect
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
index 3c1a2eb962037..519f695d99f91 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp
@@ -9,6 +9,8 @@
#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
@@ -108,6 +110,8 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
PatternRewriter &rewriter) const override {
+ if (readOp->hasAttr("amdgpu.transformed"))
+ return failure();
bool requiresBroadcasting = false;
VectorType unbroadcastedVectorType;
@@ -117,20 +121,85 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
}
Location loc = readOp.getLoc();
- Value fill = rewriter.create<vector::SplatOp>(loc, unbroadcastedVectorType,
- readOp.getPadding());
- Value load = rewriter.create<vector::LoadOp>(
- loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
- Value res = rewriter.create<arith::SelectOp>(loc, unbroadcastedVectorType,
- readOp.getMask(), load, fill);
-
- // Insert a broadcasting op if required.
- if (requiresBroadcasting) {
- res = rewriter.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
- res);
+ Value src = readOp.getSource();
+ MemRefType memRefType = cast<MemRefType>(src.getType());
+ ArrayRef<int64_t> shape = memRefType.getShape();
+
+ Value linearIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ Value stride = one;
+
+ // Compute the linear index by linearIndex += indices[i] * stride
+ for (int i = shape.size() - 1; i >= 0; --i) {
+ Value currentIndex = readOp.getIndices()[i];
+ Value strideIndexed =
+ rewriter.create<arith::MulIOp>(loc, currentIndex, stride);
+ linearIndex =
+ rewriter.create<arith::AddIOp>(loc, linearIndex, strideIndexed);
+
+ if (i == 0)
+ break;
+
+ // Update stride for the next dimension
+ Value nextStride;
+ if (shape[i] != ShapedType::kDynamic) {
+ nextStride = rewriter.create<arith::ConstantIndexOp>(loc, shape[i]);
+ } else {
+ nextStride = rewriter.create<memref::DimOp>(loc, src, i);
+ }
+ stride = rewriter.create<arith::MulIOp>(loc, stride, nextStride);
+ }
+
+ // Add vector size offset to linear index
+ VectorType vectorType = readOp.getVectorType();
+ int64_t vectorSize = vectorType.getNumElements();
+ Value vectorSizeOffset =
+ rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
+ Value upperBoundIndex =
+ rewriter.create<arith::AddIOp>(loc, linearIndex, vectorSizeOffset);
+
+ Value totalSize = one;
+ for (size_t i = 0; i < shape.size(); ++i) {
+ Value dimensionSize;
+ if (shape[i] != ShapedType::kDynamic) {
+ dimensionSize = rewriter.create<arith::ConstantIndexOp>(loc, shape[i]);
+ } else {
+ dimensionSize = rewriter.create<memref::DimOp>(loc, src, i);
+ }
+ totalSize = rewriter.create<arith::MulIOp>(loc, totalSize, dimensionSize);
}
- rewriter.replaceOp(readOp, res);
+ Value isInBounds = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::ule, upperBoundIndex, totalSize);
+
+ auto thenBuilder = [&](OpBuilder &builder, Location loc) {
+ Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
+ readOp.getPadding());
+ Value load = builder.create<vector::LoadOp>(loc, unbroadcastedVectorType,
+ readOp.getSource(),
+ readOp.getIndices());
+ Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
+ readOp.getMask(), load, fill);
+
+ // Insert a broadcasting op if required.
+ if (requiresBroadcasting) {
+ res = builder.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
+ res);
+ }
+ rewriter.create<scf::YieldOp>(loc, res);
+ };
+
+ auto elseBuilder = [&](OpBuilder &builder, Location loc) {
+ Operation *read = builder.clone(*readOp.getOperation());
+ read->setAttr("amdgpu.transformed", builder.getUnitAttr());
+ Value readResult = read->getResult(0);
+ builder.create<scf::YieldOp>(loc, readResult);
+ };
+
+ auto ifOp =
+ rewriter.create<scf::IfOp>(loc, isInBounds, thenBuilder, elseBuilder);
+
+ rewriter.replaceOp(readOp, ifOp);
return success();
}
diff --git a/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir b/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
index 3e1283579f2b1..776a047e6a85d 100644
--- a/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
+++ b/mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir
@@ -10,10 +10,54 @@ func.func @transfer_to_maskedload_fatrawbuffer(%mem : memref<8x8xf32, #amdgpu.ad
return %res : vector<4xf32>
}
// CHECK: %[[CST:.*]] = arith.constant 0.0
+// CHECK: %[[C0:.*]] = arith.constant 0
+// CHECK: %[[C1:.*]] = arith.constant 1
+// CHECK: %[[MUL0:.*]] = arith.muli %[[ARG1]], %[[C1]]
+// CHECK: %[[ADD0:.*]] = arith.addi %[[C0]], %[[MUL0]]
+// CHECK: %[[C8:.*]] = arith.constant 8
+// CHECK: %[[MUL1:.*]] = arith.muli %[[C1]], %[[C8]]
+// CHECK: %[[MUL2:.*]] = arith.muli %[[ARG1]], %[[MUL1]]
+// CHECK: %[[ADD1:.*]] = arith.addi %[[ADD0]], %[[MUL2]]
+// CHECK: %[[C4:.*]] = arith.constant 4
+// CHECK: %[[ADD2:.*]] = arith.addi %[[ADD1]], %[[C4]]
+
+// CHECK: %[[MUL3:.*]] = arith.muli %[[C1]], %[[C8]]
+// CHECK: %[[MUL4:.*]] = arith.muli
+
+// CHECK: %[[CMP:.*]] = arith.cmpi ule, %[[ADD2]], %[[MUL4]]
+// CHECK: %[[IF:.*]] = scf.if %[[CMP]] -> (vector<4xf32>) {
+
// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
-// CHECK: return %[[SELECT]] : vector<4xf32>
+
+// CHECK: } else {
+// CHECK: %[[LOAD:.*]] = vector.transfer_read %arg0[%arg1, %arg1], %[[CST]], %arg2 {amdgpu.transformed, in_bounds = [true]} : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+
+// CHECK: return %[[IF]] : vector<4xf32>
+
+// -----
+
+// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer_dynamic(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-SAME: %[[ARG2:.*]]: vector<4xi1>
+func.func @transfer_to_maskedload_fatrawbuffer_dynamic(%mem : memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>, %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
+ %cf0 = arith.constant 0.0 : f32
+ %res = vector.transfer_read %mem[%idx, %idx], %cf0, %mask {in_bounds = [true]} : memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
+ return %res : vector<4xf32>
+}
+
+// CHECK: %[[C1:.*]] = arith.constant 1
+// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK: %[[MUL0:.*]] = arith.muli %{{.*}}, %[[DIM1]]
+// CHECK: %[[C0:.*]] = arith.constant 0
+// CHECK: %[[DIM0:.*]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK: %[[MUL1:.*]] = arith.muli %{{.*}}, %[[DIM0]]
+
+// CHECK: %[[C1_1:.*]] = arith.constant 1
+// CHECK: %[[DIM1_1:.*]] = memref.dim %[[ARG0]], %[[C1_1]]
+// CHECK: %[[MUL2:.*]] = arith.muli %{{.*}}, %[[DIM1_1]]
// -----
@@ -64,7 +108,6 @@ func.func @transfer_broadcasting(%mem : memref<8x8xf32, #amdgpu.address_space<fa
// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
// CHECK: %[[BROADCAST:.*]] = vector.broadcast %[[SELECT]] : vector<1xf32> to vector<4xf32>
-// CHECK: return %[[BROADCAST]] : vector<4xf32>
// -----
@@ -83,4 +126,3 @@ func.func @transfer_scalar(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_
// CHECK: %[[SPLAT:.*]] = vector.splat %[[CST]]
// CHECK: %[[LOAD:.*]] = vector.load %arg0[%arg1, %arg1]
// CHECK: %[[SELECT:.*]] = arith.select %arg2, %[[LOAD]], %[[SPLAT]]
-// CHECK: return %[[SELECT]] : vector<1xf32>
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 141986392917e..c4d87484fd5d5 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1568,6 +1568,7 @@ cc_library(
":IR",
":MemRefDialect",
":Pass",
+ ":SCFDialect",
":SideEffectInterfaces",
":Support",
":TransformUtils",
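For reference, the IR produced for the static 2-D test case above has roughly the following shape. This is a hand-written sketch based on the CHECK lines in transfer-read-to-load.mlir: the function and SSA value names are illustrative, and constants may be materialized or folded differently in practice.

```mlir
// Sketch of the lowered form for the 2-D static-shape test case
// (memref<8x8xf32> in the fat_raw_buffer address space, vector<4xf32> read).
func.func @bounds_checked_read(%mem : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>,
                               %idx : index, %mask : vector<4xi1>) -> vector<4xf32> {
  %pad = arith.constant 0.0 : f32
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c8 = arith.constant 8 : index
  %c4 = arith.constant 4 : index
  // Linearized start index: idx * 1 + idx * 8.
  %m0  = arith.muli %idx, %c1 : index
  %a0  = arith.addi %c0, %m0 : index
  %s1  = arith.muli %c1, %c8 : index
  %m1  = arith.muli %idx, %s1 : index
  %lin = arith.addi %a0, %m1 : index
  // Upper bound of the access: linear index + vector size (4 elements).
  %ub  = arith.addi %lin, %c4 : index
  // Total buffer size in elements: 8 * 8.
  %t0    = arith.muli %c1, %c8 : index
  %total = arith.muli %t0, %c8 : index
  %in_bounds = arith.cmpi ule, %ub, %total : index
  %res = scf.if %in_bounds -> (vector<4xf32>) {
    // Fast path: plain vector.load plus a select against the padding splat.
    %splat = vector.splat %pad : vector<4xf32>
    %load  = vector.load %mem[%idx, %idx]
        : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
    %sel   = arith.select %mask, %load, %splat : vector<4xi1>, vector<4xf32>
    scf.yield %sel : vector<4xf32>
  } else {
    // Fallback: the original masked transfer_read, tagged so the pattern
    // does not fire on it again.
    %fallback = vector.transfer_read %mem[%idx, %idx], %pad, %mask
        {amdgpu.transformed, in_bounds = [true]}
        : memref<8x8xf32, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf32>
    scf.yield %fallback : vector<4xf32>
  }
  return %res : vector<4xf32>
}
```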
Copilot reviewed 1 out of 5 changed files in this pull request and generated no comments.
Files not reviewed (4)
- mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td: Language not supported
- mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt: Language not supported
- mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir: Language not supported
- utils/bazel/llvm-project-overlay/mlir/BUILD.bazel: Language not supported
Comments suppressed due to low confidence (1)
mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp:140
- [nitpick] Consider restructuring the loop to iterate only over the necessary dimensions instead of breaking when i equals 0. This could improve clarity and intent in the linear index computation.
if (i == 0) break;
Will give an actual review when properly awake, but I wanted to flag that I really don't think the approach in the summary is the right one.
@krzysz00 I'd be interested to hear your thoughts on this. My wild guess is that you'd prefer a more precise condition that checks specifically whether the sub-word loading situation has appeared. The reason I didn't take that approach is that it adds a lot of overhead: 1) taking an additional look at each element of the select vector, 2) checking whether the offset is misaligned, and 3) verifying whether the offset + select vector size exceeds the buffer size. Realistically, I'd argue that my approach is a generic enough proxy that only triggers at the last iteration of the K loop and doesn't add as much overhead as checking the precise conditions.
First, I'm not convinced this sort of dynamic control flow is the way to go here at all, and I want to have a broader discussion of how we got here. If we're implementing an MLIR-side fix here - which I'm still not convinced of, especially since some of these tests are easier to implement at the LLVM level - we need:
(and also, because these are vector ops, I think negative starting indices might be allowed, so that's another scalarization case)
OK, that's quite a few points... I'd be happy to discuss offline, but below is my understanding:
I can't think of any better approach on the MLIR side that can apply unconditionally to both a small toy example and a large matmul example, so a dynamic check seems necessary. If you come up with a better LLVM implementation that can help avoid this additional overhead, we can revert this PR.
Good point, I can skip the check when the element type is >= a word.
Even if the alignment is >= 4 bytes and the buffer is a multiple of 4, can't the boundary condition still be triggered if we read from the middle of the last word?
I thought the conclusion we reached the other day was to avoid negative offsets as much as possible? I see this as slightly off topic, and it needs broader discussion/approval.
To summarize discussions elsewhere:
A few nits, I think this is all reasonable
Force-pushed ddee079 to d30d971
I'm pretty happy with these changes
…fastpath (#135982)

The `delta_bytes % (32 ceilDiv elementBitwidth) != 0` condition introduced in llvm/llvm-project#135014 is incorrect. For example, when the last load fetches only the final element of an fp16 buffer, `delta_bytes = 2` and `32 ceilDiv 16 = 2`, so the access is judged word-aligned and sent to the fast path, but it gets all zeros for the fp16 value because it crosses the word boundary. In reality the equation should be just `delta_bytes % 4`, since a word is 4 bytes. This PR fixes the bug by amending the mod target to 4.
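In formula form, using the fp16 example above with a tail load starting delta = 2 bytes past the last word boundary:

$$
2 \bmod \left\lceil \tfrac{32}{16} \right\rceil = 2 \bmod 2 = 0 \quad \text{(old check: treated as word-aligned, takes the fast path)}
\qquad \text{vs.} \qquad
2 \bmod 4 = 2 \neq 0 \quad \text{(fixed check: not word-aligned, takes the fallback)}.
$$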