Skip to content

Commit a09cd51

Browse files
committed
Invoking vector transfer lowering pattern in amdgpu pass
1 parent ca9d7df commit a09cd51

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@
1616
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
1717
#include "mlir/Dialect/SCF/IR/SCF.h"
1818
#include "mlir/Dialect/Vector/IR/VectorOps.h"
19+
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
1920
#include "mlir/IR/BuiltinTypes.h"
2021
#include "mlir/IR/OpDefinition.h"
2122
#include "mlir/IR/PatternMatch.h"
2223
#include "mlir/IR/TypeUtilities.h"
2324
#include "mlir/Pass/Pass.h"
2425
#include "mlir/Support/LogicalResult.h"
26+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2527
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
2628

2729
namespace mlir::amdgpu {
@@ -132,7 +134,7 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
132134

133135
LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
134136
PatternRewriter &rewriter) const override {
135-
if (readOp->hasAttr("amdgpu.transformed"))
137+
if (readOp->hasAttr("amdgpu.buffer_transfer_read_needs_mask"))
136138
return failure();
137139

138140
bool requiresBroadcasting = false;
@@ -148,7 +150,6 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
148150
VectorType vectorType = readOp.getVectorType();
149151
int64_t vectorSize = vectorType.getNumElements();
150152
int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
151-
// Value linearIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
152153
SmallVector<OpFoldResult> indices = readOp.getIndices();
153154

154155
auto stridedMetadata =
@@ -161,16 +162,15 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
161162
stridedMetadata.getConstifiedMixedOffset(),
162163
stridedMetadata.getConstifiedMixedSizes(),
163164
stridedMetadata.getConstifiedMixedStrides(), indices);
164-
// OpFoldResult linearIndexSize = linearizedInfo.linearizedSize;
165165
Value linearIndex =
166166
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
167167

168-
// Note below doesn't give the correct result for the linearized size.
169-
// It compute the mutiplied sizes of all dimensions instead of taking
170-
// the maximum of each dimension size * stride.
171168
// TODO(jerryyin): Fix the getLinearizedMemRefOffsetAndSize() function
169+
// Note below doesn't give the correct result for the linearized size.
172170
// Value totalSize = getValueOrCreateConstantIndexOp(
173171
// rewriter, loc, linearizedInfo.linearizedSize);
172+
// It computes the multiplied sizes of all dimensions instead of taking
173+
// the maximum of each dimension size * stride.
174174
SmallVector<AffineExpr> productExpressions;
175175
SmallVector<Value> productResults;
176176
unsigned sourceRank =
@@ -201,7 +201,7 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
201201
Value isOutofBounds = rewriter.create<arith::CmpIOp>(
202202
loc, arith::CmpIPredicate::ule, delta, vectorSizeOffset);
203203

204-
// 2) check if (detla(bytes) % (32 / elementBitwidth) != 0)
204+
// 2) check if (delta_bytes % (32 / elementBitwidth) != 0)
205205
Value deltaBytes = rewriter.create<arith::MulIOp>(
206206
loc, delta,
207207
rewriter.create<arith::ConstantIndexOp>(loc, elementBitWidth / 8));
@@ -219,7 +219,8 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
219219

220220
auto thenBuilder = [&](OpBuilder &builder, Location loc) {
221221
Operation *read = builder.clone(*readOp.getOperation());
222-
read->setAttr("amdgpu.transformed", builder.getUnitAttr());
222+
read->setAttr("amdgpu.buffer_transfer_read_needs_mask",
223+
builder.getUnitAttr());
223224
Value readResult = read->getResult(0);
224225
builder.create<scf::YieldOp>(loc, readResult);
225226
};
@@ -244,6 +245,7 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
244245
void mlir::amdgpu::populateAmdgpuTransferReadToLoadPatterns(
245246
RewritePatternSet &patterns) {
246247
patterns.add<TransferReadLowering>(patterns.getContext());
248+
vector::populateVectorTransferLoweringPatterns(patterns);
247249
}
248250

249251
struct AmdgpuTransferReadToLoadPass final
@@ -252,6 +254,8 @@ struct AmdgpuTransferReadToLoadPass final
252254
void runOnOperation() override {
253255
RewritePatternSet patterns(&getContext());
254256
populateAmdgpuTransferReadToLoadPatterns(patterns);
255-
walkAndApplyPatterns(getOperation(), std::move(patterns));
257+
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
258+
return signalPassFailure();
259+
}
256260
}
257261
};

0 commit comments

Comments
 (0)