 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h"

 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LogicalResult.h"
-#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/MathExtras.h"

 namespace mlir::amdgpu {
 #define GEN_PASS_DEF_AMDGPUTRANSFERREADTOLOADPASS
@@ -67,6 +76,9 @@ static LogicalResult transferPreconditions(
   if (!memRefType.isLastDimUnitStride())
     return rewriter.notifyMatchFailure(xferOp, "!= 1 stride needs VectorToSCF");

+  if (memRefType.getElementTypeBitWidth() < 8)
+    return rewriter.notifyMatchFailure(xferOp, "unsupported sub-byte type");
+
   // If there is broadcasting involved then we first load the unbroadcasted
   // vector, and then broadcast it with `vector.broadcast`.
   ArrayRef<int64_t> vectorShape = xferOp.getVectorType().getShape();
@@ -101,13 +113,35 @@ static LogicalResult transferPreconditions(
   return success();
 }

+static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
+                                           vector::TransferReadOp readOp,
+                                           bool requiresBroadcasting,
+                                           VectorType unbroadcastedVectorType) {
+  Value fill = builder.create<vector::SplatOp>(loc, unbroadcastedVectorType,
+                                               readOp.getPadding());
+  Value load = builder.create<vector::LoadOp>(
+      loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
+  Value res = builder.create<arith::SelectOp>(loc, unbroadcastedVectorType,
+                                              readOp.getMask(), load, fill);
+  // Insert a broadcasting op if required.
+  if (requiresBroadcasting) {
+    res = builder.create<vector::BroadcastOp>(loc, readOp.getVectorType(), res);
+  }
+  return res;
+}
+
+static constexpr char kTransferReadNeedsMask[] =
+    "amdgpu.buffer_transfer_read_needs_mask";
+
 namespace {

 struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
   using OpRewritePattern::OpRewritePattern;

   LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                 PatternRewriter &rewriter) const override {
+    if (readOp->hasAttr(kTransferReadNeedsMask))
+      return failure();

     bool requiresBroadcasting = false;
     VectorType unbroadcastedVectorType;
@@ -117,20 +151,115 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
     }

     Location loc = readOp.getLoc();
-    Value fill = rewriter.create<vector::SplatOp>(loc, unbroadcastedVectorType,
-                                                  readOp.getPadding());
-    Value load = rewriter.create<vector::LoadOp>(
-        loc, unbroadcastedVectorType, readOp.getSource(), readOp.getIndices());
-    Value res = rewriter.create<arith::SelectOp>(loc, unbroadcastedVectorType,
-                                                 readOp.getMask(), load, fill);
-
-    // Insert a broadcasting op if required.
-    if (requiresBroadcasting) {
-      res = rewriter.create<vector::BroadcastOp>(loc, readOp.getVectorType(),
-                                                 res);
+    Value src = readOp.getSource();
+
+    VectorType vectorType = readOp.getVectorType();
+    int64_t vectorSize = vectorType.getNumElements();
+    int64_t elementBitWidth = vectorType.getElementTypeBitWidth();
+    SmallVector<OpFoldResult> indices = readOp.getIndices();
+
+    auto stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
+    SmallVector<OpFoldResult> strides =
+        stridedMetadata.getConstifiedMixedStrides();
+    SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes();
+    OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset();
+    OpFoldResult linearizedIndices;
+    std::tie(std::ignore, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
+                                                 elementBitWidth, offset, sizes,
+                                                 strides, indices);
+
+    // TODO(jerryyin): Fix the getLinearizedMemRefOffsetAndSize() function.
+    // The snippet below does not give the correct result for the linearized
+    // size:
+    //   Value totalSize = getValueOrCreateConstantIndexOp(
+    //       rewriter, loc, linearizedInfo.linearizedSize);
+    // It computes the product of all dimension sizes instead of taking the
+    // maximum of size * stride over all dimensions.
+    SmallVector<AffineExpr> productExpressions;
+    SmallVector<Value> productResults;
+    unsigned sourceRank = cast<ShapedType>(src.getType()).getRank();
+
+    SmallVector<AffineExpr> symbols(2 * sourceRank);
+    SmallVector<Value> offsetValues;
+    bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
+
+    size_t symbolIndex = 0;
+    for (size_t i = 0; i < sourceRank; ++i) {
+      AffineExpr strideExpr, sizeExpr;
+      OpFoldResult stride = strides[i];
+      OpFoldResult size = sizes[i];
+      if (auto constantStride = getConstantIntValue(stride)) {
+        strideExpr = rewriter.getAffineConstantExpr(*constantStride);
+      } else {
+        strideExpr = symbols[symbolIndex++];
+        offsetValues.push_back(
+            getValueOrCreateConstantIndexOp(rewriter, loc, stride));
+      }
+
+      if (auto constantSize = getConstantIntValue(size)) {
+        sizeExpr = rewriter.getAffineConstantExpr(*constantSize);
+      } else {
+        sizeExpr = symbols[symbolIndex++];
+        offsetValues.push_back(
+            getValueOrCreateConstantIndexOp(rewriter, loc, size));
+      }
+
+      productExpressions.push_back(strideExpr * sizeExpr);
     }

-    rewriter.replaceOp(readOp, res);
+    AffineMap maxMap = AffineMap::get(
+        /*dimCount=*/0, /*symbolCount=*/symbolIndex, productExpressions,
+        rewriter.getContext());
+    Value totalSize =
+        rewriter.create<affine::AffineMaxOp>(loc, maxMap, offsetValues);
+
+    // delta = bufferSize - linearizedOffset
+    Value vectorSizeOffset =
+        rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
+    Value linearIndex =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+    Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
+
+    // 1) Check if delta < vectorSize.
+    Value isOutofBounds = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
+
+    // 2) Check if (deltaBytes % (32 / elementBitWidth) != 0).
+    Value deltaBytes = rewriter.create<arith::MulIOp>(
+        loc, delta,
+        rewriter.create<arith::ConstantIndexOp>(loc, elementBitWidth / 8));
+    Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
+        loc, llvm::divideCeil(32, elementBitWidth));
+    Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
+        loc, arith::CmpIPredicate::ne,
+        rewriter.create<arith::RemUIOp>(loc, deltaBytes, elementsPerWord),
+        rewriter.create<arith::ConstantIndexOp>(loc, 0));
+
+    // We fall back to the default transfer_read lowering only when the access
+    // is both out-of-bounds and not word aligned. The fallback ensures correct
+    // results when loading at the boundary of the buffer, since a buffer load
+    // returns inconsistent zeros for the whole word when the boundary is
+    // crossed.
+    Value ifCondition =
+        rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);
+
+    auto thenBuilder = [&](OpBuilder &builder, Location loc) {
+      Operation *read = builder.clone(*readOp.getOperation());
+      read->setAttr(kTransferReadNeedsMask, builder.getUnitAttr());
+      Value readResult = read->getResult(0);
+      builder.create<scf::YieldOp>(loc, readResult);
+    };
+
+    auto elseBuilder = [&](OpBuilder &builder, Location loc) {
+      Value res = createVectorLoadForMaskedLoad(
+          builder, loc, readOp, requiresBroadcasting, unbroadcastedVectorType);
+      rewriter.create<scf::YieldOp>(loc, res);
+    };
+
+    auto ifOp =
+        rewriter.create<scf::IfOp>(loc, ifCondition, thenBuilder, elseBuilder);
+
+    rewriter.replaceOp(readOp, ifOp);

     return success();
   }
@@ -149,6 +278,8 @@ struct AmdgpuTransferReadToLoadPass final
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
     populateAmdgpuTransferReadToLoadPatterns(patterns);
-    walkAndApplyPatterns(getOperation(), std::move(patterns));
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+      return signalPassFailure();
+    }
   }
 };