16
16
#include " mlir/Dialect/MemRef/Utils/MemRefUtils.h"
17
17
#include " mlir/Dialect/SCF/IR/SCF.h"
18
18
#include " mlir/Dialect/Vector/IR/VectorOps.h"
19
+ #include " mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
19
20
#include " mlir/IR/BuiltinTypes.h"
20
21
#include " mlir/IR/OpDefinition.h"
21
22
#include " mlir/IR/PatternMatch.h"
22
23
#include " mlir/IR/TypeUtilities.h"
23
24
#include " mlir/Pass/Pass.h"
24
25
#include " mlir/Support/LogicalResult.h"
26
+ #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
25
27
#include " mlir/Transforms/WalkPatternRewriteDriver.h"
26
28
27
29
namespace mlir ::amdgpu {
@@ -132,7 +134,7 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
132
134
133
135
LogicalResult matchAndRewrite (vector::TransferReadOp readOp,
134
136
PatternRewriter &rewriter) const override {
135
- if (readOp->hasAttr (" amdgpu.transformed " ))
137
+ if (readOp->hasAttr (" amdgpu.buffer_transfer_read_needs_mask " ))
136
138
return failure ();
137
139
138
140
bool requiresBroadcasting = false ;
@@ -148,7 +150,6 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
148
150
VectorType vectorType = readOp.getVectorType ();
149
151
int64_t vectorSize = vectorType.getNumElements ();
150
152
int64_t elementBitWidth = vectorType.getElementTypeBitWidth ();
151
- // Value linearIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
152
153
SmallVector<OpFoldResult> indices = readOp.getIndices ();
153
154
154
155
auto stridedMetadata =
@@ -161,16 +162,15 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
161
162
stridedMetadata.getConstifiedMixedOffset (),
162
163
stridedMetadata.getConstifiedMixedSizes (),
163
164
stridedMetadata.getConstifiedMixedStrides (), indices);
164
- // OpFoldResult linearIndexSize = linearizedInfo.linearizedSize;
165
165
Value linearIndex =
166
166
getValueOrCreateConstantIndexOp (rewriter, loc, linearizedIndices);
167
167
168
- // Note below doesn't give the correct result for the linearized size.
169
- // It compute the mutiplied sizes of all dimensions instead of taking
170
- // the maximum of each dimension size * stride.
171
168
// TODO(jerryyin): Fix the getLinearizedMemRefOffsetAndSize() function
169
+ // Note below doesn't give the correct result for the linearized size.
172
170
// Value totalSize = getValueOrCreateConstantIndexOp(
173
171
// rewriter, loc, linearizedInfo.linearizedSize);
172
+ // It computes the multiplied sizes of all dimensions instead of taking
173
+ // the maximum of each dimension size * stride.
174
174
SmallVector<AffineExpr> productExpressions;
175
175
SmallVector<Value> productResults;
176
176
unsigned sourceRank =
@@ -201,7 +201,7 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
201
201
Value isOutofBounds = rewriter.create <arith::CmpIOp>(
202
202
loc, arith::CmpIPredicate::ule, delta, vectorSizeOffset);
203
203
204
- // 2) check if (detla(bytes) % (32 / elementBitwidth) != 0)
204
+ // 2) check if (delta_bytes % (32 / elementBitwidth) != 0)
205
205
Value deltaBytes = rewriter.create <arith::MulIOp>(
206
206
loc, delta,
207
207
rewriter.create <arith::ConstantIndexOp>(loc, elementBitWidth / 8 ));
@@ -219,7 +219,8 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
219
219
220
220
auto thenBuilder = [&](OpBuilder &builder, Location loc) {
221
221
Operation *read = builder.clone (*readOp.getOperation ());
222
- read->setAttr (" amdgpu.transformed" , builder.getUnitAttr ());
222
+ read->setAttr (" amdgpu.buffer_transfer_read_needs_mask" ,
223
+ builder.getUnitAttr ());
223
224
Value readResult = read->getResult (0 );
224
225
builder.create <scf::YieldOp>(loc, readResult);
225
226
};
@@ -244,6 +245,7 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
244
245
// Populates `patterns` with the rewrite patterns used by the AMDGPU
// transfer-read-to-load lowering.
//
// Registers TransferReadLowering, constructed with the pattern set's own
// context. The added (`+`) line additionally registers the patterns provided
// by vector::populateVectorTransferLoweringPatterns into the same set.
// NOTE(review): presumably this lets the vector.transfer_read clones that
// TransferReadLowering tags with "amdgpu.buffer_transfer_read_needs_mask" be
// lowered further by the generic vector patterns — confirm.
void mlir::amdgpu::populateAmdgpuTransferReadToLoadPatterns (
245
246
RewritePatternSet &patterns) {
246
247
patterns.add <TransferReadLowering>(patterns.getContext ());
248
+ vector::populateVectorTransferLoweringPatterns (patterns);
247
249
}
248
250
249
251
struct AmdgpuTransferReadToLoadPass final
@@ -252,6 +254,8 @@ struct AmdgpuTransferReadToLoadPass final
252
254
// Pass entry point: builds a RewritePatternSet on the pass's context, fills
// it via populateAmdgpuTransferReadToLoadPatterns, and applies it to the
// operation returned by getOperation().
//
// In this diff the walkAndApplyPatterns call (the `-` line) is replaced by
// applyPatternsGreedily (the `+` lines); if the greedy application reports
// failure, the pass is marked as failed through signalPassFailure().
void runOnOperation () override {
253
255
RewritePatternSet patterns (&getContext ());
254
256
populateAmdgpuTransferReadToLoadPatterns (patterns);
255
- walkAndApplyPatterns (getOperation (), std::move (patterns));
257
+ if (failed (applyPatternsGreedily (getOperation (), std::move (patterns)))) {
258
+ return signalPassFailure ();
259
+ }
256
260
}
257
261
};
0 commit comments